Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)

Compare commits: opt-attn-m...0.6.0 (184 commits)
.github/workflows/trufflehog.yml (new file, 15 lines)
@@ -0,0 +1,15 @@
+on:
+  push:
+
+name: Secret Leaks
+
+jobs:
+  trufflehog:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - name: Secret Scanning
+        uses: trufflesecurity/trufflehog@main
Cargo.toml (29 lines changed)
@@ -9,6 +9,7 @@ members = [
     "candle-transformers",
     "candle-wasm-examples/*",
     "candle-wasm-tests",
+    "tensor-tools",
 ]
 exclude = [
     "candle-flash-attn",
@@ -19,7 +20,7 @@ exclude = [
 resolver = "2"
 
 [workspace.package]
-version = "0.4.2"
+version = "0.6.0"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -32,21 +33,22 @@ ab_glyph = "0.2.23"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.4.2" }
+candle = { path = "./candle-core", package = "candle-core", version = "0.6.0" }
-candle-datasets = { path = "./candle-datasets", version = "0.4.2" }
+candle-datasets = { path = "./candle-datasets", version = "0.6.0" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.2" }
+candle-flash-attn = { path = "./candle-flash-attn", version = "0.6.0" }
-candle-kernels = { path = "./candle-kernels", version = "0.4.2" }
+candle-kernels = { path = "./candle-kernels", version = "0.6.0" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.2" }
+candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.6.0" }
-candle-nn = { path = "./candle-nn", version = "0.4.2" }
+candle-nn = { path = "./candle-nn", version = "0.6.0" }
-candle-onnx = { path = "./candle-onnx", version = "0.4.2" }
+candle-onnx = { path = "./candle-onnx", version = "0.6.0" }
-candle-transformers = { path = "./candle-transformers", version = "0.4.2" }
+candle-transformers = { path = "./candle-transformers", version = "0.6.0" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
-cudarc = { version = "0.10.0", features = ["f16"] }
+cudarc = { version = "0.11.4", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
 fancy-regex = "0.13.0"
 gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
 hf-hub = "0.3.0"
 half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
+hound = "3.5.1"
 image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] }
 imageproc = { version = "0.24.0", default-features = false }
 intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
@@ -55,7 +57,7 @@ log = "0.4"
 memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
 num_cpus = "1.15.0"
 num-traits = "0.2.15"
-parquet = { version = "50.0.0" }
+parquet = { version = "51.0.0" }
 rand = "0.8.5"
 rand_distr = "0.4.3"
 rayon = "1.7.0"
@@ -64,13 +66,12 @@ serde = { version = "1.0.171", features = ["derive"] }
 serde_plain = "1.0.2"
 serde_json = "1.0.99"
 thiserror = "1"
-tokenizers = { version = "0.15.0", default-features = false }
+tokenizers = { version = "0.19.1", default-features = false }
 tracing = "0.1.37"
 tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.7"
-wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
-zip = { version = "0.6.6", default-features = false }
+zip = { version = "1.1.1", default-features = false }
 metal = { version = "0.27.0", features = ["mps"]}
 
 [profile.release-with-debug]
README.md (34 lines changed)
@@ -60,12 +60,14 @@ These online demos run entirely in your browser:
 
 We also provide a some command line based examples using state of the art models:
 
-- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
+- [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes
   the SOLAR-10.7B variant.
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google
-  Deepmind.
-- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
+- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
+- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
+  Griffin based models from Google that mix attention with a RNN like state.
+- [Phi-1, Phi-1.5, Phi-2, and Phi-3](./candle-examples/examples/phi/): 1.3b,
+  2.7b, and 3.8b general LLMs with performance on par with 7b models.
 - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
   pre-trained on 1T tokens of English and code datasets. Also supports
   StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
@@ -110,7 +112,7 @@ We also provide a some command line based examples using state of the art models
 
   <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
 
-- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmantation model.
+- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmentation model.
 - [Whisper](./candle-examples/examples/whisper/): speech recognition model.
 - [EnCodec](./candle-examples/examples/encodec/): high-quality audio compression
   model using residual vector quantization.
@@ -125,10 +127,14 @@ We also provide a some command line based examples using state of the art models
   [RepVGG](./candle-examples/examples/repvgg): computer vision models.
 - [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
   generate captions for an image.
+- [CLIP](./candle-examples/examples/clip/): multi-model vision and language
+  model.
 - [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
   dedicated submodels for hand-writing and printed recognition.
 - [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
   model, generates the translated text from the input text.
+- [Moondream](./candle-examples/examples/moondream/): tiny computer-vision model
+  that can answer real-world questions about images.
 
 Run them using commands like:
 ```
@@ -172,6 +178,7 @@ And then head over to
 - [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
   serving local LLMs including an OpenAI compatible API server.
 - [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
+- [`candle-coursera-ml`](https://github.com/vishpat/candle-coursera-ml): Implementation of ML algorithms from Coursera's [Machine Learning Specialization](https://www.coursera.org/specializations/machine-learning-introduction) course.
 - [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
 - [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
 - [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
@@ -194,10 +201,10 @@ If you have an addition to this list, please submit a pull request.
 - WASM support, run your models in a browser.
 - Included models.
   - Language Models.
-    - LLaMA v1 and v2 with variants such as SOLAR-10.7B.
+    - LLaMA v1, v2, and v3 with variants such as SOLAR-10.7B.
     - Falcon.
     - StarCoder, StarCoder2.
-    - Phi 1, 1.5, and 2.
+    - Phi 1, 1.5, 2, and 3.
     - Mamba, Minimal Mamba
     - Gemma 2b and 7b.
     - Mistral 7b v0.1.
@@ -206,7 +213,7 @@ If you have an addition to this list, please submit a pull request.
     - Replit-code-v1.5-3B.
     - Bert.
     - Yi-6B and Yi-34B.
-    - Qwen1.5.
+    - Qwen1.5, Qwen1.5 MoE.
     - RWKV v5 and v6.
   - Quantized LLMs.
     - Llama 7b, 13b, 70b, as well as the chat and code variants.
@@ -369,9 +376,9 @@ git submodule update --init
 /usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
 ```
 
-This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
+This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the NVCC_CCBIN environment variable.
 ```
-env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
+env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
 ```
 
 #### Linking error on windows when running rustdoc or mdbook tests
@@ -401,3 +408,10 @@ This may be caused by the models being loaded from `/mnt/c`, more details on
 
 You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
 error is generated.
+
+#### CudaRC error
+
+If you encounter an error like this one `called `Result::unwrap()` on an `Err` value: LoadLibraryExW { source: Os { code: 126, kind: Uncategorized, message: "The specified module could not be found." } }` on windows. To fix copy and rename these 3 files (make sure they are in path). The paths depend on your cuda version.
+`c:\Windows\System32\nvcuda.dll` -> `cuda.dll`
+`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\cublas64_12.dll` -> `cublas.dll`
+`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\curand64_10.dll` -> `curand.dll`
@@ -37,7 +37,6 @@ tokenizers = { workspace = true, features = ["onig"] }
 tracing = { workspace = true }
 tracing-chrome = { workspace = true }
 tracing-subscriber = { workspace = true }
-wav = { workspace = true }
 # Necessary to disambiguate with tokio in wasm examples which are 1.28.1
 parquet = { workspace = true }
 image = { workspace = true }
@@ -81,7 +81,7 @@ let mut tp_shape = view.shape().to_vec();
 let size = tp_shape[0];
 
 if size % world_size != 0 {
-    panic!("The dimension is not divisble by `world_size`");
+    panic!("The dimension is not divisible by `world_size`");
 }
 let block_size = size / world_size;
 let start = rank * block_size;
@@ -106,8 +106,8 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
 }
 }
 
+#[allow(unused)]
 #[rustfmt::skip]
-#[test]
 fn book_training_1() -> Result<()>{
 // ANCHOR: book_training_1
 use hf_hub::{api::sync::Api, Repo, RepoType};
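The first hunk above only fixes a typo in the book's tensor-parallel sharding snippet, but the arithmetic it performs is worth spelling out. A minimal self-contained sketch of that sharding rule follows; `shard_range` is a hypothetical helper written just for illustration, with `world_size` and `rank` playing the same roles as in the snippet.

```rust
/// Half-open row range [start, stop) owned by `rank` when the leading
/// dimension `size` is split evenly across `world_size` ranks.
/// Panics when `size` is not divisible by `world_size`, mirroring the snippet above.
fn shard_range(size: usize, world_size: usize, rank: usize) -> (usize, usize) {
    if size % world_size != 0 {
        panic!("The dimension is not divisible by `world_size`");
    }
    let block_size = size / world_size;
    let start = rank * block_size;
    (start, start + block_size)
}

fn main() {
    // With 8 ranks and a 4096-row tensor, rank 3 owns rows 1536..2048.
    assert_eq!(shard_range(4096, 8, 3), (1536, 2048));
    println!("ok");
}
```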
@@ -7,4 +7,6 @@ criterion_main!(
     benchmarks::random::benches,
     benchmarks::where_cond::benches,
     benchmarks::conv_transpose2d::benches,
+    benchmarks::qmatmul::benches,
+    benchmarks::unary::benches
 );
@@ -1,7 +1,9 @@
 pub(crate) mod affine;
 pub(crate) mod conv_transpose2d;
 pub(crate) mod matmul;
+pub(crate) mod qmatmul;
 pub(crate) mod random;
+pub(crate) mod unary;
 pub(crate) mod where_cond;
 
 use candle_core::{Device, Result};
candle-core/benches/benchmarks/qmatmul.rs (new file, 72 lines)
@@ -0,0 +1,72 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{
+    quantized::{self, GgmlDType, QMatMul},
+    Device, Module, Tensor,
+};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use std::time::Instant;
+
+fn run(matmul: &QMatMul, x: &Tensor) {
+    matmul.forward(&x).unwrap();
+}
+
+fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
+    let b = 1;
+    let m = 1;
+    let n = 1024;
+    let k = 1024;
+
+    let lhs = (0..(m * k))
+        .map(|v| v as f32 / (m * k) as f32)
+        .collect::<Vec<_>>();
+    let rhs = (0..(k * n))
+        .map(|v| v as f32 / (n * k) as f32)
+        .collect::<Vec<_>>();
+
+    let lhs = Tensor::from_slice(&lhs, (m, k), device).unwrap();
+    let rhs = Tensor::from_slice(&rhs, (k, n), device).unwrap();
+
+    let qtensor = quantized::QTensor::quantize(&rhs.t().unwrap(), dtype).unwrap();
+    let matmul = quantized::QMatMul::from_qtensor(qtensor).unwrap();
+
+    let flops = b * m * n * k;
+
+    let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype)));
+    group.sample_size(200);
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run(black_box(&matmul), black_box(&lhs));
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let handler = BenchDeviceHandler::new().unwrap();
+    for device in handler.devices {
+        for dtype in vec![
+            GgmlDType::F32,
+            GgmlDType::F16,
+            GgmlDType::Q4_0,
+            GgmlDType::Q4_1,
+            GgmlDType::Q5_0,
+            GgmlDType::Q5_1,
+            GgmlDType::Q8_0,
+            GgmlDType::Q2K,
+            GgmlDType::Q3K,
+            GgmlDType::Q4K,
+            GgmlDType::Q5K,
+            GgmlDType::Q6K,
+        ] {
+            run_bench(c, &device, dtype);
+        }
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
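The new benchmark exercises `QMatMul` through criterion's harness. For readers unfamiliar with that API, here is a minimal standalone sketch of the same flow (quantize a weight matrix, wrap it in `QMatMul`, run a forward pass). The shapes and the `Q8_0` dtype are arbitrary choices for illustration, not values taken from the benchmark.

```rust
use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
use candle_core::{Device, Module, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Weight of shape (out_features, in_features); QMatMul multiplies inputs
    // by the transpose of this matrix, as the benchmark's rhs.t() call shows.
    let weight = Tensor::randn(0f32, 1.0, (1024, 256), &device)?;
    let qweight = QTensor::quantize(&weight, GgmlDType::Q8_0)?;
    let qmatmul = QMatMul::from_qtensor(qweight)?;

    // A batch of 4 input rows with 256 features each -> output of shape (4, 1024).
    let x = Tensor::randn(0f32, 1.0, (4, 256), &device)?;
    let y = qmatmul.forward(&x)?;
    println!("{:?}", y.shape());
    Ok(())
}
```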
candle-core/benches/benchmarks/unary.rs (new file, 49 lines)
@@ -0,0 +1,49 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{DType, Device, Tensor};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use std::time::Instant;
+
+fn run(a: &Tensor) {
+    a.sqrt().unwrap();
+}
+
+fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
+    let b = 1;
+    let m = 1024;
+    let k = 1024;
+
+    let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, &device)
+        .unwrap()
+        .to_dtype(dtype)
+        .unwrap()
+        .reshape((b, m, k))
+        .unwrap();
+
+    let flops = b * m * k * dtype.size_in_bytes();
+
+    let mut group = c.benchmark_group(device.bench_name(name));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run(black_box(&tensor));
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let handler = BenchDeviceHandler::new().unwrap();
+    for device in handler.devices {
+        for dtype in [DType::F32, DType::BF16, DType::F16] {
+            let name = format!("sqrt_{:?}", dtype);
+            run_unary_benchmark(c, &device, dtype, &name);
+        }
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
@@ -5,32 +5,29 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;
 
 use anyhow::Result;
-use candle_core::{Device, Module, Tensor};
+use candle_core::{Device, Tensor};
 
-use candle_core::quantized::{QMatMul, QTensor};
-
 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
-    let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
-    let q_cpu = q.to_device(&Device::Cpu)?;
-    let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
-    let q = QMatMul::from_qtensor(q)?;
-    let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
-    let res_q_cuda = q.forward(&x)?;
-    println!("{res_q_cuda}");
-
-    let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
-    let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
-    let q_cpu = QMatMul::from_qtensor(q_cpu)?;
-    let x_cpu = x.to_device(&Device::Cpu)?;
-    let res_q_cpu = q_cpu.forward(&x_cpu)?;
-    println!("{res_q_cpu}");
-
-    let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
-    let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
-        .abs()?
-        .flatten_all()?
-        .max(0)?;
-    println!("{diff}");
+    let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
+        .to_dtype(candle_core::DType::BF16)?;
+    candle_core::cuda::set_gemm_reduced_precision_f32(false);
+    candle_core::cuda::set_gemm_reduced_precision_bf16(false);
+    let _x1 = x.matmul(&x)?;
+    drop(_x1);
+    let start_time = std::time::Instant::now();
+    let _x1 = x.matmul(&x)?;
+    device.synchronize()?;
+    println!("fp32: {:?}", start_time.elapsed());
+    drop(_x1);
+    candle_core::cuda::set_gemm_reduced_precision_f32(true);
+    candle_core::cuda::set_gemm_reduced_precision_bf16(true);
+    let _x1 = x.matmul(&x)?;
+    drop(_x1);
+    let start_time = std::time::Instant::now();
+    let _x1 = x.matmul(&x)?;
+    device.synchronize()?;
+    println!("tf32: {:?}", start_time.elapsed());
+    drop(_x1);
 
     Ok(())
 }
@@ -133,6 +133,8 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
     /// after this call.
     unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
 
+    fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage>;
+
     fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
 
     fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
@@ -142,4 +144,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
     fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
 
     fn set_seed(&self, _: u64) -> Result<()>;
+
+    /// Synchronize should block until all the operations on the device are completed.
+    fn synchronize(&self) -> Result<()>;
 }
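To make the two new trait methods concrete, here is a minimal sketch of how a CPU-like backend can satisfy them, modeled on the `CpuDevice` implementation that appears further down in this diff. `MiniBackendDevice`, `MiniCpuDevice`, and the `Vec<u8>` storage are hypothetical stand-ins used only for illustration; they are not candle's actual trait or types.

```rust
use anyhow::Result;

// Simplified stand-in for candle's BackendDevice, keeping only the two
// methods added in the hunk above.
trait MiniBackendDevice {
    type Storage;
    fn storage_from_slice<T: Copy>(&self, data: &[T]) -> Result<Self::Storage>;
    /// Blocks until all queued work on the device has completed.
    fn synchronize(&self) -> Result<()>;
}

struct MiniCpuDevice;

impl MiniBackendDevice for MiniCpuDevice {
    type Storage = Vec<u8>;

    fn storage_from_slice<T: Copy>(&self, data: &[T]) -> Result<Self::Storage> {
        // A CPU backend just copies the bytes; candle's CpuDevice similarly
        // returns `T::to_cpu_storage(s)` (see the CpuDevice hunk below).
        let bytes = unsafe {
            std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
        };
        Ok(bytes.to_vec())
    }

    fn synchronize(&self) -> Result<()> {
        // CPU work runs eagerly, so there is nothing to wait for, matching
        // the `Ok(())` implementation added for CpuDevice in this diff.
        Ok(())
    }
}

fn main() -> Result<()> {
    let dev = MiniCpuDevice;
    let storage = dev.storage_from_slice(&[1.0f32, 2.0, 3.0])?;
    dev.synchronize()?;
    println!("{} bytes stored", storage.len());
    Ok(())
}
```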
@@ -112,7 +112,8 @@ impl Tensor {
             }
             Op::Unary(_node, UnaryOp::Ceil)
             | Op::Unary(_node, UnaryOp::Floor)
-            | Op::Unary(_node, UnaryOp::Round) => nodes,
+            | Op::Unary(_node, UnaryOp::Round)
+            | Op::Unary(_node, UnaryOp::Sign) => nodes,
             Op::Reshape(node)
             | Op::UpsampleNearest1D { arg: node, .. }
             | Op::UpsampleNearest2D { arg: node, .. }
@@ -488,7 +489,6 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&grad)?;
                 }
-                Op::Cmp(_args, _) => {}
                 Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
                     let node = broadcast_back(arg, node, reduced_dims)?;
                     let grad = broadcast_back(arg, &grad, reduced_dims)?;
@@ -578,20 +578,18 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&arg_grad)?
                 }
-                Op::Reduce(_, ReduceOp::ArgMin, _) => {}
-                Op::Reduce(_, ReduceOp::ArgMax, _) => {}
+                Op::Unary(_, UnaryOp::Floor)
+                | Op::Unary(_, UnaryOp::Round)
+                | Op::Reduce(_, ReduceOp::ArgMin, _)
+                | Op::Reduce(_, ReduceOp::ArgMax, _)
+                | Op::Unary(_, UnaryOp::Sign)
+                | Op::Cmp(_, _) => {}
                 Op::Reshape(arg) => {
                     let arg_grad = grad.reshape(arg.dims())?;
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&arg_grad)?
                 }
                 Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
-                Op::Unary(_, UnaryOp::Floor) => {
-                    Err(Error::BackwardNotSupported { op: "floor" })?
-                }
-                Op::Unary(_, UnaryOp::Round) => {
-                    Err(Error::BackwardNotSupported { op: "round" })?
-                }
                 Op::Unary(arg, UnaryOp::Gelu) => {
                     let sum_grad = grads.or_insert(arg)?;
                     let cube = arg.powf(3.)?;
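The regrouping above folds `floor`, `round`, `sign`, `argmin`, `argmax`, and comparison ops into a single arm whose backward pass does nothing. As a brief math aside (not part of the diff): these outputs are piecewise constant in their inputs, so wherever a derivative exists at all,

$$\frac{\partial}{\partial x}\lfloor x\rfloor \;=\; \frac{\partial}{\partial x}\,\mathrm{round}(x) \;=\; \frac{\partial}{\partial x}\,\mathrm{sign}(x) \;=\; 0 \quad \text{almost everywhere},$$

and `argmin`, `argmax`, and comparisons produce indices or booleans with no usable gradient. Treating them as constants, rather than raising `BackwardNotSupported` as `floor` and `round` previously did, lets graphs that merely pass through these ops still backpropagate to their other inputs.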
@@ -626,7 +624,7 @@ impl Tensor {
                 Op::Unary(arg, UnaryOp::Silu) => {
                     let sum_grad = grads.or_insert(arg)?;
                     // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
-                    let sigmoid_arg = (*node / arg)?;
+                    let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
                     let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
                     *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
                 }
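The comment in the hunk above states the gradient identity it implements; the one-line derivation (a math aside, not part of the diff) goes as follows. With $\sigma(x) = 1/(1+e^{-x})$ and $\mathrm{silu}(x) = x\,\sigma(x)$,

$$\frac{d}{dx}\,\mathrm{silu}(x) \;=\; \sigma(x) + x\,\sigma'(x) \;=\; \sigma(x)\bigl(1 + x\,(1 - \sigma(x))\bigr), \qquad \sigma'(x) = \sigma(x)\,(1-\sigma(x)).$$

The change replaces the old shortcut `sigmoid_arg = node / arg` (i.e. $\mathrm{silu}(x)/x$, which divides by the input and is ill-behaved at $x = 0$) with a direct evaluation `sigmoid_arg = (exp(-x) + 1)^{-1}`, matching $\sigma(x)$ above; the presumed motivation is numerical robustness.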
@@ -1,6 +1,7 @@
 pub mod erf;
 pub mod kernels;
 
+#[allow(unused)]
 trait Cpu<const ARR: usize> {
     type Unit;
     type Array;
@@ -18,6 +19,7 @@ trait Cpu<const ARR: usize> {
     unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
 }
 
+#[allow(unused)]
 trait CpuF16<const ARR: usize> {
     type Unit;
     type Array;
@@ -4,8 +4,13 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
 use half::{bf16, f16};
 use rayon::prelude::*;
 
+mod utils;
+pub use utils::{
+    binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
+};
+
 const USE_IM2COL_CONV1D: bool = true;
-const USE_IM2COL_CONV1D_TR: bool = true;
+const USE_COL2IM_CONV1D_TR: bool = true;
 const USE_IM2COL_CONV2D: bool = true;
 
 // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@@ -21,105 +26,20 @@ pub enum CpuStorage {
     F64(Vec<f64>),
 }
 
+#[derive(Debug, Clone)]
+pub enum CpuStorageRef<'a> {
+    U8(&'a [u8]),
+    U32(&'a [u32]),
+    I64(&'a [i64]),
+    BF16(&'a [bf16]),
+    F16(&'a [f16]),
+    F32(&'a [f32]),
+    F64(&'a [f64]),
+}
+
 #[derive(Debug, Clone)]
 pub struct CpuDevice;
 
-pub trait Map1 { /* definition moved to cpu_backend/utils.rs (see the new file below) */ }
-
-pub trait Map1Any { /* definition moved to cpu_backend/utils.rs (see the new file below) */ }
-
-type C = CpuStorage;
-pub trait Map2 { /* definition moved to cpu_backend/utils.rs (see the new file below) */ }
-
-pub trait Map2U8 { /* definition moved to cpu_backend/utils.rs */ }
-
 struct Cmp(CmpOp);
 impl Map2U8 for Cmp {
     const OP: &'static str = "cmp";
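The new `CpuStorageRef` enum mirrors `CpuStorage` but borrows its data instead of owning it. As a small illustrative sketch, a helper that reports the element count of any variant could look like this; the enum is mirrored locally and the `len` function is hypothetical, written only to show the borrowed-variant pattern.

```rust
use half::{bf16, f16};

// Local mirror of the CpuStorageRef enum added in the hunk above, for
// illustration only; in candle it lives in candle-core's cpu_backend module.
#[derive(Debug, Clone)]
pub enum CpuStorageRef<'a> {
    U8(&'a [u8]),
    U32(&'a [u32]),
    I64(&'a [i64]),
    BF16(&'a [bf16]),
    F16(&'a [f16]),
    F32(&'a [f32]),
    F64(&'a [f64]),
}

// A borrowed storage knows how many elements it refers to without owning them.
fn len(s: &CpuStorageRef<'_>) -> usize {
    match s {
        CpuStorageRef::U8(v) => v.len(),
        CpuStorageRef::U32(v) => v.len(),
        CpuStorageRef::I64(v) => v.len(),
        CpuStorageRef::BF16(v) => v.len(),
        CpuStorageRef::F16(v) => v.len(),
        CpuStorageRef::F32(v) => v.len(),
        CpuStorageRef::F64(v) => v.len(),
    }
}

fn main() {
    let data = [1.0f32, 2.0, 3.0];
    assert_eq!(len(&CpuStorageRef::F32(&data)), 3);
    println!("ok");
}
```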
@@ -201,7 +121,8 @@ impl ReduceIndex {
         let dst_len = src_l.shape().elem_count() / reduce_dim_size;
         let mut dst: Vec<U> = Vec::with_capacity(dst_len);
         let dst_to_set = dst.spare_capacity_mut();
-        let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) };
+        let dst_to_set =
+            unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(dst_to_set) };
         match src_l.contiguous_offsets() {
             Some((o1, o2)) => {
                 let src = &src[o1..o2];
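The change above only spells out the source type of the `transmute` (`&mut [MaybeUninit<U>]`) instead of relying on inference. The underlying pattern, used throughout this file, is: reserve capacity, let a kernel fill the uninitialized tail, then call `set_len`. A minimal self-contained sketch of that pattern follows; it is illustrative only (not candle code) and uses `MaybeUninit::write` where candle transmutes the slice so it can hand a plain `&mut [U]` to a vectorized closure.

```rust
use std::mem::MaybeUninit;

/// Builds a Vec<f32> of length `len` filled with `f(i)` using the
/// spare-capacity pattern from the hunk above: write into the uninitialized
/// buffer, then declare it initialized with `set_len`.
fn fill_with(len: usize, f: impl Fn(usize) -> f32) -> Vec<f32> {
    let mut ys: Vec<f32> = Vec::with_capacity(len);
    let ys_to_set: &mut [MaybeUninit<f32>] = ys.spare_capacity_mut();
    for (i, slot) in ys_to_set.iter_mut().enumerate().take(len) {
        slot.write(f(i));
    }
    // SAFETY: the first `len` elements were all written above.
    unsafe { ys.set_len(len) };
    ys
}

fn main() {
    let squares = fill_with(5, |i| (i * i) as f32);
    assert_eq!(squares, vec![0.0, 1.0, 4.0, 9.0, 16.0]);
    println!("{squares:?}");
}
```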
@@ -366,275 +287,6 @@ impl<'a> Map1 for ReduceSum<'a> {
     }
 }
 
-pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
-    vs: &[T],
-    layout: &Layout,
-    mut f: F,
-) -> Vec<U> {
-    /* body moved to cpu_backend/utils.rs: maps over a single or multiple strided blocks */
-}
-
-pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
-    vs: &[T],
-    layout: &Layout,
-    mut f: F,
-    mut f_vec: FV,
-) -> Vec<U> {
-    /* body moved to cpu_backend/utils.rs: vectorized variant using spare_capacity_mut/set_len */
-}
-
-// This function maps over two strided index sequences.
-pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
-    lhs_l: &Layout,
-    rhs_l: &Layout,
-    lhs: &[T],
-    rhs: &[T],
-    mut f: F,
-) -> Vec<U> {
-    /* body moved to cpu_backend/utils.rs: handles contiguous, right-broadcast, and strided cases */
-}
-
-// Similar to binary_map but with vectorized variants.
-pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
-    lhs_l: &Layout,
-    rhs_l: &Layout,
-    lhs: &[T],
-    rhs: &[T],
-    mut f: F,
-    mut f_vec: FV,
-) -> Vec<T> {
-    /* body moved to cpu_backend/utils.rs: vectorized broadcast-aware variant */
-}
-
 struct Affine(f64, f64);
 
 impl Map1 for Affine {
@@ -1564,6 +1216,30 @@ impl MatMul {
         }))
         .bt()
     }
+
+    fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
+        let lhs_stride = lhs_l.stride();
+        let rhs_stride = rhs_l.stride();
+        let rank = lhs_stride.len();
+        let (_b, m, n, k) = self.0;
+        let a_skip: usize = match lhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+            [_, stride] if lhs_l.dims()[0] == 1 => stride,
+            [stride, _] if lhs_l.dims()[1] == 1 => stride,
+            [stride] => stride,
+            [] => m * k,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
+        };
+        let b_skip: usize = match rhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+            [_, stride] if rhs_l.dims()[0] == 1 => stride,
+            [stride, _] if rhs_l.dims()[1] == 1 => stride,
+            [stride] => stride,
+            [] => n * k,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
+        };
+        Ok((a_skip, b_skip))
+    }
 }
 
 impl Map2 for MatMul {
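The batch strides computed by `ab_skip` above come straight from the row-major stride layout of a rank-3 tensor: for a contiguous `(b, m, k)` operand the strides are `(m*k, k, 1)`, so the leading stride (the `[stride]` arm) is exactly the distance between consecutive matrices in the batch. A small self-contained check of that relationship follows; it uses no candle APIs, and `contiguous_strides` is a helper written only for this illustration.

```rust
/// Row-major (C-contiguous) strides for the given dimensions.
fn contiguous_strides(dims: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; dims.len()];
    for i in (0..dims.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * dims[i + 1];
    }
    strides
}

fn main() {
    let (b, m, k) = (4usize, 8usize, 16usize);
    let strides = contiguous_strides(&[b, m, k]);
    // The stride of the batch dimension is m * k, matching the a_skip that
    // ab_skip derives for a contiguous lhs.
    assert_eq!(strides, vec![m * k, k, 1]);
    println!("{strides:?}");
}
```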
@@ -1597,18 +1273,7 @@ impl Map2 for MatMul {
         let rhs_cs = rhs_stride[rank - 1];
         let rhs_rs = rhs_stride[rank - 2];
 
-        let a_skip: usize = match lhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => m * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
-        };
-        let b_skip: usize = match rhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => n * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
-        };
+        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
         let c_skip: usize = m * n;
 
         let dst_shape: Shape = (m, n).into();
@@ -1668,20 +1333,8 @@ impl Map2 for MatMul {
 
         let lhs_stride = lhs_l.stride();
         let rhs_stride = rhs_l.stride();
-        let rank = lhs_stride.len();
-
-        let a_skip: usize = match lhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => m * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
-        };
-        let b_skip: usize = match rhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => n * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
-        };
+        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
         let c_skip: usize = m * n;
 
         let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
@@ -1689,7 +1342,7 @@ impl Map2 for MatMul {
         let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
         let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
 
-        let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
+        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
             (n as i32, b'N')
         } else if rhs_m1 == k && rhs_m2 == 1 {
             (k as i32, b'T')
@@ -1697,7 +1350,7 @@ impl Map2 for MatMul {
             Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
         };
         // The b tensor has dims batching, m, k (lhs)
-        let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
+        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
             (k as i32, b'N')
         } else if lhs_m1 == m && lhs_m2 == 1 {
             (m as i32, b'T')
@@ -1771,20 +1424,8 @@ impl Map2 for MatMul {
 
         let lhs_stride = lhs_l.stride();
         let rhs_stride = rhs_l.stride();
-        let rank = lhs_stride.len();
-
-        let a_skip: usize = match lhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => m * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
-        };
-        let b_skip: usize = match rhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => n * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
-        };
+        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
         let c_skip: usize = m * n;
 
         let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
@@ -1792,7 +1433,7 @@ impl Map2 for MatMul {
         let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
         let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
 
-        let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
+        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
             (n as i32, b'N')
         } else if rhs_m1 == k && rhs_m2 == 1 {
             (k as i32, b'T')
@@ -1800,7 +1441,7 @@ impl Map2 for MatMul {
             Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
         };
         // The b tensor has dims batching, m, k (lhs)
-        let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
+        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
            (k as i32, b'N')
         } else if lhs_m1 == m && lhs_m2 == 1 {
            (m as i32, b'T')
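The relaxed conditions above accept degenerate dimensions: when `n == 1` (or `k == 1`) the stride of that size-1 axis never affects which elements are addressed, so an operand the old check rejected as "non-contiguous" can still be handed to gemm as row-major. A tiny illustration of why a size-1 dimension makes its stride irrelevant (plain Rust, no candle types; `offset` is a hypothetical helper):

```rust
/// Linear offset of element (r, c) in a 2D view with the given strides.
fn offset(r: usize, c: usize, row_stride: usize, col_stride: usize) -> usize {
    r * row_stride + c * col_stride
}

fn main() {
    // A column vector: m rows, n = 1 column. Whatever col_stride we claim,
    // the addresses visited are identical, so the layout check may ignore it.
    let m = 4;
    let a: Vec<usize> = (0..m).map(|r| offset(r, 0, 7, 1)).collect();
    let b: Vec<usize> = (0..m).map(|r| offset(r, 0, 7, 999)).collect();
    assert_eq!(a, b);
    println!("{a:?}");
}
```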
@ -2609,7 +2250,7 @@ impl BackendStorage for CpuStorage {
|
|||||||
&& params.dilation == 1
|
&& params.dilation == 1
|
||||||
&& params.padding == 0
|
&& params.padding == 0
|
||||||
&& params.output_padding == 0;
|
&& params.output_padding == 0;
|
||||||
if USE_IM2COL_CONV1D_TR && can_use_col2im {
|
if USE_COL2IM_CONV1D_TR && can_use_col2im {
|
||||||
let (b_size, c_in, l_in) = l.shape().dims3()?;
|
let (b_size, c_in, l_in) = l.shape().dims3()?;
|
||||||
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
|
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
|
||||||
if !kernel_l.is_contiguous() {
|
if !kernel_l.is_contiguous() {
|
||||||
@ -2816,6 +2457,10 @@ impl BackendDevice for CpuDevice {
|
|||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||||
|
Ok(T::to_cpu_storage(s))
|
||||||
|
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
|
fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
|
||||||
Ok(s.clone())
|
Ok(s.clone())
|
||||||
}
|
}
|
||||||
@ -2999,6 +2644,10 @@ impl BackendDevice for CpuDevice {
|
|||||||
};
|
};
|
||||||
Ok(storage)
|
Ok(storage)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn synchronize(&self) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
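The refactor above replaces the inline stride matching with a call to `ab_skip`, whose body is not shown in this hunk. The sketch below is a hypothetical, standalone reconstruction of what that per-operand batch-skip logic looks like, based only on the deleted `match` blocks; the function name and signature are illustrative, not candle's.

```rust
// Illustrative sketch: compute how many elements to skip per batch for one
// matmul operand, mirroring the deleted inline `match` above.
fn batch_skip(strides: &[usize], dim1: usize, fallback: usize) -> Option<usize> {
    match strides[..strides.len() - 2] {
        // Two batching dims folded into a contiguous leading stride.
        [s1, stride] if s1 == stride * dim1 => Some(stride),
        // A single batching dimension.
        [stride] => Some(stride),
        // No batching dimension: one full matrix per step.
        [] => Some(fallback),
        // Anything else is a non-contiguous batch layout.
        _ => None,
    }
}

fn main() {
    // A contiguous (2, 3, 4) tensor has strides [12, 4, 1]: skip 12 elements per batch.
    assert_eq!(batch_skip(&[12, 4, 1], 3, 3 * 4), Some(12));
    println!("ok");
}
```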
candle-core/src/cpu_backend/utils.rs (new file, 360 lines)
@@ -0,0 +1,360 @@
/// Helper functions to write CPU kernels.
use crate::backend::BackendStorage;
use crate::{Error, Layout, Result, WithDType};

type C = super::CpuStorage;

pub trait Map1 {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;

    fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
        match vs {
            C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
            C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
            C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
            C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
            C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
            C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
            C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
        }
    }
}

pub trait Map1Any {
    fn f<T: WithDType, W: Fn(Vec<T>) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result<C>;

    fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
        match vs {
            C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
            C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
            C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
            C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
            C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
            C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
            C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
        }
    }
}

pub trait Map2 {
    const OP: &'static str;
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;

    fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
        match (v1, v2) {
            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
            (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
            (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
            (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
            (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
            (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
            _ => Err(Error::DTypeMismatchBinaryOp {
                lhs: v1.dtype(),
                rhs: v2.dtype(),
                op: Self::OP,
            }
            .bt()),
        }
    }
}

pub trait Map2U8 {
    const OP: &'static str;
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;

    fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
        match (v1, v2) {
            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            _ => Err(Error::DTypeMismatchBinaryOp {
                lhs: v1.dtype(),
                rhs: v2.dtype(),
                op: Self::OP,
            }
            .bt()),
        }
    }
}

pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
    lhs_l: &Layout,
    rhs_l: &Layout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
) -> Vec<U> {
    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
            .iter()
            .zip(rhs[o_r1..o_r2].iter())
            .map(|(&l, &r)| f(l, r))
            .collect(),
        (Some((o_l1, o_l2)), None) => {
            // TODO: Maybe we want to avoid going through the layout twice.
            match rhs_l.offsets_b() {
                Some(ob) => {
                    let mut i_in_block = 0;
                    let mut i_right_broadcast = 0;
                    lhs[o_l1..o_l2]
                        .iter()
                        .map(|&l| {
                            let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
                            i_right_broadcast += 1;
                            if i_right_broadcast >= ob.right_broadcast {
                                i_in_block += 1;
                                i_right_broadcast = 0;
                            }
                            if i_in_block >= ob.len {
                                i_in_block = 0
                            }
                            f(l, *r)
                        })
                        .collect()
                }
                None => lhs_l
                    .strided_index()
                    .zip(rhs_l.strided_index())
                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                    .collect(),
            }
        }
        (None, Some((o_r1, o_r2))) => {
            // TODO: Maybe we want to avoid going through the layout twice.
            match lhs_l.offsets_b() {
                Some(ob) => {
                    let mut i_in_block = 0;
                    let mut i_right_broadcast = 0;
                    rhs[o_r1..o_r2]
                        .iter()
                        .map(|&r| {
                            let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
                            i_right_broadcast += 1;
                            if i_right_broadcast >= ob.right_broadcast {
                                i_in_block += 1;
                                i_right_broadcast = 0;
                            }
                            if i_in_block >= ob.len {
                                i_in_block = 0
                            }
                            f(*l, r)
                        })
                        .collect()
                }
                None => lhs_l
                    .strided_index()
                    .zip(rhs_l.strided_index())
                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                    .collect(),
            }
        }
        _ => lhs_l
            .strided_index()
            .zip(rhs_l.strided_index())
            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
            .collect(),
    }
}

// Similar to binary_map but with vectorized variants.
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
    lhs_l: &Layout,
    rhs_l: &Layout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
    mut f_vec: FV,
) -> Vec<T> {
    let el_count = lhs_l.shape().elem_count();
    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
            let mut ys: Vec<T> = Vec::with_capacity(el_count);
            let ys_to_set = ys.spare_capacity_mut();
            let ys_to_set = unsafe {
                std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
            };
            f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
            // SAFETY: values are all set by f_vec.
            unsafe { ys.set_len(el_count) };
            ys
        }
        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
            Some(ob) if ob.right_broadcast == 1 => {
                let rhs = &rhs[ob.start..ob.start + ob.len];
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe {
                    std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
                };
                let mut dst_i = 0;
                for src_i in (o_l1..o_l2).step_by(ob.len) {
                    f_vec(
                        &lhs[src_i..src_i + ob.len],
                        rhs,
                        &mut ys_to_set[dst_i..dst_i + ob.len],
                    );
                    dst_i += ob.len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
            Some(ob) => {
                let rhs = &rhs[ob.start..ob.start + ob.len];
                let mut ys = lhs[o_l1..o_l2].to_vec();
                for idx_l in 0..ob.left_broadcast {
                    let start = idx_l * ob.len * ob.right_broadcast;
                    for (i, &r) in rhs.iter().enumerate() {
                        let start = start + i * ob.right_broadcast;
                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
                            *v = f(*v, r)
                        }
                    }
                }
                ys
            }
            None => lhs_l
                .strided_index()
                .zip(rhs_l.strided_index())
                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                .collect(),
        },
        (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
            Some(ob) if ob.right_broadcast == 1 => {
                let lhs = &lhs[ob.start..ob.start + ob.len];
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe {
                    std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
                };
                let mut dst_i = 0;
                for src_i in (o_r1..o_r2).step_by(ob.len) {
                    f_vec(
                        lhs,
                        &rhs[src_i..src_i + ob.len],
                        &mut ys_to_set[dst_i..dst_i + ob.len],
                    );
                    dst_i += ob.len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
            Some(ob) => {
                let lhs = &lhs[ob.start..ob.start + ob.len];
                let mut ys = rhs[o_r1..o_r2].to_vec();
                for idx_l in 0..ob.left_broadcast {
                    let start = idx_l * ob.len * ob.right_broadcast;
                    for (i, &l) in lhs.iter().enumerate() {
                        let start = start + i * ob.right_broadcast;
                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
                            *v = f(l, *v)
                        }
                    }
                }
                ys
            }
            None => lhs_l
                .strided_index()
                .zip(rhs_l.strided_index())
                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                .collect(),
        },
        _ => lhs_l
            .strided_index()
            .zip(rhs_l.strided_index())
            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
            .collect(),
    }
}

pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
    vs: &[T],
    layout: &Layout,
    mut f: F,
) -> Vec<U> {
    match layout.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => vs
            [start_offset..start_offset + len]
            .iter()
            .map(|&v| f(v))
            .collect(),
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let mut result = Vec::with_capacity(layout.shape().elem_count());
            // Specialize the case where block_len is one to avoid the second loop.
            if block_len == 1 {
                for index in block_start_index {
                    let v = unsafe { vs.get_unchecked(index) };
                    result.push(f(*v))
                }
            } else {
                for index in block_start_index {
                    for offset in 0..block_len {
                        let v = unsafe { vs.get_unchecked(index + offset) };
                        result.push(f(*v))
                    }
                }
            }
            result
        }
    }
}

pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
    vs: &[T],
    layout: &Layout,
    mut f: F,
    mut f_vec: FV,
) -> Vec<U> {
    match layout.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => {
            let mut ys: Vec<U> = Vec::with_capacity(len);
            let ys_to_set = ys.spare_capacity_mut();
            let ys_to_set = unsafe {
                std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
            };
            f_vec(&vs[start_offset..start_offset + len], ys_to_set);
            // SAFETY: values are all set by f_vec.
            unsafe { ys.set_len(len) };
            ys
        }
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let el_count = layout.shape().elem_count();
            // Specialize the case where block_len is one to avoid the second loop.
            if block_len == 1 {
                let mut result = Vec::with_capacity(el_count);
                for index in block_start_index {
                    let v = unsafe { vs.get_unchecked(index) };
                    result.push(f(*v))
                }
                result
            } else {
                let mut ys: Vec<U> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe {
                    std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
                };
                let mut dst_index = 0;
                for src_index in block_start_index {
                    let vs = &vs[src_index..src_index + block_len];
                    let ys = &mut ys_to_set[dst_index..dst_index + block_len];
                    f_vec(vs, ys);
                    dst_index += block_len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
        }
    }
}
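The `Map1`/`Map2` traits above let a kernel be written once over a generic element type and then dispatched across every dtype variant of the storage enum. The sketch below is a standalone, simplified illustration of that dispatch pattern with stand-in types; it is not candle's own API surface.

```rust
// Standalone illustration of the Map1-style dtype dispatch used above.
#[derive(Debug)]
enum MiniStorage {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

trait MiniMap1 {
    // The kernel is written once, generically over the element type.
    fn f<T: Copy + std::ops::Add<Output = T>>(&self, vs: &[T]) -> Vec<T>;

    // The dispatcher unwraps the enum, runs the kernel, and re-wraps the result.
    fn map(&self, vs: &MiniStorage) -> MiniStorage {
        match vs {
            MiniStorage::F32(vs) => MiniStorage::F32(self.f(vs)),
            MiniStorage::F64(vs) => MiniStorage::F64(self.f(vs)),
        }
    }
}

// "Double every element", implemented a single time for all dtypes.
struct Double;
impl MiniMap1 for Double {
    fn f<T: Copy + std::ops::Add<Output = T>>(&self, vs: &[T]) -> Vec<T> {
        vs.iter().map(|&v| v + v).collect()
    }
}

fn main() {
    let s = MiniStorage::F32(vec![1.0, 2.0, 3.0]);
    println!("{:?}", Double.map(&s)); // F32([2.0, 4.0, 6.0])
}
```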
candle-core/src/cuda_backend/device.rs (new file, 452 lines)
@@ -0,0 +1,452 @@
use crate::backend::BackendDevice;
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
use half::{bf16, f16};
use std::sync::{Arc, Mutex};

use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};

/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);

impl DeviceId {
    fn new() -> Self {
        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
        use std::sync::atomic;
        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
    }
}

struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}

#[derive(Clone)]
pub struct CudaDevice {
    id: DeviceId,
    device: Arc<cudarc::driver::CudaDevice>,
    pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
    curand: Arc<Mutex<CudaRng>>,
}

impl std::fmt::Debug for CudaDevice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "CudaDevice({:?})", self.id)
    }
}

impl std::ops::Deref for CudaDevice {
    type Target = Arc<cudarc::driver::CudaDevice>;

    fn deref(&self) -> &Self::Target {
        &self.device
    }
}

impl CudaDevice {
    pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
        self.device.clone()
    }

    pub fn id(&self) -> DeviceId {
        self.id
    }

    fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
        let slice = match dtype {
            DType::U8 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
                let params = (&data, v as u8, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::U8(data)
            }
            DType::U32 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
                let params = (&data, v as u32, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::U32(data)
            }
            DType::I64 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
                let params = (&data, v as i64, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::I64(data)
            }
            DType::BF16 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
                let params = (&data, bf16::from_f64(v), elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::BF16(data)
            }
            DType::F16 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
                let params = (&data, f16::from_f64(v), elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::F16(data)
            }
            DType::F32 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
                let params = (&data, v as f32, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
                let params = (&data, v, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
        if !self.has_func(module_name, module_name) {
            // Leaking the string here is a bit sad but we need a &'static str and this is only
            // done once per kernel name.
            let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
            self.load_ptx(ptx.into(), module_name, &[static_module_name])
                .map_err(|cuda| CudaError::Load {
                    cuda,
                    module_name: module_name.to_string(),
                })
                .w()?;
        }
        self.get_func(module_name, module_name)
            // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
            // able to only build the error value if needed.
            .ok_or(CudaError::MissingKernel {
                module_name: module_name.to_string(),
            })
            .w()
    }
}

impl BackendDevice for CudaDevice {
    type Storage = CudaStorage;

    fn new(ordinal: usize) -> Result<Self> {
        let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
        let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
        let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
        Ok(Self {
            id: DeviceId::new(),
            device,
            blas: Arc::new(blas),
            curand: Arc::new(Mutex::new(CudaRng(curand))),
        })
    }

    fn set_seed(&self, seed: u64) -> Result<()> {
        // We do not call set_seed but instead create a new curand object. This ensures that the
        // state will be identical and the same random numbers will be generated.
        let mut curand = self.curand.lock().unwrap();
        curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
        Ok(())
    }

    fn location(&self) -> crate::DeviceLocation {
        crate::DeviceLocation::Cuda {
            gpu_id: self.device.ordinal(),
        }
    }

    fn same_device(&self, rhs: &Self) -> bool {
        self.id == rhs.id
    }

    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let slice = match dtype {
            DType::U8 => {
                let data = self.alloc_zeros::<u8>(elem_count).w()?;
                CudaStorageSlice::U8(data)
            }
            DType::U32 => {
                let data = self.alloc_zeros::<u32>(elem_count).w()?;
                CudaStorageSlice::U32(data)
            }
            DType::I64 => {
                let data = self.alloc_zeros::<i64>(elem_count).w()?;
                CudaStorageSlice::I64(data)
            }
            DType::BF16 => {
                let data = self.alloc_zeros::<bf16>(elem_count).w()?;
                CudaStorageSlice::BF16(data)
            }
            DType::F16 => {
                let data = self.alloc_zeros::<f16>(elem_count).w()?;
                CudaStorageSlice::F16(data)
            }
            DType::F32 => {
                let data = self.alloc_zeros::<f32>(elem_count).w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let data = self.alloc_zeros::<f64>(elem_count).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let curand = self.curand.lock().unwrap();
        let slice = match dtype {
            // TODO: Add support for F16 and BF16 though this is likely to require some upstream
            // cudarc changes.
            DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
                Err(CudaError::UnsupportedDtype {
                    dtype,
                    op: "rand_uniform",
                })
                .w()?
            }
            DType::F32 => {
                let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
                curand.0.fill_with_uniform(&mut data).w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
                curand.0.fill_with_uniform(&mut data).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        let slice = if lo == 0. && up == 1.0 {
            slice
        } else {
            use super::utils::Map1;
            let layout = Layout::contiguous(shape);
            super::Affine(up - lo, lo).map(&slice, self, &layout)?
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
        // TODO: Add support for F16 and BF16 though this is likely to require some upstream
        // cudarc changes.
        let elem_count = shape.elem_count();
        let curand = self.curand.lock().unwrap();
        // curand can only generate an odd number of values.
        // https://github.com/huggingface/candle/issues/734
        let elem_count_round = if elem_count % 2 == 1 {
            elem_count + 1
        } else {
            elem_count
        };
        let slice = match dtype {
            DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
                Err(CudaError::UnsupportedDtype {
                    dtype,
                    op: "rand_normal",
                })
                .w()?
            }
            DType::F32 => {
                let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
                curand
                    .0
                    .fill_with_normal(&mut data, mean as f32, std as f32)
                    .w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
                curand.0.fill_with_normal(&mut data, mean, std).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        self.const_impl(1., shape, dtype)
    }

    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
        let elem_count = shape.elem_count();
        let slice = match dtype {
            DType::U8 => {
                let data = self.alloc::<u8>(elem_count).w()?;
                CudaStorageSlice::U8(data)
            }
            DType::U32 => {
                let data = self.alloc::<u32>(elem_count).w()?;
                CudaStorageSlice::U32(data)
            }
            DType::I64 => {
                let data = self.alloc::<i64>(elem_count).w()?;
                CudaStorageSlice::I64(data)
            }
            DType::BF16 => {
                let data = self.alloc::<bf16>(elem_count).w()?;
                CudaStorageSlice::BF16(data)
            }
            DType::F16 => {
                let data = self.alloc::<f16>(elem_count).w()?;
                CudaStorageSlice::F16(data)
            }
            DType::F32 => {
                let data = self.alloc::<f32>(elem_count).w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let data = self.alloc::<f64>(elem_count).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
        let slice = match T::cpu_storage_ref(s) {
            CpuStorageRef::U8(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U8(data)
            }
            CpuStorageRef::U32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U32(data)
            }
            CpuStorageRef::I64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::I64(data)
            }
            CpuStorageRef::BF16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::BF16(data)
            }
            CpuStorageRef::F16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F16(data)
            }
            CpuStorageRef::F32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F32(data)
            }
            CpuStorageRef::F64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
        let slice = match storage {
            CpuStorage::U8(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U8(data)
            }
            CpuStorage::U32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U32(data)
            }
            CpuStorage::I64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::I64(data)
            }
            CpuStorage::BF16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::BF16(data)
            }
            CpuStorage::F16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F16(data)
            }
            CpuStorage::F32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F32(data)
            }
            CpuStorage::F64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
        let slice = match storage {
            CpuStorage::U8(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::U8(data)
            }
            CpuStorage::U32(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::U32(data)
            }
            CpuStorage::I64(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::I64(data)
            }
            CpuStorage::BF16(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::BF16(data)
            }
            CpuStorage::F16(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F16(data)
            }
            CpuStorage::F32(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F32(data)
            }
            CpuStorage::F64(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn synchronize(&self) -> Result<()> {
        self.device.synchronize().map_err(crate::Error::wrap)?;
        Ok(())
    }
}
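The `DeviceId` scheme above relies on a process-wide atomic counter so that two device handles can be compared cheaply in `same_device` without talking to the driver. The snippet below isolates just that pattern as a runnable example.

```rust
// The unique-id pattern used by CudaDevice above, in isolation.
use std::sync::atomic::{AtomicUsize, Ordering};

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct DeviceId(usize);

impl DeviceId {
    fn new() -> Self {
        // A static counter hands out a fresh id per call, starting at 1.
        static COUNTER: AtomicUsize = AtomicUsize::new(1);
        Self(COUNTER.fetch_add(1, Ordering::Relaxed))
    }
}

fn main() {
    let (a, b) = (DeviceId::new(), DeviceId::new());
    assert_ne!(a, b); // every handle gets its own identity
}
```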
candle-core/src/cuda_backend/error.rs (new file, 62 lines)
@@ -0,0 +1,62 @@
use crate::{DType, Layout};

/// cudarc related errors
#[derive(thiserror::Error, Debug)]
pub enum CudaError {
    #[error(transparent)]
    Cuda(#[from] cudarc::driver::DriverError),

    #[error(transparent)]
    Compiler(#[from] cudarc::nvrtc::CompileError),

    #[error(transparent)]
    Cublas(#[from] cudarc::cublas::result::CublasError),

    #[error(transparent)]
    Curand(#[from] cudarc::curand::result::CurandError),

    #[error("missing kernel '{module_name}'")]
    MissingKernel { module_name: String },

    #[error("unsupported dtype {dtype:?} for {op}")]
    UnsupportedDtype { dtype: DType, op: &'static str },

    #[error("internal error '{0}'")]
    InternalError(&'static str),

    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
    MatMulNonContiguous {
        lhs_stride: Layout,
        rhs_stride: Layout,
        mnk: (usize, usize, usize),
    },

    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
    UnexpectedDType {
        msg: &'static str,
        expected: DType,
        got: DType,
    },

    #[error("{cuda} when loading {module_name}")]
    Load {
        cuda: cudarc::driver::DriverError,
        module_name: String,
    },
}

impl From<CudaError> for crate::Error {
    fn from(val: CudaError) -> Self {
        crate::Error::Cuda(Box::new(val)).bt()
    }
}

pub trait WrapErr<O> {
    fn w(self) -> std::result::Result<O, crate::Error>;
}

impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
    fn w(self) -> std::result::Result<O, crate::Error> {
        self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
    }
}
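The `WrapErr` extension trait above is what allows the rest of the CUDA backend to convert any cudarc error into the crate-level error with a single `.w()?`. Below is a standalone, simplified version of that pattern with illustrative names (not candle's types), just to show the mechanics.

```rust
// Standalone sketch of the `.w()?` error-wrapping pattern used above.
#[derive(Debug)]
struct AppError(String);

trait Wrap<O> {
    fn w(self) -> Result<O, AppError>;
}

// Any Result whose error can be displayed gains a `.w()` adapter.
impl<O, E: std::fmt::Display> Wrap<O> for Result<O, E> {
    fn w(self) -> Result<O, AppError> {
        self.map_err(|e| AppError(e.to_string()))
    }
}

fn parse(s: &str) -> Result<i64, AppError> {
    let n: i64 = s.parse().w()?; // std::num::ParseIntError -> AppError
    Ok(n)
}

fn main() {
    assert!(parse("42").is_ok());
    assert!(parse("not a number").is_err());
}
```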
(File diff suppressed because it is too large.)
candle-core/src/cuda_backend/utils.rs (new file, 172 lines)
@@ -0,0 +1,172 @@
/// Helper functions to plug cuda kernels in candle.
use crate::{Layout, Result, Shape, WithDType};
pub use cudarc;
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};

use super::{CudaDevice, CudaError, WrapErr};

pub type S = super::CudaStorageSlice;

pub trait Map1 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>>;

    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
        let out = match s {
            S::U8(s) => S::U8(self.f(s, d, l)?),
            S::U32(s) => S::U32(self.f(s, d, l)?),
            S::I64(s) => S::I64(self.f(s, d, l)?),
            S::BF16(s) => S::BF16(self.f(s, d, l)?),
            S::F16(s) => S::F16(self.f(s, d, l)?),
            S::F32(s) => S::F32(self.f(s, d, l)?),
            S::F64(s) => S::F64(self.f(s, d, l)?),
        };
        Ok(out)
    }
}

pub trait Map2 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
        layout1: &Layout,
        src2: &CudaSlice<T>,
        layout2: &Layout,
        dev: &CudaDevice,
    ) -> Result<CudaSlice<T>>;

    fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
        let out = match (s1, s2) {
            (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
            (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
            (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
            (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
            (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
            (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
            (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
        };
        Ok(out)
    }
}

pub trait Map3 {
    #[allow(clippy::too_many_arguments)]
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
        layout1: &Layout,
        src2: &CudaSlice<T>,
        layout2: &Layout,
        src3: &CudaSlice<T>,
        layout3: &Layout,
        dev: &CudaDevice,
    ) -> Result<CudaSlice<T>>;

    #[allow(clippy::too_many_arguments)]
    fn map(
        &self,
        s1: &S,
        l1: &Layout,
        s2: &S,
        l2: &Layout,
        s3: &S,
        l3: &Layout,
        d: &CudaDevice,
    ) -> Result<S> {
        let out = match (s1, s2, s3) {
            (S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?),
            (S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?),
            (S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?),
            (S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?),
            (S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?),
            (S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?),
            (S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?),
            _ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?,
        };
        Ok(out)
    }
}

pub trait Map2InPlace {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        dst: &mut CudaSlice<T>,
        dst_shape: &Shape,
        src: &CudaSlice<T>,
        src_l: &Layout,
        dev: &CudaDevice,
    ) -> Result<()>;

    fn map(
        &self,
        dst: &mut S,
        dst_s: &Shape,
        src: &S,
        src_l: &Layout,
        d: &CudaDevice,
    ) -> Result<()> {
        match (dst, src) {
            (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
        }
    }
}

pub trait Map1Any {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
        wrap: W,
    ) -> Result<S>;

    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
        let out = match s {
            S::U8(s) => self.f(s, d, l, S::U8)?,
            S::U32(s) => self.f(s, d, l, S::U32)?,
            S::I64(s) => self.f(s, d, l, S::I64)?,
            S::BF16(s) => self.f(s, d, l, S::BF16)?,
            S::F16(s) => self.f(s, d, l, S::F16)?,
            S::F32(s) => self.f(s, d, l, S::F32)?,
            S::F64(s) => self.f(s, d, l, S::F64)?,
        };
        Ok(out)
    }
}

pub trait Map2Any {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
        layout1: &Layout,
        src2: &CudaSlice<T>,
        layout2: &Layout,
        dev: &CudaDevice,
    ) -> Result<S>;

    fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
        let out = match (s1, s2) {
            (S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
            _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
        };
        Ok(out)
    }
}
candle-core/src/device.rs
@@ -306,6 +306,20 @@ impl Device {
         }
     }
 
+    pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
+        match self {
+            Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
+            Device::Cuda(device) => {
+                let storage = device.storage_from_slice(data)?;
+                Ok(Storage::Cuda(storage))
+            }
+            Device::Metal(device) => {
+                let storage = device.storage_from_slice(data)?;
+                Ok(Storage::Metal(storage))
+            }
+        }
+    }
+
     pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
         match self {
             Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
@@ -337,4 +351,12 @@ impl Device {
             }
         }
     }
+
+    pub fn synchronize(&self) -> Result<()> {
+        match self {
+            Self::Cpu => Ok(()),
+            Self::Cuda(d) => d.synchronize(),
+            Self::Metal(d) => d.synchronize(),
+        }
+    }
 }
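A possible way to exercise the new slice-ingestion and `synchronize` plumbing from the public API is sketched below. It assumes `Tensor::new` and the `Device::synchronize` method added above behave as suggested by this diff; treat the exact call sites as an assumption rather than documented behaviour.

```rust
// Hedged usage sketch against candle-core 0.6.0's public API.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu; // or e.g. Device::new_cuda(0)? when built with the cuda feature
    // Building a tensor from a plain slice goes through the device's storage_from_slice path.
    let t = Tensor::new(&[1f32, 2., 3., 4.], &device)?;
    let t = (t.clone() * t)?; // elementwise square
    // New in this release: a no-op on CPU, a blocking wait on GPU backends.
    device.synchronize()?;
    println!("{t}");
    Ok(())
}
```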
@ -1,7 +1,7 @@
|
|||||||
//! Types for elements that can be stored and manipulated using tensors.
|
//! Types for elements that can be stored and manipulated using tensors.
|
||||||
#![allow(clippy::redundant_closure_call)]
|
#![allow(clippy::redundant_closure_call)]
|
||||||
use crate::backend::BackendStorage;
|
use crate::backend::BackendStorage;
|
||||||
use crate::{CpuStorage, Error, Result};
|
use crate::{CpuStorage, CpuStorageRef, Error, Result};
|
||||||
|
|
||||||
/// The different types of elements allowed in tensors.
|
/// The different types of elements allowed in tensors.
|
||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
@ -100,12 +100,14 @@ pub trait WithDType:
|
|||||||
+ 'static
|
+ 'static
|
||||||
+ Send
|
+ Send
|
||||||
+ Sync
|
+ Sync
|
||||||
|
+ std::any::Any
|
||||||
+ crate::cpu::kernels::VecOps
|
+ crate::cpu::kernels::VecOps
|
||||||
{
|
{
|
||||||
const DTYPE: DType;
|
const DTYPE: DType;
|
||||||
|
|
||||||
fn from_f64(v: f64) -> Self;
|
fn from_f64(v: f64) -> Self;
|
||||||
fn to_f64(self) -> f64;
|
fn to_f64(self) -> f64;
|
||||||
|
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
|
||||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
||||||
|
|
||||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||||
@ -129,6 +131,10 @@ macro_rules! with_dtype {
|
|||||||
$to_f64(self)
|
$to_f64(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
|
||||||
|
CpuStorageRef::$dtype(data)
|
||||||
|
}
|
||||||
|
|
||||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
|
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
|
||||||
CpuStorage::$dtype(data)
|
CpuStorage::$dtype(data)
|
||||||
}
|
}
|
||||||
|
@ -214,6 +214,10 @@ impl crate::backend::BackendDevice for CudaDevice {
|
|||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
@ -229,4 +233,38 @@ impl crate::backend::BackendDevice for CudaDevice {
|
|||||||
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn synchronize(&self) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||||
|
/// allowed with f16 GEMMs.
|
||||||
|
pub fn gemm_reduced_precision_f16() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||||
|
/// allowed with f16 GEMMs.
|
||||||
|
pub fn set_gemm_reduced_precision_f16(_: bool) {}
|
||||||
|
|
||||||
|
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||||
|
/// allowed with bf16 GEMMs.
|
||||||
|
pub fn gemm_reduced_precision_bf16() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||||
|
/// allowed with bf16 GEMMs.
|
||||||
|
pub fn set_gemm_reduced_precision_bf16(_: bool) {}
|
||||||
|
|
||||||
|
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
|
||||||
|
/// allowed with f32 GEMMs.
|
||||||
|
pub fn gemm_reduced_precision_f32() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
|
||||||
|
/// allowed with f32 GEMMs.
|
||||||
|
pub fn set_gemm_reduced_precision_f32(_b: bool) {}
|
||||||
|
@ -226,6 +226,10 @@ impl crate::backend::BackendDevice for MetalDevice {
|
|||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
@ -241,4 +245,8 @@ impl crate::backend::BackendDevice for MetalDevice {
|
|||||||
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn synchronize(&self) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -219,10 +219,14 @@ impl Error {
|
|||||||
Self::Wrapped(Box::new(err)).bt()
|
Self::Wrapped(Box::new(err)).bt()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
|
pub fn msg(err: impl std::error::Error) -> Self {
|
||||||
Self::Msg(err.to_string()).bt()
|
Self::Msg(err.to_string()).bt()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn debug(err: impl std::fmt::Debug) -> Self {
|
||||||
|
Self::Msg(format!("{err:?}")).bt()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn bt(self) -> Self {
|
pub fn bt(self) -> Self {
|
||||||
let backtrace = std::backtrace::Backtrace::capture();
|
let backtrace = std::backtrace::Backtrace::capture();
|
||||||
match backtrace.status() {
|
match backtrace.status() {
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
//!
|
//!
|
||||||
//! ## Features
|
//! ## Features
|
||||||
//!
|
//!
|
||||||
//! - Simple syntax (looks and like PyTorch)
|
//! - Simple syntax (looks and feels like PyTorch)
|
||||||
//! - CPU and Cuda backends (and M1 support)
|
//! - CPU and Cuda backends (and M1 support)
|
||||||
//! - Enable serverless (CPU) small and fast deployments
|
//! - Enable serverless (CPU) small and fast deployments
|
||||||
//! - Model training
|
//! - Model training
|
||||||
@ -37,19 +37,17 @@
|
|||||||
mod accelerate;
|
mod accelerate;
|
||||||
pub mod backend;
|
pub mod backend;
|
||||||
pub mod backprop;
|
pub mod backprop;
|
||||||
mod conv;
|
pub mod conv;
|
||||||
mod convert;
|
mod convert;
|
||||||
pub mod cpu;
|
pub mod cpu;
|
||||||
pub mod cpu_backend;
|
pub mod cpu_backend;
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
pub mod cuda_backend;
|
pub mod cuda_backend;
|
||||||
#[cfg(feature = "cudnn")]
|
|
||||||
pub mod cudnn;
|
|
||||||
mod custom_op;
|
mod custom_op;
|
||||||
mod device;
|
mod device;
|
||||||
pub mod display;
|
pub mod display;
|
||||||
mod dtype;
|
mod dtype;
|
||||||
mod dummy_cuda_backend;
|
pub mod dummy_cuda_backend;
|
||||||
mod dummy_metal_backend;
|
mod dummy_metal_backend;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
mod indexer;
|
mod indexer;
|
||||||
@ -59,12 +57,13 @@ pub mod metal_backend;
|
|||||||
#[cfg(feature = "mkl")]
|
#[cfg(feature = "mkl")]
|
||||||
mod mkl;
|
mod mkl;
|
||||||
pub mod npy;
|
pub mod npy;
|
||||||
mod op;
|
pub mod op;
|
||||||
pub mod pickle;
|
pub mod pickle;
|
||||||
pub mod quantized;
|
pub mod quantized;
|
||||||
pub mod safetensors;
|
pub mod safetensors;
|
||||||
pub mod scalar;
|
pub mod scalar;
|
||||||
pub mod shape;
|
pub mod shape;
|
||||||
|
mod sort;
|
||||||
mod storage;
|
mod storage;
|
||||||
mod strided_index;
|
mod strided_index;
|
||||||
mod tensor;
|
mod tensor;
|
||||||
@ -73,10 +72,13 @@ pub mod test_utils;
|
|||||||
pub mod utils;
|
pub mod utils;
|
||||||
mod variable;
|
mod variable;
|
||||||
|
|
||||||
pub use cpu_backend::CpuStorage;
|
#[cfg(feature = "cudnn")]
|
||||||
|
pub use cuda_backend::cudnn;
|
||||||
|
|
||||||
|
pub use cpu_backend::{CpuStorage, CpuStorageRef};
|
||||||
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
|
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
|
||||||
pub use device::{Device, DeviceLocation, NdArray};
|
pub use device::{Device, DeviceLocation, NdArray};
|
||||||
pub use dtype::{DType, FloatDType, IntDType, WithDType};
|
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
|
||||||
pub use error::{Error, Result};
|
pub use error::{Error, Result};
|
||||||
pub use indexer::IndexOp;
|
pub use indexer::IndexOp;
|
||||||
pub use layout::Layout;
|
pub use layout::Layout;
|
||||||
@ -87,10 +89,12 @@ pub use tensor::{Tensor, TensorId};
|
|||||||
pub use variable::Var;
|
pub use variable::Var;
|
||||||
|
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
pub use cuda_backend::{CudaDevice, CudaStorage};
|
pub use cuda_backend as cuda;
|
||||||
|
|
||||||
#[cfg(not(feature = "cuda"))]
|
#[cfg(not(feature = "cuda"))]
|
||||||
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
|
pub use dummy_cuda_backend as cuda;
|
||||||
|
|
||||||
|
pub use cuda::{CudaDevice, CudaStorage};
|
||||||
|
|
||||||
#[cfg(feature = "metal")]
|
#[cfg(feature = "metal")]
|
||||||
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
|
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
|
||||||
|
287
candle-core/src/metal_backend/device.rs
Normal file
287
candle-core/src/metal_backend/device.rs
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
use crate::{DType, Result};
|
||||||
|
use candle_metal_kernels::Kernels;
|
||||||
|
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::ffi::c_void;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};
|
||||||
|
|
||||||
|
use super::MetalError;
|
||||||
|
|
||||||
|
/// Unique identifier for cuda devices.
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||||
|
pub struct DeviceId(usize);
|
||||||
|
|
||||||
|
impl DeviceId {
|
||||||
|
pub(crate) fn new() -> Self {
|
||||||
|
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
        use std::sync::atomic;
        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
    }
}

type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
type AllocatedBuffers = Arc<RwLock<BufferMap>>;

#[derive(Clone)]
pub struct MetalDevice {
    /// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
    /// the device itself.
    pub(crate) id: DeviceId,

    /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
    pub(crate) device: metal::Device,

    /// Single command queue for the entire device.
    pub(crate) command_queue: CommandQueue,
    /// One command buffer at a time.
    /// The scheduler works by allowing multiple
    /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
    /// on a single command buffer. Using a single command buffer would be fastest on the GPU but
    /// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
    /// to start to work).
    /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
    /// for their START time, but there's no guarantee that command buffer1 will finish before
    /// command buffer2 starts (or there are metal bugs there)
    pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
    /// Keeps track of the current amount of compute command encoders on the current
    /// command buffer
    /// Arc, RwLock because of the interior mutability.
    pub(crate) command_buffer_index: Arc<RwLock<usize>>,
    /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
    pub(crate) compute_per_buffer: usize,
    /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
    /// Heavily used by [`candle_metal_kernels`]
    pub(crate) kernels: Arc<Kernels>,
    /// Simple allocator struct.
    /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
    /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
    /// (could be linked to FFI communication overhead).
    ///
    /// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the
    /// graph calculation, and only we the allocator kept a reference to it, therefore it's free
    /// to be reused. However, in order for this to work, we need to guarantee the order of
    /// operation, so that this buffer is not being used by another kernel at the same time.
    /// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
    ///
    /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
    /// (strong_count = 1).
    pub(crate) buffers: AllocatedBuffers,
    /// Seed for random number generation.
    pub(crate) seed: Arc<Mutex<Buffer>>,
}

impl std::fmt::Debug for MetalDevice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "MetalDevice({:?})", self.id)
    }
}

impl std::ops::Deref for MetalDevice {
    type Target = metal::DeviceRef;

    fn deref(&self) -> &Self::Target {
        &self.device
    }
}

impl MetalDevice {
    pub fn id(&self) -> DeviceId {
        self.id
    }

    pub fn metal_device(&self) -> &metal::Device {
        &self.device
    }

    pub fn command_queue(&self) -> &CommandQueue {
        &self.command_queue
    }

    pub fn command_buffer(&self) -> Result<CommandBuffer> {
        let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
        let mut command_buffer = command_buffer_lock.to_owned();
        let mut index = self
            .command_buffer_index
            .write()
            .map_err(MetalError::from)?;
        if *index > self.compute_per_buffer {
            command_buffer.commit();
            command_buffer = self.command_queue.new_command_buffer().to_owned();
            *command_buffer_lock = command_buffer.clone();
            *index = 0;

            self.drop_unused_buffers()?;
        }
        *index += 1;
        Ok(command_buffer)
    }

    pub fn wait_until_completed(&self) -> Result<()> {
        let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
        match command_buffer.status() {
            metal::MTLCommandBufferStatus::Committed
            | metal::MTLCommandBufferStatus::Scheduled
            | metal::MTLCommandBufferStatus::Completed => {
                panic!("Already committed");
            }
            _ => {}
        }
        command_buffer.commit();
        command_buffer.wait_until_completed();
        *command_buffer = self.command_queue.new_command_buffer().to_owned();

        Ok(())
    }

    pub fn kernels(&self) -> &Kernels {
        &self.kernels
    }

    pub fn device(&self) -> &metal::Device {
        &self.device
    }

    /// Creates a new buffer (not necessarily zeroed).
    /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
    /// This means the buffer data cannot be read on the CPU directly.
    ///
    /// [`name`] is only used to keep track of the resource origin in case of bugs
    pub fn new_buffer(
        &self,
        element_count: usize,
        dtype: DType,
        name: &str,
    ) -> Result<Arc<Buffer>> {
        let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
        self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
    }

    /// Creates a new buffer (not necessarily zeroed).
    /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
    /// This means the buffer can be read on the CPU but will require manual
    /// synchronization when the CPU memory is modified
    /// Used as a bridge to gather data back from the GPU
    pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
        self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
    }

    /// Creates a new buffer from data.
    /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
    ///
    /// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
    /// allocates the buffer and copies over the existing data before returning the MTLBuffer.
    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
        let size = core::mem::size_of_val(data) as NSUInteger;
        let new_buffer = self.device.new_buffer_with_data(
            data.as_ptr() as *const c_void,
            size,
            MTLResourceOptions::StorageModeManaged,
        );
        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
        let subbuffers = buffers
            .entry((size, MTLResourceOptions::StorageModeManaged))
            .or_insert(vec![]);

        let new_buffer = Arc::new(new_buffer);
        subbuffers.push(new_buffer.clone());
        Ok(new_buffer)
    }

    pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
        let buffer = self.allocate_buffer(
            size_in_bytes as NSUInteger,
            MTLResourceOptions::StorageModePrivate,
            "allocate_zeros",
        )?;
        let command_buffer = self.command_buffer()?;
        command_buffer.set_label("zeros");
        let blit = command_buffer.new_blit_command_encoder();
        blit.fill_buffer(
            &buffer,
            metal::NSRange {
                location: 0,
                length: buffer.length(),
            },
            0,
        );
        blit.end_encoding();
        Ok(buffer)
    }

    fn find_available_buffer(
        &self,
        size: NSUInteger,
        option: MTLResourceOptions,
        buffers: &RwLockWriteGuard<BufferMap>,
    ) -> Option<Arc<Buffer>> {
        let mut best_buffer: Option<&Arc<Buffer>> = None;
        let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
        for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
            if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
                for sub in subbuffers {
                    if Arc::strong_count(sub) == 1 {
                        best_buffer = Some(sub);
                        best_buffer_size = *buffer_size;
                    }
                }
            }
        }
        best_buffer.cloned()
    }

    fn drop_unused_buffers(&self) -> Result<()> {
        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
        for subbuffers in buffers.values_mut() {
            let newbuffers = subbuffers
                .iter()
                .filter(|s| Arc::strong_count(*s) > 1)
                .map(Arc::clone)
                .collect();
            *subbuffers = newbuffers;
        }
        Ok(())
    }

    /// The critical allocator algorithm
    fn allocate_buffer(
        &self,
        size: NSUInteger,
        option: MTLResourceOptions,
        _name: &str,
    ) -> Result<Arc<Buffer>> {
        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
        if let Some(b) = self.find_available_buffer(size, option, &buffers) {
            // Cloning also ensures we increment the strong count
            return Ok(b.clone());
        }

        let size = buf_size(size);
        let subbuffers = buffers.entry((size, option)).or_insert(vec![]);

        let new_buffer = self.device.new_buffer(size as NSUInteger, option);
        let new_buffer = Arc::new(new_buffer);
        subbuffers.push(new_buffer.clone());

        Ok(new_buffer)
    }

    /// Create a metal GPU capture trace on [`path`].
    pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let capture = metal::CaptureManager::shared();
        let descriptor = metal::CaptureDescriptor::new();
        descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
        descriptor.set_capture_device(self);
        descriptor.set_output_url(path);

        capture
            .start_capture(&descriptor)
            .map_err(MetalError::from)?;
        Ok(())
    }
}

fn buf_size(size: NSUInteger) -> NSUInteger {
    size.saturating_sub(1).next_power_of_two() as NSUInteger
}
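The allocator above buckets requests by rounding sizes up to the next power of two and reuses a buffer only when the allocator itself holds the last Arc reference to it. A minimal standalone sketch of that policy (plain Rust, no Metal types; FakeBuffer stands in for metal::Buffer, and unlike find_available_buffer it takes the first large-enough bucket rather than the best fit and skips the unused-buffer sweep):

use std::collections::HashMap;
use std::sync::Arc;

// Stand-in for metal::Buffer: only the allocated size matters here.
struct FakeBuffer(u64);

fn buf_size(size: u64) -> u64 {
    // Same rounding as in the code above: bucket sizes are powers of two.
    size.saturating_sub(1).next_power_of_two()
}

struct Allocator {
    buffers: HashMap<u64, Vec<Arc<FakeBuffer>>>,
}

impl Allocator {
    fn allocate(&mut self, size: u64) -> Arc<FakeBuffer> {
        // Reuse any entry that is large enough and that nobody else holds
        // (strong_count == 1 means only the allocator still references it).
        for (&bucket, subbuffers) in self.buffers.iter() {
            if bucket >= size {
                if let Some(b) = subbuffers.iter().find(|b| Arc::strong_count(*b) == 1) {
                    return b.clone();
                }
            }
        }
        let bucket = buf_size(size);
        let b = Arc::new(FakeBuffer(bucket));
        self.buffers.entry(bucket).or_default().push(b.clone());
        b
    }
}

fn main() {
    let mut alloc = Allocator { buffers: HashMap::new() };
    let a = alloc.allocate(1000); // lands in the 1024-byte bucket
    assert_eq!(a.0, 1024);
    drop(a); // the allocator now holds the only reference...
    let b = alloc.allocate(900); // ...so this request reuses the same buffer
    assert_eq!(Arc::strong_count(&b), 2);
    println!("reused a {}-byte bucket", b.0);
}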
File diff suppressed because it is too large.
@@ -330,7 +330,7 @@ impl Tensor {
         path: P,
     ) -> Result<()> {
         let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
-        let options =
+        let options: zip::write::FileOptions<()> =
             zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
 
         for (name, tensor) in ts.iter() {
@@ -66,6 +66,7 @@ pub enum UnaryOp {
     Floor,
     Ceil,
     Round,
+    Sign,
 }
 
 #[derive(Clone)]
@@ -254,6 +255,7 @@ pub(crate) struct Tanh;
 pub(crate) struct Floor;
 pub(crate) struct Ceil;
 pub(crate) struct Round;
+pub(crate) struct Sign;
 
 macro_rules! bin_op {
     ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@@ -457,6 +459,13 @@ unary_op!(Recip, "recip", v, v.recip());
 unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
 unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
 
+// Hardcode the value for sqrt(2/pi)
+// https://github.com/huggingface/candle/issues/1982
+#[allow(clippy::excessive_precision)]
+const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
+#[allow(clippy::excessive_precision)]
+const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
+
 /// Tanh based approximation of the `gelu` operation
 /// GeluErf is the more precise one.
 /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
@@ -469,7 +478,7 @@ impl UnaryOpT for Gelu {
             * v
             * (bf16::ONE
                 + bf16::tanh(
-                    (bf16::from_f32_const(2.0) / bf16::PI).sqrt()
+                    bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
                         * v
                         * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
                 ))
@@ -480,22 +489,18 @@ impl UnaryOpT for Gelu {
             * v
             * (f16::ONE
                 + f16::tanh(
-                    (f16::from_f32_const(2.0) / f16::PI).sqrt()
+                    f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
                         * v
                         * (f16::ONE + f16::from_f32_const(0.044715) * v * v),
                 ))
     }
     #[inline(always)]
     fn f32(v: f32) -> f32 {
-        0.5 * v
-            * (1.0
-                + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
+        0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
     }
     #[inline(always)]
     fn f64(v: f64) -> f64 {
-        0.5 * v
-            * (1.0
-                + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
+        0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
     }
     #[inline(always)]
     fn u8(_: u8) -> u8 {
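For reference, the rewritten f32/f64 branches compute the usual tanh approximation gelu(x) ≈ 0.5·x·(1 + tanh(sqrt(2/π)·x·(1 + 0.044715·x²))). A small self-contained check of that formula using the same hardcoded constant (the 0.8412 reference is the commonly quoted value of gelu(1.0), not a number taken from candle's tests):

// Tanh-based GELU exactly as in the f32 branch above.
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;

fn gelu_tanh(v: f32) -> f32 {
    0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
}

fn main() {
    assert_eq!(gelu_tanh(0.0), 0.0);
    // For large |x| the approximation saturates to x (positive side) or 0 (negative side).
    assert!((gelu_tanh(6.0) - 6.0).abs() < 1e-4);
    assert!(gelu_tanh(-6.0).abs() < 1e-4);
    // Commonly quoted reference value: gelu(1.0) is roughly 0.8412.
    assert!((gelu_tanh(1.0) - 0.8412).abs() < 1e-3);
    println!("gelu_tanh(1.0) = {}", gelu_tanh(1.0));
}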
@@ -922,3 +927,37 @@ impl std::ops::Deref for BackpropOp {
         &self.0
     }
 }
+
+impl UnaryOpT for Sign {
+    const NAME: &'static str = "sign";
+    const KERNEL: &'static str = "usign";
+    const V: Self = Sign;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        f32::from(v > 0.) - f32::from(v < 0.)
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        f64::from(v > 0.) - f64::from(v < 0.)
+    }
+    #[inline(always)]
+    fn u8(v: u8) -> u8 {
+        u8::min(1, v)
+    }
+    #[inline(always)]
+    fn u32(v: u32) -> u32 {
+        u32::min(1, v)
+    }
+    #[inline(always)]
+    fn i64(v: i64) -> i64 {
+        (v > 0) as i64 - (v < 0) as i64
+    }
+}
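The new Sign op maps values to -1, 0 or 1 using comparisons rather than branches, and clamps to {0, 1} for unsigned types. A standalone sketch of the same trick:

// Same branch-free formulation as the f64/i64 branches of the Sign op above.
fn sign_f64(v: f64) -> f64 {
    f64::from(v > 0.) - f64::from(v < 0.)
}

fn sign_i64(v: i64) -> i64 {
    (v > 0) as i64 - (v < 0) as i64
}

fn main() {
    assert_eq!(sign_f64(3.5), 1.0);
    assert_eq!(sign_f64(-0.1), -1.0);
    assert_eq!(sign_f64(0.0), 0.0);
    // NaN compares false on both sides, so it maps to 0 rather than propagating.
    assert_eq!(sign_f64(f64::NAN), 0.0);
    assert_eq!(sign_i64(-42), -1);
    assert_eq!(sign_i64(0), 0);
}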
@@ -1,24 +1,66 @@
 use super::{GgmlDType, QStorage};
+use crate::quantized::k_quants::GgmlType;
 use crate::{backend::BackendDevice, cuda_backend::WrapErr};
 use crate::{CudaDevice, CudaStorage, Result};
+use half::f16;
 
-use cudarc::driver::{CudaSlice, DeviceSlice};
+use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
 
+#[derive(Clone, Debug)]
 pub struct QCudaStorage {
     data: CudaSlice<u8>,
     dtype: GgmlDType,
     device: CudaDevice,
 }
 
+static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
+
+pub fn set_force_dmmv(f: bool) {
+    FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed)
+}
+
 pub const WARP_SIZE: usize = 32;
 pub const MMQ_X_Q4_0_AMPERE: usize = 4;
 pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
 pub const NWARPS_Q4_0_AMPERE: usize = 4;
 pub const GGML_CUDA_MMV_X: usize = 32;
 pub const GGML_CUDA_MMV_Y: usize = 1;
+pub const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256;
 pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
+pub const MATRIX_ROW_PADDING: usize = 512;
 
-fn dequantize(
+fn ceil_div(p: usize, q: usize) -> usize {
+    (p + q - 1) / q
+}
+
+fn pad(p: usize, q: usize) -> usize {
+    ceil_div(p, q) * q
+}
+
+fn quantize_q8_1(
+    src: &CudaView<f32>,
+    dst: &mut CudaSlice<u8>,
+    elem_count: usize,
+    ky: usize,
+    dev: &CudaDevice,
+) -> Result<()> {
+    use cudarc::driver::LaunchAsync;
+
+    let kx = elem_count;
+    let kx_padded = pad(kx, MATRIX_ROW_PADDING);
+    let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
+    let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
+    let cfg = cudarc::driver::LaunchConfig {
+        grid_dim: (num_blocks as u32, ky as u32, 1),
+        block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
+        shared_mem_bytes: 0,
+    };
+    let params = (src, dst, kx as i32, kx_padded as i32);
+    unsafe { func.launch(cfg, params) }.w()?;
+    Ok(())
+}
+
+fn dequantize_f32(
     data: &CudaSlice<u8>,
     dtype: GgmlDType,
     elem_count: usize,
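The ceil_div and pad helpers introduced above are used to round activation lengths up to MATRIX_ROW_PADDING before the Q8_1 quantization kernel runs, and to size kernel grids. A quick standalone check of that arithmetic (the 4097-element figure is only an illustration):

fn ceil_div(p: usize, q: usize) -> usize {
    (p + q - 1) / q
}

fn pad(p: usize, q: usize) -> usize {
    ceil_div(p, q) * q
}

fn main() {
    const MATRIX_ROW_PADDING: usize = 512;
    const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256;
    // 4096 is already a multiple of 512, 4097 gets rounded up to the next one.
    assert_eq!(pad(4096, MATRIX_ROW_PADDING), 4096);
    assert_eq!(pad(4097, MATRIX_ROW_PADDING), 4608);
    // Grid size for the quantize kernel: one block per CUDA_QUANTIZE_BLOCK_SIZE elements.
    assert_eq!(ceil_div(pad(4097, MATRIX_ROW_PADDING), CUDA_QUANTIZE_BLOCK_SIZE), 18);
    println!("padding ok");
}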
@@ -28,39 +70,31 @@ fn dequantize(
 
     let nb = (elem_count + 255) / 256;
     let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
-        GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
-        GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
-        GgmlDType::Q5_0 => {
-            let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
-                / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
-            (
-                "dequantize_block_q5_0",
-                false,
-                CUDA_DEQUANTIZE_BLOCK_SIZE,
-                nb,
-            )
-        }
-        GgmlDType::Q5_1 => {
-            let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
-                / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
-            (
-                "dequantize_block_q5_1",
-                false,
-                CUDA_DEQUANTIZE_BLOCK_SIZE,
-                nb,
-            )
-        }
-        GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
-        GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
-        GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
-        GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
-        GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
-        GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
-        GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
+        GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
+        GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
+        GgmlDType::Q5_0 => (
+            "dequantize_block_q5_0_f32",
+            false,
+            CUDA_DEQUANTIZE_BLOCK_SIZE,
+            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
+        ),
+        GgmlDType::Q5_1 => (
+            "dequantize_block_q5_1_f32",
+            false,
+            CUDA_DEQUANTIZE_BLOCK_SIZE,
+            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
+        ),
+        GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb),
+        GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb),
+        GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb),
+        GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb),
+        GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb),
+        GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb),
+        GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
         _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
     };
     let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
-    let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
+    let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
     // See e.g.
     // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
     let cfg = cudarc::driver::LaunchConfig {
@@ -83,9 +117,66 @@ fn dequantize(
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }
 
-fn dequantize_mut_mal_vec(
+fn dequantize_f16(
     data: &CudaSlice<u8>,
-    y: &cudarc::driver::CudaView<f32>,
+    dtype: GgmlDType,
+    elem_count: usize,
+    dev: &CudaDevice,
+) -> Result<CudaStorage> {
+    use cudarc::driver::LaunchAsync;
+
+    let nb = (elem_count + 255) / 256;
+    let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
+        GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
+        GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
+        GgmlDType::Q5_0 => (
+            "dequantize_block_q5_0_f16",
+            false,
+            CUDA_DEQUANTIZE_BLOCK_SIZE,
+            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
+        ),
+        GgmlDType::Q5_1 => (
+            "dequantize_block_q5_1_f16",
+            false,
+            CUDA_DEQUANTIZE_BLOCK_SIZE,
+            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
+        ),
+        GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb),
+        GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb),
+        GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb),
+        GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb),
+        GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb),
+        GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb),
+        GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
+        _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
+    };
+    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
+    // See e.g.
+    // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
+    let cfg = cudarc::driver::LaunchConfig {
+        grid_dim: (num_blocks as u32, 1, 1),
+        block_dim: (block_dim as u32, 1, 1),
+        shared_mem_bytes: 0,
+    };
+
+    if is_k {
+        let params = (data, &dst);
+        unsafe { func.launch(cfg, params) }.w()?;
+    } else {
+        let nb32 = match dtype {
+            GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
+            _ => elem_count / 32,
+        };
+        let params = (data, &dst, nb32 as i32);
+        unsafe { func.launch(cfg, params) }.w()?;
+    }
+    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
+}
+
+fn dequantize_mul_mat_vec(
+    data: &CudaSlice<u8>,
+    y: &CudaView<f32>,
     dtype: GgmlDType,
     ncols: usize,
     nrows: usize,
@@ -93,6 +184,13 @@ fn dequantize_mut_mal_vec(
 ) -> Result<CudaStorage> {
     use cudarc::driver::LaunchAsync;
 
+    let data_elems = data.len() / dtype.type_size() * dtype.block_size();
+    if data_elems < ncols * nrows {
+        crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
+    }
+    if y.len() != ncols {
+        crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
+    }
     let kernel_name = match dtype {
         GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
         GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
@@ -107,8 +205,8 @@ fn dequantize_mut_mal_vec(
         _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
     };
     let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
-    let dst = dev.alloc_zeros::<f32>(nrows).w()?;
-    let block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
+    let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
     let cfg = cudarc::driver::LaunchConfig {
         grid_dim: (block_num_y as u32, 1, 1),
         block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
@@ -120,9 +218,149 @@ fn dequantize_mut_mal_vec(
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }
 
+fn mul_mat_vec_via_q8_1(
+    data: &CudaSlice<u8>,
+    y: &CudaView<f32>,
+    dtype: GgmlDType,
+    ncols: usize,
+    nrows: usize,
+    b_size: usize,
+    dev: &CudaDevice,
+) -> Result<CudaStorage> {
+    use cudarc::driver::LaunchAsync;
+
+    let data_elems = data.len() / dtype.type_size() * dtype.block_size();
+    if data_elems < ncols * nrows {
+        crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
+    }
+    if y.len() != ncols * b_size {
+        crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
+    }
+    if b_size == 0 || b_size > 8 {
+        crate::bail!("only bsize between 1 and 8 are supported, got {b_size}")
+    }
+    // Start by quantizing y
+    let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
+    let y_size_in_bytes =
+        b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
+    let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
+    quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
+
+    let kernel_name = match dtype {
+        GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
+        GgmlDType::Q4_1 => "mul_mat_vec_q4_1_q8_1_cuda",
+        GgmlDType::Q5_0 => "mul_mat_vec_q5_0_q8_1_cuda",
+        GgmlDType::Q5_1 => "mul_mat_vec_q5_1_q8_1_cuda",
+        GgmlDType::Q8_0 => "mul_mat_vec_q8_0_q8_1_cuda",
+        GgmlDType::Q2K => "mul_mat_vec_q2_K_q8_1_cuda",
+        GgmlDType::Q3K => "mul_mat_vec_q3_K_q8_1_cuda",
+        GgmlDType::Q4K => "mul_mat_vec_q4_K_q8_1_cuda",
+        GgmlDType::Q5K => "mul_mat_vec_q5_K_q8_1_cuda",
+        GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
+        _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
+    };
+    let kernel_name = format!("{kernel_name}{b_size}");
+    let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
+    let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
+    // https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
+    let (nblocks, nwarps) = match b_size {
+        1 => (nrows as u32, 4),
+        2..=4 => ((nrows as u32 + 1) / 2, 4),
+        5..=8 => ((nrows as u32 + 1) / 2, 2),
+        _ => crate::bail!("unexpected bsize {b_size}"),
+    };
+    let cfg = cudarc::driver::LaunchConfig {
+        grid_dim: (nblocks, 1, 1),
+        block_dim: (WARP_SIZE as u32, nwarps, 1),
+        shared_mem_bytes: 0,
+    };
+
+    let params = (
+        data,
+        &y_q8_1,
+        &dst,
+        /* ncols_x */ ncols as i32,
+        /* nrows_x */ nrows as i32,
+        /* nrows_y */ ncols_padded as i32,
+        /* nrows_dst */ nrows as i32,
+    );
+    unsafe { func.launch(cfg, params) }.w()?;
+    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
+}
+
+#[allow(clippy::too_many_arguments)]
+fn mul_mat_via_q8_1(
+    data: &CudaSlice<u8>,
+    y: &CudaView<f32>,
+    dtype: GgmlDType,
+    x_rows: usize,
+    x_cols: usize,
+    y_rows: usize,
+    y_cols: usize,
+    dev: &CudaDevice,
+) -> Result<CudaStorage> {
+    use cudarc::driver::LaunchAsync;
+
+    let data_elems = data.len() / dtype.type_size() * dtype.block_size();
+    if data_elems < x_rows * x_cols {
+        crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
+    }
+    if y.len() != y_rows * y_cols {
+        crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len())
+    }
+    if x_cols != y_rows {
+        crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}")
+    }
+    let k = x_cols;
+    // Start by quantizing y
+    let k_padded = pad(k, MATRIX_ROW_PADDING);
+    let y_size_in_bytes =
+        k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
+    let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
+    quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
+
+    let (kernel_name, mmq_x, mmq_y) = match dtype {
+        GgmlDType::Q4_0 => ("mul_mat_q4_0", 64, 128),
+        GgmlDType::Q4_1 => ("mul_mat_q4_1", 64, 128),
+        GgmlDType::Q5_0 => ("mul_mat_q5_0", 128, 64),
+        GgmlDType::Q5_1 => ("mul_mat_q5_1", 128, 64),
+        GgmlDType::Q8_0 => ("mul_mat_q8_0", 128, 64),
+        GgmlDType::Q2K => ("mul_mat_q2_K", 64, 128),
+        GgmlDType::Q3K => ("mul_mat_q3_K", 128, 128),
+        GgmlDType::Q4K => ("mul_mat_q4_K", 64, 128),
+        GgmlDType::Q5K => ("mul_mat_q5_K", 64, 128),
+        GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
+        _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
+    };
+    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
+    let cfg = cudarc::driver::LaunchConfig {
+        grid_dim: (
+            ceil_div(x_rows, mmq_y) as u32,
+            ceil_div(y_cols, mmq_x) as u32,
+            1,
+        ),
+        block_dim: (WARP_SIZE as u32, 4, 1),
+        shared_mem_bytes: 0,
+    };
+
+    let params = (
+        /* vx */ data,
+        /* vy */ &y_q8_1,
+        /* dst */ &dst,
+        /* ncols_x */ x_cols as i32,
+        /* nrows_x */ x_rows as i32,
+        /* ncols_y */ y_cols as i32,
+        /* nrows_y */ k_padded as i32,
+        /* nrows_dst */ x_rows as i32,
+    );
+    unsafe { func.launch(cfg, params) }.w()?;
+    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
+}
+
 impl QCudaStorage {
     pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
-        let size_in_bytes = el_count * dtype.type_size() / dtype.block_size();
+        let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
         let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
         Ok(QCudaStorage {
             data,
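A standalone sketch of the launch-geometry arithmetic used by mul_mat_vec_via_q8_1 above, kept outside of CUDA so it can run anywhere (the 11008-row, batch-of-4 numbers are only an illustration):

// Mirrors the padding and block/grid selection above: activation columns are padded to
// MATRIX_ROW_PADDING before quantization, and the grid puts one output row per block for
// b_size == 1, two rows per block otherwise.
fn ceil_div(p: usize, q: usize) -> usize {
    (p + q - 1) / q
}

fn pad(p: usize, q: usize) -> usize {
    ceil_div(p, q) * q
}

fn launch_geometry(nrows: usize, b_size: usize) -> Option<(usize, usize)> {
    match b_size {
        1 => Some((nrows, 4)),
        2..=4 => Some(((nrows + 1) / 2, 4)),
        5..=8 => Some(((nrows + 1) / 2, 2)),
        _ => None, // larger batches go through the full mul_mat_via_q8_1 path instead
    }
}

fn main() {
    const MATRIX_ROW_PADDING: usize = 512;
    assert_eq!(pad(4000, MATRIX_ROW_PADDING), 4096);
    // 11008 output rows, batch of 4 activations -> 5504 blocks of 32 x 4 threads.
    assert_eq!(launch_geometry(11008, 4), Some((5504, 4)));
    // A batch of 9 is out of range for the fused vec kernel.
    assert_eq!(launch_geometry(11008, 9), None);
    println!("geometry ok");
}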
@@ -140,6 +378,12 @@ impl QCudaStorage {
     }
 
     pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
+        fn deq<T: GgmlType>(buffer: &[u8], n: usize, dst: &mut [f32]) -> Result<()> {
+            let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
+            let vec = slice.to_vec();
+            T::to_float(&vec, dst)
+        }
+
         let fast_kernel = matches!(
             self.dtype,
             GgmlDType::Q4_0
@@ -155,78 +399,38 @@ impl QCudaStorage {
             | GgmlDType::Q8K
         );
         if fast_kernel {
-            return dequantize(&self.data, self.dtype, elem_count, self.device());
+            return dequantize_f32(&self.data, self.dtype, elem_count, self.device());
         }
         // Run the dequantization on cpu.
-        use crate::quantized::k_quants::GgmlType;
 
         let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
         let mut out = vec![0.0; elem_count];
         let block_len = elem_count / self.dtype.block_size();
         match self.dtype {
-            GgmlDType::F32 => {
-                let slice =
-                    unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const f32, block_len) };
-                out.copy_from_slice(slice)
-            }
-            GgmlDType::F16 => {
-                let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
-                half::f16::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q4_0 => {
-                let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q4_1 => {
-                let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q5_0 => {
-                let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q5_1 => {
-                let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q8_0 => {
-                let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q8_1 => {
-                let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q2K => {
-                let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q3K => {
-                let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q4K => {
-                let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q5K => {
-                let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q6K => {
-                let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q8K => {
-                let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
-            }
+            GgmlDType::F32 => deq::<f32>(&buffer, block_len, &mut out)?,
+            GgmlDType::F16 => deq::<half::f16>(&buffer, block_len, &mut out)?,
+            GgmlDType::Q4_0 => deq::<crate::quantized::BlockQ4_0>(&buffer, block_len, &mut out)?,
+            GgmlDType::Q4_1 => deq::<crate::quantized::BlockQ4_1>(&buffer, block_len, &mut out)?,
+            GgmlDType::Q5_0 => deq::<crate::quantized::BlockQ5_0>(&buffer, block_len, &mut out)?,
+            GgmlDType::Q5_1 => deq::<crate::quantized::BlockQ5_1>(&buffer, block_len, &mut out)?,
+            GgmlDType::Q8_0 => deq::<crate::quantized::BlockQ8_0>(&buffer, block_len, &mut out)?,
+            GgmlDType::Q8_1 => deq::<crate::quantized::BlockQ8_1>(&buffer, block_len, &mut out)?,
+            GgmlDType::Q2K => deq::<crate::quantized::BlockQ2K>(&buffer, block_len, &mut out)?,
+            GgmlDType::Q3K => deq::<crate::quantized::BlockQ3K>(&buffer, block_len, &mut out)?,
+            GgmlDType::Q4K => deq::<crate::quantized::BlockQ4K>(&buffer, block_len, &mut out)?,
+            GgmlDType::Q5K => deq::<crate::quantized::BlockQ5K>(&buffer, block_len, &mut out)?,
+            GgmlDType::Q6K => deq::<crate::quantized::BlockQ6K>(&buffer, block_len, &mut out)?,
+            GgmlDType::Q8K => deq::<crate::quantized::BlockQ8K>(&buffer, block_len, &mut out)?,
         }
 
         self.device
             .storage_from_cpu_storage(&crate::CpuStorage::F32(out))
     }
 
+    pub fn dequantize_f16(&self, elem_count: usize) -> Result<CudaStorage> {
+        dequantize_f16(&self.data, self.dtype, elem_count, self.device())
+    }
+
     pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
         // Run the quantization on cpu.
         let src = match &src.slice {
@@ -255,7 +459,17 @@ impl QCudaStorage {
         storage: &CudaStorage,
         layout: &crate::Layout,
     ) -> Result<(CudaStorage, crate::Shape)> {
-        if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
+        let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
+            1
+        } else {
+            8
+        };
+        let use_vec_kernel = match layout.shape().dims() {
+            [b, m, _k] => b * m <= max_bm,
+            [b, _k] => *b <= max_bm,
+            _ => false,
+        };
+        if use_vec_kernel {
             self.dequantize_matmul_vec(self_shape, storage, layout)
         } else {
             self.dequantize_matmul(self_shape, storage, layout)
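The hunk above replaces the old "shape is exactly [1, 1, k] or [1, k]" test with a threshold on the flattened batch size, and FORCE_DMMV shrinks that threshold back to 1. A small standalone sketch of the same dispatch rule:

// Mirrors the dispatch introduced above: the fused vec kernel handles up to 8 stacked
// rows unless the dmmv fallback is forced.
fn use_vec_kernel(dims: &[usize], force_dmmv: bool) -> bool {
    let max_bm = if force_dmmv { 1 } else { 8 };
    match dims {
        [b, m, _k] => b * m <= max_bm,
        [b, _k] => *b <= max_bm,
        _ => false,
    }
}

fn main() {
    assert!(use_vec_kernel(&[1, 1, 4096], false));
    assert!(use_vec_kernel(&[1, 7, 4096], false)); // small batches still take the vec path
    assert!(!use_vec_kernel(&[1, 7, 4096], true)); // unless the dmmv fallback is forced
    assert!(!use_vec_kernel(&[1, 128, 4096], false)); // prompt-sized batches use the mm path
    println!("dispatch ok");
}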
@@ -276,22 +490,31 @@ impl QCudaStorage {
             Some((o1, o2)) => rhs.slice(o1..o2),
             None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
         };
-        let (with_batch, k) = match rhs_l.shape().dims() {
-            [1, 1, k] => (true, k),
-            [1, k] => (false, k),
+        let (b_size, k) = match rhs_l.shape().dims() {
+            [b, m, k] => (b * m, *k),
+            [b, k] => (*b, *k),
             _ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
         };
-        if ncols != *k {
+        if ncols != k {
             crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
         }
 
-        let out =
-            dequantize_mut_mal_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?;
-        let out_shape = if with_batch {
-            vec![1, 1, nrows]
+        let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
+            dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
         } else {
-            vec![1, nrows]
+            mul_mat_vec_via_q8_1(
+                &self.data,
+                &rhs,
+                self.dtype,
+                ncols,
+                nrows,
+                b_size,
+                self.device(),
+            )?
         };
+        let mut out_shape = rhs_l.shape().dims().to_vec();
+        out_shape.pop();
+        out_shape.push(nrows);
         Ok((out, out_shape.into()))
     }
 
@@ -312,9 +535,30 @@ impl QCudaStorage {
             crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
         }
 
-        let data_f32 = self.dequantize(n * k)?;
-        let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
-        let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
+        let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
+            let data_f32 = self.dequantize(n * k)?;
+            let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
+            storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?
+        } else {
+            let storage = storage.as_cuda_slice::<f32>()?;
+            let storage = match layout.contiguous_offsets() {
+                Some((o1, o2)) => storage.slice(o1..o2),
+                None => Err(crate::Error::RequiresContiguous {
+                    op: "quantized-matmul",
+                }
+                .bt())?,
+            };
+            mul_mat_via_q8_1(
+                &self.data,
+                &storage,
+                self.dtype,
+                /* x_rows */ n,
+                /* x_cols */ k,
+                /* y_rows */ k,
+                /* y_cols */ b * m,
+                self.device(),
+            )?
+        };
         let mut out_shape = layout.shape().dims().to_vec();
         out_shape.pop();
         out_shape.push(n);
@@ -322,11 +566,6 @@ impl QCudaStorage {
     }
 }
 
-fn read_to_vec<T: Clone>(buffer: &[u8], n: usize) -> Vec<T> {
-    let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
-    slice.to_vec()
-}
-
 pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
     device: &CudaDevice,
     data: &[T],
@@ -341,3 +580,101 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
         dtype: T::DTYPE,
     }))
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn cuda_quantize_q8_1() -> Result<()> {
+        let dev = CudaDevice::new(0)?;
+        let el = 256;
+        let el_padded = pad(el, MATRIX_ROW_PADDING);
+        let y_size_in_bytes =
+            el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
+        let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
+        let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
+        let y = dev.htod_sync_copy(&vs).w()?;
+        quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
+        Ok(())
+    }
+
+    #[test]
+    fn cuda_mmv_q8_1() -> Result<()> {
+        let dev = CudaDevice::new(0)?;
+        let ncols = 256;
+        let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
+        let y = dev.htod_sync_copy(&vs).w()?;
+        let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
+        xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
+        let cuda_storage = mul_mat_vec_via_q8_1(
+            &xs.data,
+            &y.slice(..),
+            /* dtype */ GgmlDType::Q4_0,
+            /* ncols */ ncols,
+            /* nrows */ 1,
+            /* b_size */ 1,
+            &dev,
+        )?;
+        let vs = cuda_storage.as_cuda_slice::<f32>()?;
+        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+        assert_eq!(vs.len(), 1);
+        // for n = 255, n.(n+1).(2n+1) / 6 = 5559680
+        // Q8 means 1/256 precision.
+        assert_eq!(vs[0], 5561664.5);
+
+        let cuda_storage = dequantize_mul_mat_vec(
+            &xs.data,
+            &y.slice(..),
+            /* dtype */ GgmlDType::Q4_0,
+            /* ncols */ ncols,
+            /* nrows */ 1,
+            &dev,
+        )?;
+        let vs = cuda_storage.as_cuda_slice::<f32>()?;
+        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+        assert_eq!(vs.len(), 1);
+        assert_eq!(vs[0], 5561851.0);
+        Ok(())
+    }
+
+    #[test]
+    fn cuda_mm_q8_1() -> Result<()> {
+        let dev = CudaDevice::new(0)?;
+        let ncols = 256;
+        let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
+        let y = dev.htod_sync_copy(&vs).w()?;
+        let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
+        xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
+        let cuda_storage = mul_mat_via_q8_1(
+            &xs.data,
+            &y.slice(..),
+            /* dtype */ GgmlDType::Q4_0,
+            /* x_rows */ 4,
+            /* x_cols */ ncols,
+            /* y_rows */ ncols,
+            /* y_cols */ 4,
+            &dev,
+        )?;
+        let vs = cuda_storage.as_cuda_slice::<f32>()?;
+        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+
+        /*
+        x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
+        x @ x.t() / 16
+        tensor([[  347480.0000,   869720.0000,  1391960.0000,  1914200.0000],
+                [  869720.0000,  2440536.0000,  4011352.0000,  5582166.5000],
+                [ 1391960.0000,  4011352.0000,  6630742.0000,  9250132.0000],
+                [ 1914200.0000,  5582166.5000,  9250132.0000, 12918099.0000]])
+        */
+        assert_eq!(vs.len(), 16);
+        assert_eq!(vs[0], 347604.0);
+        assert_eq!(vs[1], 888153.06);
+        assert_eq!(vs[4], 869780.7);
+        assert_eq!(vs[5], 2483145.0);
+        assert_eq!(vs[11], 9407368.0);
+        assert_eq!(vs[14], 9470856.0);
+        assert_eq!(vs[15], 13138824.0);
+        Ok(())
+    }
+}
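The cuda_mmv_q8_1 test above multiplies the vector [0, 1, ..., 255] with its own quantized copy, so the exact reference is the sum of squares n(n+1)(2n+1)/6 for n = 255. A tiny check of that reference and of how far the asserted kernel outputs sit from it (reading the gap as quantization error is an interpretation, not a statement from the commit):

fn main() {
    // Exact dot([0..256), [0..256)) = 0^2 + ... + 255^2 = n(n+1)(2n+1)/6 for n = 255.
    let exact: u64 = (0u64..256).map(|v| v * v).sum();
    assert_eq!(exact, 255 * 256 * 511 / 6);
    assert_eq!(exact, 5_559_680);
    // The test asserts 5_561_664.5 (q8_1 path) and 5_561_851.0 (dmmv path): both are
    // within roughly 0.04% of the exact value.
    let rel = (5_561_851.0f64 - exact as f64) / exact as f64;
    assert!(rel.abs() < 5e-4);
    println!("exact = {exact}, relative gap of dmmv path = {rel:.5}");
}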
@@ -24,6 +24,10 @@ impl QCudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub fn dequantize_f16(&self, _elem_count: usize) -> Result<CudaStorage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
         Err(Error::NotCompiledWithCudaSupport)
     }
@@ -135,7 +135,6 @@ pub enum ValueType {
     // The value is a UTF-8 non-null-terminated string, with length prepended.
     String,
     // The value is an array of other values, with the length and type prepended.
-    ///
     // Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
     Array,
 }
@@ -218,10 +217,16 @@ impl Value {
         }
     }
 
+    /// This will also automatically upcast any integral types which will not truncate.
     pub fn to_u64(&self) -> Result<u64> {
         match self {
             Self::U64(v) => Ok(*v),
-            v => crate::bail!("not a u64 {v:?}"),
+            // Autoupcast cases here
+            Self::U8(v) => Ok(*v as u64),
+            Self::U16(v) => Ok(*v as u64),
+            Self::U32(v) => Ok(*v as u64),
+            Self::Bool(v) => Ok(*v as u64),
+            v => crate::bail!("not a u64 or upcastable to u64 {v:?}"),
         }
     }
 
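Value::to_u64 now widens any unsigned integer (and bool) instead of failing; signed and floating values still error. A trimmed-down mimic of that rule (a hypothetical Value enum for illustration, not the gguf one):

#[derive(Debug)]
enum Value {
    U8(u8),
    U32(u32),
    U64(u64),
    Bool(bool),
    I64(i64),
}

impl Value {
    // Widening conversions from unsigned types and bool succeed, anything else errors.
    fn to_u64(&self) -> Result<u64, String> {
        match self {
            Self::U64(v) => Ok(*v),
            Self::U8(v) => Ok(*v as u64),
            Self::U32(v) => Ok(*v as u64),
            Self::Bool(v) => Ok(*v as u64),
            v => Err(format!("not a u64 or upcastable to u64 {v:?}")),
        }
    }
}

fn main() {
    assert_eq!(Value::U8(7).to_u64(), Ok(7));
    assert_eq!(Value::Bool(true).to_u64(), Ok(1));
    assert!(Value::I64(-1).to_u64().is_err());
}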
@@ -149,9 +149,12 @@ impl QMetalStorage {
         let (n, k) = self_shape.dims2()?;
         let mut dst_shape = src_shape.dims().to_vec();
 
-        let (b, m) = match dst_shape.len() {
-            3 => (dst_shape[0], dst_shape[1]),
-            2 => (1, dst_shape[0]),
+        // We always use a single batch dimension and stack all the tensors in the batch on the
+        // second dimension as the implementation in candle-metal-kernels doesn't handle batch
+        // properly.
+        let m = match dst_shape.len() {
+            3 => dst_shape[0] * dst_shape[1],
+            2 => dst_shape[0],
             n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
         };
         let last_k = dst_shape.pop().unwrap();
@@ -163,18 +166,23 @@ impl QMetalStorage {
         let device = storage.device().clone();
         let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
         let command_buffer = device.command_buffer()?;
-        candle_metal_kernels::call_quantized_matmul_t(
-            device.device(),
-            &command_buffer,
-            device.kernels(),
-            self.dtype.into(),
-            (b, m, n, k),
-            storage.buffer(),
-            layout.start_offset() * storage.dtype().size_in_bytes(),
-            &self.buffer,
-            &dst,
-        )
-        .map_err(MetalError::from)?;
+        // In some cases it would be better to use the mm variant, though it has its drawbacks
+        // around memory alignemnt.
+        for batch_id in 0..m {
+            candle_metal_kernels::call_quantized_matmul_mv_t(
+                device.device(),
+                &command_buffer,
+                device.kernels(),
+                self.dtype.into(),
+                (1, 1, n, k),
+                storage.buffer(),
+                (layout.start_offset() + batch_id * k) * storage.dtype().size_in_bytes(),
+                &self.buffer,
+                batch_id * n * DType::F32.size_in_bytes(),
+                &dst,
+            )
+            .map_err(MetalError::from)?;
+        }
         let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
         Ok((dst_storage, dst_shape))
     }
@@ -1,4 +1,4 @@
-use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
+use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
 use k_quants::*;
 use std::borrow::Cow;
 
@@ -360,9 +360,24 @@ impl QTensor {
     pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
         let storage = self.storage.dequantize(self.shape.elem_count())?;
         let none = crate::op::BackpropOp::none();
-        let is_variable = false;
-        crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
-            .to_device(device)
+        crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
+    }
+
+    pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
+        // In the CUDA case, we have a specialized kernel as this can be useful for volta
+        // architectures. https://github.com/huggingface/candle/issues/2136
+        match &self.storage {
+            QStorage::Cuda(s) => {
+                let s = s.dequantize_f16(self.shape.elem_count())?;
+                let none = crate::op::BackpropOp::none();
+                crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
+                    .to_device(device)
+            }
+            _ => {
+                let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
+                Ok(s)
+            }
+        }
     }
 
     pub fn storage_size_in_bytes(&self) -> usize {
@@ -378,6 +393,7 @@ impl QTensor {
 pub enum QMatMul {
     QTensor(std::sync::Arc<QTensor>),
     Tensor(Tensor),
+    TensorF16(Tensor),
 }
 
 thread_local! {
@@ -391,6 +407,17 @@ thread_local! {
     }
 }
 
+thread_local! {
+    static DEQUANTIZE_ALL_F16: bool = {
+        match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
+            Ok(s) => {
+                !s.is_empty() && s != "0"
+            },
+            Err(_) => false,
+        }
+    }
+}
+
 impl QMatMul {
     pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
         let dequantize = match qtensor.dtype() {
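The new CANDLE_DEQUANTIZE_ALL_F16 switch is treated as enabled for any non-empty value other than "0". A tiny standalone check of that truthiness rule, so e.g. running with CANDLE_DEQUANTIZE_ALL_F16=1 turns the f16 dequantize path on:

// Same truthiness rule as the DEQUANTIZE_ALL_F16 thread_local above,
// applied to the raw environment-variable string.
fn truthy(s: &str) -> bool {
    !s.is_empty() && s != "0"
}

fn main() {
    assert!(!truthy(""));
    assert!(!truthy("0"));
    assert!(truthy("1"));
    assert!(truthy("true"));
    println!("CANDLE_DEQUANTIZE_ALL_F16 is read with this truthiness rule");
}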
@@ -400,6 +427,9 @@ impl QMatMul {
         let t = if dequantize {
             let tensor = qtensor.dequantize(&qtensor.device())?;
             Self::Tensor(tensor)
+        } else if DEQUANTIZE_ALL_F16.with(|b| *b) {
+            let tensor = qtensor.dequantize_f16(&qtensor.device())?;
+            Self::TensorF16(tensor)
         } else {
             Self::QTensor(qtensor)
         };
@@ -409,6 +439,25 @@ impl QMatMul {
     pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
         Self::from_arc(std::sync::Arc::new(qtensor))
     }
+
+    pub fn dequantize_f16(&self) -> Result<Tensor> {
+        match self {
+            Self::QTensor(t) => t.dequantize_f16(&t.device()),
+            Self::Tensor(t) => t.to_dtype(DType::F16),
+            Self::TensorF16(t) => Ok(t.clone()),
+        }
+    }
+
+    pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
+        let w = self.dequantize_f16()?;
+        let in_dtype = xs.dtype();
+        let w = match *xs.dims() {
+            [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
+            [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
+            _ => w.t()?,
+        };
+        xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
+    }
 }
 
 impl crate::CustomOp1 for QTensor {
@@ -486,6 +535,15 @@ impl crate::Module for QMatMul {
                 };
                 xs.matmul(&w)
             }
+            Self::TensorF16(w) => {
+                let in_dtype = xs.dtype();
+                let w = match *xs.dims() {
+                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
+                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
+                    _ => w.t()?,
+                };
+                xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
+            }
         }
     }
 }
|
@@ -349,6 +349,30 @@ impl MmapedSafetensors {
    }
}

+pub struct SliceSafetensors<'a> {
+    safetensors: SafeTensors<'a>,
+}
+
+impl<'a> SliceSafetensors<'a> {
+    /// Creates a wrapper around a binary buffer and deserialize the safetensors header.
+    pub fn new(buffer: &'a [u8]) -> Result<Self> {
+        let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
+        Ok(Self { safetensors })
+    }
+
+    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
+        self.safetensors.tensor(name)?.load(dev)
+    }
+
+    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
+        self.safetensors.tensors()
+    }
+
+    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
+        Ok(self.safetensors.tensor(name)?)
+    }
+}
+
pub struct BufferedSafetensors {
    safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
}
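A hedged illustration of the new SliceSafetensors wrapper (not part of the diff): it reads tensors from a safetensors buffer that is already in memory, without writing it to disk first. The tensor name "w" is an arbitrary assumption.

use candle_core::{Device, Result, Tensor};

fn load_from_bytes(bytes: &[u8]) -> Result<Tensor> {
    // Parse the safetensors header from the in-memory buffer.
    let st = candle_core::safetensors::SliceSafetensors::new(bytes)?;
    // Load a named tensor onto the CPU device.
    st.load("w", &Device::Cpu)
}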
@@ -171,7 +171,7 @@ impl Shape {
        }
        let mut acc = 1;
        for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
-            if stride != acc {
+            if dim > 1 && stride != acc {
                return false;
            }
            acc *= dim;
@@ -186,7 +186,7 @@ impl Shape {
        }
        let mut acc = 1;
        for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
-            if stride != acc {
+            if dim > 1 && stride != acc {
                return false;
            }
            acc *= dim;
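The two hunks above relax the contiguity checks: a dimension of size 1 is never actually stepped over, so its stride cannot affect the memory layout and is now ignored. The sketch below is illustrative only (the shape and stride values are made up, not taken from the diff); it mirrors the row-major check after the change.

fn main() {
    // Shape (1, 4, 2): the leading size-1 axis may report an arbitrary stride (here 2)
    // after a view operation, yet the buffer is still plain row-major.
    let dims = [1usize, 4, 2];
    let stride = [2usize, 2, 1];
    let mut acc = 1;
    let mut contiguous = true;
    for (&stride, &dim) in stride.iter().zip(dims.iter()).rev() {
        if dim > 1 && stride != acc {
            contiguous = false;
        }
        acc *= dim;
    }
    // The old check (no `dim > 1` guard) would have rejected this layout because the
    // size-1 axis reports stride 2 while the accumulated stride is 8.
    assert!(contiguous);
    println!("contiguous: {contiguous}");
}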
candle-core/src/sort.rs  (new file, 239 lines)
@@ -0,0 +1,239 @@
use crate::{Result, Tensor};
use rayon::prelude::*;

#[derive(Debug, Clone, Copy)]
struct ArgSort {
    asc: bool,
    last_dim: usize,
}

impl ArgSort {
    fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
        #[allow(clippy::uninit_vec)]
        // Safety: indexes are set later in the parallelized section.
        let mut sort_indexes = unsafe {
            let el_count = layout.shape().elem_count();
            let mut v = Vec::with_capacity(el_count);
            v.set_len(el_count);
            v
        };
        if self.asc {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
                    indexes.sort_by(|&i, &j| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        } else {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
                    indexes.sort_by(|&j, &i| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        }
        sort_indexes
    }
}

impl crate::CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }

    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, crate::Shape)> {
        let sort_indexes = match storage {
            crate::CpuStorage::U8(vs) => self.asort(vs, layout),
            crate::CpuStorage::U32(vs) => self.asort(vs, layout),
            crate::CpuStorage::I64(vs) => self.asort(vs, layout),
            crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F32(vs) => self.asort(vs, layout),
            crate::CpuStorage::F64(vs) => self.asort(vs, layout),
        };
        let sort_indexes = crate::CpuStorage::U32(sort_indexes);
        Ok((sort_indexes, layout.shape().into()))
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &crate::CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CudaStorage, crate::Shape)> {
        use crate::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
        };
        use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr};
        use crate::{CudaDevice, WithDType};

        impl Map1Any for ArgSort {
            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &crate::Layout,
                _wrap: W,
            ) -> Result<S> {
                let slice = match layout.contiguous_offsets() {
                    None => crate::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let elem_count = layout.shape().elem_count();
                let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
                let func = if self.asc {
                    dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
                } else {
                    dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
                };
                let ncols = self.last_dim;
                let nrows = elem_count / ncols;
                let ncols_pad = next_power_of_2(ncols);
                let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
                let cfg = LaunchConfig {
                    grid_dim: (1, nrows as u32, 1),
                    block_dim: (ncols_pad as u32, 1, 1),
                    shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
                };
                unsafe { func.launch(cfg, params) }.w()?;
                Ok(S::U32(dst))
            }
        }

        use crate::backend::BackendStorage;
        let dev = storage.device();
        let slice = self.map(&storage.slice, dev, layout)?;
        let dst = crate::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, crate::Shape)> {
        use crate::backend::BackendStorage;
        use crate::DType;

        let name = {
            if self.asc {
                match storage.dtype() {
                    DType::BF16 => "asort_asc_bf16",
                    DType::F16 => "asort_asc_f16",
                    DType::F32 => "asort_asc_f32",
                    DType::F64 => "asort_asc_f64",
                    DType::U8 => "asort_asc_u8",
                    DType::U32 => "asort_asc_u32",
                    DType::I64 => "asort_asc_i64",
                }
            } else {
                match storage.dtype() {
                    DType::BF16 => "asort_desc_bf16",
                    DType::F16 => "asort_desc_f16",
                    DType::F32 => "asort_desc_f32",
                    DType::F64 => "asort_desc_f64",
                    DType::U8 => "asort_desc_u8",
                    DType::U32 => "asort_desc_u32",
                    DType::I64 => "asort_desc_i64",
                }
            }
        };
        let device = storage.device();
        let kernels = device.kernels();
        let command_buffer = device.command_buffer()?;
        let el = layout.shape().elem_count();
        let ncols = self.last_dim;
        let nrows = el / ncols;
        let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
        let dst = device.new_buffer(el, DType::U32, "asort")?;
        let mut ncols_pad = 1;
        while ncols_pad < ncols {
            ncols_pad *= 2;
        }
        candle_metal_kernels::call_arg_sort(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            nrows,
            ncols,
            ncols_pad,
            src,
            &dst,
        )
        .map_err(crate::Error::wrap)?;
        let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
        Ok((dst, layout.shape().clone()))
    }
}

#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
    let mut n = 1;
    while n < x {
        n *= 2
    }
    n
}

impl Tensor {
    /// Returns the indices that sort the tensor along the last dimension.
    ///
    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
    /// descending order. The sort is unstable so there is no guarantees on the final order when it
    /// comes to ties.
    pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(crate::Error::RequiresContiguous {
                op: "arg_sort_last_dim",
            });
        }
        let last_dim = match self.dims().last() {
            None => crate::bail!("empty last-dim in arg-sort"),
            Some(last_dim) => *last_dim,
        };
        // No need for a backward pass for arg sort.
        self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
    }

    /// Sorts the tensor along the last dimension, returns the sorted tensor together with the
    /// sorted indexes.
    ///
    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
    /// descending order. The sort is unstable so there is no guarantees on the final order when it
    /// comes to ties.
    pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
        if !self.is_contiguous() {
            return Err(crate::Error::RequiresContiguous {
                op: "sort_last_dim",
            });
        }
        let asort = self.arg_sort_last_dim(asc)?;
        let sorted = self.gather(&asort, crate::D::Minus1)?;
        Ok((sorted, asort))
    }
}
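A short usage sketch for the new sort API (illustrative only, not part of the diff; the input values are arbitrary):

use candle_core::{Device, Result, Tensor};

fn sort_example() -> Result<()> {
    let t = Tensor::new(&[[3f32, 1., 4., 1.5, 5.]], &Device::Cpu)?;
    // Indices that would sort each row in ascending order.
    let idx = t.arg_sort_last_dim(true)?;
    assert_eq!(idx.to_vec2::<u32>()?, [[1, 3, 0, 2, 4]]);
    // Values and indices together, descending this time.
    let (sorted, idx) = t.sort_last_dim(false)?;
    assert_eq!(sorted.to_vec2::<f32>()?, [[5.0, 4.0, 3.0, 1.5, 1.0]]);
    assert_eq!(idx.to_vec2::<u32>()?, [[4, 2, 0, 3, 1]]);
    Ok(())
}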
@@ -44,9 +44,19 @@ impl Storage {
    }

    pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
-        let lhs = self.device().location();
-        let rhs = rhs.device().location();
-        if lhs != rhs {
+        let lhs_device = self.device();
+        let rhs_device = rhs.device();
+        let lhs = lhs_device.location();
+        let rhs = rhs_device.location();
+        let same_device = if self.device().is_metal() {
+            // On metal, we require the device to be exactly the same rather than
+            // having the same location. In cuda this is not necessary as all CudaDevice on the
+            // same GPU will use the same cuda stream.
+            lhs_device.same_device(&rhs_device)
+        } else {
+            lhs == rhs
+        };
+        if !same_device {
            Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
        } else {
            Ok(())
@@ -79,6 +79,9 @@ macro_rules! unary_op {
    ($fn_name:ident, $op_name:ident) => {
        pub fn $fn_name(&self) -> Result<Self> {
            let shape = self.shape();
+            if shape.elem_count() == 0 {
+                return Ok(self.clone());
+            }
            let storage = self
                .storage()
                .unary_impl::<crate::op::$op_name>(self.layout())?;
@@ -92,6 +95,9 @@ macro_rules! binary_op {
    ($fn_name:ident, $op_name:ident) => {
        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
            let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
+            if shape.elem_count() == 0 {
+                return Ok(self.clone());
+            }
            let storage = self.storage().binary_impl::<crate::op::$op_name>(
                &*rhs.storage(),
                self.layout(),
@@ -114,6 +120,9 @@ macro_rules! binary_op_scalar {
                    .broadcast_as(self.shape())?,
            };
            let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
+            if self.elem_count() == 0 {
+                return Ok(self.clone());
+            }
            let storage = self.storage().binary_impl::<crate::op::$op_name>(
                &*rhs.storage(),
                self.layout(),
@@ -447,7 +456,15 @@ impl Tensor {
        shape: S,
        device: &Device,
    ) -> Result<Self> {
-        Self::new_impl(array, shape.into(), device, false)
+        let shape = shape.into();
+        let n: usize = shape.elem_count();
+        let buffer_size: usize = array.len();
+        if buffer_size != n {
+            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
+        }
+        let storage = device.storage_from_slice(array)?;
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, shape, none, false))
    }

    pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
@@ -510,6 +527,7 @@ impl Tensor {
    unary_op!(ceil, Ceil);
    unary_op!(floor, Floor);
    unary_op!(round, Round);
+    unary_op!(sign, Sign);

    /// Round element of the input tensor to the nearest integer.
    ///
@@ -645,6 +663,9 @@ impl Tensor {
    /// # Ok::<(), candle_core::Error>(())
    /// ```
    pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
+        if self.elem_count() == 0 {
+            return Ok(self.clone());
+        }
        let storage = self.storage().affine(self.layout(), mul, add)?;
        let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
        Ok(from_storage(storage, self.shape(), op, false))
@@ -652,6 +673,9 @@ impl Tensor {

    /// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
    pub fn elu(&self, alpha: f64) -> Result<Self> {
+        if self.elem_count() == 0 {
+            return Ok(self.clone());
+        }
        let storage = self.storage().elu(self.layout(), alpha)?;
        let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
        Ok(from_storage(storage, self.shape(), op, false))
@@ -659,6 +683,9 @@ impl Tensor {

    /// Raise the tensor to some float exponent `e`.
    pub fn powf(&self, e: f64) -> Result<Self> {
+        if self.elem_count() == 0 {
+            return Ok(self.clone());
+        }
        let storage = self.storage().powf(self.layout(), e)?;
        let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
        Ok(from_storage(storage, self.shape(), op, false))
@@ -1153,6 +1180,9 @@ impl Tensor {
        let n = b_dims[dim - 1];

        let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
+        if c_shape.elem_count() == 0 || k == 0 {
+            return Tensor::zeros(c_shape, self.dtype(), self.device());
+        }
        let batching: usize = a_dims[..dim - 2].iter().product();
        let batching_b: usize = b_dims[..dim - 2].iter().product();
        if k != k2 || batching != batching_b {
@@ -2007,6 +2037,16 @@ impl Tensor {
        }
    }

+    /// Returns a tensor that is in row major order. This always makes a copy.
+    pub fn force_contiguous(&self) -> Result<Tensor> {
+        let shape = self.shape();
+        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
+        self.storage()
+            .copy_strided_src(&mut storage, 0, self.layout())?;
+        let op = BackpropOp::new1(self, Op::Copy);
+        Ok(from_storage(storage, shape.clone(), op, false))
+    }
+
    /// Create a variable based on the values currently stored in a tensor. The storage is always
    /// copied.
    pub(crate) fn make_var(&self) -> Result<Tensor> {
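The force_contiguous helper added above always copies into a fresh row-major buffer, even when the tensor already reports itself as contiguous. A hedged sketch of how it can be used (the shapes are arbitrary assumptions, not taken from the diff):

use candle_core::{Device, Result, Tensor};

fn force_copy_example() -> Result<()> {
    let dev = Device::Cpu;
    let a = Tensor::arange(0f32, 8f32, &dev)?.reshape((1, 1, 4, 2))?;
    // A transpose round-trip keeps the same logical values but can leave an unusual
    // stride pattern; force_contiguous rebuilds a plain row-major copy in between.
    let b = a.transpose(1, 2)?.force_contiguous()?.transpose(1, 2)?;
    let diff = (a - b)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert_eq!(diff, 0.);
    Ok(())
}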
@@ -58,20 +58,18 @@ impl Tensor {
                }
            }
        }
-        if dim == 0 {
+        let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
+        if all_contiguous {
+            Self::cat_contiguous(args, dim)
+        } else if dim == 0 {
            Self::cat0(args)
        } else {
-            let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
-            if all_contiguous {
-                Self::cat_contiguous(args, dim)
-            } else {
-                let args: Vec<Tensor> = args
-                    .iter()
-                    .map(|a| a.as_ref().transpose(0, dim))
-                    .collect::<Result<Vec<_>>>()?;
-                let cat = Self::cat0(&args)?;
-                cat.transpose(0, dim)
-            }
+            let args: Vec<Tensor> = args
+                .iter()
+                .map(|a| a.as_ref().transpose(0, dim))
+                .collect::<Result<Vec<_>>>()?;
+            let cat = Self::cat0(&args)?;
+            cat.transpose(0, dim)
        }
    }

@@ -237,4 +235,66 @@ impl Tensor {
        }
        Ok(crate::tensor::from_storage(storage, shape, op, false))
    }
+
+    /// Set the values on `self` using values from `src`. The copy starts at the specified
+    /// `offset` for the target dimension `dim` on `self`.
+    /// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
+    /// has to be greater than or equal to `offset` plus the `src` size.
+    ///
+    /// Note that this modifies `self` in place and as such is not compatibel with
+    /// back-propagation.
+    pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
+        let dim = dim.to_index(self.shape(), "slice-set")?;
+        if !self.is_contiguous() || !src.is_contiguous() {
+            Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
+        }
+        if self.dtype() != src.dtype() {
+            Err(Error::DTypeMismatchBinaryOp {
+                lhs: self.dtype(),
+                rhs: src.dtype(),
+                op: "slice-set",
+            }
+            .bt())?
+        }
+        if self.device().location() != src.device().location() {
+            Err(Error::DeviceMismatchBinaryOp {
+                lhs: self.device().location(),
+                rhs: src.device().location(),
+                op: "slice-set",
+            }
+            .bt())?
+        }
+        if self.rank() != src.rank() {
+            Err(Error::UnexpectedNumberOfDims {
+                expected: self.rank(),
+                got: src.rank(),
+                shape: self.shape().clone(),
+            }
+            .bt())?
+        }
+        for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
+            if dim_idx == dim && *v2 + offset > *v1 {
+                crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
+            }
+            if dim_idx != dim && v1 != v2 {
+                crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
+            }
+        }
+        let block_size: usize = src.dims().iter().skip(1 + dim).product();
+        let d1: usize = src.dims().iter().take(dim).product();
+        let d2 = block_size * src.dims()[dim];
+        let dst_o = self.layout().start_offset() + offset * block_size;
+        let src_o = src.layout().start_offset();
+        src.storage().copy2d(
+            &mut self.storage_mut(),
+            d1,
+            d2,
+            /* src_s */ d2,
+            /* dst_s */ block_size * self.dims()[dim],
+            src_o,
+            dst_o,
+        )?;
+
+        Ok(())
+    }
}
|
|||||||
Ok(Self(inner))
|
Ok(Self(inner))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert a tensor to a variable, if the tensor is already a variable then it is returned as is.
|
||||||
pub fn from_tensor(t: &Tensor) -> Result<Self> {
|
pub fn from_tensor(t: &Tensor) -> Result<Self> {
|
||||||
let inner = t.make_var()?;
|
if t.is_variable() {
|
||||||
Ok(Self(inner))
|
Ok(Self(t.clone()))
|
||||||
|
} else {
|
||||||
|
let inner = t.make_var()?;
|
||||||
|
Ok(Self(inner))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn rand_f64<S: Into<Shape>>(
|
pub fn rand_f64<S: Into<Shape>>(
|
||||||
|
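With the change above, Var::from_tensor no longer re-copies a tensor that is already a variable. A small illustrative sketch, not part of the diff (the values are arbitrary and the call pattern is an assumption about how this is typically used):

use candle_core::{Device, Result, Tensor, Var};

fn var_roundtrip() -> Result<()> {
    let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    let v = Var::from_tensor(&t)?;
    // Converting the variable's tensor back into a Var keeps it a variable
    // and should not need another copy of the underlying storage.
    let v2 = Var::from_tensor(v.as_tensor())?;
    assert_eq!(v2.to_vec1::<f32>()?, [1., 2., 3.]);
    Ok(())
}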
candle-core/tests/matmul_tests.rs  (new file, 106 lines)
@@ -0,0 +1,106 @@
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};

fn matmul(device: &Device) -> Result<()> {
    let data = vec![1.0f32, 2.0, 3.0, 4.0];
    let a = Tensor::from_slice(&data, (2, 2), device)?;
    let data = vec![1.0f32, 2.0, 3.0, 4.0];
    let b = Tensor::from_slice(&data, (2, 2), device)?;

    let c = a.matmul(&b)?;
    assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);

    let data = vec![1.0f32, 2.0];
    let a = Tensor::from_slice(&data, (2, 1), device)?;
    let data = vec![3.0f32, 4.0];
    let b = Tensor::from_slice(&data, (1, 2), device)?;
    let c = a.matmul(&b)?;
    assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);

    let data: Vec<_> = (0..6).map(|i| i as f32).collect();
    let a = Tensor::from_slice(&data, (2, 3), device)?;
    let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
    let b = Tensor::from_slice(&data, (3, 2), device)?;
    let c = a.matmul(&b)?;
    assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);

    let data: Vec<_> = (0..12).map(|i| i as f32).collect();
    let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
    let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
    let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
    let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];

    let c = a.matmul(&b)?;
    assert_eq!(c.to_vec3::<f32>()?, &expected);

    // Also perform the matmul on contiguous transposed versions.
    let a_tt = a.t()?.contiguous()?.t()?;
    assert!(!a_tt.is_contiguous());
    assert_eq!(a.dims(), a_tt.dims());
    assert_eq!(a_tt.stride(), &[6, 1, 2]);

    let b_tt = b.t()?.contiguous()?.t()?;
    assert!(!b_tt.is_contiguous());
    assert_eq!(b.dims(), b_tt.dims());
    assert_eq!(b_tt.stride(), &[6, 1, 3]);

    assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
    assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
    assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
    Ok(())
}

fn broadcast_matmul(device: &Device) -> Result<()> {
    let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
    let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
    let out = lhs.broadcast_matmul(&rhs)?;
    assert_eq!(out.dims(), &[3, 6, 4, 2]);
    for idx1 in 0..3 {
        for idx2 in 0..6 {
            let out = out.i((idx1, idx2))?;
            let lhs = lhs.i((idx1, 0))?;
            let rhs = rhs.i(idx2)?;
            let out2 = lhs.matmul(&rhs);
            let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
            // With cuda, we see errors of up to ~1e-12.
            assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
        }
    }
    Ok(())
}

// https://github.com/huggingface/candle/issues/1948
fn squeeze_mm(device: &Device) -> Result<()> {
    let seq_len = 8_usize;
    let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
    let x = a.i((.., seq_len - 1, ..))?;
    let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
    let x = x.matmul(&w)?;
    assert_eq!(x.dims(), &[1, 32]);
    Ok(())
}

// https://github.com/huggingface/candle/issues/1992
fn mm_layout(device: &Device) -> Result<()> {
    let a = Tensor::arange(0f32, 16f32, device)?.reshape((1, 1, 4, 4))?;
    let b = Tensor::arange(0f32, 8f32, device)?.reshape((1, 1, 4, 2))?;
    let mm1 = a.matmul(&b)?;
    // Forces the layout to be:
    // shape: [1, 1, 4, 2], stride: [8, 2, 2, 1], start_offset: 0
    // This is still a contiguous matrix but matmul checks are only the two last dimensions have
    // non 1 sizes but matmul check may be reluctant to handle it.
    let b = b.transpose(1, 2)?.force_contiguous()?.transpose(1, 2)?;
    let mm2 = a.matmul(&b)?;
    let diff = (mm1 - mm2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert_eq!(diff, 0.);
    Ok(())
}

test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
    broadcast_matmul,
    broadcast_matmul_cpu,
    broadcast_matmul_gpu,
    broadcast_matmul_metal
);
test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal);
test_device!(mm_layout, mm_layout_cpu, mm_layout_gpu, mm_layout_metal);
@@ -3,7 +3,7 @@ use candle_core::{
    quantized::{self, GgmlDType},
    test_device,
    test_utils::to_vec2_round,
-    Device, Module, Result, Tensor,
+    DType, Device, IndexOp, Module, Result, Tensor,
};
use quantized::{k_quants, GgmlType};
use rand::prelude::*;
@@ -47,18 +47,14 @@ fn test_matmul(
}

fn quantized_matmul(device: &Device) -> Result<()> {
-    // TODO Enable this later when we enable cuda.
-    if device.is_cuda() {
-        return Ok(());
-    }
    let (m, k, n) = (3, 64, 4);
-    let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
-    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
+    let lhs_s = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
+    let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?;
    let mut dst = vec![42.; 3 * 4];
    let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
    let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
    k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
-    k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
+    k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?;
    assert_eq!(
        dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
        &[
@@ -67,7 +63,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
        ]
    );
    let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
-    let mm = tensor_lhs.matmul(&tensor_rhs)?;
+    let mm = lhs.matmul(&tensor_rhs)?;
    assert_eq!(
        mm.to_vec2::<f32>()?,
        &[
@@ -79,7 +75,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {

    let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
-    let res = matmul.forward(&tensor_lhs)?;
+    let res = matmul.forward(&lhs)?;
    match device {
        Device::Metal(_) => assert_eq!(
            to_vec2_round(&res, 0)?,
@@ -89,7 +85,15 @@ fn quantized_matmul(device: &Device) -> Result<()> {
                [341970.0, 994574.0, 1656181.0, 2302182.0]
            ]
        ),
-        _ => assert_eq!(
+        Device::Cuda(_) => assert_eq!(
+            to_vec2_round(&res, 0)?,
+            &[
+                [84866.0, 214045.0, 344676.0, 473707.0],
+                [213425.0, 604313.0, 1000431.0, 1387960.0],
+                [342030.0, 994630.0, 1656248.0, 2302250.0]
+            ]
+        ),
+        Device::Cpu => assert_eq!(
            to_vec2_round(&res, 0)?,
            &[
                [85120.0, 214562.0, 345455.0, 474748.0],
@@ -98,22 +102,16 @@ fn quantized_matmul(device: &Device) -> Result<()> {
            ]
        ),
    }
-
    test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
-
    Ok(())
}

fn quantized_matmul_neg(device: &Device) -> Result<()> {
-    // TODO Enable this later when we enable cuda.
-    if device.is_cuda() {
-        return Ok(());
-    }
    let (m, k, n) = (3, 64, 4);
-    let lhs = (0..(m * k))
+    let lhs_s = (0..(m * k))
        .map(|v| v as f32 - (m * k) as f32 / 2.0)
        .collect::<Vec<_>>();
-    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
+    let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?;
    let mut dst = vec![42.; 3 * 4];
    let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
    let rhs = (0..k * n)
@@ -121,7 +119,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
        .collect::<Vec<_>>();
    let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
    k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
-    k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
+    k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?;
    assert_eq!(
        dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
        &[
@@ -129,7 +127,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
            -196472.0, 63012.0, 324585.0, 587902.0
        ]
    );
-    let mm = tensor_lhs.matmul(&tensor_rhs)?;
+    let mm = lhs.matmul(&tensor_rhs)?;
    assert_eq!(
        to_vec2_round(&mm, 0)?,
        &[
@@ -141,7 +139,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {

    let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
-    let res = matmul.forward(&tensor_lhs)?;
+    let res = matmul.forward(&lhs)?;
    match device {
        Device::Metal(_) => assert_eq!(
            to_vec2_round(&res, 0)?,
@@ -151,7 +149,15 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
                [-196102.0, 63022.0, 324233.0, 587191.0]
            ]
        ),
-        _ => assert_eq!(
+        Device::Cuda(_) => assert_eq!(
+            to_vec2_round(&res, 0)?,
+            &[
+                [243740.0, -19762.0, -285476.0, -550498.0],
+                [23774.0, 21645.0, 19395.0, 18364.0],
+                [-196045.0, 63030.0, 324120.0, 587079.0]
+            ]
+        ),
+        Device::Cpu => assert_eq!(
            to_vec2_round(&res, 0)?,
            &[
                [243524.0, -19596.0, -285051.0, -549815.0],
@@ -160,22 +166,58 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
            ]
        ),
    }
+    let lhs2 = Tensor::stack(&[&lhs, &lhs], 0)?;
+    let res2 = matmul.forward(&lhs2)?;
+    let res2 = res2.i(1)?;
+    let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    if device.is_cuda() {
+        assert!(diff < 0.1);
+    } else {
+        assert_eq!(diff, 0.);
+    }
    Ok(())
}

-test_device!(
-    quantized_matmul,
-    quantized_matmul_cpu,
-    quantized_matmul_cuda,
-    quantized_matmul_metal
-);
-test_device!(
-    quantized_matmul_neg,
-    quantized_matmul_neg_cpu,
-    quantized_matmul_neg_cuda,
-    quantized_matmul_neg_metal
-);
+fn qmm_batch(dev: &Device) -> Result<()> {
+    let (lhs, rhs, _mm) = get_random_tensors(2, 256, 6, dev)?;
+    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
+    let mm = rhs.forward(&lhs)?;
+    assert_eq!(mm.shape().dims(), [2, 6]);
+    let lhs2 = Tensor::cat(&[&lhs, &lhs], 0)?;
+    let mm2 = rhs.forward(&lhs2)?;
+    assert_eq!(mm2.shape().dims(), [4, 6]);
+    let diff2 = (mm2.i(2..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert_eq!(diff2, 0.0);
+    let lhs3 = Tensor::cat(&[&lhs2, &lhs], 0)?;
+    let mm3 = rhs.forward(&lhs3)?;
+    assert_eq!(mm3.shape().dims(), [6, 6]);
+    let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert_eq!(diff3, 0.0);
+    let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert_eq!(diff3, 0.0);
+    let lhs4 = Tensor::cat(&[&lhs3, &lhs3], 0)?;
+    let mm4 = rhs.forward(&lhs4)?;
+    assert_eq!(mm4.shape().dims(), [12, 6]);
+    let diff4 = (mm4.i(..6)? - &mm3)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    if dev.is_cuda() {
+        // We use a different kernel for sizes from 1 to 8 on cuda which explains
+        // the difference here.
+        assert!(0. < diff4 && diff4 < 1e-4)
+    } else {
+        assert_eq!(diff4, 0.0)
+    };
+    let diff4 = (mm4.i(6..)? - &mm4.i(..6)?)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff4, 0.0);
+    Ok(())
+}
+
+test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal);
+test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal);
+test_device!(qmm_batch, qmm_b_cpu, qmm_b_cuda, qmm_b_metal);

fn quantize_q4_0(device: &Device) -> Result<()> {
    let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
@@ -183,6 +225,13 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
    let dst = quant.dequantize(device)?;
+    let dst_f16 = quant.dequantize_f16(device)?;
+    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);
    assert_eq!(
        dst.to_vec1::<f32>()?,
        &[
@@ -209,6 +258,13 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
    let dst = quant.dequantize(device)?;
+    let dst_f16 = quant.dequantize_f16(device)?;
+    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);
    assert_eq!(
        round_vector(&dst.to_vec1::<f32>()?),
        &[
@@ -235,6 +291,13 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
    let dst = quant.dequantize(device)?;
+    let dst_f16 = quant.dequantize_f16(device)?;
+    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);
    assert_eq!(
        round_vector(&dst.to_vec1::<f32>()?),
        &[
@@ -261,6 +324,13 @@ fn quantize_q5_1(device: &Device) -> Result<()> {
    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
    let dst = quant.dequantize(device)?;
+    let dst_f16 = quant.dequantize_f16(device)?;
+    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);
    assert_eq!(
        round_vector(&dst.to_vec1::<f32>()?),
        &[
@@ -345,6 +415,13 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32
    let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
    let dst = quant.dequantize(device)?;
+    let dst_f16 = quant.dequantize_f16(device)?;
+    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);
    let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
    if error > max_error {
        bail!(
@@ -362,6 +439,13 @@ fn quantize_q2k(device: &Device) -> Result<()> {
    let src = get_test_vector2(0.5, 1024, device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
    let dst = quant.dequantize(device)?;
+    let dst_f16 = quant.dequantize_f16(device)?;
+    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);

    let src = src.to_vec1::<f32>()?;
    let dst = dst.to_vec1::<f32>()?;
@@ -381,6 +465,13 @@ fn quantize_q2k(device: &Device) -> Result<()> {
    let src_big = get_test_vector2(128.0, 1024, device)?;
    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
    let dst_big = quant_big.dequantize(device)?;
+    let dst_big_f16 = quant_big.dequantize_f16(device)?;
+    let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);

    let src_big = src_big.to_vec1::<f32>()?;
    let dst_big = dst_big.to_vec1::<f32>()?;
@@ -395,6 +486,13 @@ fn quantize_q3k(device: &Device) -> Result<()> {
    let src = get_test_vector2(0.5, 1024, device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
    let dst = quant.dequantize(device)?;
+    let dst_f16 = quant.dequantize_f16(device)?;
+    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);

    let src = src.to_vec1::<f32>()?;
    let dst = dst.to_vec1::<f32>()?;
@@ -414,6 +512,13 @@ fn quantize_q3k(device: &Device) -> Result<()> {
    let src_big = get_test_vector2(128.0, 1024, device)?;
    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
    let dst_big = quant_big.dequantize(device)?;
+    let dst_big_f16 = quant_big.dequantize_f16(device)?;
+    let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);

    let src_big = src_big.to_vec1::<f32>()?;
    let dst_big = dst_big.to_vec1::<f32>()?;
@@ -428,6 +533,13 @@ fn quantize_q4k(device: &Device) -> Result<()> {
    let src = get_test_vector2(0.5, 1024, device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
    let dst = quant.dequantize(device)?;
+    let dst_f16 = quant.dequantize_f16(device)?;
+    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);

    let src = src.to_vec1::<f32>()?;
    let dst = dst.to_vec1::<f32>()?;
@@ -447,6 +559,13 @@ fn quantize_q4k(device: &Device) -> Result<()> {
    let src_big = get_test_vector2(128.0, 1024, device)?;
    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
    let dst_big = quant_big.dequantize(device)?;
+    let dst_big_f16 = quant_big.dequantize_f16(device)?;
+    let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);

    let src_big = src_big.to_vec1::<f32>()?;
    let dst_big = dst_big.to_vec1::<f32>()?;
@@ -461,6 +580,13 @@ fn quantize_q5k(device: &Device) -> Result<()> {
    let src = get_test_vector2(0.5, 1024, device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
    let dst = quant.dequantize(device)?;
+    let dst_f16 = quant.dequantize_f16(device)?;
+    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);

    let src = src.to_vec1::<f32>()?;
    let dst = dst.to_vec1::<f32>()?;
@@ -480,6 +606,13 @@ fn quantize_q5k(device: &Device) -> Result<()> {
    let src_big = get_test_vector2(128.0, 1024, device)?;
    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
    let dst_big = quant_big.dequantize(device)?;
+    let dst_big_f16 = quant_big.dequantize_f16(device)?;
+    let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);

    let src_big = src_big.to_vec1::<f32>()?;
    let dst_big = dst_big.to_vec1::<f32>()?;
@@ -494,6 +627,13 @@ fn quantize_q6k(device: &Device) -> Result<()> {
    let src = get_test_vector2(0.5, 1024, device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
    let dst = quant.dequantize(device)?;
+    let dst_f16 = quant.dequantize_f16(device)?;
+    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);

    let src = src.to_vec1::<f32>()?;
    let dst = dst.to_vec1::<f32>()?;
@@ -513,6 +653,13 @@ fn quantize_q6k(device: &Device) -> Result<()> {
    let src_big = get_test_vector2(128.0, 1024, device)?;
    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
    let dst_big = quant_big.dequantize(device)?;
+    let dst_big_f16 = quant_big.dequantize_f16(device)?;
+    let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);

    let src_big = src_big.to_vec1::<f32>()?;
    let dst_big = dst_big.to_vec1::<f32>()?;
@@ -527,6 +674,13 @@ fn quantize_q8k(device: &Device) -> Result<()> {
    let src = get_test_vector2(0.5, 1024, device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
    let dst = quant.dequantize(device)?;
+    let dst_f16 = quant.dequantize_f16(device)?;
+    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);

    let src = src.to_vec1::<f32>()?;
    let dst = dst.to_vec1::<f32>()?;
@@ -546,6 +700,13 @@ fn quantize_q8k(device: &Device) -> Result<()> {
    let src_big = get_test_vector2(128.0, 1024, device)?;
    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
    let dst_big = quant_big.dequantize(device)?;
+    let dst_big_f16 = quant_big.dequantize_f16(device)?;
+    let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);

    let src_big = src_big.to_vec1::<f32>()?;
    let dst_big = dst_big.to_vec1::<f32>()?;
@@ -1,5 +1,31 @@
use candle_core::{DType, Result, Tensor};

+struct TmpFile(std::path::PathBuf);
+
+impl TmpFile {
+    fn create(base: &str) -> TmpFile {
+        let filename = std::env::temp_dir().join(format!(
+            "candle-{}-{}-{:?}",
+            base,
+            std::process::id(),
+            std::thread::current().id(),
+        ));
+        TmpFile(filename)
+    }
+}
+
+impl std::convert::AsRef<std::path::Path> for TmpFile {
+    fn as_ref(&self) -> &std::path::Path {
+        self.0.as_path()
+    }
+}
+
+impl Drop for TmpFile {
+    fn drop(&mut self) {
+        std::fs::remove_file(&self.0).unwrap()
+    }
+}
+
#[test]
fn npy() -> Result<()> {
    let npy = Tensor::read_npy("tests/test.npy")?;
@@ -22,3 +48,24 @@ fn npz() -> Result<()> {
    );
    Ok(())
}
+
+#[test]
+fn safetensors() -> Result<()> {
+    use candle_core::safetensors::Load;
+
+    let tmp_file = TmpFile::create("st");
+    let t = Tensor::arange(0f32, 24f32, &candle_core::Device::Cpu)?;
+    t.save_safetensors("t", &tmp_file)?;
+    // Load from file.
+    let st = candle_core::safetensors::load(&tmp_file, &candle_core::Device::Cpu)?;
+    let t2 = st.get("t").unwrap();
+    let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert_eq!(diff, 0f32);
+    // Load from bytes.
+    let bytes = std::fs::read(tmp_file)?;
+    let st = candle_core::safetensors::SliceSafetensors::new(&bytes)?;
+    let t2 = st.get("t").unwrap().load(&candle_core::Device::Cpu);
+    let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert_eq!(diff, 0f32);
+    Ok(())
+}
@ -96,6 +96,40 @@ fn clamp(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn asort(device: &Device) -> Result<()> {
|
||||||
|
let data = &[[3f32, 1., 4., 1.1, 5.], [2.1, 1., 7., 8., 2.]];
|
||||||
|
let tensor = Tensor::new(data, device)?;
|
||||||
|
let indexes = tensor.arg_sort_last_dim(true)?;
|
||||||
|
assert_eq!(
|
||||||
|
indexes.to_vec2::<u32>()?,
|
||||||
|
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
|
||||||
|
);
|
||||||
|
let indexes = tensor.arg_sort_last_dim(false)?;
|
||||||
|
assert_eq!(
|
||||||
|
indexes.to_vec2::<u32>()?,
|
||||||
|
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
|
||||||
|
);
|
||||||
|
let (sorted, indexes) = tensor.sort_last_dim(true)?;
|
||||||
|
assert_eq!(
|
||||||
|
indexes.to_vec2::<u32>()?,
|
||||||
|
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
sorted.to_vec2::<f32>()?,
|
||||||
|
[[1.0, 1.1, 3.0, 4.0, 5.0], [1.0, 2.0, 2.1, 7.0, 8.0]]
|
||||||
|
);
|
||||||
|
let (sorted, indexes) = tensor.sort_last_dim(false)?;
|
||||||
|
assert_eq!(
|
||||||
|
indexes.to_vec2::<u32>()?,
|
||||||
|
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
sorted.to_vec2::<f32>()?,
|
||||||
|
[[5.0, 4.0, 3.0, 1.1, 1.0], [8.0, 7.0, 2.1, 2.0, 1.0]]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn unary_op(device: &Device) -> Result<()> {
|
fn unary_op(device: &Device) -> Result<()> {
|
||||||
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
|
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
|
||||||
let tensor = Tensor::new(data, device)?;
|
let tensor = Tensor::new(data, device)?;
|
||||||
@@ -106,6 +140,9 @@ fn unary_op(device: &Device) -> Result<()> {
             [2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
         ]
     );
+    let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
+    let max_diff = (tensor.gelu()? - t_f16)?.flatten_all()?.max(0)?;
+    assert!(max_diff.to_vec0::<f32>()? < 5e-3);
     assert_eq!(
         test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
         [
@@ -148,6 +185,14 @@ fn unary_op(device: &Device) -> Result<()> {
         test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
         [3000.0, 300.]
     );
+    let tensor = Tensor::new(
+        &[-1.01f32, -0.9, -0.1, 0.0, -0.0, 0.1, 0.9, 1.0, 1.1],
+        device,
+    )?;
+    assert_eq!(
+        tensor.sign()?.to_vec1::<f32>()?,
+        [-1., -1., -1., 0., 0., 1., 1., 1., 1.]
+    );
     Ok(())
 }
 
@@ -620,6 +665,30 @@ fn broadcast(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn slice_set(device: &Device) -> Result<()> {
+    let (b, h, max_t, d) = (2, 4, 7, 3);
+    let cache = Tensor::zeros((b, h, max_t, d), DType::F32, device)?;
+    let tensor = Tensor::randn(0f32, 1f32, (b, h, 4, d), device)?;
+    cache.slice_set(&tensor, 2, 0)?;
+    let cache_t = cache.narrow(2, 0, 4)?;
+    let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);
+    cache.slice_set(&tensor, 2, 1)?;
+    let cache_t = cache.narrow(2, 1, 4)?;
+    let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);
+    let ones = Tensor::ones((b, h, 1, d), DType::F32, device)?;
+    cache.slice_set(&ones, 2, 6)?;
+    let diff = cache.narrow(2, 5, 1)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);
+    let diff = (cache.narrow(2, 6, 1)? - 1.)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);
+    Ok(())
+}
+
 fn cat(device: &Device) -> Result<()> {
     // 1D
     let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
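A typical use of the new `slice_set` is filling a pre-allocated buffer along one dimension, much as the test above does. A hedged sketch of that pattern for a fixed-size KV-cache style update (the `write_kv` helper and the shapes are illustrative, not part of candle):

```rust
use candle_core::{DType, Device, Result, Tensor};

// Write `new_kv` into `cache` at time offset `offset` along dim 2 and return the
// filled prefix, mirroring how a pre-allocated cache would be filled step by step.
fn write_kv(cache: &Tensor, new_kv: &Tensor, offset: usize) -> Result<Tensor> {
    cache.slice_set(new_kv, 2, offset)?;
    let s = new_kv.dim(2)?;
    cache.narrow(2, 0, offset + s)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let cache = Tensor::zeros((1, 4, 16, 8), DType::F32, &dev)?;
    let step = Tensor::randn(0f32, 1f32, (1, 4, 1, 8), &dev)?;
    let visible = write_kv(&cache, &step, 0)?;
    assert_eq!(visible.dims(), &[1, 4, 1, 8]);
    Ok(())
}
```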
@@ -707,6 +776,8 @@ fn embeddings(device: &Device) -> Result<()> {
     assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
     let hs = t.index_select(&ids, 0)?;
     assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
+    let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
+    assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
     Ok(())
 }
 
@@ -734,44 +805,47 @@ fn index_select(device: &Device) -> Result<()> {
             [9.0, 10.0, 11.0]
         ]
     );
-    let hs = t.index_select(&ids, 1)?;
-    assert_eq!(
-        hs.to_vec2::<f32>()?,
-        &[
-            [0.0, 2.0, 1.0],
-            [3.0, 5.0, 4.0],
-            [6.0, 8.0, 7.0],
-            [9.0, 11.0, 10.0]
-        ]
-    );
-    let hs = t.index_select(&ids, 0)?;
-    assert_eq!(
-        hs.to_vec2::<f32>()?,
-        &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
-    );
-    // Prior to https://github.com/huggingface/candle/pull/1022
-    // There would be a bug where the last values in the result tensor would be set to 0.
-    let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
-    let hs = t.index_select(&ids, 0)?;
-    assert_eq!(
-        hs.to_vec2::<f32>()?,
-        &[
-            [0.0, 1.0, 2.0],
-            [6.0, 7.0, 8.0],
-            [3.0, 4.0, 5.0],
-            [0.0, 1.0, 2.0],
-            [6.0, 7.0, 8.0],
-            [3.0, 4.0, 5.0],
-        ]
-    );
+    for dtype in [DType::U8, DType::U32, DType::I64] {
+        let ids = ids.to_dtype(dtype)?;
+        let hs = t.index_select(&ids, 1)?;
+        assert_eq!(
+            hs.to_vec2::<f32>()?,
+            &[
+                [0.0, 2.0, 1.0],
+                [3.0, 5.0, 4.0],
+                [6.0, 8.0, 7.0],
+                [9.0, 11.0, 10.0]
+            ]
+        );
+        let hs = t.index_select(&ids, 0)?;
+        assert_eq!(
+            hs.to_vec2::<f32>()?,
+            &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
+        );
+        // Prior to https://github.com/huggingface/candle/pull/1022
+        // There would be a bug where the last values in the result tensor would be set to 0.
+        let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
+        let hs = t.index_select(&ids, 0)?;
+        assert_eq!(
+            hs.to_vec2::<f32>()?,
+            &[
+                [0.0, 1.0, 2.0],
+                [6.0, 7.0, 8.0],
+                [3.0, 4.0, 5.0],
+                [0.0, 1.0, 2.0],
+                [6.0, 7.0, 8.0],
+                [3.0, 4.0, 5.0],
+            ]
+        );
 
     // Test when selecting dim > 0 with ids size different from elem count of
     // target dim in source/input.
     let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
     let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
     assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
     let hs = t.index_select(&ids, 1)?;
     assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
+    }
 
     Ok(())
 }
@@ -933,74 +1007,6 @@ fn gather(device: &Device) -> Result<()> {
     Ok(())
 }
 
-fn matmul(device: &Device) -> Result<()> {
-    let data = vec![1.0f32, 2.0, 3.0, 4.0];
-    let a = Tensor::from_slice(&data, (2, 2), device)?;
-    let data = vec![1.0f32, 2.0, 3.0, 4.0];
-    let b = Tensor::from_slice(&data, (2, 2), device)?;
-
-    let c = a.matmul(&b)?;
-    assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
-
-    let data = vec![1.0f32, 2.0];
-    let a = Tensor::from_slice(&data, (2, 1), device)?;
-    let data = vec![3.0f32, 4.0];
-    let b = Tensor::from_slice(&data, (1, 2), device)?;
-    let c = a.matmul(&b)?;
-    assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
-
-    let data: Vec<_> = (0..6).map(|i| i as f32).collect();
-    let a = Tensor::from_slice(&data, (2, 3), device)?;
-    let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
-    let b = Tensor::from_slice(&data, (3, 2), device)?;
-    let c = a.matmul(&b)?;
-    assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
-
-    let data: Vec<_> = (0..12).map(|i| i as f32).collect();
-    let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
-    let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
-    let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
-    let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];
-
-    let c = a.matmul(&b)?;
-    assert_eq!(c.to_vec3::<f32>()?, &expected);
-
-    // Also perform the matmul on contiguous transposed versions.
-    let a_tt = a.t()?.contiguous()?.t()?;
-    assert!(!a_tt.is_contiguous());
-    assert_eq!(a.dims(), a_tt.dims());
-    assert_eq!(a_tt.stride(), &[6, 1, 2]);
-
-    let b_tt = b.t()?.contiguous()?.t()?;
-    assert!(!b_tt.is_contiguous());
-    assert_eq!(b.dims(), b_tt.dims());
-    assert_eq!(b_tt.stride(), &[6, 1, 3]);
-
-    assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
-    assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
-    assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
-    Ok(())
-}
-
-fn broadcast_matmul(device: &Device) -> Result<()> {
-    let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
-    let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
-    let out = lhs.broadcast_matmul(&rhs)?;
-    assert_eq!(out.dims(), &[3, 6, 4, 2]);
-    for idx1 in 0..3 {
-        for idx2 in 0..6 {
-            let out = out.i((idx1, idx2))?;
-            let lhs = lhs.i((idx1, 0))?;
-            let rhs = rhs.i(idx2)?;
-            let out2 = lhs.matmul(&rhs);
-            let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
-            // With cuda, we see errors of up to ~1e-12.
-            assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
-        }
-    }
-    Ok(())
-}
-
 fn broadcasting(device: &Device) -> Result<()> {
     let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
     let t2 = Tensor::new(&[100f32, 200f32], device)?;
@@ -1135,6 +1141,27 @@ fn randn(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn zero_dim(device: &Device) -> Result<()> {
+    let t = Tensor::zeros((4, 0, 1), DType::F32, device)?;
+    assert_eq!(t.dims3()?, (4, 0, 1));
+    let t2 = Tensor::zeros((4, 3, 1), DType::F32, device)?;
+    let t_cat = Tensor::cat(&[&t, &t2], 1)?;
+    assert_eq!(t_cat.dims3()?, (4, 3, 1));
+    let t_cat = Tensor::cat(&[&t, &t], 1)?;
+    assert_eq!(t_cat.dims3()?, (4, 0, 1));
+    let t_unary = t.sqrt()?;
+    assert_eq!(t_unary.dims3()?, (4, 0, 1));
+    let t_plus = (&t + 1.)?;
+    assert_eq!(t_plus.dims3()?, (4, 0, 1));
+    let t_mm = t2.matmul(&t.t()?)?;
+    assert_eq!(t_mm.dims3()?, (4, 3, 0));
+    let t_mm = t.matmul(&t2.t()?)?;
+    assert_eq!(t_mm.dims3()?, (4, 0, 3));
+    let t_mm = t.t()?.matmul(&t)?;
+    assert_eq!(t_mm.dims3()?, (4, 1, 1));
+    Ok(())
+}
+
 test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
 test_device!(ones, ones_cpu, ones_gpu, ones_metal);
 test_device!(full, full_cpu, full_gpu, full_metal);
@@ -1143,6 +1170,7 @@ test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
 test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
 test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
 test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
+test_device!(slice_set, ss_cpu, ss_gpu, ss_metal);
 test_device!(cat, cat_cpu, cat_gpu, cat_metal);
 test_device!(sum, sum_cpu, sum_gpu, sum_metal);
 test_device!(min, min_cpu, min_gpu, min_metal);
@@ -1154,13 +1182,6 @@ test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
 test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
 test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
 test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
-test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
-test_device!(
-    broadcast_matmul,
-    broadcast_matmul_cpu,
-    broadcast_matmul_gpu,
-    broadcast_matmul_metal
-);
 test_device!(
     broadcasting,
     broadcasting_cpu,
@@ -1189,7 +1210,9 @@ test_device!(
 );
 test_device!(randn, randn_cpu, randn_gpu, randn_metal);
 test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
+test_device!(asort, asort_cpu, asort_gpu, asort_metal);
 test_device!(var, var_cpu, var_gpu, var_metal);
+test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
 
 // There was originally a bug on the CPU implementation for randn
 // https://github.com/huggingface/candle/issues/381
@@ -1317,8 +1340,8 @@ fn pow() -> Result<()> {
     let rhs = (&lhs - 2.)?;
     let res = lhs.pow(&rhs)?;
     assert_eq!(
-        test_utils::to_vec2_round(&res, 4)?,
-        [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
+        test_utils::to_vec2_round(&res, 3)?,
+        [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0]]
     );
     Ok(())
 }
@@ -89,7 +89,7 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
 
 pub fn load() -> Result<crate::vision::Dataset> {
     let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
-    let dataset_id = "mnist".to_string();
+    let dataset_id = "ylecun/mnist".to_string();
     let repo = Repo::with_revision(
         dataset_id,
         RepoType::Dataset,
|
@ -25,7 +25,9 @@ hf-hub = { workspace = true, features = ["tokio"] }
|
|||||||
image = { workspace = true }
|
image = { workspace = true }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
|
palette = { version = "0.7.6", optional = true }
|
||||||
|
enterpolation = { version = "0.2.1", optional = true}
|
||||||
|
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
|
||||||
rayon = { workspace = true }
|
rayon = { workspace = true }
|
||||||
rubato = { version = "0.15.0", optional = true }
|
rubato = { version = "0.15.0", optional = true }
|
||||||
safetensors = { workspace = true }
|
safetensors = { workspace = true }
|
||||||
@ -65,6 +67,7 @@ onnx = ["candle-onnx"]
|
|||||||
metal = ["candle/metal", "candle-nn/metal"]
|
metal = ["candle/metal", "candle-nn/metal"]
|
||||||
microphone = ["cpal"]
|
microphone = ["cpal"]
|
||||||
encodec = ["cpal", "symphonia", "rubato"]
|
encodec = ["cpal", "symphonia", "rubato"]
|
||||||
|
depth_anything_v2 = ["palette", "enterpolation"]
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
name = "llama_multiprocess"
|
name = "llama_multiprocess"
|
||||||
@ -101,3 +104,7 @@ required-features = ["candle-datasets"]
|
|||||||
[[example]]
|
[[example]]
|
||||||
name = "encodec"
|
name = "encodec"
|
||||||
required-features = ["encodec"]
|
required-features = ["encodec"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "depth_anything_v2"
|
||||||
|
required-features = ["depth_anything_v2"]
|
||||||
|
New file (46 lines): candle-examples/examples/clip/README.md

Contrastive Language-Image Pre-Training

Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
pairs of images with related texts.

https://github.com/openai/CLIP

https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip

## Running on an example on cpu

```
$ cargo run --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"


Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg

INFO clip: Probability: 0.0000% Text: a cycling race
INFO clip: Probability: 0.0000% Text: a photo of two cats
INFO clip: Probability: 100.0000% Text: a robot holding a candle

Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg

INFO clip: Probability: 99.9999% Text: a cycling race
INFO clip: Probability: 0.0001% Text: a photo of two cats
INFO clip: Probability: 0.0000% Text: a robot holding a candle
```

## Running on an example with metal feature (mac)

```
$ cargo run --features metal --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"


Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg

INFO clip: Probability: 0.0000% Text: a cycling race
INFO clip: Probability: 0.0000% Text: a photo of two cats
INFO clip: Probability: 100.0000% Text: a robot holding a candle

Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg

INFO clip: Probability: 99.9999% Text: a cycling race
INFO clip: Probability: 0.0001% Text: a photo of two cats
INFO clip: Probability: 0.0000% Text: a robot holding a candle
```
New file (202 lines): candle-examples/examples/clip/main.rs

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::Parser;

use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::clip;

use tokenizers::Tokenizer;
use tracing::info;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long, use_value_delimiter = true)]
    images: Option<Vec<String>>,

    #[arg(long)]
    cpu: bool,

    #[arg(long, use_value_delimiter = true)]
    sequences: Option<Vec<String>>,
}

fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
    let img = image::io::Reader::open(path)?.decode()?;
    let (height, width) = (image_size, image_size);
    let img = img.resize_to_fill(
        width as u32,
        height as u32,
        image::imageops::FilterType::Triangle,
    );

    let img = img.to_rgb8();

    let img = img.into_raw();
    let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
        .permute((2, 0, 1))?
        .to_dtype(DType::F32)?
        .affine(2. / 255., -1.)?;
    // .unsqueeze(0)?;
    Ok(img)
}

fn load_images<T: AsRef<std::path::Path>>(
    paths: &Vec<T>,
    image_size: usize,
) -> anyhow::Result<Tensor> {
    let mut images = vec![];

    for path in paths {
        let tensor = load_image(path, image_size)?;
        images.push(tensor);
    }

    let images = Tensor::stack(&images, 0)?;

    Ok(images)
}

pub fn main() -> anyhow::Result<()> {
    // std::env::set_var("RUST_BACKTRACE", "full");

    let args = Args::parse();

    tracing_subscriber::fmt::init();

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;

            let api = api.repo(hf_hub::Repo::with_revision(
                "openai/clip-vit-base-patch32".to_string(),
                hf_hub::RepoType::Model,
                "refs/pr/15".to_string(),
            ));

            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };

    let tokenizer = get_tokenizer(args.tokenizer)?;

    let config = clip::ClipConfig::vit_base_patch32();

    let device = candle_examples::device(args.cpu)?;

    let vec_imgs = match args.images {
        Some(imgs) => imgs,
        None => vec![
            "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
            "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
        ],
    };

    // let image = load_image(args.image, config.image_size)?.to_device(&device)?;
    let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;

    let vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };

    let model = clip::ClipModel::new(vb, &config)?;

    let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;

    let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;

    let softmax_image = softmax(&logits_per_image, 1)?;

    let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;

    info!("softmax_image_vec: {:?}", softmax_image_vec);

    let probability_vec = softmax_image_vec
        .iter()
        .map(|v| v * 100.0)
        .collect::<Vec<f32>>();

    let probability_per_image = probability_vec.len() / vec_imgs.len();

    for (i, img) in vec_imgs.iter().enumerate() {
        let start = i * probability_per_image;
        let end = start + probability_per_image;
        let prob = &probability_vec[start..end];
        info!("\n\nResults for image: {}\n", img);

        for (i, p) in prob.iter().enumerate() {
            info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
        }
    }

    Ok(())
}

pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
    let tokenizer = match tokenizer {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.repo(hf_hub::Repo::with_revision(
                "openai/clip-vit-base-patch32".to_string(),
                hf_hub::RepoType::Model,
                "refs/pr/15".to_string(),
            ));
            api.get("tokenizer.json")?
        }
        Some(file) => file.into(),
    };

    Tokenizer::from_file(tokenizer).map_err(E::msg)
}

pub fn tokenize_sequences(
    sequences: Option<Vec<String>>,
    tokenizer: &Tokenizer,
    device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
    let pad_id = *tokenizer
        .get_vocab(true)
        .get("<|endoftext|>")
        .ok_or(E::msg("No pad token"))?;

    let vec_seq = match sequences {
        Some(seq) => seq,
        None => vec![
            "a cycling race".to_string(),
            "a photo of two cats".to_string(),
            "a robot holding a candle".to_string(),
        ],
    };

    let mut tokens = vec![];

    for seq in vec_seq.clone() {
        let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
        tokens.push(encoding.get_ids().to_vec());
    }

    let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);

    // Pad the sequences to have the same length
    for token_vec in tokens.iter_mut() {
        let len_diff = max_len - token_vec.len();
        if len_diff > 0 {
            token_vec.extend(vec![pad_id; len_diff]);
        }
    }

    let input_ids = Tensor::new(tokens, device)?;

    Ok((input_ids, vec_seq))
}
New file (13 lines): candle-examples/examples/depth_anything_v2/README.md

# candle-dinov2

[Depth Anything V2] is a model for Monocular Depth Estimation (MDE, i.e. just using a single image) which
builds on the [DINOv2](https://github.com/facebookresearch/dinov2) vision transformer.

This example first instantiates the DINOv2 model and then proceeds to create DepthAnythingV2 and run it.

## Running an example with color map and CUDA

```bash
cargo run --features cuda,depth_anything_v2 --package candle-examples --example depth_anything_v2 -- --color-map --image candle-examples/examples/yolo-v8/assets/bike.jpg
```
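A condensed sketch of that two-stage setup, using only identifiers from the example source that follows (the `run_depth` helper, the weight-file arguments, and the assumption that the image tensor has already been resized and normalised are all illustrative):

```rust
use candle::{DType, Device, Module};
use candle_nn::VarBuilder;
use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config};
use candle_transformers::models::dinov2;

fn run_depth(dino_weights: &str, da_weights: &str, image: &candle::Tensor, device: &Device) -> anyhow::Result<candle::Tensor> {
    // 1. Instantiate the DINOv2 backbone from its own safetensors file.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[dino_weights], DType::F32, device)? };
    let dinov2 = dinov2::vit_small(vb)?;
    // 2. Build DepthAnythingV2 on top of the backbone and run it on a prepared image.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[da_weights], DType::F32, device)? };
    let config = DepthAnythingV2Config::vit_small();
    let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;
    Ok(depth_anything.forward(image)?)
}
```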
New file (50 lines): candle-examples/examples/depth_anything_v2/color_map.rs

use enterpolation::linear::ConstEquidistantLinear;
use enterpolation::Generator;
use palette::LinSrgb;

use candle::Tensor;

pub struct SpectralRColormap {
    gradient: ConstEquidistantLinear<f32, LinSrgb, 9>,
}

impl SpectralRColormap {
    pub(crate) fn new() -> Self {
        // Define a colormap similar to 'Spectral_r' by specifying key colors.
        // got the colors from ChatGPT-4o
        let gradient = ConstEquidistantLinear::<f32, _, 9>::equidistant_unchecked([
            LinSrgb::new(0.3686, 0.3098, 0.6353), // Dark blue
            LinSrgb::new(0.1961, 0.5333, 0.7412), // Blue
            LinSrgb::new(0.4000, 0.7608, 0.6471), // Cyan
            LinSrgb::new(0.6706, 0.8667, 0.6431), // Green
            LinSrgb::new(0.9020, 0.9608, 0.5961), // Yellow
            LinSrgb::new(0.9961, 0.8784, 0.5451), // Orange
            LinSrgb::new(0.9922, 0.6824, 0.3804), // Red
            LinSrgb::new(0.9569, 0.4275, 0.2627), // Dark red
            LinSrgb::new(0.8353, 0.2431, 0.3098), // Dark purple
        ]);
        Self { gradient }
    }

    fn get_color(&self, value: f32) -> LinSrgb {
        self.gradient.gen(value)
    }

    pub fn gray2color(&self, gray: &Tensor) -> candle::Result<Tensor> {
        println!("Gray: {:?}", gray.dims());
        let gray_values: Vec<f32> = gray.flatten_all()?.to_vec1()?;
        let rgb_values: Vec<f32> = gray_values
            .iter()
            .map(|g| self.get_color(*g))
            .flat_map(|rgb| [rgb.red, rgb.green, rgb.blue])
            .collect();

        let [.., height, width] = gray.dims() else {
            candle::bail!("Not enough dims!")
        };

        let color = Tensor::from_vec(rgb_values, (*height, *width, 3), gray.device())?;

        color.permute((2, 0, 1))
    }
}
New file (187 lines): candle-examples/examples/depth_anything_v2/main.rs

//! Depth Anything V2
//! https://huggingface.co/spaces/depth-anything/Depth-Anything-V2

#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use std::ffi::OsString;
use std::path::PathBuf;

use clap::Parser;

use candle::DType::{F32, U8};
use candle::{DType, Device, Module, Result, Tensor};
use candle_examples::{load_image, load_image_and_resize, save_image};
use candle_nn::VarBuilder;
use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config};
use candle_transformers::models::dinov2;

use crate::color_map::SpectralRColormap;

mod color_map;

// taken these from: https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/depth_anything_v2/dpt.py#L207
const MAGIC_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const MAGIC_STD: [f32; 3] = [0.229, 0.224, 0.225];

const DINO_IMG_SIZE: usize = 518;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    dinov2_model: Option<PathBuf>,

    #[arg(long)]
    depth_anything_v2_model: Option<PathBuf>,

    #[arg(long)]
    image: PathBuf,

    #[arg(long)]
    output_dir: Option<PathBuf>,

    #[arg(long)]
    cpu: bool,

    #[arg(long)]
    color_map: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;

    let dinov2_model_file = match args.dinov2_model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("lmz/candle-dino-v2".into());
            api.get("dinov2_vits14.safetensors")?
        }
        Some(dinov2_model) => dinov2_model,
    };
    println!("Using file {:?}", dinov2_model_file);

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[dinov2_model_file], F32, &device)? };
    let dinov2 = dinov2::vit_small(vb)?;
    println!("DinoV2 model built");

    let depth_anything_model_file = match args.depth_anything_v2_model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("jeroenvlek/depth-anything-v2-safetensors".into());
            api.get("depth_anything_v2_vits.safetensors")?
        }
        Some(depth_anything_model) => depth_anything_model,
    };
    println!("Using file {:?}", depth_anything_model_file);

    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&[depth_anything_model_file], DType::F32, &device)?
    };

    let config = DepthAnythingV2Config::vit_small();
    let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;

    let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;

    println!("Loaded image {image:?}");

    let depth = depth_anything.forward(&image)?;

    println!("Got predictions {:?}", depth.shape());

    let output_image = post_process_image(&depth, original_height, original_width, args.color_map)?;

    let output_path = full_output_path(&args.image, &args.output_dir);
    println!("Saving image to {}", output_path.to_string_lossy());
    save_image(&output_image, output_path)?;

    Ok(())
}

fn full_output_path(image_path: &PathBuf, output_dir: &Option<PathBuf>) -> PathBuf {
    let input_file_name = image_path.file_name().unwrap();
    let mut output_file_name = OsString::from("depth_");
    output_file_name.push(input_file_name);
    let mut output_path = match output_dir {
        None => image_path.parent().unwrap().to_path_buf(),
        Some(output_path) => output_path.clone(),
    };
    output_path.push(output_file_name);

    output_path
}

fn load_and_prep_image(
    image_path: &PathBuf,
    device: &Device,
) -> anyhow::Result<(usize, usize, Tensor)> {
    let (_original_image, original_height, original_width) = load_image(&image_path, None)?;

    let image = load_image_and_resize(&image_path, DINO_IMG_SIZE, DINO_IMG_SIZE)?
        .unsqueeze(0)?
        .to_dtype(F32)?
        .to_device(&device)?;

    let max_pixel_val = Tensor::try_from(255.0f32)?
        .to_device(&device)?
        .broadcast_as(image.shape())?;
    let image = (image / max_pixel_val)?;
    let image = normalize_image(&image, &MAGIC_MEAN, &MAGIC_STD)?;

    Ok((original_height, original_width, image))
}

fn normalize_image(image: &Tensor, mean: &[f32; 3], std: &[f32; 3]) -> Result<Tensor> {
    let mean_tensor =
        Tensor::from_vec(mean.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
    let std_tensor =
        Tensor::from_vec(std.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
    image.sub(&mean_tensor)?.div(&std_tensor)
}

fn post_process_image(
    image: &Tensor,
    original_height: usize,
    original_width: usize,
    color_map: bool,
) -> Result<Tensor> {
    let out = image.interpolate2d(original_height, original_width)?;
    let out = scale_image(&out)?;

    let out = if color_map {
        let spectral_r = SpectralRColormap::new();
        spectral_r.gray2color(&out)?
    } else {
        let rgb_slice = [&out, &out, &out];
        Tensor::cat(&rgb_slice, 0)?.squeeze(1)?
    };

    let max_pixel_val = Tensor::try_from(255.0f32)?
        .to_device(out.device())?
        .broadcast_as(out.shape())?;
    let out = (out * max_pixel_val)?;

    out.to_dtype(U8)
}

fn scale_image(depth: &Tensor) -> Result<Tensor> {
    let flat_values: Vec<f32> = depth.flatten_all()?.to_vec1()?;

    let min_val = flat_values.iter().min_by(|a, b| a.total_cmp(b)).unwrap();
    let max_val = flat_values.iter().max_by(|a, b| a.total_cmp(b)).unwrap();

    let min_val_tensor = Tensor::try_from(*min_val)?
        .to_device(depth.device())?
        .broadcast_as(depth.shape())?;
    let depth = (depth - min_val_tensor)?;

    let range = max_val - min_val;
    let range_tensor = Tensor::try_from(range)?
        .to_device(depth.device())?
        .broadcast_as(depth.shape())?;

    depth / range_tensor
}
@@ -16,6 +16,30 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+    #[value(name = "2b")]
+    Base2B,
+    #[value(name = "7b")]
+    Base7B,
+    #[value(name = "2b-it")]
+    Instruct2B,
+    #[value(name = "7b-it")]
+    Instruct7B,
+    #[value(name = "1.1-2b-it")]
+    InstructV1_1_2B,
+    #[value(name = "1.1-7b-it")]
+    InstructV1_1_7B,
+    #[value(name = "code-2b")]
+    CodeBase2B,
+    #[value(name = "code-7b")]
+    CodeBase7B,
+    #[value(name = "code-2b-it")]
+    CodeInstruct2B,
+    #[value(name = "code-7b-it")]
+    CodeInstruct7B,
+}
+
 struct TextGeneration {
     model: Model,
     device: Device,
@@ -165,6 +189,13 @@ struct Args {
     /// The context size to consider for the repeat penalty.
     #[arg(long, default_value_t = 64)]
     repeat_last_n: usize,
+
+    /// The model to use.
+    #[arg(long, default_value = "2b")]
+    which: Which,
+
+    #[arg(long)]
+    use_flash_attn: bool,
 }
 
 fn main() -> Result<()> {
@@ -196,14 +227,19 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let api = Api::new()?;
     let model_id = match &args.model_id {
-        Some(model_id) => match model_id.as_str() {
-            "7b-it" => "google/gemma-7b-it".to_string(),
-            "7b" => "google/gemma-7b".to_string(),
-            "2b-it" => "google/gemma-2b-it".to_string(),
-            "2b" => "google/gemma-2b".to_string(),
-            _ => model_id.to_string(),
+        Some(model_id) => model_id.to_string(),
+        None => match args.which {
+            Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
+            Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
+            Which::Base2B => "google/gemma-2b".to_string(),
+            Which::Base7B => "google/gemma-7b".to_string(),
+            Which::Instruct2B => "google/gemma-2b-it".to_string(),
+            Which::Instruct7B => "google/gemma-7b-it".to_string(),
+            Which::CodeBase2B => "google/codegemma-2b".to_string(),
+            Which::CodeBase7B => "google/codegemma-7b".to_string(),
+            Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
+            Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
         },
-        None => "google/gemma-2b".to_string(),
     };
     let repo = api.repo(Repo::with_revision(
         model_id,
@@ -237,7 +273,7 @@ fn main() -> Result<()> {
         DType::F32
     };
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-    let model = Model::new(&config, vb)?;
+    let model = Model::new(args.use_flash_attn, &config, vb)?;
 
     println!("loaded the model in {:?}", start.elapsed());
 
New file (19 lines): candle-examples/examples/gte-qwen/README.md

# gte-Qwen1.5-7B-instruct

gte-Qwen1.5-7B-instruct is a variant of the GTE embedding model family.

- [Model card](https://huggingface.co/Alibaba-NLP/gte-Qwen1.5-7B-instruct) on the HuggingFace Hub.
- [Technical report](https://arxiv.org/abs/2308.03281) *Towards General Text Embeddings with Multi-stage Contrastive Learning*


## Running the example

Automatically download the model from the HuggingFace hub:
```bash
$ cargo run --example gte-qwen --release
```

or, load the model from a local directory:
```bash
cargo run --example gte-qwen --release --features cuda -- --local-repo /path/to/gte_Qwen1.5-7B-instruct/
```
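The scoring step at the end of the example that follows reduces to cosine similarity between L2-normalised embeddings. A small, self-contained sketch of just that step, using only tensor ops that appear in the example (the `cosine_scores` helper name is illustrative):

```rust
use candle::{Device, Result, Tensor};

// Cosine similarity between two sets of row embeddings: normalise each row to
// unit L2 norm, then take the pairwise dot products via a matmul.
fn cosine_scores(queries: &Tensor, documents: &Tensor) -> Result<Tensor> {
    let q = queries.broadcast_div(&queries.sqr()?.sum_keepdim(1)?.sqrt()?)?;
    let d = documents.broadcast_div(&documents.sqr()?.sum_keepdim(1)?.sqrt()?)?;
    q.matmul(&d.t()?)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let q = Tensor::new(&[[1f32, 0., 0.], [0., 1., 0.]], &dev)?;
    let d = Tensor::new(&[[2f32, 0., 0.], [0., 0., 3.]], &dev)?;
    let scores = cosine_scores(&q, &d)?;
    assert_eq!(scores.to_vec2::<f32>()?, [[1.0, 0.0], [0.0, 0.0]]);
    Ok(())
}
```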
New file (178 lines): candle-examples/examples/gte-qwen/main.rs

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::qwen2::{Config, Model};

use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{
    utils::padding::{PaddingDirection, PaddingParams, PaddingStrategy},
    Tokenizer,
};

// gte-Qwen1.5-7B-instruct use EOS token as padding token
const EOS_TOKEN: &str = "<|endoftext|>";
const EOS_TOKEN_ID: u32 = 151643;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long, default_value = "Alibaba-NLP/gte-Qwen1.5-7B-instruct")]
    model_id: String,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    local_repo: Option<String>,
}

#[derive(Debug)]
struct ConfigFiles {
    pub config: std::path::PathBuf,
    pub tokenizer: std::path::PathBuf,
    pub weights: Vec<std::path::PathBuf>,
}

// Loading the model from the HuggingFace Hub. Network access is required.
fn load_from_hub(model_id: &str, revision: &str) -> Result<ConfigFiles> {
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        model_id.to_string(),
        RepoType::Model,
        revision.to_string(),
    ));
    Ok(ConfigFiles {
        config: repo.get("config.json")?,
        tokenizer: repo.get("tokenizer.json")?,
        weights: candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
    })
}

// Loading the model from a local directory.
fn load_from_local(local_path: &str) -> Result<ConfigFiles> {
    let local_path = std::path::PathBuf::from(local_path);
    let weight_path = local_path.join("model.safetensors.index.json");
    let json: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(weight_path)?)?;
    let weight_map = match json.get("weight_map") {
        Some(serde_json::Value::Object(map)) => map,
        Some(_) => panic!("`weight map` is not a map"),
        None => panic!("`weight map` not found"),
    };
    let mut safetensors_files = std::collections::HashSet::new();
    for value in weight_map.values() {
        safetensors_files.insert(
            value
                .as_str()
                .expect("Weight files should be parsed as strings"),
        );
    }
    let safetensors_paths = safetensors_files
        .iter()
        .map(|v| local_path.join(v))
        .collect::<Vec<_>>();
    Ok(ConfigFiles {
        config: local_path.join("config.json"),
        tokenizer: local_path.join("tokenizer.json"),
        weights: safetensors_paths,
    })
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    // Fetch the model. Do this offline if local path provided.
    println!("Fetching model files...");
    let start = std::time::Instant::now();
    let config_files = match args.local_repo {
        Some(local_path) => load_from_local(&local_path)?,
        None => load_from_hub(&args.model_id, &args.revision)?,
    };
    println!("Model file retrieved in {:?}", start.elapsed());

    // Inputs will be padded to the longest sequence in the batch.
    let padding = PaddingParams {
        strategy: PaddingStrategy::BatchLongest,
        direction: PaddingDirection::Left,
        pad_to_multiple_of: None,
        pad_id: EOS_TOKEN_ID,
        pad_type_id: 0,
        pad_token: String::from(EOS_TOKEN),
    };

    // Tokenizer setup
    let mut tokenizer = Tokenizer::from_file(config_files.tokenizer).map_err(E::msg)?;
    tokenizer.with_padding(Some(padding));

    // Model initialization
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };
    let config: Config = serde_json::from_slice(&std::fs::read(config_files.config)?)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&config_files.weights, dtype, &device)? };
    let mut model = Model::new(&config, vb)?;
    println!("Model loaded in {:?}", start.elapsed());

    // Encode the queries and the targets
    let instruct = "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: ";
    let documents = vec![
        format!("{instruct}how much protein should a female eat{EOS_TOKEN}"),
        format!("{instruct}summit define{EOS_TOKEN}"),
        format!("As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.{EOS_TOKEN}"),
        format!("Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.{EOS_TOKEN}"),
    ];
    let encoded = tokenizer.encode_batch(documents, true).map_err(E::msg)?;
    let tokens: Vec<&[u32]> = encoded.iter().map(|x| x.get_ids()).collect();
    let tokens = Tensor::new(tokens, &device)?;
    let mask: Vec<&[u32]> = encoded.iter().map(|x| x.get_attention_mask()).collect();
    let mask = Tensor::new(mask, &device)?;

    // Inference
    let start_gen = std::time::Instant::now();
    let logits = model.forward(&tokens, 0, Some(&mask))?;

    // Extract the last hidden states as embeddings since inputs are padded left.
    let (_, seq_len, _) = logits.dims3()?;
    let embd = logits
        .narrow(1, seq_len - 1, 1)?
        .squeeze(1)?
        .to_dtype(DType::F32)?;

    // Calculate the relativity scores. Note the embeddings should be normalized.
    let norm = embd.broadcast_div(&embd.sqr()?.sum_keepdim(1)?.sqrt()?)?;
    let scores = norm.narrow(0, 0, 2)?.matmul(&norm.narrow(0, 2, 2)?.t()?)?;

    // Print the results
    println!("Embedding done in {:?}", start_gen.elapsed());
    println!("Scores: {:?}", scores.to_vec2::<f32>()?);

    Ok(())
}
@@ -17,7 +17,7 @@ use clap::{Parser, ValueEnum};
 
 use candle::{DType, Tensor};
 use candle_nn::VarBuilder;
-use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::generation::{LogitsProcessor, Sampling};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;
 
@@ -31,6 +31,8 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 enum Which {
     V1,
     V2,
+    V3,
+    V3Instruct,
     #[value(name = "solar-10.7b")]
     Solar10_7B,
     #[value(name = "tiny-llama-1.1b-chat")]
@@ -45,19 +47,23 @@ struct Args {
     cpu: bool,
 
     /// The temperature used to generate samples.
-    #[arg(long)]
-    temperature: Option<f64>,
+    #[arg(long, default_value_t = 0.8)]
+    temperature: f64,
 
     /// Nucleus sampling probability cutoff.
     #[arg(long)]
     top_p: Option<f64>,
 
+    /// Only sample among the top K samples.
+    #[arg(long)]
+    top_k: Option<usize>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(long, default_value_t = 10000)]
+    #[arg(short = 'n', long, default_value_t = 10000)]
     sample_len: usize,
 
     /// Disable the key-value cache.
@@ -83,18 +89,18 @@ struct Args {
     revision: Option<String>,
 
     /// The model size to use.
-    #[arg(long, default_value = "v2")]
+    #[arg(long, default_value = "v3")]
     which: Which,
 
     #[arg(long)]
     use_flash_attn: bool,
 
     /// Penalty to be applied for repeating tokens, 1. means no penalty.
-    #[arg(long, default_value_t = 1.0)]
+    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,
 
     /// The context size to consider for the repeat penalty.
-    #[arg(long, default_value_t = 64)]
+    #[arg(long, default_value_t = 128)]
     repeat_last_n: usize,
 }
 
@@ -120,11 +126,13 @@ fn main() -> Result<()> {
         Some(dtype) => bail!("Unsupported dtype {dtype}"),
         None => DType::F16,
     };
-    let (llama, tokenizer_filename, mut cache) = {
+    let (llama, tokenizer_filename, mut cache, config) = {
         let api = Api::new()?;
         let model_id = args.model_id.unwrap_or_else(|| match args.which {
             Which::V1 => "Narsil/amall-7b".to_string(),
             Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
+            Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
+            Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
             Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
             Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
         });
@@ -138,7 +146,7 @@ fn main() -> Result<()> {
         let config = config.into_config(args.use_flash_attn);
 
         let filenames = match args.which {
-            Which::V1 | Which::V2 | Which::Solar10_7B => {
+            Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
                 candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
             }
             Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@@ -146,10 +154,12 @@ fn main() -> Result<()> {
         let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
 
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-        (Llama::load(vb, &config)?, tokenizer_filename, cache)
+        (Llama::load(vb, &config)?, tokenizer_filename, cache, config)
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-    let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
+    let eos_token_id = config
+        .eos_token_id
+        .or_else(|| tokenizer.token_to_id(EOS_TOKEN));
     let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
     let mut tokens = tokenizer
         .encode(prompt, true)
@@ -160,8 +170,22 @@ fn main() -> Result<()> {
 
     println!("starting the inference loop");
     print!("{prompt}");
-    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
-    let start_gen = std::time::Instant::now();
+    let mut logits_processor = {
+        let temperature = args.temperature;
+        let sampling = if temperature <= 0. {
+            Sampling::ArgMax
+        } else {
+            match (args.top_k, args.top_p) {
+                (None, None) => Sampling::All { temperature },
+                (Some(k), None) => Sampling::TopK { k, temperature },
+                (None, Some(p)) => Sampling::TopP { p, temperature },
+                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+            }
+        };
+        LogitsProcessor::from_sampling(args.seed, sampling)
+    };
+
+    let mut start_gen = std::time::Instant::now();
     let mut index_pos = 0;
     let mut token_generated = 0;
     for index in 0..args.sample_len {
@@ -170,6 +194,9 @@ fn main() -> Result<()> {
         } else {
             (tokens.len(), 0)
         };
+        if index == 1 {
+            start_gen = std::time::Instant::now()
+        }
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
         let logits = llama.forward(&input, context_index, &mut cache)?;
@@ -205,7 +232,7 @@ fn main() -> Result<()> {
     println!(
         "\n\n{} tokens generated ({} token/s)\n",
         token_generated,
-        token_generated as f64 / dt.as_secs_f64(),
+        (token_generated - 1) as f64 / dt.as_secs_f64(),
     );
     Ok(())
 }
|
@@ -10,7 +10,7 @@
 extern crate intel_mkl_src;

 use anyhow::{bail, Error as E, Result};
-use clap::Parser;
+use clap::{Parser, ValueEnum};

 use candle::{DType, Device, Tensor};
 use candle_transformers::generation::LogitsProcessor;
@@ -24,57 +24,15 @@ mod model;
 use model::{Config, Llama};

 const MAX_SEQ_LEN: usize = 4096;
-const DEFAULT_PROMPT: &str = r"
-EDWARD:
-I wonder how our princely father 'scaped,
-Or whether he be 'scaped away or no
-From Clifford's and Northumberland's pursuit:
-Had he been ta'en, we should have heard the news;
-Had he been slain, we should have heard the news;
-Or had he 'scaped, methinks we should have heard
-The happy tidings of his good escape.
-How fares my brother? why is he so sad?
-
-RICHARD:
-I cannot joy, until I be resolved
-Where our right valiant father is become.
-I saw him in the battle range about;
-And watch'd him how he singled Clifford forth.
-Methought he bore him in the thickest troop
-As doth a lion in a herd of neat;
-Or as a bear, encompass'd round with dogs,
-Who having pinch'd a few and made them cry,
-The rest stand all aloof, and bark at him.
-So fared our father with his enemies;
-So fled his enemies my warlike father:
-Methinks, 'tis prize enough to be his son.
-See how the morning opes her golden gates,
-And takes her farewell of the glorious sun!
-How well resembles it the prime of youth,
-Trimm'd like a younker prancing to his love!
-
-EDWARD:
-Dazzle mine eyes, or do I see three suns?
-
-RICHARD:
-Three glorious suns, each one a perfect sun;
-Not separated with the racking clouds,
-But sever'd in a pale clear-shining sky.
-See, see! they join, embrace, and seem to kiss,
-As if they vow'd some league inviolable:
-Now are they but one lamp, one light, one sun.
-In this the heaven figures some event.
-
-EDWARD:
-'Tis wondrous strange, the like yet never heard of.
-I think it cites us, brother, to the field,
-That we, the sons of brave Plantagenet,
-Each one already blazing by our meeds,
-Should notwithstanding join our lights together
-And over-shine the earth as this the world.
-Whate'er it bodes, henceforward will I bear
-Upon my target three fair-shining suns.
-";
+const DEFAULT_PROMPT: &str = "My favorite theorem is ";
+
+#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
+enum Which {
+V2_7b,
+V2_70b,
+V3_8b,
+V3_70b,
+}

 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
@@ -86,8 +44,8 @@ struct Args {
 rank: Option<usize>,

 /// The temperature used to generate samples.
-#[arg(long)]
-temperature: Option<f64>,
+#[arg(long, default_value_t = 0.8)]
+temperature: f64,

 /// Nucleus sampling probability cutoff.
 #[arg(long)]
@@ -117,6 +75,12 @@ struct Args {

 #[arg(long)]
 dtype: Option<String>,
+
+#[arg(long, default_value = "v3-8b")]
+which: Which,
+
+#[arg(long, default_value = "nccl_id.txt")]
+comm_file: String,
 }

 fn main() -> Result<()> {
@@ -129,14 +93,27 @@ fn main() -> Result<()> {
 Some("bf16") => DType::BF16,
 Some("f32") => DType::F32,
 Some(dtype) => bail!("Unsupported dtype {dtype}"),
-None => DType::F16,
+None => match args.which {
+Which::V2_7b | Which::V2_70b => DType::F16,
+Which::V3_8b | Which::V3_70b => DType::BF16,
+},
 };

-let api = Api::new()?;
-let model_id = args
-.model_id
-.unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
+let comm_file = std::path::PathBuf::from(&args.comm_file);
+if comm_file.exists() {
+bail!("comm file {comm_file:?} already exists, please remove it first")
+}
+
+let api = Api::new()?;
+let model_id = match args.model_id {
+Some(model) => model,
+None => match args.which {
+Which::V2_7b => "meta-llama/Llama-2-7b-hf".to_string(),
+Which::V2_70b => "meta-llama/Llama-2-70b-hf".to_string(),
+Which::V3_8b => "meta-llama/Meta-Llama-3-8B".to_string(),
+Which::V3_70b => "meta-llama/Meta-Llama-3-70B".to_string(),
+},
+};
 println!("loading the model weights from {model_id}");
 let revision = args.revision.unwrap_or("main".to_string());
 let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
@@ -145,39 +122,40 @@ fn main() -> Result<()> {
 let tokenizer_filename = api.get("tokenizer.json")?;
 let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;

-if args.rank.is_none() {
-let children: Vec<_> = (0..args.num_shards)
-.map(|rank| {
-let mut args: std::collections::VecDeque<_> = std::env::args().collect();
-args.push_back("--rank".to_string());
-args.push_back(format!("{rank}"));
-let name = args.pop_front().unwrap();
-std::process::Command::new(name).args(args).spawn().unwrap()
-})
-.collect();
-for mut child in children {
-child.wait().unwrap();
-}
-return Ok(());
-}
+let rank = match args.rank {
+None => {
+println!("creating {} child processes", args.num_shards);
+let children: Vec<_> = (0..args.num_shards)
+.map(|rank| {
+let mut args: std::collections::VecDeque<_> = std::env::args().collect();
+args.push_back("--rank".to_string());
+args.push_back(format!("{rank}"));
+let name = args.pop_front().unwrap();
+std::process::Command::new(name).args(args).spawn().unwrap()
+})
+.collect();
+for mut child in children {
+child.wait()?;
+}
+return Ok(());
+}
+Some(rank) => rank,
+};

-let i = args.rank.unwrap();
 let num_shards = args.num_shards;
-let rank = i;
 // Primitive IPC
 let id = if rank == 0 {
 let id = Id::new().unwrap();
-std::fs::File::create("nccl_id.txt.tmp")?
-.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())
-.unwrap();
-std::fs::rename("nccl_id.txt.tmp", "nccl_id.txt")?;
+let tmp_file = comm_file.with_extension(".comm.tgz");
+std::fs::File::create(&tmp_file)?
+.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())?;
+std::fs::rename(&tmp_file, &comm_file)?;
 id
 } else {
-let path = std::path::PathBuf::from("nccl_id.txt");
-while !path.exists() {
+while !comm_file.exists() {
 std::thread::sleep(std::time::Duration::from_secs(1));
 }
-let data = std::fs::read("nccl_id.txt")?;
+let data = std::fs::read(&comm_file)?;
 let internal: [i8; 128] = data
 .into_iter()
 .map(|i| i as i8)
@@ -187,14 +165,17 @@ fn main() -> Result<()> {
 let id: Id = Id::uninit(internal);
 id
 };
-let device = CudaDevice::new(i)?;
-let comm = Rc::new(Comm::from_rank(device, i, num_shards, id).unwrap());
+let device = CudaDevice::new(rank)?;
+let comm = match Comm::from_rank(device, rank, num_shards, id) {
+Ok(comm) => Rc::new(comm),
+Err(err) => anyhow::bail!("nccl error {:?}", err.0),
+};
 if rank == 0 {
-std::fs::remove_file("nccl_id.txt")?;
+std::fs::remove_file(comm_file)?;
 }
 println!("Rank {rank:?} spawned");

-let device = Device::new_cuda(i)?;
+let device = Device::new_cuda(rank)?;
 let cache = model::Cache::new(dtype, &config, &device)?;

 println!("building the model");
@@ -210,14 +191,24 @@ fn main() -> Result<()> {
 .map_err(E::msg)?
 .get_ids()
 .to_vec();
+let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);

 println!("starting the inference loop");
-let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
+let temperature = if args.temperature <= 0. {
+None
+} else {
+Some(args.temperature)
+};
+let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
 let mut new_tokens = vec![];
-let start_gen = std::time::Instant::now();
+let mut start_gen = std::time::Instant::now();
 let mut index_pos = 0;
 for index in 0..args.sample_len {
-let start_gen = std::time::Instant::now();
+// Only start timing at the second token as processing the first token waits for all the
+// weights to be loaded in an async way.
+if index == 1 {
+start_gen = std::time::Instant::now()
+};
 let context_size = if index > 0 { 1 } else { tokens.len() };
 let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
 let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
@@ -228,25 +219,23 @@ fn main() -> Result<()> {
 let next_token = logits_processor.sample(&logits)?;
 tokens.push(next_token);
 new_tokens.push(next_token);
+if Some(next_token) == config.eos_token_id {
+break;
+}
 if rank == 0 {
-println!("> {:?}", start_gen.elapsed());
-println!(
-"{} token: {} '{}'",
-index + 1,
-next_token,
-tokenizer.decode(&[next_token], true).map_err(E::msg)?
-);
+if let Some(t) = tokenizer.next_token(next_token)? {
+print!("{t}");
+std::io::stdout().flush()?;
+}
 }
 }
-let dt = start_gen.elapsed();
+println!();
 if rank == 0 {
+let dt = start_gen.elapsed();
 println!(
-"{} tokens generated ({} token/s)\n----\n{}\n----",
+"\n\n{} tokens generated ({} token/s)\n",
 args.sample_len,
-args.sample_len as f64 / dt.as_secs_f64(),
-tokenizer
-.decode(new_tokens.as_slice(), true)
-.map_err(E::msg)?
+(args.sample_len - 1) as f64 / dt.as_secs_f64(),
 );
 }
 Ok(())
@@ -1,15 +1,14 @@
 use candle::backend::BackendStorage;
 use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
+use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
 use candle_nn::{Embedding, Linear, Module, RmsNorm};
 use cudarc::nccl::safe::{Comm, ReduceOp};
-use half::f16;
-use serde::Deserialize;
 use std::rc::Rc;
 use std::sync::{Arc, Mutex};

 use super::MAX_SEQ_LEN;

-use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
+pub type Config = candle_transformers::models::llama::LlamaConfig;

 struct TensorParallelColumnLinear {
 linear: Linear,
@@ -26,7 +25,7 @@ impl TensorParallelColumnLinear {

 struct TensorParallelRowLinear {
 linear: Linear,
-comm: Rc<Comm>,
+all_reduce: AllReduce,
 }

 struct AllReduce {
@@ -36,8 +35,6 @@ struct AllReduce {
 /// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
 /// But for this example purposes, this will work
 unsafe impl Sync for AllReduce {}
-/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
-/// But for this example purposes, this will work
 unsafe impl Send for AllReduce {}

 impl CustomOp1 for AllReduce {
@@ -46,7 +43,7 @@ impl CustomOp1 for AllReduce {
 }

 fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
-todo!("implement allreduce for cpu is not necessary for single node");
+candle::bail!("AllReduce is never used on cpu")
 }

 #[cfg(feature = "cuda")]
@@ -56,31 +53,49 @@ impl CustomOp1 for AllReduce {
 l: &Layout,
 ) -> Result<(candle::CudaStorage, Shape)> {
 use candle::cuda_backend::WrapErr;
+use cudarc::driver::DeviceSlice;
+use half::{bf16, f16};

 let elem_count = l.shape().elem_count();
 let dev = s.device().clone();
-let s = s.as_cuda_slice::<f16>()?;
-// let s = match l.contiguous_offsets() {
-// None => Err(Error::Wrapped("input has to be contiguous".into()))?,
-// Some((o1, o2)) => s.slice(o1..o2),
-// };
-let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
-self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
-let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
+let dst = match s.dtype() {
+DType::BF16 => {
+let s = s.as_cuda_slice::<bf16>()?;
+let s = match l.contiguous_offsets() {
+Some((0, l)) if l == s.len() => s,
+Some(_) | None => candle::bail!("input has to be contiguous"),
+};
+let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
+self.comm
+.all_reduce(s, &mut dst, &ReduceOp::Sum)
+.map_err(candle::Error::debug)?;
+candle::CudaStorage::wrap_cuda_slice(dst, dev)
+}
+DType::F16 => {
+let s = s.as_cuda_slice::<f16>()?;
+let s = match l.contiguous_offsets() {
+Some((0, l)) if l == s.len() => s,
+Some(_) | None => candle::bail!("input has to be contiguous"),
+};
+let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+self.comm
+.all_reduce(s, &mut dst, &ReduceOp::Sum)
+.map_err(candle::Error::debug)?;
+candle::CudaStorage::wrap_cuda_slice(dst, dev)
+}
+dtype => candle::bail!("unsupported dtype {dtype:?}"),
+};
 Ok((dst, l.shape().clone()))
 }
 }

-fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
-x.apply_op1(AllReduce { comm: comm.clone() })
-}
-
 impl TensorParallelRowLinear {
 fn new(linear: Linear, comm: Rc<Comm>) -> Self {
-Self { linear, comm }
+let all_reduce = AllReduce { comm };
+Self { linear, all_reduce }
 }
 fn forward(&self, x: &Tensor) -> Result<Tensor> {
-let x = self.linear.forward(x)?;
-all_reduce_sum(&x, &self.comm)
+self.linear.forward(x)?.apply_op1_no_bwd(&self.all_reduce)
 }
 }

@@ -121,23 +136,6 @@ impl TensorParallelRowLinear {
 }
 }

-#[derive(Deserialize)]
-pub struct Config {
-pub hidden_size: usize,
-pub intermediate_size: usize,
-pub vocab_size: usize,
-pub num_hidden_layers: usize,
-pub num_attention_heads: usize,
-pub num_key_value_heads: usize,
-pub rms_norm_eps: f64,
-#[serde(default = "default_rope")]
-pub rope_theta: f32,
-}
-
-fn default_rope() -> f32 {
-10_000.0
-}
-
 #[derive(Clone)]
 pub struct Cache {
 #[allow(clippy::type_complexity)]
@@ -161,7 +159,6 @@ impl Cache {
 .matmul(&theta.reshape((1, theta.elem_count()))?)?;
 // This is different from the paper, see:
 // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
-let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
 let cos = idx_theta.cos()?.to_dtype(dtype)?;
 let sin = idx_theta.sin()?.to_dtype(dtype)?;
 Ok(Self {
@@ -197,16 +194,10 @@ struct CausalSelfAttention {

 impl CausalSelfAttention {
 fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-let (b_sz, _, seq_len, hidden_size) = x.shape().dims4()?;
+let (_b_sz, _, seq_len, _hidden_size) = x.shape().dims4()?;
 let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
 let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
-let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
-let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
-let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
-let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
-let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
-let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
-Ok(rope)
+candle_nn::rotary_emb::rope(x, &cos, &sin)
 }

 fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
@@ -232,13 +223,16 @@ impl CausalSelfAttention {

 let q = q
 .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
-.transpose(1, 2)?;
+.transpose(1, 2)?
+.contiguous()?;
 let k = k
 .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
-.transpose(1, 2)?;
+.transpose(1, 2)?
+.contiguous()?;
 let mut v = v
 .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
-.transpose(1, 2)?;
+.transpose(1, 2)?
+.contiguous()?;

 let q = self.apply_rotary_emb(&q, index_pos)?;
 let mut k = self.apply_rotary_emb(&k, index_pos)?;
@@ -269,25 +263,14 @@ impl CausalSelfAttention {
 let v = v.transpose(1, 2)?;
 let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
 let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
-.transpose(1, 2)?;
-// Convert to contiguous as matmul doesn't support strided vs for now.
-let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
+.reshape((b_sz, seq_len, hidden_size))?;
 let y = self.o_proj.forward(&y)?;
 Ok(y)
 }

 fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
 let n_rep = self.num_attention_heads / self.num_key_value_heads;
-if n_rep == 1 {
-Ok(x)
-} else {
-let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().dims4()?;
-let x = x
-.unsqueeze(2)?
-.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
-.reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
-Ok(x)
-}
+candle_transformers::utils::repeat_kv(x, n_rep)
 }

 fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
@@ -301,7 +284,7 @@ impl CausalSelfAttention {
 qkv_proj,
 o_proj,
 num_attention_heads: cfg.num_attention_heads / comm.world_size(),
-num_key_value_heads: cfg.num_key_value_heads / comm.world_size(),
+num_key_value_heads: cfg.num_key_value_heads() / comm.world_size(),
 head_dim: cfg.hidden_size / cfg.num_attention_heads,
 cache: cache.clone(),
 })
@@ -315,18 +298,6 @@ struct Mlp {
 }

 impl Mlp {
-fn new(
-c_fc1: TensorParallelColumnLinear,
-c_fc2: TensorParallelColumnLinear,
-c_proj: TensorParallelRowLinear,
-) -> Self {
-Self {
-c_fc1,
-c_fc2,
-c_proj,
-}
-}
-
 fn forward(&self, x: &Tensor) -> Result<Tensor> {
 let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
 self.c_proj.forward(&x)
@@ -336,7 +307,11 @@ impl Mlp {
 let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
 let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
 let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?;
-Ok(Self::new(c_fc1, c_fc2, c_proj))
+Ok(Self {
+c_fc1,
+c_fc2,
+c_proj,
+})
 }
 }

@@ -427,10 +402,8 @@ impl Llama {
 cfg,
 comm.clone(),
 )
-.unwrap()
 })
-.collect();
+.collect::<Result<Vec<_>>>()?;

 Ok(Self::new(wte, blocks, norm, lm_head))
 }
 }
4 candle-examples/examples/llava/constants.rs Normal file
@@ -0,0 +1,4 @@
+pub const DEFAULT_IMAGE_TOKEN: &str = "<image>";
+pub const DEFAULT_IM_START_TOKEN: &str = "<im_start>";
+pub const DEFAULT_IM_END_TOKEN: &str = "<im_end>";
+pub const IMAGE_PLACEHOLDER: &str = "<image-placeholder>";
114 candle-examples/examples/llava/conversation.rs Normal file
@@ -0,0 +1,114 @@
+pub enum SeparatorStyle {
+Two,
+Mpt,
+}
+pub struct Conversation {
+pub system: String,
+pub roles: Vec<String>,
+pub messages: Vec<(String, Option<String>)>,
+pub offset: i32,
+pub sep_style: SeparatorStyle,
+pub sep: String,
+pub sep2: Option<String>,
+pub version: String,
+}
+
+impl Conversation {
+pub fn new(
+system: &str,
+roles: &[String],
+offset: i32,
+sep_style: SeparatorStyle,
+sep: &str,
+sep2: Option<&str>,
+version: &str,
+) -> Self {
+Conversation {
+system: system.to_string(),
+roles: roles.to_vec(),
+messages: Vec::new(),
+offset,
+sep_style,
+sep: sep.to_string(),
+sep2: sep2.map(|s| s.to_string()),
+version: version.to_string(),
+}
+}
+
+pub fn conv_chatml_direct() -> Self {
+Conversation::new(
+"<|im_start|>system\nAnswer the questions.",
+&[
+"<|im_start|>user\n".to_string(),
+"<|im_start|>assistant\n".to_string(),
+],
+0,
+SeparatorStyle::Mpt,
+"<|im_end|>",
+None,
+"mpt",
+)
+}
+
+pub fn conv_llava_v1() -> Self {
+Conversation::new(
+"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
+&[
+"USER".to_string(),
+"ASSISTANT".to_string(),
+],
+0,
+SeparatorStyle::Two,
+" ",
+Some("</s>"),
+"v1"
+)
+}
+
+pub fn append_message(&mut self, role: String, message: Option<&str>) {
+self.messages.push((role, message.map(|s| s.to_string())))
+}
+
+pub fn append_user_message(&mut self, message: Option<&str>) {
+self.append_message(self.roles[0].clone(), message);
+}
+
+pub fn append_assistant_message(&mut self, message: Option<&str>) {
+self.append_message(self.roles[1].clone(), message);
+}
+
+pub fn get_prompt(&self) -> String {
+match self.sep_style {
+SeparatorStyle::Mpt => {
+let mut ret = String::new();
+ret.push_str(&self.system);
+ret.push_str(&self.sep);
+for (role, message) in &self.messages {
+ret.push_str(role);
+if let Some(message) = message {
+ret.push_str(message);
+};
+ret.push_str(&self.sep);
+}
+ret
+}
+SeparatorStyle::Two => {
+let seps = [self.sep.clone(), self.sep2.clone().unwrap()];
+let mut ret = String::new();
+ret.push_str(&self.system);
+ret.push_str(&seps[0]);
+for (i, (role, message)) in self.messages.iter().enumerate() {
+ret.push_str(role);
+if let Some(message) = message {
+ret.push_str(": "); // strictly follow the python implementation, otherwise it will cause some minor difference between tokens ^_^
+ret.push_str(message);
+ret.push_str(&seps[i % 2]);
+} else {
+ret.push(':')
+}
+}
+ret
+}
+}
+}
+}
317 candle-examples/examples/llava/image_processor.rs Normal file
@@ -0,0 +1,317 @@
+use std::cmp::min;
+
+use candle::{bail, DType, Device, Result, Tensor};
+use candle_transformers::models::llava::{
+config::{HFPreProcessorConfig, LLaVAConfig},
+utils::select_best_resolution,
+};
+use hf_hub::api::sync::Api;
+use image::{imageops::overlay, DynamicImage, GenericImageView, Rgb, RgbImage};
+use serde::{Deserialize, Serialize};
+
+//This struct is mainly for LLaVA applications, hence it's not completely compatible with the python transformers CLIPImageProcessor: it only covers the few preprocess steps that LLaVA uses, including "openai/clip-vit-large-patch14-336" and "openai/clip-vit-large-patch14".
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct ImageProcessor {
+#[serde(default = "default_size")]
+pub size: u32, // this is not the same as python transformer
+#[serde(default = "default_do_resize")]
+pub do_resize: bool,
+
+//resample: u32 // 3 for PIL bicubic, equivalent to rust CatmullRom. Hence below we use CatmullRom
+#[serde(default = "default_do_center_crop")]
+pub do_center_crop: bool,
+#[serde(default = "default_crop_size")]
+pub crop_size: u32, // this is not the same as python transformer
+#[serde(default = "default_do_rescale")]
+pub do_rescale: bool,
+#[serde(default = "default_rescale_factor")]
+pub rescale_factor: f32,
+#[serde(default = "default_do_normalize")]
+pub do_normalize: bool,
+#[serde(default = "default_image_mean")]
+pub image_mean: Vec<f32>,
+#[serde(default = "default_image_std")]
+pub image_std: Vec<f32>,
+}
+
+fn default_size() -> u32 {
+224
+}
+
+fn default_do_resize() -> bool {
+true
+}
+
+fn default_do_center_crop() -> bool {
+true
+}
+
+fn default_crop_size() -> u32 {
+224
+}
+
+fn default_do_rescale() -> bool {
+true
+}
+
+fn default_rescale_factor() -> f32 {
+1.0 / 255.0
+}
+
+fn default_do_normalize() -> bool {
+true
+}
+
+fn default_image_mean() -> Vec<f32> {
+vec![0.48145466, 0.4578275, 0.40821073]
+}
+
+fn default_image_std() -> Vec<f32> {
+vec![0.26862954, 0.2613026, 0.2757771]
+}
+
+impl ImageProcessor {
+pub fn from_pretrained(clip_id: &str) -> Result<Self> {
+let api = Api::new().map_err(|e| candle::Error::Msg(e.to_string()))?;
+let api = api.model(clip_id.to_string());
+let config_filename = api
+.get("preprocessor_config.json")
+.map_err(|e| candle::Error::Msg(e.to_string()))?;
+let image_processor =
+serde_json::from_slice(&std::fs::read(config_filename).map_err(candle::Error::Io)?)
+.map_err(|e| candle::Error::Msg(e.to_string()))?;
+Ok(image_processor)
+}
+
+pub fn from_hf_preprocessor_config(hf_preprocessor_config: &HFPreProcessorConfig) -> Self {
+Self {
+size: hf_preprocessor_config.size["shortest_edge"] as u32,
+do_resize: hf_preprocessor_config.do_resize,
+do_center_crop: hf_preprocessor_config.do_center_crop,
+crop_size: hf_preprocessor_config.crop_size["height"] as u32,
+do_rescale: hf_preprocessor_config.do_rescale,
+rescale_factor: hf_preprocessor_config.rescale_factor,
+do_normalize: hf_preprocessor_config.do_normalize,
+image_mean: hf_preprocessor_config.image_mean.clone(),
+image_std: hf_preprocessor_config.image_std.clone(),
+}
+}
+
+///shortest edge to self.resize, other edge is resized to maintain aspect ratio
+pub fn resize(&self, image: &DynamicImage) -> DynamicImage {
+let (width, height) = image.dimensions();
+let size = self.size;
+if width == size && height == size {
+image.clone()
+} else {
+let (new_width, new_height) = if width < height {
+(
+size,
+(((size * height) as f32) / width as f32).ceil() as u32,
+)
+} else {
+(
+(((size * width) as f32) / height as f32).ceil() as u32,
+size,
+)
+};
+image.resize(
+new_width,
+new_height,
+image::imageops::FilterType::CatmullRom,
+)
+}
+}
+
+pub fn center_crop(&self, image: &DynamicImage) -> DynamicImage {
+let (width, height) = image.dimensions();
+let crop_size = self.crop_size;
+let (left, top) = calculate_middle((width, height), (crop_size, crop_size));
+image.crop_imm(left, top, crop_size, crop_size)
+}
+
+pub fn to_tensor(&self, image: &DynamicImage) -> Result<Tensor> {
+let img = image.to_rgb8().into_raw();
+let (width, height) = image.dimensions();
+Tensor::from_vec(img, (height as usize, width as usize, 3), &Device::Cpu)?
+.to_dtype(DType::F32) // only for internal compute
+}
+
+pub fn rescale(&self, tensor: &Tensor) -> Result<Tensor> {
+let rescale_factor = self.rescale_factor as f64;
+tensor.affine(rescale_factor, 0.0)
+}
+
+pub fn normalize(&self, tensor: &Tensor) -> Result<Tensor> {
+let image_mean = self.image_mean.clone();
+let image_std = self.image_std.clone();
+let mean = Tensor::from_vec(image_mean, (3,), &Device::Cpu)?;
+let std = Tensor::from_vec(image_std, (3,), &Device::Cpu)?;
+tensor.broadcast_sub(&mean)?.broadcast_div(&std)
+}
+
+pub fn to_channel_dimension_format(&self, tensor: &Tensor) -> Result<Tensor> {
+tensor.permute((2, 0, 1))
+}
+
+pub fn preprocess(&self, image: &DynamicImage) -> Result<Tensor> {
+let image = if self.do_resize {
+self.resize(image)
+} else {
+image.clone()
+};
+let image = if self.do_center_crop {
+self.center_crop(&image)
+} else {
+image
+};
+let tensor = self.to_tensor(&image)?;
+let tensor = if self.do_rescale {
+self.rescale(&tensor)?
+} else {
+tensor
+};
+let tensor = if self.do_normalize {
+self.normalize(&tensor)?
+} else {
+tensor
+};
+self.to_channel_dimension_format(&tensor)
+}
+}
+
+pub fn calculate_middle(image_size: (u32, u32), center_size: (u32, u32)) -> (u32, u32) {
+let (width, height) = image_size;
+let (center_width, center_height) = center_size;
+let left = if width <= center_width {
+0
+} else {
+((width as f32 - center_width as f32) / 2.0).ceil() as u32
+};
+let top = if height <= center_height {
+0
+} else {
+((height as f32 - center_height as f32) / 2.0).ceil() as u32
+};
+(left, top)
+}
+
+pub fn process_image(
+image: &DynamicImage,
+processor: &ImageProcessor,
+llava_config: &LLaVAConfig,
+) -> candle::Result<Tensor> {
+if llava_config.image_aspect_ratio == *"square" {
+processor.preprocess(image)?.unsqueeze(0)
+} else if llava_config.image_aspect_ratio == *"anyres" {
+process_anyres_image(image, processor, &llava_config.image_grid_pinpoints)
+} else if llava_config.image_aspect_ratio == *"pad" {
+process_pad_image(image, processor)
+} else {
+bail!("Invalid image aspect ratio")
+}
+}
+
+fn process_pad_image(image: &DynamicImage, processor: &ImageProcessor) -> Result<Tensor> {
+let mean_color = processor
+.image_mean
+.iter()
+.map(|x| ((*x) * 255.0) as u8)
+.collect::<Vec<u8>>();
+let mean_color = Rgb::from([mean_color[0], mean_color[1], mean_color[2]]);
+let image_padded = expand2square(image, mean_color);
+processor.preprocess(&image_padded)
+}
+
+fn process_anyres_image(
+image: &DynamicImage,
+processor: &ImageProcessor,
+grid_pinpoints: &[(u32, u32)],
+) -> Result<Tensor> {
+let original_size = image.dimensions();
+let best_resolution = select_best_resolution(original_size, grid_pinpoints);
+let image_padded = resize_and_pad_image(image, best_resolution);
+let image_original_resize = image.resize_exact(
+processor.size,
+processor.size,
+image::imageops::FilterType::CatmullRom,
+);
+let mut patches = vec![image_original_resize];
+for patch in divide_to_patches(&image_padded, processor.crop_size) {
+patches.push(patch);
+}
+let tensors = patches
+.iter()
+.map(|patch| processor.preprocess(patch))
+.collect::<Result<Vec<Tensor>>>()?;
+Tensor::stack(&tensors, 0)
+}
+
+fn expand2square(image: &DynamicImage, background_color: Rgb<u8>) -> DynamicImage {
+let (width, height) = image.dimensions();
+match width.cmp(&height) {
+std::cmp::Ordering::Less => {
+let mut new_image =
+DynamicImage::from(RgbImage::from_pixel(height, height, background_color));
+overlay(&mut new_image, image, ((height - width) / 2) as i64, 0);
+new_image
+}
+std::cmp::Ordering::Equal => image.clone(),
+std::cmp::Ordering::Greater => {
+let mut new_image =
+DynamicImage::from(RgbImage::from_pixel(width, width, background_color));
+overlay(&mut new_image, image, 0, ((width - height) / 2) as i64);
+new_image
+}
+}
+}
+
+fn resize_and_pad_image(image: &DynamicImage, target_resolution: (u32, u32)) -> DynamicImage {
+let (original_width, original_height) = image.dimensions();
+let original_width_f = original_width as f32;
+let original_height_f = original_height as f32;
+let (target_width, target_height) = target_resolution;
+let target_width_f = target_width as f32;
+let target_height_f = target_height as f32;
+let scale_w = target_width_f / original_width_f;
+let scale_h = target_height_f / original_height_f;
+let (new_width, new_height) = if scale_w < scale_h {
+(
+target_width,
+min((original_height_f * scale_w).ceil() as u32, target_height),
+)
+} else {
+(
+min((original_width_f * scale_h).ceil() as u32, target_width),
+target_height,
+)
+};
+let resized_image = image.resize_exact(
+new_width,
+new_height,
+image::imageops::FilterType::CatmullRom,
+);
+let mut new_image = DynamicImage::new_rgb8(target_width, target_height);
+let (paste_x, paste_y) =
+calculate_middle((target_width, target_height), (new_width, new_height));
+overlay(
+&mut new_image,
+&resized_image,
+paste_x.into(),
+paste_y.into(),
+);
+new_image
+}
+
+fn divide_to_patches(image: &DynamicImage, patch_size: u32) -> Vec<DynamicImage> {
+let (width, height) = image.dimensions();
+let mut patches = Vec::new();
+for y in (0..height).step_by(patch_size as usize) {
+for x in (0..width).step_by(patch_size as usize) {
+let patch = image.crop_imm(x, y, patch_size, patch_size);
+patches.push(patch);
+}
+}
+patches
+}
316 candle-examples/examples/llava/main.rs Normal file
@@ -0,0 +1,316 @@
+pub mod constants;
+pub mod conversation;
+pub mod image_processor;
+
+use candle_transformers::generation::{LogitsProcessor, Sampling};
+use candle_transformers::models::llama::Cache;
+
+use anyhow::{bail, Error as E, Result};
+use candle::{DType, Device, IndexOp, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::llava::config::{
+HFGenerationConfig, HFLLaVAConfig, HFPreProcessorConfig,
+};
+use candle_transformers::models::llava::{config::LLaVAConfig, LLaVA};
+use clap::Parser;
+use constants::*;
+use conversation::Conversation;
+use hf_hub::api::sync::Api;
+use image_processor::{process_image, ImageProcessor};
+use std::io::Write;
+use tokenizers::Tokenizer;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about,long_about=None)]
+struct Args {
+#[arg(long, default_value = "llava-hf/llava-v1.6-vicuna-7b-hf")]
+model_path: String,
+#[arg(long, default_value = "tokenizer/tokenizer.json")]
+tokenizer_path: String,
+#[arg(long)]
+model_base: Option<String>,
+#[arg(long)]
+image_file: String, // Required
+#[arg(long)]
+conv_mode: Option<String>,
+#[arg(long, default_value_t = 0.2)]
+temperature: f32,
+#[arg(long, default_value_t = 512)]
+max_new_tokens: usize,
+#[arg(long, action)]
+hf: bool,
+#[arg(long, action)]
+cpu: bool,
+#[arg(long, action)]
+no_kv_cache: bool,
+#[arg(long)]
+prompt: String,
+/// The seed to use when generating random samples. Copy from candle llama. Not exist in python llava.
+#[arg(long, default_value_t = 299792458)]
+seed: u64,
+}
+
+//from https://github.com/huggingface/candle/blob/main/candle-examples/examples/clip/main.rs
+fn load_image<T: AsRef<std::path::Path>>(
+path: T,
+processor: &ImageProcessor,
+llava_config: &LLaVAConfig,
+dtype: DType,
+) -> Result<((u32, u32), Tensor)> {
+let img = image::io::Reader::open(path)?.decode()?;
+let img_tensor = process_image(&img, processor, llava_config)?;
+Ok(((img.width(), img.height()), img_tensor.to_dtype(dtype)?))
+}
+
+fn get_model_name_from_path(model_path: &str) -> String {
+let model_paths: Vec<String> = model_path
+.trim_matches('/')
+.split('/')
+.map(|s| s.to_string())
+.collect();
+if model_paths.last().unwrap().starts_with("checkpoint-") {
+format!(
+"{}_{}",
+model_paths[model_paths.len() - 2],
+model_paths.last().unwrap()
+)
+} else {
+model_paths.last().unwrap().to_string()
+}
+}
+
+fn duplicate_vec<T>(vec: &[T], n: usize) -> Vec<T>
+where
+T: Clone,
+{
+let mut res = Vec::new();
+for _ in 0..n {
+res.extend(vec.to_owned());
+}
+res
+}
+
+fn insert_separator<T>(x: Vec<Vec<T>>, sep: Vec<T>) -> Vec<Vec<T>>
+where
+T: Clone,
+{
+let sep = vec![sep];
+let sep = duplicate_vec(&sep, x.len());
+let mut res = x
+.iter()
+.zip(sep.iter())
+.flat_map(|(x, y)| vec![x.clone(), y.clone()])
+.collect::<Vec<Vec<T>>>();
+res.pop();
+res
+}
+
+fn tokenizer_image_token(
+prompt: &str,
+tokenizer: &Tokenizer,
+image_token_index: i64,
+llava_config: &LLaVAConfig,
+) -> Result<Tensor> {
+let prompt_chunks = prompt
+.split("<image>")
+.map(|s| {
+tokenizer
+.encode(s, true)
+.unwrap()
+.get_ids()
+.to_vec()
+.iter()
+.map(|x| *x as i64)
+.collect()
+})
+.collect::<Vec<Vec<i64>>>();
+let mut input_ids = Vec::new();
+let mut offset = 0;
+if !prompt_chunks.is_empty()
+&& !prompt_chunks[0].is_empty()
+&& prompt_chunks[0][0] == llava_config.bos_token_id as i64
+{
+offset = 1;
+input_ids.push(prompt_chunks[0][0]);
+}
+
+for x in insert_separator(
+prompt_chunks,
+duplicate_vec(&[image_token_index], offset + 1),
+)
+.iter()
+{
+input_ids.extend(x[1..].to_vec())
+}
+let input_len = input_ids.len();
+Tensor::from_vec(input_ids, (1, input_len), &Device::Cpu).map_err(E::msg)
+}
+
+fn main() -> Result<()> {
+let mut args = Args::parse();
+let device = candle_examples::device(args.cpu)?;
+println!("Start loading model");
+let api = Api::new()?;
+let api = api.model(args.model_path.clone());
+let (llava_config, tokenizer, clip_vision_config, image_processor) = if args.hf {
+let config_filename = api.get("config.json")?;
+let hf_llava_config: HFLLaVAConfig =
+serde_json::from_slice(&std::fs::read(config_filename)?)?;
+let generation_config_filename = api.get("generation_config.json")?;
+let generation_config: HFGenerationConfig =
+serde_json::from_slice(&std::fs::read(generation_config_filename)?)?;
+let preprocessor_config_filename = api.get("preprocessor_config.json")?;
+let preprocessor_config: HFPreProcessorConfig =
+serde_json::from_slice(&std::fs::read(preprocessor_config_filename)?)?;
+let llava_config =
+hf_llava_config.to_llava_config(&generation_config, &preprocessor_config);
+let tokenizer_filename = api.get("tokenizer.json")?;
+let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+let clip_vision_config = hf_llava_config.to_clip_vision_config();
+(
+llava_config,
+tokenizer,
+Some(clip_vision_config),
+ImageProcessor::from_hf_preprocessor_config(&preprocessor_config),
+)
+} else {
+let config_filename = api.get("config.json")?;
+let llava_config: LLaVAConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
+let tokenizer = Tokenizer::from_file(&args.tokenizer_path)
+.map_err(|e| E::msg(format!("Error loading {}: {}", &args.tokenizer_path, e)))?;
+(
+llava_config.clone(),
+tokenizer,
+None,
+ImageProcessor::from_pretrained(&llava_config.mm_vision_tower.unwrap())?,
+)
+};
+
+let llama_config = llava_config.to_llama_config();
+let dtype: DType = match llava_config.torch_dtype.as_str() {
+"float16" => DType::F16,
+"bfloat16" => DType::BF16,
+_ => bail!("unsupported dtype"),
+};
+
+let eos_token_id = llava_config.eos_token_id;
+
+println!("setting kv cache");
+let mut cache = Cache::new(!args.no_kv_cache, dtype, &llama_config, &device)?;
+
+println!("loading model weights");
+
+let weight_filenames =
+candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
+let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_filenames, dtype, &device)? };
+let llava: LLaVA = LLaVA::load(vb, &llava_config, clip_vision_config)?;
+
+println!("generating conv template");
+let image_token_se = format!(
+"{}{}{}",
+DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_END_TOKEN
+);
+let qs = if args.prompt.contains(IMAGE_PLACEHOLDER) {
+if llava_config.mm_use_im_start_end {
+args.prompt.replace(IMAGE_PLACEHOLDER, &image_token_se)
+} else {
+args.prompt.replace(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN)
+}
+} else if llava_config.mm_use_im_start_end {
+format!("{}\n{}", image_token_se, args.prompt)
+} else {
+format!("{}\n{}", DEFAULT_IMAGE_TOKEN, args.prompt)
+};
+
+let model_name = get_model_name_from_path(&args.model_path).to_lowercase();
+let conv_mode = if model_name.contains("llama-2") {
+"llava_llama_2"
+} else if model_name.contains("mistral") {
+"mistral_instruct"
+} else if model_name.contains("v1.6-34b") {
+"chatml_direct"
+} else if model_name.contains("v1") {
+"llava_v1"
+} else if model_name.contains("mpt") {
+"mpt"
+} else {
+"llava_v0"
+};
+if args.conv_mode.is_some() && args.conv_mode.as_deref() != Some(conv_mode) {
+println!(
+"Warning: the model is trained with {}, but you are using {}",
+conv_mode,
+args.conv_mode.as_deref().unwrap()
+);
+} else {
+args.conv_mode = Some(conv_mode.to_string());
+}
+
+let mut conv = match args.conv_mode {
+Some(conv_mode) => match conv_mode.as_str() {
+"chatml_direct" => Conversation::conv_chatml_direct(),
+"llava_v1" => Conversation::conv_llava_v1(),
+_ => todo!("not implement yet"),
+},
+None => bail!("conv_mode is required"),
+};
+conv.append_user_message(Some(&qs));
+conv.append_assistant_message(None);
+let prompt = conv.get_prompt();
+println!("loading image");
+let (image_size, image_tensor) =
+load_image(&args.image_file, &image_processor, &llava_config, dtype)
+.map_err(|e| E::msg(format!("Error loading {}: {}", &args.image_file, e)))?;
+let image_tensor = image_tensor.to_device(&device)?;
+
+let mut logits_processor = {
+let temperature = f64::from(args.temperature);
+let sampling = if temperature <= 0. {
+Sampling::ArgMax
+} else {
+Sampling::All { temperature }
+};
+LogitsProcessor::from_sampling(args.seed, sampling)
+};
+
+// get input tokens
+let tokens = tokenizer_image_token(
+&prompt,
+&tokenizer,
+llava_config.image_token_index as i64,
+&llava_config,
+)?;
+let mut input_embeds =
+llava.prepare_inputs_labels_for_multimodal(&tokens, &[image_tensor], &[image_size])?;
+//inference loop, based on https://github.com/huggingface/candle/blob/main/candle-examples/examples/llama/main.rs
+let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
+let mut index_pos = 0;
+for index in 0..args.max_new_tokens {
+let (_, input_embeds_len, _) = input_embeds.dims3()?;
+let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
+(1, index_pos)
+} else {
+(input_embeds_len, 0)
+};
+let input = input_embeds.i((.., input_embeds_len.saturating_sub(context_size).., ..))?;
+let logits = llava.forward(&input, context_index, &mut cache)?; //[1,32000]
+let logits = logits.squeeze(0)?;
+let (_, input_len, _) = input.dims3()?;
+index_pos += input_len;
+let next_token = logits_processor.sample(&logits)?;
+let next_token_tensor = Tensor::from_vec(vec![next_token], 1, &device)?;
+let next_embeds = llava.llama.embed(&next_token_tensor)?.unsqueeze(0)?;
+input_embeds = Tensor::cat(&[input_embeds, next_embeds], 1)?;
+if next_token == eos_token_id as u32 {
+break;
+}
+if let Some(t) = tokenizer.next_token(next_token)? {
+print!("{t}");
+std::io::stdout().flush()?;
+}
+}
+if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
+print!("{rest}");
+}
+Ok(())
+}
40
candle-examples/examples/llava/readme.md
Normal file
40
candle-examples/examples/llava/readme.md
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
# candle-llava

LLaVA (Large Language-and-Vision Assistant) is an end-to-end trained large
multimodal model. This example is from [candle-llava](https://github.com/chenwanqq/candle-llava).

The code is based on [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA), hence the llava-hf version of the config may perform differently.

## model zoo

* [liuhaotian/LLaVA](https://huggingface.co/liuhaotian)
* [llava-hf](https://huggingface.co/llava-hf)

Right now this has been tested on `liuhaotian/llava-v1.6-vicuna-7b` and
`llava-hf/llava-v1.6-vicuna-7b-hf`. Memory usage might have room for optimization.

## Tokenizer Setup

The llava-hf models contain a `tokenizer.json` file, so they can be used directly with
the `-hf` command line flag.

For the original llava models, you can use the following commands to generate the `tokenizer.json` file.

```bash
conda create -n llava python=3.10
conda activate llava
pip install transformers protobuf
python -c "from transformers import AutoTokenizer;tokenizer=AutoTokenizer.from_pretrained('liuhaotian/llava-v1.6-vicuna-7b');tokenizer.save_pretrained('tokenizer')"
```

The `tokenizer.json` file should then be in `tokenizer/tokenizer.json` (which is the default path).
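As a quick sanity check, the generated file can be loaded with the `tokenizers` crate the same way the example does; this is a minimal sketch and the path below is just the default location mentioned above.

```rust
use anyhow::Error as E;
use tokenizers::Tokenizer;

// Minimal sketch: load the tokenizer.json generated by the commands above
// from the default `tokenizer/` directory.
fn load_tokenizer() -> anyhow::Result<Tokenizer> {
    Tokenizer::from_file("tokenizer/tokenizer.json").map_err(E::msg)
}
```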

## eval

```bash
cargo run --example llava --features cuda -- --image-file "llava_logo.png" --prompt "is this a cat?" --hf # default args, use llava-hf/llava-v1.6-vicuna-7b-hf. image-file is required^_^
cargo run --example llava --features cuda -- --model-path liuhaotian/llava-v1.6-vicuna-7b --image-file "llava_logo.png" --prompt "is this a cat?" # use liuhaotian/llava-v1.6-vicuna-7b, tokenizer setup should be done
```

## Major Limitations

1. Currently only llama-2/vicuna LLMs are supported; Mistral is not supported yet.
2. Some ops such as split, nonzero and where are not supported by candle yet.
3. Lack of quantization and LoRA support.
|
@ -54,6 +54,7 @@ impl TextGeneration {
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         use std::io::Write;
         self.tokenizer.clear();
+        let dtype = self.model.dtype();
         let mut tokens = self
             .tokenizer
             .tokenizer()
@ -66,7 +67,7 @@ impl TextGeneration {
             Some(token) => token,
             None => anyhow::bail!("cannot find the </s> token"),
         };
-        let mut state = State::new(1, &self.config, &self.device)?;
+        let mut state = State::new(1, &self.config, dtype, &self.device)?;
         let mut next_logits = None;
         for &t in tokens.iter() {
             let input = Tensor::new(&[t], &self.device)?;
@ -84,7 +85,7 @@ impl TextGeneration {
             Some(logits) => logits,
             None => anyhow::bail!("cannot work on an empty prompt"),
         };
-        let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+        let logits = logits.squeeze(0)?.to_dtype(dtype)?;
         let logits = if self.repeat_penalty == 1. {
             logits
         } else {
@ -210,6 +211,9 @@ struct Args {
     #[arg(long)]
     config_file: Option<String>,
 
+    #[arg(long, default_value = "f32")]
+    dtype: String,
+
     /// Penalty to be applied for repeating tokens, 1. means no penalty.
     #[arg(long, default_value_t = 1.1)]
     repeat_penalty: f32,
@ -220,6 +224,7 @@ struct Args {
 }
 
 fn main() -> Result<()> {
+    use std::str::FromStr;
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
 
@ -279,7 +284,8 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
     let device = candle_examples::device(args.cpu)?;
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+    let dtype = DType::from_str(&args.dtype)?;
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
     let model = Model::new(&config, vb.pp("backbone"))?;
     println!("loaded the model in {:?}", start.elapsed());
|
|
||||||
|
@ -13,7 +13,7 @@ use candle_transformers::models::quantized_mistral::Model as QMistral;
 use candle::{DType, Device, Tensor};
 use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
-use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::generation::{LogitsProcessor, Sampling};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
@ -39,11 +39,26 @@ impl TextGeneration {
         seed: u64,
         temp: Option<f64>,
         top_p: Option<f64>,
+        top_k: Option<usize>,
         repeat_penalty: f32,
         repeat_last_n: usize,
         device: &Device,
     ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+        let logits_processor = {
+            let temperature = temp.unwrap_or(0.);
+            let sampling = if temperature <= 0. {
+                Sampling::ArgMax
+            } else {
+                match (top_k, top_p) {
+                    (None, None) => Sampling::All { temperature },
+                    (Some(k), None) => Sampling::TopK { k, temperature },
+                    (None, Some(p)) => Sampling::TopP { p, temperature },
+                    (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+                }
+            };
+            LogitsProcessor::from_sampling(seed, sampling)
+        };
+
         Self {
             model,
             tokenizer: TokenOutputStream::new(tokenizer),
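For reference, the same `Sampling` selection added in the hunk above can be exercised on its own; the concrete values below are made up for illustration.

```rust
use candle_transformers::generation::{LogitsProcessor, Sampling};

fn main() {
    // Illustrative values only: top-k 40 combined with nucleus sampling at p = 0.9.
    let sampling = Sampling::TopKThenTopP { k: 40, p: 0.9, temperature: 0.8 };
    let _processor = LogitsProcessor::from_sampling(299792458, sampling);
    // In the generation loop, the processor's `sample(&logits)` call picks the next token id.
}
```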
@ -159,6 +174,10 @@ struct Args {
     #[arg(long)]
     top_p: Option<f64>,
 
+    /// Only sample among the top K samples.
+    #[arg(long)]
+    top_k: Option<usize>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@ -196,6 +215,10 @@ struct Args {
     /// The context size to consider for the repeat penalty.
     #[arg(long, default_value_t = 64)]
     repeat_last_n: usize,
+
+    /// Use the slower dmmv cuda kernel.
+    #[arg(long)]
+    force_dmmv: bool,
 }
 
 fn main() -> Result<()> {
@ -203,6 +226,9 @@ fn main() -> Result<()> {
     use tracing_subscriber::prelude::*;
 
     let args = Args::parse();
+    #[cfg(feature = "cuda")]
+    candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
+
     let _guard = if args.tracing {
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
         tracing_subscriber::registry().with(chrome_layer).init();
@ -307,6 +333,7 @@ fn main() -> Result<()> {
         args.seed,
         args.temperature,
         args.top_p,
+        args.top_k,
         args.repeat_penalty,
         args.repeat_last_n,
         &device,
|
||||||
|
26  candle-examples/examples/moondream/README.md  Normal file
@ -0,0 +1,26 @@
# candle-moondream

[Moondream](https://github.com/vikhyat/moondream) is a computer-vision model that can answer real-world questions about images. It's tiny compared to today's models, with only 1.6B parameters. That enables it to run on a variety of devices, including mobile phones and edge devices.

## Running some examples

First download an example image:
```bash
$ wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg
```

<img src="https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg" width="200">

Now you can run Moondream from the `candle-examples` crate:
```bash
$ cargo run --example moondream --release -- --prompt "What is the girl eating?" --image "./demo-1.jpg"

avx: false, neon: true, simd128: false, f16c: false
temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64
retrieved the files in 3.395583ms
Running on CPU, to run on GPU(metal), build this example with `--features metal`
loaded the model in 5.485493792s
loaded and encoded the image Tensor[dims 3, 378, 378; f32] in 4.801396417s
starting the inference loop
The girl is eating a hamburger.<
9 tokens generated (0.68 token/s)
```
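The `Tensor[dims 3, 378, 378; f32]` line in the log above comes from the example's image preprocessing. A condensed sketch of that step, mirroring the `load_image` helper in the example code below (the anyhow error handling here is just for illustration):

```rust
use candle::{DType, Device, Tensor};

// Resize to 378x378, convert to CHW f32 in [0, 1], then normalize with mean = std = 0.5,
// following the example's `load_image` helper.
fn load_image_378(path: &str) -> anyhow::Result<Tensor> {
    let img = image::io::Reader::open(path)?
        .decode()?
        .resize_to_fill(378, 378, image::imageops::FilterType::Triangle)
        .to_rgb8();
    let data = Tensor::from_vec(img.into_raw(), (378, 378, 3), &Device::Cpu)?.permute((2, 0, 1))?;
    let mean = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
    let std = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
    let img = (data.to_dtype(DType::F32)? / 255.)?
        .broadcast_sub(&mean)?
        .broadcast_div(&std)?;
    Ok(img)
}
```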
343  candle-examples/examples/moondream/main.rs  Normal file
@ -0,0 +1,343 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::{Error as E, Result};
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle::{DType, Device, Tensor};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::{
|
||||||
|
generation::LogitsProcessor,
|
||||||
|
models::{moondream, quantized_moondream},
|
||||||
|
};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
enum Model {
|
||||||
|
Moondream(moondream::Model),
|
||||||
|
Quantized(quantized_moondream::Model),
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TextGeneration {
|
||||||
|
model: Model,
|
||||||
|
device: Device,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
logits_processor: LogitsProcessor,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
verbose_prompt: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextGeneration {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn new(
|
||||||
|
model: Model,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
seed: u64,
|
||||||
|
temp: Option<f64>,
|
||||||
|
top_p: Option<f64>,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
verbose_prompt: bool,
|
||||||
|
device: &Device,
|
||||||
|
) -> Self {
|
||||||
|
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
logits_processor,
|
||||||
|
repeat_penalty,
|
||||||
|
repeat_last_n,
|
||||||
|
verbose_prompt,
|
||||||
|
device: device.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(&mut self, prompt: &str, image_embeds: &Tensor, sample_len: usize) -> Result<()> {
|
||||||
|
use std::io::Write;
|
||||||
|
println!("starting the inference loop");
|
||||||
|
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
|
||||||
|
if tokens.is_empty() {
|
||||||
|
anyhow::bail!("Empty prompts are not supported in the Moondream model.")
|
||||||
|
}
|
||||||
|
if self.verbose_prompt {
|
||||||
|
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
||||||
|
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
||||||
|
println!("{id:7} -> '{token}'");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut tokens = tokens.get_ids().to_vec();
|
||||||
|
let mut generated_tokens = 0usize;
|
||||||
|
|
||||||
|
// Moondream tokenizer bos_token and eos_token is "<|endoftext|>"
|
||||||
|
// https://huggingface.co/vikhyatk/moondream2/blob/main/special_tokens_map.json
|
||||||
|
let special_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
|
||||||
|
Some(token) => *token,
|
||||||
|
None => anyhow::bail!("cannot find the special token"),
|
||||||
|
};
|
||||||
|
let (bos_token, eos_token) = (special_token, special_token);
|
||||||
|
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
let mut load_t = std::time::Duration::from_secs_f64(0f64);
|
||||||
|
for index in 0..sample_len {
|
||||||
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||||
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
|
let logits = if index > 0 {
|
||||||
|
match self.model {
|
||||||
|
Model::Moondream(ref mut model) => model.text_model.forward(&input)?,
|
||||||
|
Model::Quantized(ref mut model) => model.text_model.forward(&input)?,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let bos_token = Tensor::new(&[bos_token], &self.device)?.unsqueeze(0)?;
|
||||||
|
let logits = match self.model {
|
||||||
|
Model::Moondream(ref mut model) => {
|
||||||
|
model
|
||||||
|
.text_model
|
||||||
|
.forward_with_img(&bos_token, &input, image_embeds)?
|
||||||
|
}
|
||||||
|
Model::Quantized(ref mut model) => {
|
||||||
|
model
|
||||||
|
.text_model
|
||||||
|
.forward_with_img(&bos_token, &input, image_embeds)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
load_t = start_gen.elapsed();
|
||||||
|
println!("load_t: {:?}", load_t);
|
||||||
|
logits
|
||||||
|
};
|
||||||
|
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
let logits = if self.repeat_penalty == 1. {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||||
|
candle_transformers::utils::apply_repeat_penalty(
|
||||||
|
&logits,
|
||||||
|
self.repeat_penalty,
|
||||||
|
&tokens[start_at..],
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
|
tokens.push(next_token);
|
||||||
|
generated_tokens += 1;
|
||||||
|
if next_token == eos_token || tokens.ends_with(&[27, 10619, 29] /* <END> */) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||||
|
print!("{token}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let dt = start_gen.elapsed() - load_t;
|
||||||
|
println!(
|
||||||
|
"\ngenerated in {} seconds\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||||
|
dt.as_secs_f64(),
|
||||||
|
(generated_tokens - 1) as f64 / dt.as_secs_f64()
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
/// Display the token for the specified prompt.
|
||||||
|
#[arg(long)]
|
||||||
|
verbose_prompt: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
image: String,
|
||||||
|
|
||||||
|
/// The temperature used to generate samples.
|
||||||
|
#[arg(long)]
|
||||||
|
temperature: Option<f64>,
|
||||||
|
|
||||||
|
/// Nucleus sampling probability cutoff.
|
||||||
|
#[arg(long)]
|
||||||
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 0)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
|
#[arg(long, default_value_t = 5000)]
|
||||||
|
sample_len: usize,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.0)]
|
||||||
|
repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
repeat_last_n: usize,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "main")]
|
||||||
|
revision: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
quantized: bool,
|
||||||
|
|
||||||
|
/// Use f16 precision for all the computations rather than f32.
|
||||||
|
#[arg(long)]
|
||||||
|
f16: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer_file: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
||||||
|
/// (3, 378, 378).
|
||||||
|
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> candle::Result<Tensor> {
|
||||||
|
let img = image::io::Reader::open(p)?
|
||||||
|
.decode()
|
||||||
|
.map_err(candle::Error::wrap)?
|
||||||
|
.resize_to_fill(378, 378, image::imageops::FilterType::Triangle); // Adjusted to 378x378
|
||||||
|
let img = img.to_rgb8();
|
||||||
|
let data = img.into_raw();
|
||||||
|
let data = Tensor::from_vec(data, (378, 378, 3), &Device::Cpu)?.permute((2, 0, 1))?;
|
||||||
|
let mean = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||||
|
let std = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||||
|
(data.to_dtype(candle::DType::F32)? / 255.)?
|
||||||
|
.broadcast_sub(&mean)?
|
||||||
|
.broadcast_div(&std)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||||
|
candle::utils::with_avx(),
|
||||||
|
candle::utils::with_neon(),
|
||||||
|
candle::utils::with_simd128(),
|
||||||
|
candle::utils::with_f16c()
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||||
|
args.temperature.unwrap_or(0.),
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n
|
||||||
|
);
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let api = hf_hub::api::tokio::Api::new()?;
|
||||||
|
let model_id = match args.model_id {
|
||||||
|
Some(model_id) => model_id.to_string(),
|
||||||
|
None => {
|
||||||
|
if args.quantized {
|
||||||
|
"santiagomed/candle-moondream".to_string()
|
||||||
|
} else {
|
||||||
|
"vikhyatk/moondream2".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let repo = api.repo(hf_hub::Repo::with_revision(
|
||||||
|
model_id,
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
args.revision,
|
||||||
|
));
|
||||||
|
let model_file = match args.model_file {
|
||||||
|
Some(m) => m.into(),
|
||||||
|
None => {
|
||||||
|
if args.quantized {
|
||||||
|
repo.get("model-q4_0.gguf").await?
|
||||||
|
} else {
|
||||||
|
repo.get("model.safetensors").await?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let tokenizer = match args.tokenizer_file {
|
||||||
|
Some(m) => m.into(),
|
||||||
|
None => repo.get("tokenizer.json").await?,
|
||||||
|
};
|
||||||
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let config = moondream::Config::v2();
|
||||||
|
let dtype = if args.quantized {
|
||||||
|
if args.f16 {
|
||||||
|
anyhow::bail!("Quantized model does not support f16");
|
||||||
|
}
|
||||||
|
DType::F32
|
||||||
|
} else if device.is_cuda() || args.f16 {
|
||||||
|
DType::F16
|
||||||
|
} else {
|
||||||
|
DType::F32
|
||||||
|
};
|
||||||
|
let model = if args.quantized {
|
||||||
|
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
||||||
|
&model_file,
|
||||||
|
&device,
|
||||||
|
)?;
|
||||||
|
let model = quantized_moondream::Model::new(&config, vb)?;
|
||||||
|
Model::Quantized(model)
|
||||||
|
} else {
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
||||||
|
let model = moondream::Model::new(&config, vb)?;
|
||||||
|
Model::Moondream(model)
|
||||||
|
};
|
||||||
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let image = load_image(args.image)?
|
||||||
|
.to_device(&device)?
|
||||||
|
.to_dtype(dtype)?;
|
||||||
|
let image_embeds = image.unsqueeze(0)?;
|
||||||
|
let image_embeds = match model {
|
||||||
|
Model::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?,
|
||||||
|
Model::Quantized(ref m) => image_embeds.apply(m.vision_encoder())?,
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"loaded and encoded the image {image:?} in {:?}",
|
||||||
|
start.elapsed()
|
||||||
|
);
|
||||||
|
|
||||||
|
let prompt = format!("\n\nQuestion: {0}\n\nAnswer:", args.prompt);
|
||||||
|
let mut pipeline = TextGeneration::new(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
args.seed,
|
||||||
|
args.temperature,
|
||||||
|
args.top_p,
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n,
|
||||||
|
args.verbose_prompt,
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
pipeline.run(&prompt, &image_embeds, args.sample_len)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
36  candle-examples/examples/olmo/README.md  Normal file
@ -0,0 +1,36 @@
# candle-olmo: Open Language Models designed to enable the science of language models

OLMo is a series of Open Language Models designed to enable the science of language models.

- **Project Page:** https://allenai.org/olmo
- **Paper:** [Link](https://arxiv.org/abs/2402.00838)
- **Technical blog post:** https://blog.allenai.org/olmo-open-language-model-87ccfc95f580
- **W&B Logs:** https://wandb.ai/ai2-llm/OLMo-1B/reports/OLMo-1B--Vmlldzo2NzY1Njk1
<!-- - **Press release:** TODO -->

## Running the example

```bash
$ cargo run --example olmo --release -- --prompt "It is only with the heart that one can see rightly"

avx: true, neon: false, simd128: false, f16c: true
temp: 0.20 repeat-penalty: 1.10 repeat-last-n: 64
retrieved the files in 354.977µs
loaded the model in 19.87779666s
It is only with the heart that one can see rightly; what is essential is invisible to the eye.
```

Various model sizes are available via the `--model` argument.

```bash
$ cargo run --example olmo --release -- --model 1.7-7b --prompt 'It is only with the heart that one can see rightly'

avx: true, neon: false, simd128: false, f16c: true
temp: 0.20 repeat-penalty: 1.10 repeat-last-n: 64
retrieved the files in 1.226087ms
loaded the model in 171.274578609s
It is only with the heart that one can see rightly; what is essential is invisible to the eye.”
~ Antoine de Saint-Exupery, The Little Prince

I am a big fan of this quote. It reminds me that I need to be open and aware of my surroundings in order to truly appreciate them.
```
284  candle-examples/examples/olmo/main.rs  Normal file
@ -0,0 +1,284 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::{Error as E, Result};
|
||||||
|
use clap::{Parser, ValueEnum};
|
||||||
|
|
||||||
|
use candle_transformers::models::olmo::{Config, Model as OLMo};
|
||||||
|
|
||||||
|
use candle::{DType, Device, Tensor};
|
||||||
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
enum Model {
|
||||||
|
OLMo(OLMo),
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TextGeneration {
|
||||||
|
model: Model,
|
||||||
|
device: Device,
|
||||||
|
tokenizer: TokenOutputStream,
|
||||||
|
logits_processor: LogitsProcessor,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextGeneration {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn new(
|
||||||
|
model: Model,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
seed: u64,
|
||||||
|
temp: Option<f64>,
|
||||||
|
top_p: Option<f64>,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
device: &Device,
|
||||||
|
) -> Self {
|
||||||
|
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
tokenizer: TokenOutputStream::new(tokenizer),
|
||||||
|
logits_processor,
|
||||||
|
repeat_penalty,
|
||||||
|
repeat_last_n,
|
||||||
|
device: device.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||||
|
use std::io::Write;
|
||||||
|
self.tokenizer.clear();
|
||||||
|
let mut tokens = self
|
||||||
|
.tokenizer
|
||||||
|
.tokenizer()
|
||||||
|
.encode(prompt, false)
|
||||||
|
.map_err(E::msg)?
|
||||||
|
.get_ids()
|
||||||
|
.to_vec();
|
||||||
|
for &t in tokens.iter() {
|
||||||
|
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||||
|
print!("{t}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
|
||||||
|
let mut generated_tokens = 0usize;
|
||||||
|
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
||||||
|
Some(token) => token,
|
||||||
|
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
||||||
|
};
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
for index in 0..sample_len {
|
||||||
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
|
let start_pos = tokens.len().saturating_sub(context_size);
|
||||||
|
let ctxt = &tokens[start_pos..];
|
||||||
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
|
let logits = match &mut self.model {
|
||||||
|
Model::OLMo(m) => m.forward(&input, start_pos)?,
|
||||||
|
};
|
||||||
|
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
let logits = if self.repeat_penalty == 1. {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||||
|
candle_transformers::utils::apply_repeat_penalty(
|
||||||
|
&logits,
|
||||||
|
self.repeat_penalty,
|
||||||
|
&tokens[start_at..],
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
|
tokens.push(next_token);
|
||||||
|
generated_tokens += 1;
|
||||||
|
if next_token == eos_token {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||||
|
print!("{t}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let dt = start_gen.elapsed();
|
||||||
|
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||||
|
print!("{rest}");
|
||||||
|
}
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
println!(
|
||||||
|
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||||
|
generated_tokens as f64 / dt.as_secs_f64(),
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "1b")]
|
||||||
|
W1b,
|
||||||
|
#[value(name = "7b")]
|
||||||
|
W7b,
|
||||||
|
#[value(name = "7b-twin-2t")]
|
||||||
|
W7bTwin2T,
|
||||||
|
#[value(name = "1.7-7b")]
|
||||||
|
V1_7W7b,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
/// The temperature used to generate samples.
|
||||||
|
#[arg(long)]
|
||||||
|
temperature: Option<f64>,
|
||||||
|
|
||||||
|
/// Nucleus sampling probability cutoff.
|
||||||
|
#[arg(long)]
|
||||||
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
|
/// The length of the sample to generate (in tokens).
|
||||||
|
#[arg(long, short = 'n', default_value_t = 1000)]
|
||||||
|
sample_len: usize,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "main")]
|
||||||
|
revision: String,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "1b")]
|
||||||
|
model: Which,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
weight_files: Option<String>,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.1)]
|
||||||
|
repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
repeat_last_n: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||||
|
candle::utils::with_avx(),
|
||||||
|
candle::utils::with_neon(),
|
||||||
|
candle::utils::with_simd128(),
|
||||||
|
candle::utils::with_f16c()
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||||
|
args.temperature.unwrap_or(0.),
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n
|
||||||
|
);
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let api = Api::new()?;
|
||||||
|
let model_id = match args.model_id {
|
||||||
|
Some(model_id) => model_id,
|
||||||
|
None => match args.model {
|
||||||
|
Which::W1b => "allenai/OLMo-1B-hf".to_string(),
|
||||||
|
Which::W7b => "allenai/OLMo-7B-hf".to_string(),
|
||||||
|
Which::W7bTwin2T => "allenai/OLMo-7B-Twin-2T-hf".to_string(),
|
||||||
|
Which::V1_7W7b => "allenai/OLMo-1.7-7B-hf".to_string(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let repo = api.repo(Repo::with_revision(
|
||||||
|
model_id,
|
||||||
|
RepoType::Model,
|
||||||
|
args.revision,
|
||||||
|
));
|
||||||
|
let tokenizer_filename = match args.tokenizer_file {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => repo.get("tokenizer.json")?,
|
||||||
|
};
|
||||||
|
let filenames = match args.weight_files {
|
||||||
|
Some(files) => files
|
||||||
|
.split(',')
|
||||||
|
.map(std::path::PathBuf::from)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
None => match args.model {
|
||||||
|
Which::W1b => {
|
||||||
|
vec![repo.get("model.safetensors")?]
|
||||||
|
}
|
||||||
|
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let config = {
|
||||||
|
let config_filename = repo.get("config.json")?;
|
||||||
|
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||||
|
config
|
||||||
|
};
|
||||||
|
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let model = {
|
||||||
|
let dtype = if device.is_cuda() {
|
||||||
|
DType::BF16
|
||||||
|
} else {
|
||||||
|
DType::F32
|
||||||
|
};
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
|
let model = OLMo::new(&config, vb)?;
|
||||||
|
Model::OLMo(model)
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
let mut pipeline = TextGeneration::new(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
args.seed,
|
||||||
|
args.temperature,
|
||||||
|
args.top_p,
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n,
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
pipeline.run(&args.prompt, args.sample_len)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -1,8 +1,9 @@
 # candle-phi: 1.3b and 2.7b LLM with state of the art performance for <10b models.
 
-[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) and
-[Phi-2](https://huggingface.co/microsoft/phi-2) are language models using
-only 1.3 and 2.7 billion parameters but with state of the art performance compared to
+[Phi-1.5](https://huggingface.co/microsoft/phi-1_5),
+[Phi-2](https://huggingface.co/microsoft/phi-2), and
+[Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) are language models using
+only 1.3, 2.7, and 3.8 billion parameters but with state of the art performance compared to
 models with up to 10 billion parameters.
 
 The candle implementation provides both the standard version as well as a
@ -7,11 +7,13 @@ extern crate accelerate_src;
|
|||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
|
|
||||||
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
|
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
|
||||||
use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
|
use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
|
||||||
|
use candle_transformers::models::phi3::{Config as Phi3Config, Model as Phi3};
|
||||||
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
|
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, IndexOp, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
@ -20,13 +22,14 @@ use tokenizers::Tokenizer;
|
|||||||
enum Model {
|
enum Model {
|
||||||
MixFormer(MixFormer),
|
MixFormer(MixFormer),
|
||||||
Phi(Phi),
|
Phi(Phi),
|
||||||
|
Phi3(Phi3),
|
||||||
Quantized(QMixFormer),
|
Quantized(QMixFormer),
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TextGeneration {
|
struct TextGeneration {
|
||||||
model: Model,
|
model: Model,
|
||||||
device: Device,
|
device: Device,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: TokenOutputStream,
|
||||||
logits_processor: LogitsProcessor,
|
logits_processor: LogitsProcessor,
|
||||||
repeat_penalty: f32,
|
repeat_penalty: f32,
|
||||||
repeat_last_n: usize,
|
repeat_last_n: usize,
|
||||||
@ -49,7 +52,7 @@ impl TextGeneration {
|
|||||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||||
Self {
|
Self {
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer: TokenOutputStream::new(tokenizer),
|
||||||
logits_processor,
|
logits_processor,
|
||||||
repeat_penalty,
|
repeat_penalty,
|
||||||
repeat_last_n,
|
repeat_last_n,
|
||||||
@ -61,7 +64,11 @@ impl TextGeneration {
|
|||||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
println!("starting the inference loop");
|
println!("starting the inference loop");
|
||||||
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
|
let tokens = self
|
||||||
|
.tokenizer
|
||||||
|
.tokenizer()
|
||||||
|
.encode(prompt, true)
|
||||||
|
.map_err(E::msg)?;
|
||||||
if tokens.is_empty() {
|
if tokens.is_empty() {
|
||||||
anyhow::bail!("Empty prompts are not supported in the phi model.")
|
anyhow::bail!("Empty prompts are not supported in the phi model.")
|
||||||
}
|
}
|
||||||
@ -73,13 +80,14 @@ impl TextGeneration {
|
|||||||
}
|
}
|
||||||
let mut tokens = tokens.get_ids().to_vec();
|
let mut tokens = tokens.get_ids().to_vec();
|
||||||
let mut generated_tokens = 0usize;
|
let mut generated_tokens = 0usize;
|
||||||
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
|
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
||||||
Some(token) => *token,
|
Some(token) => token,
|
||||||
None => anyhow::bail!("cannot find the endoftext token"),
|
None => anyhow::bail!("cannot find the endoftext token"),
|
||||||
};
|
};
|
||||||
print!("{prompt}");
|
print!("{prompt}");
|
||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
|
let mut pos = 0;
|
||||||
for index in 0..sample_len {
|
for index in 0..sample_len {
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||||
@ -88,6 +96,7 @@ impl TextGeneration {
|
|||||||
Model::MixFormer(m) => m.forward(&input)?,
|
Model::MixFormer(m) => m.forward(&input)?,
|
||||||
Model::Phi(m) => m.forward(&input)?,
|
Model::Phi(m) => m.forward(&input)?,
|
||||||
Model::Quantized(m) => m.forward(&input)?,
|
Model::Quantized(m) => m.forward(&input)?,
|
||||||
|
Model::Phi3(m) => m.forward(&input, pos)?.i((.., 0, ..))?,
|
||||||
};
|
};
|
||||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
let logits = if self.repeat_penalty == 1. {
|
let logits = if self.repeat_penalty == 1. {
|
||||||
@ -107,9 +116,11 @@ impl TextGeneration {
|
|||||||
if next_token == eos_token {
|
if next_token == eos_token {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||||
print!("{token}");
|
print!("{t}");
|
||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
pos += context_size;
|
||||||
}
|
}
|
||||||
let dt = start_gen.elapsed();
|
let dt = start_gen.elapsed();
|
||||||
println!(
|
println!(
|
||||||
@ -128,6 +139,10 @@ enum WhichModel {
|
|||||||
V1_5,
|
V1_5,
|
||||||
#[value(name = "2")]
|
#[value(name = "2")]
|
||||||
V2,
|
V2,
|
||||||
|
#[value(name = "3")]
|
||||||
|
V3,
|
||||||
|
#[value(name = "3-medium")]
|
||||||
|
V3Medium,
|
||||||
#[value(name = "2-old")]
|
#[value(name = "2-old")]
|
||||||
V2Old,
|
V2Old,
|
||||||
PuffinPhiV2,
|
PuffinPhiV2,
|
||||||
@ -196,6 +211,10 @@ struct Args {
|
|||||||
/// The context size to consider for the repeat penalty.
|
/// The context size to consider for the repeat penalty.
|
||||||
#[arg(long, default_value_t = 64)]
|
#[arg(long, default_value_t = 64)]
|
||||||
repeat_last_n: usize,
|
repeat_last_n: usize,
|
||||||
|
|
||||||
|
/// The dtype to be used for running the model, e.g. f32, bf16, or f16.
|
||||||
|
#[arg(long)]
|
||||||
|
dtype: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
@ -236,6 +255,8 @@ fn main() -> Result<()> {
|
|||||||
WhichModel::V1 => "microsoft/phi-1".to_string(),
|
WhichModel::V1 => "microsoft/phi-1".to_string(),
|
||||||
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
|
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
|
||||||
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
|
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
|
||||||
|
WhichModel::V3 => "microsoft/Phi-3-mini-4k-instruct".to_string(),
|
||||||
|
WhichModel::V3Medium => "microsoft/Phi-3-medium-4k-instruct".to_string(),
|
||||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||||
"lmz/candle-quantized-phi".to_string()
|
"lmz/candle-quantized-phi".to_string()
|
||||||
}
|
}
|
||||||
@ -253,9 +274,11 @@ fn main() -> Result<()> {
|
|||||||
WhichModel::V1 => "refs/pr/8".to_string(),
|
WhichModel::V1 => "refs/pr/8".to_string(),
|
||||||
WhichModel::V1_5 => "refs/pr/73".to_string(),
|
WhichModel::V1_5 => "refs/pr/73".to_string(),
|
||||||
WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
|
WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
|
||||||
WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
WhichModel::V2
|
||||||
"main".to_string()
|
| WhichModel::V3
|
||||||
}
|
| WhichModel::V3Medium
|
||||||
|
| WhichModel::PuffinPhiV2
|
||||||
|
| WhichModel::PhiHermes => "main".to_string(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -264,9 +287,12 @@ fn main() -> Result<()> {
|
|||||||
let tokenizer_filename = match args.tokenizer {
|
let tokenizer_filename = match args.tokenizer {
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
None => match args.model {
|
None => match args.model {
|
||||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2Old => {
|
WhichModel::V1
|
||||||
repo.get("tokenizer.json")?
|
| WhichModel::V1_5
|
||||||
}
|
| WhichModel::V2
|
||||||
|
| WhichModel::V2Old
|
||||||
|
| WhichModel::V3
|
||||||
|
| WhichModel::V3Medium => repo.get("tokenizer.json")?,
|
||||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||||
repo.get("tokenizer-puffin-phi-v2.json")?
|
repo.get("tokenizer-puffin-phi-v2.json")?
|
||||||
}
|
}
|
||||||
@ -282,14 +308,19 @@ fn main() -> Result<()> {
|
|||||||
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
|
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
|
||||||
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
|
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
|
||||||
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
|
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
|
||||||
|
WhichModel::V3 | WhichModel::V3Medium => anyhow::bail!(
|
||||||
|
"use the quantized or quantized-phi examples for quantized phi-v3"
|
||||||
|
),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
match args.model {
|
match args.model {
|
||||||
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
|
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
|
||||||
WhichModel::V2 | WhichModel::V2Old => candle_examples::hub_load_safetensors(
|
WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 | WhichModel::V3Medium => {
|
||||||
&repo,
|
candle_examples::hub_load_safetensors(
|
||||||
"model.safetensors.index.json",
|
&repo,
|
||||||
)?,
|
"model.safetensors.index.json",
|
||||||
|
)?
|
||||||
|
}
|
||||||
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
|
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
|
||||||
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
|
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
|
||||||
}
|
}
|
||||||
@ -306,6 +337,9 @@ fn main() -> Result<()> {
|
|||||||
WhichModel::V2 | WhichModel::V2Old => Config::v2(),
|
WhichModel::V2 | WhichModel::V2Old => Config::v2(),
|
||||||
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
|
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
|
||||||
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
|
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
|
||||||
|
WhichModel::V3 | WhichModel::V3Medium => {
|
||||||
|
panic!("use the quantized or quantized-phi examples for quantized phi-v3")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let model = if args.quantized {
|
let model = if args.quantized {
|
||||||
@ -320,7 +354,19 @@ fn main() -> Result<()> {
|
|||||||
};
|
};
|
||||||
Model::Quantized(model)
|
Model::Quantized(model)
|
||||||
} else {
|
} else {
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
let dtype = match args.dtype {
|
||||||
|
Some(dtype) => std::str::FromStr::from_str(&dtype)?,
|
||||||
|
None => {
|
||||||
|
if (args.model == WhichModel::V3 || args.model == WhichModel::V3Medium)
|
||||||
|
&& device.is_cuda()
|
||||||
|
{
|
||||||
|
DType::BF16
|
||||||
|
} else {
|
||||||
|
DType::F32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
match args.model {
|
match args.model {
|
||||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
|
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
|
||||||
let config_filename = repo.get("config.json")?;
|
let config_filename = repo.get("config.json")?;
|
||||||
@ -329,6 +375,13 @@ fn main() -> Result<()> {
|
|||||||
let phi = Phi::new(&config, vb)?;
|
let phi = Phi::new(&config, vb)?;
|
||||||
Model::Phi(phi)
|
Model::Phi(phi)
|
||||||
}
|
}
|
||||||
|
WhichModel::V3 | WhichModel::V3Medium => {
|
||||||
|
let config_filename = repo.get("config.json")?;
|
||||||
|
let config = std::fs::read_to_string(config_filename)?;
|
||||||
|
let config: Phi3Config = serde_json::from_str(&config)?;
|
||||||
|
let phi3 = Phi3::new(&config, vb)?;
|
||||||
|
Model::Phi3(phi3)
|
||||||
|
}
|
||||||
WhichModel::V2Old => {
|
WhichModel::V2Old => {
|
||||||
let config = config();
|
let config = config();
|
||||||
Model::MixFormer(MixFormer::new_v2(&config, vb)?)
|
Model::MixFormer(MixFormer::new_v2(&config, vb)?)
|
||||||
@ -421,6 +474,10 @@ fn mmlu<P: AsRef<std::path::Path>>(
|
|||||||
m.clear_kv_cache();
|
m.clear_kv_cache();
|
||||||
m.forward(&input)?
|
m.forward(&input)?
|
||||||
}
|
}
|
||||||
|
Model::Phi3(m) => {
|
||||||
|
m.clear_kv_cache();
|
||||||
|
m.forward(&input, 0)?
|
||||||
|
}
|
||||||
Model::Quantized(m) => {
|
Model::Quantized(m) => {
|
||||||
m.clear_kv_cache();
|
m.clear_kv_cache();
|
||||||
m.forward(&input)?
|
m.forward(&input)?
|
||||||
|
325  candle-examples/examples/quantized-phi/main.rs  Normal file
@ -0,0 +1,325 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use clap::{Parser, ValueEnum};
|
||||||
|
use std::io::Write;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
use candle::quantized::gguf_file;
|
||||||
|
use candle::Tensor;
|
||||||
|
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||||
|
|
||||||
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
|
use candle_transformers::models::quantized_llama::ModelWeights as Phi3b;
|
||||||
|
use candle_transformers::models::quantized_phi::ModelWeights as Phi2;
|
||||||
|
use candle_transformers::models::quantized_phi3::ModelWeights as Phi3;
|
||||||
|
|
||||||
|
const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. ";
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "phi-2")]
|
||||||
|
Phi2,
|
||||||
|
#[value(name = "phi-3")]
|
||||||
|
Phi3,
|
||||||
|
/// Alternative implementation of phi-3, based on llama.
|
||||||
|
#[value(name = "phi-3b")]
|
||||||
|
Phi3b,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp
|
||||||
|
#[arg(long)]
|
||||||
|
model: Option<String>,
|
||||||
|
|
||||||
|
/// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
|
||||||
|
/// and 'chat' for an interactive model where history of previous prompts and generated tokens
|
||||||
|
/// is preserved.
|
||||||
|
#[arg(long)]
|
||||||
|
prompt: Option<String>,
|
||||||
|
|
||||||
|
/// The length of the sample to generate (in tokens).
|
||||||
|
#[arg(short = 'n', long, default_value_t = 1000)]
|
||||||
|
sample_len: usize,
|
||||||
|
|
||||||
|
/// The tokenizer config in json format.
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
|
/// The temperature used to generate samples, use 0 for greedy sampling.
|
||||||
|
#[arg(long, default_value_t = 0.8)]
|
||||||
|
temperature: f64,
|
||||||
|
|
||||||
|
/// Nucleus sampling probability cutoff.
|
||||||
|
#[arg(long)]
|
||||||
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// Only sample among the top K samples.
|
||||||
|
#[arg(long)]
|
||||||
|
top_k: Option<usize>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
/// Process prompt elements separately.
|
||||||
|
#[arg(long)]
|
||||||
|
split_prompt: bool,
|
||||||
|
|
||||||
|
/// Run on CPU rather than GPU even if a GPU is available.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.1)]
|
||||||
|
repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
repeat_last_n: usize,
|
||||||
|
|
||||||
|
/// The model size to use.
|
||||||
|
#[arg(long, default_value = "phi-3b")]
|
||||||
|
which: Which,
|
||||||
|
|
||||||
|
    #[arg(long)]
    use_flash_attn: bool,
}

impl Args {
    fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
        let tokenizer_path = match &self.tokenizer {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let api = hf_hub::api::sync::Api::new()?;
                let repo = match self.which {
                    Which::Phi2 => "microsoft/phi-2",
                    Which::Phi3 | Which::Phi3b => "microsoft/Phi-3-mini-4k-instruct",
                };
                let api = api.model(repo.to_string());
                api.get("tokenizer.json")?
            }
        };
        Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
    }

    fn model(&self) -> anyhow::Result<std::path::PathBuf> {
        let model_path = match &self.model {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let (repo, filename, revision) = match self.which {
                    Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf", "main"),
                    Which::Phi3 => (
                        "microsoft/Phi-3-mini-4k-instruct-gguf",
                        "Phi-3-mini-4k-instruct-q4.gguf",
                        "main",
                    ),
                    Which::Phi3b => (
                        "microsoft/Phi-3-mini-4k-instruct-gguf",
                        "Phi-3-mini-4k-instruct-q4.gguf",
                        "5eef2ce24766d31909c0b269fe90c817a8f263fb",
                    ),
                };
                let api = hf_hub::api::sync::Api::new()?;
                api.repo(hf_hub::Repo::with_revision(
                    repo.to_string(),
                    hf_hub::RepoType::Model,
                    revision.to_string(),
                ))
                .get(filename)?
            }
        };
        Ok(model_path)
    }
}

fn format_size(size_in_bytes: usize) -> String {
    if size_in_bytes < 1_000 {
        format!("{}B", size_in_bytes)
    } else if size_in_bytes < 1_000_000 {
        format!("{:.2}KB", size_in_bytes as f64 / 1e3)
    } else if size_in_bytes < 1_000_000_000 {
        format!("{:.2}MB", size_in_bytes as f64 / 1e6)
    } else {
        format!("{:.2}GB", size_in_bytes as f64 / 1e9)
    }
}

enum Model {
    Phi2(Phi2),
    Phi3(Phi3),
    Phi3b(Phi3b),
}

impl Model {
    fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result<Tensor> {
        match self {
            Self::Phi2(m) => m.forward(xs, pos),
            Self::Phi3(m) => m.forward(xs, pos),
            Self::Phi3b(m) => m.forward(xs, pos),
        }
    }
}

fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature, args.repeat_penalty, args.repeat_last_n
    );

    let model_path = args.model()?;
    let mut file = std::fs::File::open(&model_path)?;
    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;

    let mut model = {
        let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
        let mut total_size_in_bytes = 0;
        for (_, tensor) in model.tensor_infos.iter() {
            let elem_count = tensor.shape.elem_count();
            total_size_in_bytes +=
                elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
        }
        println!(
            "loaded {:?} tensors ({}) in {:.2}s",
            model.tensor_infos.len(),
            &format_size(total_size_in_bytes),
            start.elapsed().as_secs_f32(),
        );
        match args.which {
            Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
            Which::Phi3 => Model::Phi3(Phi3::from_gguf(
                args.use_flash_attn,
                model,
                &mut file,
                &device,
            )?),
            Which::Phi3b => Model::Phi3b(Phi3b::from_gguf(model, &mut file, &device)?),
        }
    };
    println!("model built");

    let tokenizer = args.tokenizer()?;
    let mut tos = TokenOutputStream::new(tokenizer);
    let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
    print!("{}", &prompt_str);
    let tokens = tos
        .tokenizer()
        .encode(prompt_str, true)
        .map_err(anyhow::Error::msg)?;
    let tokens = tokens.get_ids();
    let to_sample = args.sample_len.saturating_sub(1);
    let mut all_tokens = vec![];
    let mut logits_processor = {
        let temperature = args.temperature;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (args.top_k, args.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };

    let start_prompt_processing = std::time::Instant::now();
    let mut next_token = if !args.split_prompt {
        let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, 0)?;
        let logits = logits.squeeze(0)?;
        logits_processor.sample(&logits)?
    } else {
        let mut next_token = 0;
        for (pos, token) in tokens.iter().enumerate() {
            let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
            let logits = model.forward(&input, pos)?;
            let logits = logits.squeeze(0)?;
            next_token = logits_processor.sample(&logits)?
        }
        next_token
    };
    let prompt_dt = start_prompt_processing.elapsed();
    all_tokens.push(next_token);
    if let Some(t) = tos.next_token(next_token)? {
        print!("{t}");
        std::io::stdout().flush()?;
    }
    let eos_token = *tos
        .tokenizer()
        .get_vocab(true)
        .get("<|endoftext|>")
        .unwrap();
    let start_post_prompt = std::time::Instant::now();
    let mut sampled = 0;
    for index in 0..to_sample {
        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, tokens.len() + index)?;
        let logits = logits.squeeze(0)?;
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &all_tokens[start_at..],
            )?
        };
        next_token = logits_processor.sample(&logits)?;
        all_tokens.push(next_token);
        if let Some(t) = tos.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
        sampled += 1;
        if next_token == eos_token {
            break;
        };
    }
    if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
        print!("{rest}");
    }
    std::io::stdout().flush()?;
    let dt = start_post_prompt.elapsed();
    println!(
        "\n\n{:4} prompt tokens processed: {:.2} token/s",
        tokens.len(),
        tokens.len() as f64 / prompt_dt.as_secs_f64(),
    );
    println!(
        "{sampled:4} tokens generated: {:.2} token/s",
        sampled as f64 / dt.as_secs_f64(),
    );
    Ok(())
}
@@ -17,7 +17,7 @@ generate quantized weight files from the original safetensors file by using the
 `tensor-tools` command line utility via:

 ```bash
-$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
+$ cargo run --bin tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
 ```

 ## Using custom models
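Not part of the diff above: a minimal sketch of how a gguf file produced by that `tensor-tools` invocation can be inspected from candle, reusing the `gguf_file::Content::read` call and the `tensor_infos`/`shape`/`ggml_dtype` fields that appear in the quantized examples elsewhere in this comparison; the file path is a placeholder.

```rust
use candle::quantized::gguf_file;

fn main() -> anyhow::Result<()> {
    // Placeholder path: the file written by the tensor-tools command above.
    let mut file = std::fs::File::open("/tmp/model.gguf")?;
    let content = gguf_file::Content::read(&mut file)?;
    // List the quantized tensors together with their shapes and ggml dtypes.
    for (name, info) in content.tensor_infos.iter() {
        println!("{name}: {:?} {:?}", info.shape, info.ggml_dtype);
    }
    Ok(())
}
```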
@@ -10,7 +10,7 @@ use tokenizers::Tokenizer;

 use candle::quantized::{ggml_file, gguf_file};
 use candle::Tensor;
-use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::generation::{LogitsProcessor, Sampling};

 use candle_examples::token_output_stream::TokenOutputStream;
 use candle_transformers::models::quantized_llama as model;
@@ -67,6 +67,10 @@ enum Which {
     Mixtral,
     #[value(name = "mixtral-instruct")]
     MixtralInstruct,
+    #[value(name = "llama3-8b")]
+    L8b,
+    #[value(name = "phi3")]
+    Phi3,
 }

 impl Which {
@@ -82,7 +86,9 @@ impl Which {
             | Self::L13bCode
             | Self::L34bCode
             | Self::Leo7b
-            | Self::Leo13b => false,
+            | Self::Leo13b
+            | Self::L8b
+            | Self::Phi3 => false,
             // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
             // same way. Starling is a fine tuned version of OpenChat.
             Self::OpenChat35
@@ -116,7 +122,9 @@ impl Which {
             | Self::Mistral7bInstruct
             | Self::Mistral7bInstructV02
             | Self::OpenChat35
-            | Self::Starling7bAlpha => false,
+            | Self::Starling7bAlpha
+            | Self::L8b
+            | Self::Phi3 => false,
             Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
         }
     }
@@ -140,33 +148,37 @@ impl Which {
             | Self::Mistral7bInstruct
             | Self::Mistral7bInstructV02
             | Self::Zephyr7bAlpha
-            | Self::Zephyr7bBeta => false,
+            | Self::Zephyr7bBeta
+            | Self::L8b
+            | Self::Phi3 => false,
             Self::OpenChat35 | Self::Starling7bAlpha => true,
         }
     }

     fn tokenizer_repo(&self) -> &'static str {
         match self {
-            Which::L7b
-            | Which::L13b
-            | Which::L70b
-            | Which::L7bChat
-            | Which::L13bChat
-            | Which::L70bChat
-            | Which::L7bCode
-            | Which::L13bCode
-            | Which::L34bCode => "hf-internal-testing/llama-tokenizer",
-            Which::Leo7b => "LeoLM/leo-hessianai-7b",
-            Which::Leo13b => "LeoLM/leo-hessianai-13b",
-            Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
-            Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
-            Which::Mistral7b
-            | Which::Mistral7bInstruct
-            | Which::Mistral7bInstructV02
-            | Which::Zephyr7bAlpha
-            | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
-            Which::OpenChat35 => "openchat/openchat_3.5",
-            Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode => "hf-internal-testing/llama-tokenizer",
+            Self::Leo7b => "LeoLM/leo-hessianai-7b",
+            Self::Leo13b => "LeoLM/leo-hessianai-13b",
+            Self::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
+            Self::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
+            Self::Mistral7b
+            | Self::Mistral7bInstruct
+            | Self::Mistral7bInstructV02
+            | Self::Zephyr7bAlpha
+            | Self::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
+            Self::OpenChat35 => "openchat/openchat_3.5",
+            Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
+            Self::L8b => "meta-llama/Meta-Llama-3-8B",
+            Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
         }
     }
 }
@@ -200,6 +212,10 @@ struct Args {
     #[arg(long)]
     top_p: Option<f64>,

+    /// Only sample among the top K samples.
+    #[arg(long)]
+    top_k: Option<usize>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -235,6 +251,10 @@ struct Args {
     /// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
     #[arg(long)]
     gqa: Option<usize>,
+
+    /// Use the slower dmmv cuda kernel.
+    #[arg(long)]
+    force_dmmv: bool,
 }

 impl Args {
@@ -314,10 +334,28 @@ impl Args {
                         "TheBloke/Starling-LM-7B-alpha-GGUF",
                         "starling-lm-7b-alpha.Q4_K_M.gguf",
                     ),
+                    // TODO: swap to TheBloke model when available
+                    Which::L8b => (
+                        "QuantFactory/Meta-Llama-3-8B-GGUF",
+                        "Meta-Llama-3-8B.Q4_K_S.gguf",
+                    ),
+                    Which::Phi3 => (
+                        "microsoft/Phi-3-mini-4k-instruct-gguf",
+                        "Phi-3-mini-4k-instruct-q4.gguf",
+                    ),
+                };
+                let revision = if self.which == Which::Phi3 {
+                    "5eef2ce24766d31909c0b269fe90c817a8f263fb"
+                } else {
+                    "main"
                 };
                 let api = hf_hub::api::sync::Api::new()?;
-                let api = api.model(repo.to_string());
-                api.get(filename)?
+                api.repo(hf_hub::Repo::with_revision(
+                    repo.to_string(),
+                    hf_hub::RepoType::Model,
+                    revision.to_string(),
+                ))
+                .get(filename)?
             }
         };
         Ok(model_path)
@@ -341,11 +379,13 @@ fn main() -> anyhow::Result<()> {
     use tracing_subscriber::prelude::*;

     let args = Args::parse();
-    let temperature = if args.temperature == 0. {
-        None
-    } else {
-        Some(args.temperature)
-    };
+    #[cfg(feature = "cuda")]
+    candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
+
+    candle::cuda::set_gemm_reduced_precision_f16(true);
+    candle::cuda::set_gemm_reduced_precision_bf16(true);
+
     let _guard = if args.tracing {
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
         tracing_subscriber::registry().with(chrome_layer).init();
@@ -413,7 +453,9 @@ fn main() -> anyhow::Result<()> {
         | Which::L13bCode
         | Which::L34bCode
         | Which::Leo7b
-        | Which::Leo13b => 1,
+        | Which::Leo13b
+        | Which::L8b
+        | Which::Phi3 => 1,
         Which::Mixtral
         | Which::MixtralInstruct
         | Which::Mistral7b
@@ -492,7 +534,20 @@ fn main() -> anyhow::Result<()> {
         prompt_tokens
     };
     let mut all_tokens = vec![];
-    let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
+    let mut logits_processor = {
+        let temperature = args.temperature;
+        let sampling = if temperature <= 0. {
+            Sampling::ArgMax
+        } else {
+            match (args.top_k, args.top_p) {
+                (None, None) => Sampling::All { temperature },
+                (Some(k), None) => Sampling::TopK { k, temperature },
+                (None, Some(p)) => Sampling::TopP { p, temperature },
+                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+            }
+        };
+        LogitsProcessor::from_sampling(args.seed, sampling)
+    };

     let start_prompt_processing = std::time::Instant::now();
     let mut next_token = if !args.split_prompt {
@@ -517,11 +572,14 @@ fn main() -> anyhow::Result<()> {
         std::io::stdout().flush()?;
     }

-    let eos_token = if args.which.is_open_chat() {
-        "<|end_of_turn|>"
-    } else {
-        "</s>"
+    let eos_token = match args.which {
+        Which::L8b => "<|end_of_text|>",
+        _ => match args.which.is_open_chat() {
+            true => "<|end_of_turn|>",
+            false => "</s>",
+        },
     };

     let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
     let start_post_prompt = std::time::Instant::now();
     let mut sampled = 0;
candle-examples/examples/qwen/README.md (new file, 27 lines)
@@ -0,0 +1,27 @@
# candle-qwen: large language model series from Alibaba Cloud

Qwen 1.5 is a series of large language models that provide strong performances
on English and Chinese.

- [Blog post](https://qwenlm.github.io/blog/qwen1.5/) introducing Qwen1.5.
- [Model card](https://huggingface.co/Qwen/Qwen1.5-0.5B) on the HuggingFace Hub.
- [Blog post](https://qwenlm.github.io/blog/qwen-moe/) for the
  mixture-of-experts (MoE) variant.

## Running the example

```bash
$ cargo run --example qwen --release -- --prompt "Hello there "
```

Various model sizes are available via the `--model` argument, including the MoE
variant.

```bash
$ cargo run --example qwen --release -- --model moe-a2.7b --prompt 'def print_prime(n: int): '
def print_prime(n: int): # n is the number of primes to be printed
    for i in range(2, n + 1):
        if all(i % j != 0 for j in range(2, i)):
            print(i)
```
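Not part of the diff: a minimal sketch of how the `--model` values described above map to Hugging Face repo ids, mirroring the `(version, size)` match added to the example's `main.rs` further down in this comparison; the enum variants follow that diff and only a subset is shown.

```rust
// Sketch only: condensed from the (version, size) mapping in the qwen example below.
#[derive(Clone, Copy)]
enum WhichModel {
    W0_5b,   // "0.5b"  -> Qwen 1.5, 0.5B
    MoeA27b, // "moe-a2.7b" -> Qwen 1.5 MoE
    W2_7b,   // "2-7b"  -> Qwen 2, 7B
}

fn repo_id(which: WhichModel) -> String {
    let (version, size) = match which {
        WhichModel::W0_5b => ("1.5", "0.5B"),
        WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
        WhichModel::W2_7b => ("2", "7B"),
    };
    format!("Qwen/Qwen{version}-{size}")
}

fn main() {
    for which in [WhichModel::W0_5b, WhichModel::MoeA27b, WhichModel::W2_7b] {
        println!("{}", repo_id(which));
    }
}
```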
@@ -7,7 +7,8 @@ extern crate accelerate_src;
 use anyhow::{Error as E, Result};
 use clap::Parser;

-use candle_transformers::models::qwen2::{Config, Model};
+use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
+use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};

 use candle::{DType, Device, Tensor};
 use candle_examples::token_output_stream::TokenOutputStream;
@@ -16,6 +17,20 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;

+enum Model {
+    Base(ModelBase),
+    Moe(ModelMoe),
+}
+
+impl Model {
+    fn forward(&mut self, xs: &Tensor, s: usize) -> candle::Result<Tensor> {
+        match self {
+            Self::Moe(ref mut m) => m.forward(xs, s),
+            Self::Base(ref mut m) => m.forward(xs, s),
+        }
+    }
+}
+
 struct TextGeneration {
     model: Model,
     device: Device,
@@ -127,6 +142,16 @@ enum WhichModel {
     W14b,
     #[value(name = "72b")]
     W72b,
+    #[value(name = "moe-a2.7b")]
+    MoeA27b,
+    #[value(name = "2-0.5b")]
+    W2_0_5b,
+    #[value(name = "2-1.5b")]
+    W2_1_5b,
+    #[value(name = "2-7b")]
+    W2_7b,
+    #[value(name = "2-72b")]
+    W2_72b,
 }

 #[derive(Parser, Debug)]
@@ -217,15 +242,20 @@ fn main() -> Result<()> {
     let model_id = match args.model_id {
         Some(model_id) => model_id,
         None => {
-            let size = match args.model {
-                WhichModel::W0_5b => "0.5B",
-                WhichModel::W1_8b => "1.8B",
-                WhichModel::W4b => "4B",
-                WhichModel::W7b => "7B",
-                WhichModel::W14b => "14B",
-                WhichModel::W72b => "72B",
+            let (version, size) = match args.model {
+                WhichModel::W2_0_5b => ("2", "0.5B"),
+                WhichModel::W2_1_5b => ("2", "1.5B"),
+                WhichModel::W2_7b => ("2", "7B"),
+                WhichModel::W2_72b => ("2", "72B"),
+                WhichModel::W0_5b => ("1.5", "0.5B"),
+                WhichModel::W1_8b => ("1.5", "1.8B"),
+                WhichModel::W4b => ("1.5", "4B"),
+                WhichModel::W7b => ("1.5", "7B"),
+                WhichModel::W14b => ("1.5", "14B"),
+                WhichModel::W72b => ("1.5", "72B"),
+                WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
             };
-            format!("Qwen/Qwen1.5-{size}")
+            format!("Qwen/Qwen{version}-{size}")
         }
     };
     let repo = api.repo(Repo::with_revision(
@@ -243,8 +273,16 @@ fn main() -> Result<()> {
             .map(std::path::PathBuf::from)
             .collect::<Vec<_>>(),
         None => match args.model {
-            WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
-            WhichModel::W4b | WhichModel::W7b | WhichModel::W14b | WhichModel::W72b => {
+            WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
+                vec![repo.get("model.safetensors")?]
+            }
+            WhichModel::W4b
+            | WhichModel::W7b
+            | WhichModel::W2_7b
+            | WhichModel::W14b
+            | WhichModel::W72b
+            | WhichModel::W2_72b
+            | WhichModel::MoeA27b => {
                 candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
             }
         },
@@ -254,7 +292,6 @@ fn main() -> Result<()> {

     let start = std::time::Instant::now();
     let config_file = repo.get("config.json")?;
-    let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
     let device = candle_examples::device(args.cpu)?;
     let dtype = if device.is_cuda() {
         DType::BF16
@@ -262,7 +299,16 @@ fn main() -> Result<()> {
         DType::F32
     };
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-    let model = Model::new(&config, vb)?;
+    let model = match args.model {
+        WhichModel::MoeA27b => {
+            let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
+            Model::Moe(ModelMoe::new(&config, vb)?)
+        }
+        _ => {
+            let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
+            Model::Base(ModelBase::new(&config, vb)?)
+        }
+    };

     println!("loaded the model in {:?}", start.elapsed());

candle-examples/examples/recurrent-gemma/README.md (new file, 9 lines)
@@ -0,0 +1,9 @@
# candle-recurrent-gemma

This model card corresponds to the 2B base version of the RecurrentGemma model
[huggingface model card](https://huggingface.co/google/recurrentgemma-2b).

```bash
cargo run --features cuda -r --example recurrent-gemma -- \
  --prompt "Write me a poem about Machine Learning."
```
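Not part of the diff: a condensed sketch of how the example's `--quantized` flag (added in the `main.rs` that follows) switches between a gguf file and sharded safetensors; the repo and file names are taken from that diff, while the `ApiRepo` type name is an assumption about the hf-hub sync API.

```rust
// Condensed from the recurrent-gemma example below: pick weight files depending
// on whether the quantized gguf or the full safetensors checkpoint is wanted.
fn weight_files(
    api: &hf_hub::api::sync::Api,
    repo: &hf_hub::api::sync::ApiRepo,
    quantized: bool,
) -> anyhow::Result<Vec<std::path::PathBuf>> {
    if quantized {
        // The quantized weights live in a separate repo in the example below.
        let filename = api
            .model("lmz/candle-gemma".to_string())
            .get("recurrent-gemma-2b-q4k.gguf")?;
        Ok(vec![filename])
    } else {
        Ok(candle_examples::hub_load_safetensors(repo, "model.safetensors.index.json")?)
    }
}
```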
|
321
candle-examples/examples/recurrent-gemma/main.rs
Normal file
321
candle-examples/examples/recurrent-gemma/main.rs
Normal file
@ -0,0 +1,321 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::{Error as E, Result};
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle_transformers::models::quantized_recurrent_gemma::Model as QModel;
|
||||||
|
use candle_transformers::models::recurrent_gemma::{Config, Model as BModel};
|
||||||
|
|
||||||
|
use candle::{DType, Device, Tensor};
|
||||||
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
enum Model {
|
||||||
|
B(BModel),
|
||||||
|
Q(QModel),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result<Tensor> {
|
||||||
|
match self {
|
||||||
|
Self::B(m) => m.forward(xs, pos),
|
||||||
|
Self::Q(m) => m.forward(xs, pos),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "2b")]
|
||||||
|
Base2B,
|
||||||
|
#[value(name = "2b-it")]
|
||||||
|
Instruct2B,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TextGeneration {
|
||||||
|
model: Model,
|
||||||
|
device: Device,
|
||||||
|
tokenizer: TokenOutputStream,
|
||||||
|
logits_processor: LogitsProcessor,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextGeneration {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn new(
|
||||||
|
model: Model,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
seed: u64,
|
||||||
|
temp: Option<f64>,
|
||||||
|
top_p: Option<f64>,
|
||||||
|
top_k: usize,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
device: &Device,
|
||||||
|
) -> Self {
|
||||||
|
let sampling = match temp {
|
||||||
|
None => candle_transformers::generation::Sampling::ArgMax,
|
||||||
|
Some(temperature) => match top_p {
|
||||||
|
None => candle_transformers::generation::Sampling::TopK {
|
||||||
|
temperature,
|
||||||
|
k: top_k,
|
||||||
|
},
|
||||||
|
Some(top_p) => candle_transformers::generation::Sampling::TopKThenTopP {
|
||||||
|
temperature,
|
||||||
|
k: top_k,
|
||||||
|
p: top_p,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let logits_processor = LogitsProcessor::from_sampling(seed, sampling);
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
tokenizer: TokenOutputStream::new(tokenizer),
|
||||||
|
logits_processor,
|
||||||
|
repeat_penalty,
|
||||||
|
repeat_last_n,
|
||||||
|
device: device.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||||
|
use std::io::Write;
|
||||||
|
self.tokenizer.clear();
|
||||||
|
let mut tokens = self
|
||||||
|
.tokenizer
|
||||||
|
.tokenizer()
|
||||||
|
.encode(prompt, true)
|
||||||
|
.map_err(E::msg)?
|
||||||
|
.get_ids()
|
||||||
|
.to_vec();
|
||||||
|
for &t in tokens.iter() {
|
||||||
|
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||||
|
print!("{t}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
|
||||||
|
let mut generated_tokens = 0usize;
|
||||||
|
let eos_token = match self.tokenizer.get_token("<eos>") {
|
||||||
|
Some(token) => token,
|
||||||
|
None => anyhow::bail!("cannot find the <eos> token"),
|
||||||
|
};
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
for index in 0..sample_len {
|
||||||
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
|
let start_pos = tokens.len().saturating_sub(context_size);
|
||||||
|
let ctxt = &tokens[start_pos..];
|
||||||
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
|
let logits = self.model.forward(&input, start_pos)?;
|
||||||
|
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
let logits = if self.repeat_penalty == 1. {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||||
|
candle_transformers::utils::apply_repeat_penalty(
|
||||||
|
&logits,
|
||||||
|
self.repeat_penalty,
|
||||||
|
&tokens[start_at..],
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
|
tokens.push(next_token);
|
||||||
|
generated_tokens += 1;
|
||||||
|
if next_token == eos_token {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||||
|
print!("{t}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let dt = start_gen.elapsed();
|
||||||
|
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||||
|
print!("{rest}");
|
||||||
|
}
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
println!(
|
||||||
|
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||||
|
generated_tokens as f64 / dt.as_secs_f64(),
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
/// The temperature used to generate samples.
|
||||||
|
#[arg(long)]
|
||||||
|
temperature: Option<f64>,
|
||||||
|
|
||||||
|
/// Nucleus sampling probability cutoff.
|
||||||
|
#[arg(long)]
|
||||||
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
#[arg(long, default_value_t = 250)]
|
||||||
|
top_k: usize,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
|
/// The length of the sample to generate (in tokens).
|
||||||
|
#[arg(long, short = 'n', default_value_t = 8000)]
|
||||||
|
sample_len: usize,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "main")]
|
||||||
|
revision: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
config_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
weight_files: Option<String>,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.1)]
|
||||||
|
repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
repeat_last_n: usize,
|
||||||
|
|
||||||
|
/// The model to use.
|
||||||
|
#[arg(long, default_value = "2b")]
|
||||||
|
which: Which,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
quantized: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||||
|
candle::utils::with_avx(),
|
||||||
|
candle::utils::with_neon(),
|
||||||
|
candle::utils::with_simd128(),
|
||||||
|
candle::utils::with_f16c()
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||||
|
args.temperature.unwrap_or(0.),
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n
|
||||||
|
);
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let api = Api::new()?;
|
||||||
|
let model_id = match &args.model_id {
|
||||||
|
Some(model_id) => model_id.to_string(),
|
||||||
|
None => match args.which {
|
||||||
|
Which::Base2B => "google/recurrentgemma-2b".to_string(),
|
||||||
|
Which::Instruct2B => "google/recurrentgemma-2b-it".to_string(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let repo = api.repo(Repo::with_revision(
|
||||||
|
model_id,
|
||||||
|
RepoType::Model,
|
||||||
|
args.revision,
|
||||||
|
));
|
||||||
|
let tokenizer_filename = match args.tokenizer_file {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => repo.get("tokenizer.json")?,
|
||||||
|
};
|
||||||
|
let config_filename = match args.config_file {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => repo.get("config.json")?,
|
||||||
|
};
|
||||||
|
let filenames = match args.weight_files {
|
||||||
|
Some(files) => files
|
||||||
|
.split(',')
|
||||||
|
.map(std::path::PathBuf::from)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
None => {
|
||||||
|
if args.quantized {
|
||||||
|
let filename = match args.which {
|
||||||
|
Which::Base2B => "recurrent-gemma-2b-q4k.gguf",
|
||||||
|
Which::Instruct2B => "recurrent-gemma-7b-q4k.gguf",
|
||||||
|
};
|
||||||
|
let filename = api.model("lmz/candle-gemma".to_string()).get(filename)?;
|
||||||
|
vec![filename]
|
||||||
|
} else {
|
||||||
|
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let dtype = if device.is_cuda() {
|
||||||
|
DType::BF16
|
||||||
|
} else {
|
||||||
|
DType::F32
|
||||||
|
};
|
||||||
|
let model = if args.quantized {
|
||||||
|
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
||||||
|
&filenames[0],
|
||||||
|
&device,
|
||||||
|
)?;
|
||||||
|
Model::Q(QModel::new(&config, vb.pp("model"))?)
|
||||||
|
} else {
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
|
Model::B(BModel::new(&config, vb.pp("model"))?)
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
let mut pipeline = TextGeneration::new(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
args.seed,
|
||||||
|
args.temperature,
|
||||||
|
args.top_p,
|
||||||
|
args.top_k,
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n,
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
pipeline.run(&args.prompt, args.sample_len)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
@@ -42,7 +42,7 @@ impl GymEnv {
     /// Creates a new session of the specified OpenAI Gym environment.
     pub fn new(name: &str) -> Result<GymEnv> {
         Python::with_gil(|py| {
-            let gym = py.import("gymnasium")?;
+            let gym = py.import_bound("gymnasium")?;
             let make = gym.getattr("make")?;
             let env = make.call1((name,))?;
             let action_space = env.getattr("action_space")?;
@@ -66,10 +66,10 @@ impl GymEnv {
     /// Resets the environment, returning the observation tensor.
     pub fn reset(&self, seed: u64) -> Result<Tensor> {
         let state: Vec<f32> = Python::with_gil(|py| {
-            let kwargs = PyDict::new(py);
+            let kwargs = PyDict::new_bound(py);
             kwargs.set_item("seed", seed)?;
-            let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
-            state.as_ref(py).get_item(0)?.extract()
+            let state = self.env.call_method_bound(py, "reset", (), Some(&kwargs))?;
+            state.bind(py).get_item(0)?.extract()
         })
         .map_err(w)?;
         Tensor::new(state, &Device::Cpu)
@@ -81,8 +81,10 @@ impl GymEnv {
         action: A,
     ) -> Result<Step<A>> {
         let (state, reward, terminated, truncated) = Python::with_gil(|py| {
-            let step = self.env.call_method(py, "step", (action.clone(),), None)?;
-            let step = step.as_ref(py);
+            let step = self
+                .env
+                .call_method_bound(py, "step", (action.clone(),), None)?;
+            let step = step.bind(py);
             let state: Vec<f32> = step.get_item(0)?.extract()?;
             let reward: f64 = step.get_item(1)?.extract()?;
             let terminated: bool = step.get_item(2)?.extract()?;
@@ -24,13 +24,13 @@ fn w(res: PyErr) -> candle::Error {
 impl VecGymEnv {
     pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
         Python::with_gil(|py| {
-            let sys = py.import("sys")?;
+            let sys = py.import_bound("sys")?;
             let path = sys.getattr("path")?;
             let _ = path.call_method1(
                 "append",
                 ("candle-examples/examples/reinforcement-learning",),
             )?;
-            let gym = py.import("atari_wrappers")?;
+            let gym = py.import_bound("atari_wrappers")?;
             let make = gym.getattr("make")?;
             let env = make.call1((name, img_dir, nprocesses))?;
             let action_space = env.getattr("action_space")?;
@@ -60,10 +60,10 @@ impl VecGymEnv {

     pub fn step(&self, action: Vec<usize>) -> Result<Step> {
         let (obs, reward, is_done) = Python::with_gil(|py| {
-            let step = self.env.call_method(py, "step", (action,), None)?;
-            let step = step.as_ref(py);
+            let step = self.env.call_method_bound(py, "step", (action,), None)?;
+            let step = step.bind(py);
             let obs = step.get_item(0)?.call_method("flatten", (), None)?;
-            let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?;
+            let obs_buffer = pyo3::buffer::PyBuffer::get_bound(&obs)?;
             let obs: Vec<u8> = obs_buffer.to_vec(py)?;
             let reward: Vec<f32> = step.get_item(1)?.extract()?;
             let is_done: Vec<f32> = step.get_item(2)?.extract()?;
@@ -39,7 +39,7 @@ struct Args {

     /// The detection threshold for the mask, 0 is the default value, negative values mean a larger
     /// mask, positive makes the mask more selective.
-    #[arg(long, default_value_t = 0.)]
+    #[arg(long, allow_hyphen_values = true, default_value_t = 0.)]
     threshold: f32,

     /// Enable tracing (generates a trace-timestamp.json file).
@@ -46,7 +46,8 @@ The default scheduler for the XL Turbo version is the Euler Ancestral scheduler.
 - `--cpu`: use the cpu rather than the gpu (much slower).
 - `--height`, `--width`: set the height and width for the generated image.
 - `--n-steps`: the number of steps to be used in the diffusion process.
-- `--num-samples`: the number of samples to generate.
+- `--num-samples`: the number of samples to generate iteratively.
+- `--bsize`: the numbers of samples to generate simultaneously.
 - `--final-image`: the filename for the generated image(s).

 ### Using flash-attention
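Not part of the diff: a sketch of the batched decode that the new `--bsize` flag enables, following the `save_image` helper added to the stable-diffusion example further down; the `AutoEncoderKL` decode and scaling steps are taken from that diff, while the function name here is hypothetical.

```rust
// Sketch of decoding a batch of latents into u8 image tensors, mirroring the
// save_image helper added below. Assumes `vae`, `latents`, `vae_scale` and
// `bsize` are set up as in the stable-diffusion example.
fn decode_batch(
    vae: &candle_transformers::models::stable_diffusion::vae::AutoEncoderKL,
    latents: &candle::Tensor,
    vae_scale: f64,
    bsize: usize,
) -> anyhow::Result<Vec<candle::Tensor>> {
    use candle::{DType, IndexOp};
    // Decode the whole batch at once, then rescale from [-1, 1] to [0, 255].
    let images = vae.decode(&(latents / vae_scale)?)?;
    let images = ((images / 2.)? + 0.5)?;
    let images = (images.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?;
    let mut out = Vec::with_capacity(bsize);
    for batch in 0..bsize {
        out.push(images.i(batch)?);
    }
    Ok(out)
}
```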
@@ -9,6 +9,7 @@ use candle_transformers::models::stable_diffusion;
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, IndexOp, Module, Tensor, D};
 use clap::Parser;
+use stable_diffusion::vae::AutoEncoderKL;
 use tokenizers::Tokenizer;

 #[derive(Parser)]
@@ -64,9 +65,13 @@ struct Args {
     #[arg(long)]
     n_steps: Option<usize>,

-    /// The number of samples to generate.
+    /// The number of samples to generate iteratively.
     #[arg(long, default_value_t = 1)]
-    num_samples: i64,
+    num_samples: usize,
+
+    /// The numbers of samples to generate simultaneously.
+    #[arg[long, default_value_t = 1]]
+    bsize: usize,

     /// The name of the final image to generate.
     #[arg(long, value_name = "FILE", default_value = "sd_final.png")]
@@ -236,8 +241,8 @@ impl ModelFile {

 fn output_filename(
     basename: &str,
-    sample_idx: i64,
-    num_samples: i64,
+    sample_idx: usize,
+    num_samples: usize,
     timestep_idx: Option<usize>,
 ) -> String {
     let filename = if num_samples > 1 {
@@ -261,6 +266,33 @@ fn output_filename(
     }
 }

+#[allow(clippy::too_many_arguments)]
+fn save_image(
+    vae: &AutoEncoderKL,
+    latents: &Tensor,
+    vae_scale: f64,
+    bsize: usize,
+    idx: usize,
+    final_image: &str,
+    num_samples: usize,
+    timestep_ids: Option<usize>,
+) -> Result<()> {
+    let images = vae.decode(&(latents / vae_scale)?)?;
+    let images = ((images / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
+    let images = (images.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?;
+    for batch in 0..bsize {
+        let image = images.i(batch)?;
+        let image_filename = output_filename(
+            final_image,
+            (bsize * idx) + batch + 1,
+            batch + num_samples,
+            timestep_ids,
+        );
+        candle_examples::save_image(&image, image_filename)?;
+    }
+    Ok(())
+}
+
 #[allow(clippy::too_many_arguments)]
 fn text_embeddings(
     prompt: &str,
@@ -382,6 +414,7 @@ fn run(args: Args) -> Result<()> {
         final_image,
         sliced_attention_size,
         num_samples,
+        bsize,
         sd_version,
         clip_weights,
         vae_weights,
@@ -475,6 +508,7 @@ fn run(args: Args) -> Result<()> {
         .collect::<Result<Vec<_>>>()?;

     let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
+    let text_embeddings = text_embeddings.repeat((bsize, 1, 1))?;
     println!("{text_embeddings:?}");

     println!("Building the autoencoder.");
@@ -496,7 +530,6 @@ fn run(args: Args) -> Result<()> {
     } else {
         0
     };
-    let bsize = 1;

     let vae_scale = match sd_version {
         StableDiffusionVersion::V1_5
@@ -560,12 +593,16 @@ fn run(args: Args) -> Result<()> {
             println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);

             if args.intermediary_images {
-                let image = vae.decode(&(&latents / vae_scale)?)?;
-                let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
-                let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
-                let image_filename =
-                    output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
-                candle_examples::save_image(&image, image_filename)?
+                save_image(
+                    &vae,
+                    &latents,
+                    vae_scale,
+                    bsize,
+                    idx,
+                    &final_image,
+                    num_samples,
+                    Some(timestep_index + 1),
+                )?;
             }
         }

@@ -574,11 +611,16 @@ fn run(args: Args) -> Result<()> {
             idx + 1,
             num_samples
         );
-        let image = vae.decode(&(&latents / vae_scale)?)?;
-        let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
-        let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
-        let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
-        candle_examples::save_image(&image, image_filename)?
+        save_image(
+            &vae,
+            &latents,
+            vae_scale,
+            bsize,
+            idx,
+            &final_image,
+            num_samples,
+            None,
+        )?;
     }
     Ok(())
 }
@@ -12,12 +12,23 @@ use anyhow::{Error as E, Result};
 use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;

 const DTYPE: DType = DType::F32;

+#[derive(Clone, Debug, Copy, ValueEnum)]
+enum Which {
+    T5Base,
+    T5Small,
+    T5Large,
+    T5_3B,
+    Mt5Base,
+    Mt5Small,
+    Mt5Large,
+}
+
 #[derive(Parser, Debug, Clone)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -36,6 +47,15 @@ struct Args {
     #[arg(long)]
     revision: Option<String>,

+    #[arg(long)]
+    model_file: Option<String>,
+
+    #[arg(long)]
+    tokenizer_file: Option<String>,
+
+    #[arg(long)]
+    config_file: Option<String>,
+
     /// Enable decoding.
     #[arg(long)]
     decode: bool,
@@ -71,6 +91,10 @@ struct Args {
     /// The context size to consider for the repeat penalty.
     #[arg(long, default_value_t = 64)]
     repeat_last_n: usize,
+
+    /// The model to be used.
+    #[arg(long, default_value = "t5-small")]
+    which: Which,
 }

 struct T5ModelBuilder {
@@ -82,8 +106,17 @@ struct T5ModelBuilder {
 impl T5ModelBuilder {
     pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
         let device = candle_examples::device(args.cpu)?;
-        let default_model = "t5-small".to_string();
-        let default_revision = "refs/pr/15".to_string();
+        let (default_model, default_revision) = match args.which {
+            Which::T5Base => ("t5-base", "main"),
+            Which::T5Small => ("t5-small", "refs/pr/15"),
+            Which::T5Large => ("t5-large", "main"),
+            Which::T5_3B => ("t5-3b", "main"),
+            Which::Mt5Base => ("google/mt5-base", "refs/pr/5"),
+            Which::Mt5Small => ("google/mt5-small", "refs/pr/6"),
+            Which::Mt5Large => ("google/mt5-large", "refs/pr/2"),
+        };
+        let default_model = default_model.to_string();
+        let default_revision = default_revision.to_string();
         let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
             (Some(model_id), Some(revision)) => (model_id, revision),
             (Some(model_id), None) => (model_id, "main".to_string()),
@@ -93,14 +126,35 @@ impl T5ModelBuilder {

         let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
         let api = Api::new()?;
-        let api = api.repo(repo);
-        let config_filename = api.get("config.json")?;
-        let tokenizer_filename = api.get("tokenizer.json")?;
-        let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
-        {
-            candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
-        } else {
-            vec![api.get("model.safetensors")?]
+        let repo = api.repo(repo);
+        let config_filename = match &args.config_file {
+            None => repo.get("config.json")?,
+            Some(f) => f.into(),
+        };
+        let tokenizer_filename = match &args.tokenizer_file {
+            None => match args.which {
+                Which::Mt5Base => api
+                    .model("lmz/mt5-tokenizers".into())
+                    .get("mt5-base.tokenizer.json")?,
+                Which::Mt5Small => api
+                    .model("lmz/mt5-tokenizers".into())
+                    .get("mt5-small.tokenizer.json")?,
+                Which::Mt5Large => api
+                    .model("lmz/mt5-tokenizers".into())
+                    .get("mt5-large.tokenizer.json")?,
+                _ => repo.get("tokenizer.json")?,
+            },
+            Some(f) => f.into(),
+        };
+        let weights_filename = match &args.model_file {
+            Some(f) => f.split(',').map(|v| v.into()).collect::<Vec<_>>(),
+            None => {
+                if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2" {
+                    candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
+                } else {
+                    vec![repo.get("model.safetensors")?]
+                }
+            }
         };
         let config = std::fs::read_to_string(config_filename)?;
         let mut config: t5::Config = serde_json::from_str(&config)?;
@@ -115,7 +115,7 @@ pub fn main() -> anyhow::Result<()> {
     let processor = image_processor::ViTImageProcessor::new(&processor_config);

     let image = vec![args.image.as_str()];
-    let image = processor.preprocess(image)?;
+    let image = processor.preprocess(image)?.to_device(&device)?;

     let encoder_xs = model.encoder().forward(&image)?;

@@ -13,7 +13,7 @@ struct Block {

 impl Block {
     fn get(&self, key: &str) -> Result<&str> {
-        match self.parameters.get(&key.to_string()) {
+        match self.parameters.get(key) {
             None => candle::bail!("cannot find {} in {}", key, self.block_type),
             Some(value) => Ok(value),
         }
@@ -28,7 +28,7 @@ pub struct Darknet {

 impl Darknet {
     fn get(&self, key: &str) -> Result<&str> {
-        match self.parameters.get(&key.to_string()) {
+        match self.parameters.get(key) {
             None => candle::bail!("cannot find {} in net parameters", key),
             Some(value) => Ok(value),
         }

candle-examples/examples/yolo-v8/assets/bike.pp.jpg (new binary file, 175 KiB, not shown)
@@ -448,9 +448,9 @@ pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows10
 /// Perform gating and averaging for a BS.1770-4 integrated loudness measurement.
 ///
 /// The integrated loudness measurement is not just the average power over the
-/// entire signal. BS.1770-4 defines defines two stages of gating that exclude
+/// entire signal. BS.1770-4 defines two stages of gating that exclude
 /// parts of the signal, to ensure that silent parts do not contribute to the
-/// loudness measurment. This function performs that gating, and returns the
+/// loudness measurement. This function performs that gating, and returns the
 /// average power over the windows that were not excluded.
 ///
 /// The result of this function is the integrated loudness measurement.
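Not part of the diff: a rough, standalone sketch of the two-stage gating the doc comment above describes. The -0.691 power-to-loudness offset, the -70 LUFS absolute gate and the -10 LU relative gate are assumptions about BS.1770-4, not values taken from this repository.

```rust
// Rough sketch of BS.1770-style gating over per-window mean-square power values.
fn loudness(power: f64) -> f64 {
    // Assumed conversion from mean-square power to loudness in LUFS.
    -0.691 + 10.0 * power.log10()
}

fn gated_power(windows: &[f64]) -> Option<f64> {
    // Stage 1: absolute gate, drop windows quieter than the assumed -70 LUFS.
    let above_abs: Vec<f64> = windows
        .iter()
        .copied()
        .filter(|&p| loudness(p) > -70.0)
        .collect();
    if above_abs.is_empty() {
        return None;
    }
    let mean = above_abs.iter().sum::<f64>() / above_abs.len() as f64;
    // Stage 2: relative gate, assumed 10 LU below the absolutely-gated loudness.
    let rel_gate = loudness(mean) - 10.0;
    let kept: Vec<f64> = above_abs
        .into_iter()
        .filter(|&p| loudness(p) > rel_gate)
        .collect();
    if kept.is_empty() {
        return None;
    }
    Some(kept.iter().sum::<f64>() / kept.len() as f64)
}

fn main() {
    let windows = vec![1e-3, 5e-3, 1e-9, 2e-3];
    if let Some(p) = gated_power(&windows) {
        println!("integrated loudness: {:.2} LUFS", loudness(p));
    }
}
```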
@@ -1,6 +1,6 @@
 [package]
 name = "candle-flash-attn"
-version = "0.4.2"
+version = "0.6.0"
 edition = "2021"

 description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
 readme = "README.md"

 [dependencies]
-candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.2" }
+candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.6.0" }
 half = { version = "2.3.1", features = ["num-traits"] }

 [build-dependencies]
@@ -42,6 +42,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
             // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
             // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
+            if (smem_size >= 48 * 1024) {
+                cudaFuncSetAttribute(
+                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+            }
             // int ctas_per_sm;
             // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
             //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
@@ -139,7 +139,9 @@ impl FlashAttn {
 
         let elem_count = out_shape.elem_count();
         let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
-        let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
+        let softmax_lse = dev
+            .alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
+            .w()?;
 
         let is_bf16 = if is_bf16 { 1 } else { 0 };
 
@@ -1,6 +1,6 @@
 [package]
 name = "candle-kernels"
-version = "0.4.2"
+version = "0.6.0"
 edition = "2021"
 
 description = "CUDA kernels for Candle"
@@ -1,5 +1,8 @@
 fn main() {
     println!("cargo:rerun-if-changed=build.rs");
+    println!("cargo:rerun-if-changed=src/compatibility.cuh");
+    println!("cargo:rerun-if-changed=src/cuda_utils.cuh");
+    println!("cargo:rerun-if-changed=src/binary_op_macros.cuh");
 
     let builder = bindgen_cuda::Builder::default();
     println!("cargo:info={builder:?}");
@@ -97,6 +97,50 @@ __device__ void im2col1d(
     }
 }
 
+template <typename T>
+__device__ void col2im1d(
+    const size_t dst_el,
+    const size_t l_out,
+    const size_t l_in,
+    const size_t c_out,
+    const size_t k_size,
+    const size_t stride,
+    const T *src,
+    T *dst
+) {
+    const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+    // src: (b_size, l_in, c_out, l_k)
+    // dst: (b_size, c_out, l_out)
+    if (dst_i >= dst_el) {
+        return;
+    }
+
+    const size_t dst_s0 = c_out * l_out;
+    const size_t dst_s1 = l_out;
+    const size_t src_s0 = c_out * k_size * l_in;
+    const size_t src_s1 = c_out * k_size;
+    const size_t src_s2 = k_size;
+
+    size_t tmp_dst_i = dst_i;
+    const size_t b_idx = tmp_dst_i / dst_s0;
+    tmp_dst_i -= b_idx * dst_s0;
+    const size_t c_idx = tmp_dst_i / dst_s1;
+    tmp_dst_i -= c_idx * dst_s1;
+    const int l_out_idx = tmp_dst_i;
+
+    dst[dst_i] = static_cast<T>(0);
+
+    int l_in_idx = l_out_idx / stride;
+    int k0 = l_out_idx - l_in_idx * stride;
+    // l_out_idx = l_in_idx * stride + k0
+    for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) {
+        if (l_in_idx < l_in) {
+            const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0;
+            dst[dst_i] += src[src_i];
+        }
+    }
+}
+
 template <typename T>
 __device__ void im2col(
     const size_t dst_numel,
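The indexing in col2im1d is easier to follow as a CPU scatter-add. The Rust sketch below is illustrative only (not from this diff); it assumes the usual l_out = (l_in - 1) * stride + k_size relation and mirrors the kernel's (b, l_in, c_out, k) source layout.

// Illustrative CPU version of col2im1d: accumulate the (b, l_in, c_out, k)
// column buffer back into a (b, c_out, l_out) signal.
fn col2im1d(src: &[f32], b_size: usize, l_in: usize, c_out: usize, k_size: usize, stride: usize) -> Vec<f32> {
    let l_out = (l_in - 1) * stride + k_size; // assumed output length
    let mut dst = vec![0f32; b_size * c_out * l_out];
    for b in 0..b_size {
        for l in 0..l_in {
            for c in 0..c_out {
                for k in 0..k_size {
                    let src_i = ((b * l_in + l) * c_out + c) * k_size + k;
                    let dst_i = (b * c_out + c) * l_out + l * stride + k;
                    dst[dst_i] += src[src_i];
                }
            }
        }
    }
    dst
}

The kernel above computes the same sums, but one thread per output element gathers its contributions instead of scattering.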
@@ -542,6 +586,20 @@ extern "C" __global__ void FN_NAME( \
   im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
 } \
 
+#define COL2IM1D_OP(TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t dst_el, \
+    const size_t l_out, \
+    const size_t l_in, \
+    const size_t c_out, \
+    const size_t k_size, \
+    const size_t stride, \
+    const TYPENAME *src, \
+    TYPENAME *dst \
+) { \
+  col2im1d<TYPENAME>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst); \
+} \
+
 #define IM2COL_OP(TYPENAME, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
     const size_t dst_numel, \
@@ -643,6 +701,7 @@ MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
 UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
 IM2COL_OP(__nv_bfloat16, im2col_bf16)
 IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
+COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
@@ -655,6 +714,7 @@ MAX_POOL2D_OP(__half, max_pool2d_f16)
 UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
 IM2COL_OP(__half, im2col_f16)
 IM2COL1D_OP(__half, im2col1d_f16)
+COL2IM1D_OP(__half, col2im1d_f16)
 #endif
 
 CONV1D_OP(float, float, conv1d_f32)
@@ -701,3 +761,8 @@ IM2COL1D_OP(float, im2col1d_f32)
 IM2COL1D_OP(double, im2col1d_f64)
 IM2COL1D_OP(uint8_t, im2col1d_u8)
 IM2COL1D_OP(uint32_t, im2col1d_u32)
+
+COL2IM1D_OP(float, col2im1d_f32)
+COL2IM1D_OP(double, col2im1d_f64)
+COL2IM1D_OP(uint8_t, col2im1d_u8)
+COL2IM1D_OP(uint32_t, col2im1d_u32)
@@ -14,7 +14,7 @@ __device__ bool is_contiguous(
   size_t acc = 1;
   for (unsigned int d = 0; d < num_dims; d++) {
     unsigned int dim_idx = num_dims - 1 - d;
-    if (acc != strides[dim_idx]) {
+    if (dims[dim_idx] > 1 && acc != strides[dim_idx]) {
      return false;
     }
     acc *= dims[dim_idx];
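The new dims[dim_idx] > 1 guard matters because a size-1 dimension never affects which element an index maps to, so its stride can be arbitrary without breaking contiguity. An illustrative Rust analogue of the same check (not code from this diff):

// A shape is treated as C-contiguous if every dimension of size > 1 has the
// stride that row-major layout would give it; size-1 dims are skipped because
// their stride never changes which element an index selects.
fn is_contiguous(dims: &[usize], strides: &[usize]) -> bool {
    let mut acc = 1;
    for (&dim, &stride) in dims.iter().zip(strides.iter()).rev() {
        if dim > 1 && acc != stride {
            return false;
        }
        acc *= dim;
    }
    true
}

// e.g. a (1, 3, 4) tensor with strides (0, 4, 1) now counts as contiguous.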
@@ -6,5 +6,6 @@ pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
 pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
 pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
 pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
+pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
 pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
 pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
File diff suppressed because it is too large.
@@ -50,6 +50,15 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
   dst[dst_id] = shr[0];
 }
 
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+  for (int mask = 16; mask > 0; mask >>= 1) {
+    a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+    a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+  }
+  return a;
+}
+
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
 for (int mask = 16; mask > 0; mask >>= 1) {
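The float2 overload reduces both running sums of the layernorm kernel below in a single butterfly pass. Purely as an illustration of the XOR-butterfly pattern (not code from this diff), here is a host-side Rust simulation where the 32 warp lanes are modeled as an array:

// At each step every "lane" adds the value held by the lane whose index
// differs in exactly one bit (lane ^ mask). After masks 16, 8, 4, 2, 1 all
// 32 lanes hold the total sum, mirroring __shfl_xor_sync above.
fn warp_reduce_sum_sim(mut lanes: [f32; 32]) -> [f32; 32] {
    let mut mask = 16;
    while mask > 0 {
        let snapshot = lanes;
        for lane in 0..32 {
            lanes[lane] = snapshot[lane] + snapshot[lane ^ mask];
        }
        mask >>= 1;
    }
    lanes
}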
@@ -58,6 +67,70 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
   return x;
 }
 
+// LayerNorm implementation adapted from ggml, accumulation is made using f32.
+// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477
+template <typename T>
+__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) {
+  const int row = blockIdx.x*blockDim.y + threadIdx.y;
+  const int tid = threadIdx.x;
+  const int block_size = blockDim.x;
+
+  float2 mean_var = make_float2(0.f, 0.f);
+
+  for (int col = tid; col < ncols; col += block_size) {
+    const float xi = x[row*ncols + col];
+    mean_var.x += xi;
+    mean_var.y += xi * xi;
+  }
+
+  // sum up partial sums
+  mean_var = warp_reduce_sum(mean_var);
+  if (block_size > WARP_SIZE) {
+    __shared__ float2 s_sum[32];
+    int warp_id = threadIdx.x / WARP_SIZE;
+    int lane_id = threadIdx.x % WARP_SIZE;
+    if (lane_id == 0) {
+      s_sum[warp_id] = mean_var;
+    }
+    __syncthreads();
+    mean_var = s_sum[lane_id];
+    mean_var = warp_reduce_sum(mean_var);
+  }
+
+  const float mean = mean_var.x / ncols;
+  const float var = mean_var.y / ncols - mean * mean;
+  const float inv_std = rsqrtf(var + eps);
+
+  if (alpha == nullptr && beta == nullptr) {
+    for (int col = tid; col < ncols; col += block_size) {
+      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+      dst[row*ncols + col] = static_cast<T>(lhs);
+    }
+  }
+  else if (alpha == nullptr && beta != nullptr) {
+    for (int col = tid; col < ncols; col += block_size) {
+      float b = static_cast<float>(beta[col]);
+      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+      dst[row*ncols + col] = static_cast<T>(lhs + b);
+    }
+  }
+  else if (alpha != nullptr && beta == nullptr) {
+    for (int col = tid; col < ncols; col += block_size) {
+      float a = static_cast<float>(alpha[col]);
+      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+      dst[row*ncols + col] = static_cast<T>(lhs * a);
+    }
+  }
+  else {
+    for (int col = tid; col < ncols; col += block_size) {
+      float a = static_cast<float>(alpha[col]);
+      float b = static_cast<float>(beta[col]);
+      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+      dst[row*ncols + col] = static_cast<T>(lhs * a + b);
+    }
+  }
+}
+
 // RmsNorm implementation adapted from ggml, accumulation is made using f32.
 // https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
 template <typename T>
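For reference, the same computation as a plain CPU routine: the sketch below is an illustrative Rust reference (not from this diff) that matches the kernel's f32 accumulation and its optional alpha/beta handling, without the per-block parallel reduction.

// Per row: compute mean and variance in f32, normalize, then apply the
// optional per-column scale (alpha) and shift (beta).
fn layernorm(x: &[f32], alpha: Option<&[f32]>, beta: Option<&[f32]>, ncols: usize, eps: f32) -> Vec<f32> {
    let mut dst = vec![0f32; x.len()];
    for (row_in, row_out) in x.chunks(ncols).zip(dst.chunks_mut(ncols)) {
        let mean = row_in.iter().sum::<f32>() / ncols as f32;
        let var = row_in.iter().map(|v| v * v).sum::<f32>() / ncols as f32 - mean * mean;
        let inv_std = 1.0 / (var + eps).sqrt();
        for (col, (v, o)) in row_in.iter().zip(row_out.iter_mut()).enumerate() {
            let a = alpha.map_or(1.0, |a| a[col]);
            let b = beta.map_or(0.0, |b| b[col]);
            *o = (v - mean) * inv_std * a + b;
        }
    }
    dst
}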
@@ -147,6 +220,65 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
   }
 }
 
+template <typename T>
+__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
+  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (2 * idx >= bh * td) return;
+
+  uint32_t rope_idx = idx % (td / 2);
+  T c = cos[rope_idx];
+  T s = sin[rope_idx];
+
+  dst[2 * idx] = src[2 * idx] * c - src[2 * idx + 1] * s;
+  dst[2 * idx + 1] = src[2 * idx] * s + src[2 * idx + 1] * c;
+}
+
+template <typename T>
+__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
+  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (2 * idx >= bh * td) return;
+
+  uint32_t i_bh = idx / (td / 2);
+  uint32_t i_td = idx - (td / 2) * i_bh;
+  uint32_t i_t = i_td / (d / 2);
+  uint32_t i_d = i_td - (d / 2) * i_t;
+  uint32_t i1 = i_bh * td + i_t * d + i_d;
+  uint32_t i2 = i1 + d / 2;
+  uint32_t i_cs = i_t * (d / 2) + i_d;
+  T c = cos[i_cs];
+  T s = sin[i_cs];
+
+  dst[i1] = src[i1] * c - src[i2] * s;
+  dst[i2] = src[i1] * s + src[i2] * c;
+}
+
+template <typename T>
+__device__ void rope_thd(
+    const T * src,
+    const T * cos,
+    const T * sin,
+    T * dst,
+    const uint32_t b,
+    const uint32_t t,
+    const uint32_t h,
+    const uint32_t d
+) {
+  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (2 * idx >= b * t * h * d) return;
+
+  uint32_t i_bth = idx / (d / 2);
+  uint32_t i_d = idx - (d / 2) * i_bth;
+  uint32_t i_t = (i_bth / h) % t;
+  uint32_t i1 = i_bth * d + i_d;
+  uint32_t i2 = i1 + d / 2;
+  uint32_t i_cs = i_t * (d / 2) + i_d;
+  T c = cos[i_cs];
+  T s = sin[i_cs];
+
+  dst[i1] = src[i1] * c - src[i2] * s;
+  dst[i2] = src[i1] * s + src[i2] * c;
+}
+
 template <typename T>
 __device__ void
 fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
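The three rope kernels differ only in how elements are paired and how the cos/sin index is derived: ropei rotates interleaved pairs (x[2i], x[2i+1]), while rope and rope_thd rotate split-half pairs (x[i], x[i + d/2]) within each head. An illustrative Rust reference for a single head vector of even dimension d (not code from this diff):

// Rotary-embedding reference for one (token, head) vector, given
// precomputed cos/sin of length d/2.
// Interleaved pairing, as in the ropei kernel.
fn rope_interleaved(x: &[f32], cos: &[f32], sin: &[f32]) -> Vec<f32> {
    let mut out = vec![0f32; x.len()];
    for i in 0..x.len() / 2 {
        let (c, s) = (cos[i], sin[i]);
        out[2 * i] = x[2 * i] * c - x[2 * i + 1] * s;
        out[2 * i + 1] = x[2 * i] * s + x[2 * i + 1] * c;
    }
    out
}

// Split-half pairing, as in the rope/rope_thd kernels.
fn rope_split_half(x: &[f32], cos: &[f32], sin: &[f32]) -> Vec<f32> {
    let half = x.len() / 2;
    let mut out = vec![0f32; x.len()];
    for i in 0..half {
        let (c, s) = (cos[i], sin[i]);
        out[i] = x[i] * c - x[i + half] * s;
        out[i + half] = x[i] * s + x[i + half] * c;
    }
    out
}

rope_thd only changes how the (batch, token, head) position is decoded from the flat thread index; the rotation itself is identical to rope_split_half.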
@@ -402,9 +534,50 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
       rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
   } \
 
+#define LAYERNORM_OP(TYPENAME, FN_NAME) \
+  extern "C" __global__ void FN_NAME( \
+      const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
+      const TYPENAME *beta, const int n_cols, const float eps) { \
+    layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, eps); \
+  } \
+
+#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
+  extern "C" __global__ void FN_NAME_I( \
+      const TYPENAME *src, \
+      const TYPENAME *cos, \
+      const TYPENAME *sin, \
+      TYPENAME *dst, \
+      const uint32_t bh, \
+      const uint32_t td) { \
+    ropei<TYPENAME>(src, cos, sin, dst, bh, td); \
+  } \
+  extern "C" __global__ void FN_NAME( \
+      const TYPENAME *src, \
+      const TYPENAME *cos, \
+      const TYPENAME *sin, \
+      TYPENAME *dst, \
+      const uint32_t bh, \
+      const uint32_t td, \
+      const uint32_t d) { \
+    rope<TYPENAME>(src, cos, sin, dst, bh, td, d); \
+  } \
+  extern "C" __global__ void FN_NAME_THD( \
+      const TYPENAME *src, \
+      const TYPENAME *cos, \
+      const TYPENAME *sin, \
+      TYPENAME *dst, \
+      const uint32_t b, \
+      const uint32_t t, \
+      const uint32_t h, \
+      const uint32_t d) { \
+    rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
+  } \
+
 #if __CUDA_ARCH__ >= 800
 SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
 RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
+LAYERNORM_OP(__nv_bfloat16, layernorm_bf16)
+ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
 SUM_OP(__nv_bfloat16, sum_bf16)
 FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
 #endif
@@ -412,6 +585,8 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
 #if __CUDA_ARCH__ >= 530
 SOFTMAX_OP(__half, float, softmax_f16)
 RMSNORM_OP(__half, rmsnorm_f16)
+LAYERNORM_OP(__half, layernorm_f16)
+ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
 SUM_OP(__half, sum_f16)
 FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
 #endif
@@ -423,6 +598,10 @@ SOFTMAX_OP(float, float, softmax_f32)
 SOFTMAX_OP(double, double, softmax_f64)
 RMSNORM_OP(float, rmsnorm_f32)
 RMSNORM_OP(double, rmsnorm_f64)
+LAYERNORM_OP(float, layernorm_f32)
+LAYERNORM_OP(double, layernorm_f64)
+ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32)
+ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64)
 
 FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
 FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)
candle-kernels/src/sort.cu  (new file, 88 lines)
@@ -0,0 +1,88 @@
+// Adapted from https://github.com/ggerganov/llama.cpp/blob/master/ggml-cuda/argsort.cu
+#define SORT_ORDER_ASC 1
+#define SORT_ORDER_DESC 0
+#include "cuda_utils.cuh"
+#include<stdint.h>
+
+template<typename T>
+static inline __device__ void ggml_cuda_swap(T & a, T & b) {
+    T tmp = a;
+    a = b;
+    b = tmp;
+}
+
+template<int order, typename T>
+static __device__ void k_argsort(const T * x, uint32_t * dst, const int ncols, int ncols_pad) {
+    // bitonic sort
+    int col = threadIdx.x;
+    int row = blockIdx.y;
+
+    if (col >= ncols_pad) {
+        return;
+    }
+
+    const T * x_row = x + row * ncols;
+    extern __shared__ int dst_row[];
+
+    // initialize indices
+    dst_row[col] = col;
+
+    __syncthreads();
+
+    for (int k = 2; k <= ncols_pad; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            __syncthreads();
+        }
+    }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
+}
+
+#define ASORT_OP(TYPENAME, RUST_NAME) \
+extern "C" __global__ void asort_asc_##RUST_NAME( \
+    const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
+) { \
+    k_argsort<SORT_ORDER_ASC>(x, dst, ncols, ncols_pad); \
+} \
+extern "C" __global__ void asort_desc_##RUST_NAME( \
+    const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
+) { \
+    k_argsort<SORT_ORDER_DESC>(x, dst, ncols, ncols_pad); \
+} \
+
+#if __CUDA_ARCH__ >= 800
+ASORT_OP(__nv_bfloat16, bf16)
+#endif
+
+#if __CUDA_ARCH__ >= 530
+ASORT_OP(__half, f16)
+#endif
+
+ASORT_OP(float, f32)
+ASORT_OP(double, f64)
+ASORT_OP(uint8_t, u8)
+ASORT_OP(uint32_t, u32)
+ASORT_OP(int64_t, i64)
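Semantically, each instantiation produces a per-row argsort: dst ends up holding the column indices of each row of x in ascending or descending order, with rows padded up to a power of two so the bitonic network is well formed. An illustrative CPU equivalent in Rust (not from this diff; it assumes the rows contain no NaNs):

// Return, for each row of a row-major (nrows x ncols) matrix, the column
// indices sorted by the row's values.
fn argsort_rows(x: &[f32], ncols: usize, ascending: bool) -> Vec<u32> {
    let mut dst = Vec::with_capacity(x.len());
    for row in x.chunks(ncols) {
        let mut idx: Vec<u32> = (0..ncols as u32).collect();
        idx.sort_by(|&a, &b| {
            // assumes no NaNs in the row
            let ord = row[a as usize].partial_cmp(&row[b as usize]).unwrap();
            if ascending { ord } else { ord.reverse() }
        });
        dst.extend(idx);
    }
    dst
}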
Some files were not shown because too many files have changed in this diff.