Mirror of https://github.com/huggingface/candle.git (synced 2025-06-18 19:47:12 +00:00)

Compare commits: 194 commits
Commit SHA1s, oldest last (the author, date, and message columns of the compare view were not captured by this mirror):

1f23cea90c, ce33d6ad2a, 3d0ade406a, 2ca086939f, 4349ff1fc2, 7c3cfd1086, e2eb6590ed, 481c45d78d, 14a2bdc062, bfa7c8fc01,
762e996ce6, ca19a9af62, ec23427d60, f83e14f68d, c7e613ab5e, 8f63f68289, 1edc3ddf24, b380657bfe, 60f624a902, 8d6c6de8e0,
7ec345c2eb, 671fc29b36, dc64adb8e4, c66e5d4716, bd3b243725, 2813fb5dbc, 7cfffcac10, 38de52bc4b, d46670f7c0, f710fab02e,
f82bf2d915, df6814f34e, 39406a6721, 976ad9f9c2, a4c4a56429, f49bf6a81d, 992a788da1, 8d8f48c60c, d31f11035f, 9ab3f9729f,
a1f41ab37b, 92a05b51cf, c6763e3b41, 347e31c9ff, f4fcf60900, 12561b31d3, a209ce8ceb, f1e678b39c, a007f8fdb4, 2341aa079e,
9e666d4229, 1b12142a02, d2c3f14773, 26c4e5bf1d, 18d30005c5, 6958384327, e6697471bb, 73d02f4f57, f772213e84, 2feb0b054f,
2d28497197, f3a4f3db76, 7920b45c8a, d4a45c936a, c912d24570, d5c2a7b64b, 508f811b93, a773a4b22b, 5a363dbc26, abc4f698c5,
a923e8b53a, 2a45bcf943, 47f4ddb011, f365a075e5, 60fdab4e17, 928a9d906e, d1d89bac1f, 39ad840a90, b5e4f84bed, 7051fb8098,
dc68c130e4, bc9a1bf239, f7c957d64f, 8cbb9d0e6c, bfe95115c6, 6fa3151820, 0a58886ccb, 3173b1ce3b, ad63f20781, 1cfc5d6d0c,
b07b2350b6, 1b5063f3ca, 3b0d1e7d03, be4555c5a5, 6975c65112, a2a20aeecc, e08fbb6543, d39d0c40fd, b97463098c, fbd69f952c,
6c990a33ea, 1704f1b3ae, 693fad511c, 36fb84f038, c12ad45562, 7d0202710b, 392a00a147, 4c967b9184, c05c0a8213, 969960847a,
5fc66bd4ba, 174b208052, 154c674a79, 7bbde55c61, c3f2676d49, 46d6566c99, 55bc3382cf, dece37c6f4, 498c50348c, 012ae0090e,
95a857cf57, 612f5b8156, ef33df7ae2, c8face3f95, 85bea43e5b, b3181455d5, e2826e70b3, 916619f70b, 9b1158b315, 70d06ab4b0,
0ec5ebcec4, c8e197f68c, 5f20697918, e37b487767, e5dc8cb4f4, e7b886d56f, 6a446d9d73, 0acd16751d, c698e17619, e4c9adfdbe,
b6053b938b, 45dbe541bc, 7bd0faba75, 807e3f9f52, eae94a451b, 86e1803191, 25c3cc4149, a11af79e23, 8a82d623e5, df2f89b6cf,
62fc965617, 5b32c2a41e, 3115fe42e4, 2531b13bf8, 0d9bb4eb18, e8f760ee44, 94e3373883, 34d9e91748, cfb423ab76, 7366aeac21,
99cf13e8e2, b43ab6cd1d, 31ca4897bb, 55351ef57d, 6684b7127a, 93c25e8844, cd53c472df, 6f76383f38, 8e773cc0c6, 87eb1658e1,
902d0b9166, 185b54a33b, 620c94d12e, 86e7d539d2, cb034506cd, 63c204c79e, 767a6578f1, 662c186fd5, 2cd745a97c, a72b50e2c0,
872c3f14b0, f9e93f5b69, b355ab4e2e, 2fe24ac5b1, 00948eb656, af67672207, 6c588c4792, 122da87580, 75629981bc, 0106b0b04c,
588ad4835a, b73c35cc57, 8f310cc666, 8921d5027c
.github/workflows/ci_cuda.yaml (vendored): 2 changes (1 addition, 1 deletion)

@@ -59,7 +59,7 @@ jobs:
     - name: Install Rust Stable
       run: curl https://sh.rustup.rs -sSf | sh -s -- -y
     - uses: Swatinem/rust-cache@v2
-    - run: apt-get update -y && apt-get install libssl-dev -y
+    - run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
     - name: Test (cuda)
       run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
   stop-runner:

.github/workflows/maturin.yml (vendored, new file, recorded as binary): Binary file not shown.
.github/workflows/python.yml (vendored, new file): 68 additions. All lines are new; the YAML indentation below is reconstructed from the keys, which determine the structure unambiguously.

@@ -0,0 +1,68 @@
name: PyO3-CI

on:
  workflow_dispatch:
  push:
    branches:
      - main
    paths:
      - candle-pyo3/**
  pull_request:
    paths:
      - candle-pyo3/**

jobs:
  build_and_test:
    name: Check everything builds & tests
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest] # For now, only test on Linux
    steps:
      - name: Checkout repository
        uses: actions/checkout@v2

      - name: Install Rust
        uses: actions-rs/toolchain@v1
        with:
          toolchain: stable

      - name: Install Python
        uses: actions/setup-python@v4
        with:
          python-version: 3.11
          architecture: "x64"

      - name: Cache Cargo Registry
        uses: actions/cache@v1
        with:
          path: ~/.cargo/registry
          key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}

      - name: Install Protoc
        uses: arduino/setup-protoc@v2
        with:
          version: "25.0"
          repo-token: ${{ secrets.GITHUB_TOKEN }}

      - name: Install
        working-directory: ./candle-pyo3
        run: |
          python -m venv .env
          source .env/bin/activate
          pip install -U pip
          pip install pytest maturin black
          python -m maturin develop -r --features onnx

      - name: Check style
        working-directory: ./candle-pyo3
        run: |
          source .env/bin/activate
          python stub.py --check
          black --check .

      - name: Run tests
        working-directory: ./candle-pyo3
        run: |
          source .env/bin/activate
          python -m pytest -s -v tests
Cargo.toml: 19 changes

@@ -7,20 +7,19 @@ members = [
     "candle-nn",
     "candle-pyo3",
     "candle-transformers",
-    "candle-wasm-examples/llama2-c",
-    "candle-wasm-examples/segment-anything",
-    "candle-wasm-examples/whisper",
-    "candle-wasm-examples/yolo",
-    "candle-wasm-examples/bert",
-    "candle-wasm-examples/phi",
-    "candle-wasm-examples/t5",
+    "candle-wasm-examples/*",
     "candle-wasm-tests",
 ]
-exclude = ["candle-flash-attn", "candle-kernels"]
+exclude = [
+    "candle-flash-attn",
+    "candle-kernels",
+    "candle-metal-kernels",
+    "candle-onnx",
+]
 resolver = "2"

 [workspace.package]
-version = "0.3.0"
+version = "0.3.1"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"

@@ -52,6 +51,7 @@ rayon = "1.7.0"
 rusttype = { version = "0.9", default-features = false }
 safetensors = "0.3.1"
 serde = { version = "1.0.171", features = ["derive"] }
+serde_plain = "1.0.2"
 serde_json = "1.0.99"
 thiserror = "1"
 tokenizers = { version = "0.13.4", default-features = false }

@@ -61,6 +61,7 @@ tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "0.6.6", default-features = false }
+metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }

 [profile.release-with-debug]
 inherits = "release"
README.md: 55 changes

@@ -51,22 +51,26 @@ For more advanced examples, please have a look at the following section.
 These online demos run entirely in your browser:
 - [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
   object recognition.
-- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): text to speech.
+- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
 - [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
 - [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
 - [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
 - [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
+- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.

 We also provide a some command line based examples using state of the art models:

 - [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
+- [Phi-v1 and Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
 - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
   pre-trained on 1T tokens of English and code datasets.
 - [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
   performance larger than all publicly available 13b models as of 2023-09-28.
 - [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
+- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
+- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
+  (English/Chinese) general LLMs with 6b and 34b parameters.
 - [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
   the LLaMA model using the same quantization techniques as
   [llama.cpp](https://github.com/ggerganov/llama.cpp).

@@ -94,10 +98,15 @@ We also provide a some command line based examples using state of the art models
 <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">

 - [Whisper](./candle-examples/examples/whisper/): speech recognition model.
-- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
+- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
+  [JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
 - [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
   using self-supervision (can be used for imagenet classification, depth
   evaluation, segmentation).
+- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
+  generate captions for an image.
+- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
+  model, generates the translated text from the input text.

 Run them using commands like:
 ```

@@ -129,8 +138,18 @@ And then head over to
 <!--- ANCHOR: useful_libraries --->

-## Useful Libraries
-- [`candle-lora`](https://github.com/EricLBuehler/candle-lora) provides a LoRA implementation that conforms to the official `peft` implementation.
+## Useful External Resources
+- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
+  very detailed tutorial showing how to convert a PyTorch model to Candle.
+- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and ergonomic LoRA implemenation for Candle. `candle-lora` has
+  out-of-the-box LoRA support for many models from Candle, which can be found [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
+- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
+  including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
+- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
+  serving local LLMs including an OpenAI compatible API server.
+- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
+- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
+- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.

 If you have an addition to this list, please submit a pull request.

@@ -155,17 +174,26 @@ If you have an addition to this list, please submit a pull request.
 - Phi v1.5.
 - Mistral 7b v0.1.
 - StableLM-3B-4E1T.
-- T5.
+- Replit-code-v1.5-3B.
 - Bert.
+- Yi-6B and Yi-34B.
+- Quantized LLMs.
+  - Llama 7b, 13b, 70b, as well as the chat and code variants.
+  - Mistral 7b, and 7b instruct.
+  - Zephyr 7b a and b (Mistral based).
+  - OpenChat 3.5 (Mistral based).
+- Text to text.
+  - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
+  - Marian MT (Machine Translation).
 - Whisper (multi-lingual support).
-- Stable Diffusion v1.5, v2.1, XL v1.0.
-- Wurstchen v2.
+- Text to image.
+  - Stable Diffusion v1.5, v2.1, XL v1.0.
+  - Wurstchen v2.
+- Image to text.
+  - BLIP.
 - Computer Vision Models.
-  - DINOv2.
-  - ConvMixer.
-  - EfficientNet.
-  - yolo-v3.
-  - yolo-v8.
+  - DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
+  - yolo-v3, yolo-v8.
   - Segment-Anything Model (SAM).
 - File formats: load models from safetensors, npz, ggml, or PyTorch files.
 - Serverless (on CPU), small and fast deployments.

@@ -203,6 +231,7 @@ Cheatsheet:
 - [candle-datasets](./candle-datasets/): Datasets and data loaders.
 - [candle-transformers](./candle-transformers): transformers-related utilities.
 - [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
+- [candle-onnx](./candle-onnx/): ONNX model evaluation.

 ## FAQ
Cargo.toml of a workspace crate (the compare view dropped this file header; the `../candle-core` dependency paths suggest candle-examples/Cargo.toml). All internal candle dependencies are bumped from 0.3.0 to 0.3.1:

@@ -11,11 +11,11 @@ readme = "README.md"

 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
-candle-nn = { path = "../candle-nn", version = "0.3.0" }
-candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
+candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
+candle-nn = { path = "../candle-nn", version = "0.3.1" }
+candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
Documentation hunk (file header not captured; the context matches the CUDA setup guide's compute_cap section):

@@ -12,6 +12,9 @@ compute_cap
 8.9
 ```

+You can also compile the Cuda kernels for a specific compute cap using the
+`CUDA_COMPUTE_CAP=<compute cap>` environment variable.
+
 If any of the above commands errors out, please make sure to update your Cuda version.

 2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.
Cargo.toml of candle-core (inferred from the `../candle-kernels` path; file header not captured). The new Metal dependencies and a `metal` feature flag are wired in:

@@ -12,7 +12,9 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
-candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
+candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true }
+metal = { workspace = true, optional = true}
 cudarc = { workspace = true, optional = true }
 gemm = { workspace = true }
 half = { workspace = true }

@@ -39,3 +41,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
 cudnn = ["cuda", "cudarc/cudnn"]
 mkl = ["dep:libc", "dep:intel-mkl-src"]
 accelerate = ["dep:libc", "dep:accelerate-src"]
+metal = ["dep:metal", "dep:candle-metal-kernels"]
Rust example (likely candle-core/examples/basics.rs; file header not captured). The conv2d timing snippet is replaced by a `slice_scatter` demonstration:

@@ -8,11 +8,10 @@ use anyhow::Result;
 use candle_core::{Device, Tensor};

 fn main() -> Result<()> {
-    let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
-    let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
-    let start = std::time::Instant::now();
-    let res = inp.conv2d(&w, 0, 1, 1, 1)?;
-    println!("{:?}", start.elapsed());
-    println!("{res:?}");
+    let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
+    let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
+    let new_a = a.slice_scatter(&b, 1, 2)?;
+    assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+    assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
     Ok(())
 }
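The replacement example leans on candle's functional style: every op returns a fresh tensor, which is what the first assertion pins down for `a`. A small sketch of the same check pattern using a scalar op instead (the values and the op choice are illustrative, not part of the diff):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
    // Scalar multiplication builds a new tensor; `a` itself is never mutated.
    let doubled = (2.0 * &a)?;
    assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
    assert_eq!(doubled.to_vec2::<f32>()?, [[0.0, 2.0, 4.0], [6.0, 8.0, 10.0]]);
    Ok(())
}
```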
candle-core/src/backend.rs (likely; file header not captured). The `BackendStorage` trait gains a `conv_transpose1d` method that every backend has to provide:

@@ -39,6 +39,14 @@ pub trait BackendStorage: Sized {
         _params: &crate::conv::ParamsConv1D,
     ) -> Result<Self>;

+    fn conv_transpose1d(
+        &self,
+        _l: &Layout,
+        _kernel: &Self,
+        _kernel_l: &Layout,
+        _params: &crate::conv::ParamsConvTranspose1D,
+    ) -> Result<Self>;
+
     fn conv2d(
         &self,
         _l: &Layout,
candle-core/src/backprop.rs (likely; file header not captured). Gradients are now detached by default (opt out via the `CANDLE_GRAD_DO_NOT_DETACH` environment variable), integer-dtype nodes are pruned from the backward graph, and backward passes are added for conv1d, gelu, erf, gelu-erf, and elu:

@@ -15,6 +15,17 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
     }
 }

+thread_local! {
+    static CANDLE_GRAD_DO_NOT_DETACH: bool = {
+        match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
+            Ok(s) => {
+                !s.is_empty() && s != "0"
+            },
+            Err(_) => false,
+        }
+    }
+}
+
 impl Tensor {
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first
     /// elements having dependencies on the latter ones, e.g. the first element if any is the

@@ -36,6 +47,8 @@ impl Tensor {
                 // Do not call recursively on the "leaf" nodes.
                 track_grad = true;
                 nodes
+            } else if node.dtype().is_int() {
+                nodes
             } else if let Some(op) = node.op() {
                 match op {
                     Op::IndexAdd(t1, t2, t3, _)

@@ -55,6 +68,11 @@ impl Tensor {
                         kernel: rhs,
                         ..
                     }
+                    | Op::ConvTranspose1D {
+                        arg: lhs,
+                        kernel: rhs,
+                        ..
+                    }
                     | Op::Conv2D {
                         arg: lhs,
                         kernel: rhs,

@@ -103,7 +121,6 @@ impl Tensor {
                     | Op::Broadcast(node)
                     | Op::Cmp(node, _)
                     | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
-                    | Op::ToDType(node)
                     | Op::ToDevice(node)
                     | Op::Transpose(node, _, _)
                     | Op::Permute(node, _)

@@ -116,6 +133,15 @@ impl Tensor {
                         track_grad |= tg;
                         nodes
                     }
+                    Op::ToDType(node) => {
+                        if node.dtype().is_float() {
+                            let (tg, nodes) = walk(node, nodes, already_seen);
+                            track_grad |= tg;
+                            nodes
+                        } else {
+                            nodes
+                        }
+                    }
                     Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
                 }
             } else {

@@ -140,10 +166,16 @@ impl Tensor {
             if node.is_variable() {
                 continue;
             }
-            let grad = grads.remove(node).unwrap();
-            // TODO: We should perform all these operations in place (or at least not track the
-            // whole graph). The only drawback would be if we wanted to support grad of grad but
-            // this is out of scope.
+            let grad = grads
+                .remove(node)
+                .expect("candle internal error - grad not populated");
+            // https://github.com/huggingface/candle/issues/1241
+            // Ideally, we would make these operations in place where possible to ensure that we
+            // do not have to allocate too often. Here we just call `.detach` to avoid computing
+            // the backprop graph of the backprop itself. This would be an issue for second order
+            // derivatives but these are out of scope at the moment.
+            let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
+            let grad = if do_not_detach { grad } else { grad.detach()? };
             if let Some(op) = node.op() {
                 match op {
                     Op::Binary(lhs, rhs, BinaryOp::Add) => {

@@ -198,7 +230,44 @@ impl Tensor {
                         let f_grad = pred.where_cond(&zeros, &grad)?;
                         *f_sum_grad = f_sum_grad.add(&f_grad)?;
                     }
-                    Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
+                    Op::Conv1D {
+                        arg,
+                        kernel,
+                        padding,
+                        stride,
+                        dilation,
+                    } => {
+                        // The output height for conv_transpose1d is:
+                        // (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
+                        let grad_l_in = grad.dim(2)?;
+                        let k_size = kernel.dim(2)?;
+                        let out_size =
+                            (grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
+                        let out_padding = arg.dim(2)? - out_size;
+                        let grad_arg = grad.conv_transpose1d(
+                            kernel,
+                            *padding,
+                            out_padding,
+                            *stride,
+                            *dilation,
+                        )?;
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&grad_arg)?;
+
+                        let grad_kernel = arg
+                            .transpose(0, 1)?
+                            .conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
+                            .transpose(0, 1)?;
+                        let sum_grad = grads.or_insert(kernel)?;
+                        let (_, _, k0) = kernel.dims3()?;
+                        let (_, _, g_k0) = grad_kernel.dims3()?;
+                        let grad_kernel = if g_k0 != k0 {
+                            grad_kernel.narrow(2, 0, k0)?
+                        } else {
+                            grad_kernel
+                        };
+                        *sum_grad = sum_grad.add(&grad_kernel)?;
+                    }
                     Op::Conv2D {
                         arg,
                         kernel,

@@ -228,8 +297,18 @@ impl Tensor {
                             .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                             .transpose(0, 1)?;
                         let sum_grad = grads.or_insert(kernel)?;
+                        let (_, _, k0, k1) = kernel.dims4()?;
+                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
+                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
+                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
+                        } else {
+                            grad_kernel
+                        };
                         *sum_grad = sum_grad.add(&grad_kernel)?;
                     }
+                    Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
+                        op: "conv-transpose1d",
+                    })?,
                     Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
                         op: "conv-transpose2d",
                     })?,

@@ -374,7 +453,7 @@ impl Tensor {
                     }
                     Op::ToDType(arg) => {
                         let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
+                        *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
                     }
                     Op::Copy(arg) => {
                         let sum_grad = grads.or_insert(arg)?;

@@ -461,17 +540,47 @@ impl Tensor {
                     Op::Unary(_, UnaryOp::Round) => {
                         Err(Error::BackwardNotSupported { op: "round" })?
                     }
-                    Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
-                    Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
-                    Op::Unary(_, UnaryOp::GeluErf) => {
-                        Err(Error::BackwardNotSupported { op: "gelu-erf" })?
+                    Op::Unary(arg, UnaryOp::Gelu) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        let cube = arg.powf(3.)?;
+                        let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
+                        let gelu_grad = (((0.5 * &tanh)?
+                            + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
+                            + 0.5)?;
+                        *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
+                    }
+                    Op::Unary(arg, UnaryOp::Erf) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        // d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
+                        let erf_grad =
+                            (2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
+                        *sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
+                    }
+                    Op::Unary(arg, UnaryOp::GeluErf) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        // d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
+                        let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
+                        let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
+                        let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
+                        let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
+                        let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
+                        *sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
                     }
                     Op::Unary(arg, UnaryOp::Relu) => {
                         let sum_grad = grads.or_insert(arg)?;
                         let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
                         *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
                     }
-                    Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
+                    Op::Elu(arg, alpha) => {
+                        // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
+                        let sum_grad = grads.or_insert(arg)?;
+                        let zeros = arg.zeros_like()?;
+                        let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
+                        let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
+                        let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
+                        let combined_mask = (positive_mask + negative_exp_mask)?;
+                        *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
+                    }
                     Op::Powf(arg, e) => {
                         let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
                         let sum_grad = grads.or_insert(arg)?;
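The magic constants in the new `gelu` backward arm come from the tanh approximation of GELU. A worked derivation (standard calculus, not taken from the diff):

$$\mathrm{gelu}(x) = 0.5\,x\,(1 + \tanh u), \qquad u = \sqrt{2/\pi}\,(x + 0.044715\,x^3) \approx 0.797885\,x + 0.0356774\,x^3$$

By the product and chain rules, with $u' \approx 0.797885 + 0.1070322\,x^2$ and $\mathrm{sech}^2 u = 1 - \tanh^2 u$:

$$\frac{d}{dx}\,\mathrm{gelu}(x) = 0.5\,(1 + \tanh u) + 0.5\,x\,u'\,(1 - \tanh^2 u) = 0.5 + 0.5\tanh u + (0.398942\,x + 0.0535161\,x^3)(1 - \tanh^2 u),$$

which is exactly the expression assembled in `gelu_grad` above.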
candle-core/src/conv.rs (likely; file header not captured). Parameters and output-shape helpers for the new 1D transposed convolution:

@@ -25,6 +25,33 @@ impl ParamsConv1D {
     }
 }

+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ParamsConvTranspose1D {
+    pub(crate) b_size: usize,
+    pub(crate) l_in: usize,
+    pub(crate) c_out: usize,
+    pub(crate) c_in: usize,
+    pub(crate) k_size: usize,
+    pub(crate) padding: usize,
+    pub(crate) output_padding: usize,
+    pub(crate) stride: usize,
+    pub(crate) dilation: usize,
+}
+
+impl ParamsConvTranspose1D {
+    pub(crate) fn l_out(&self) -> usize {
+        (self.l_in - 1) * self.stride - 2 * self.padding
+            + self.dilation * (self.k_size - 1)
+            + self.output_padding
+            + 1
+    }
+
+    pub(crate) fn out_dims(&self) -> Vec<usize> {
+        let l_out = self.l_out();
+        vec![self.b_size, self.c_out, l_out]
+    }
+}
+
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub enum CudnnFwdAlgo {
     ImplicitGemm,
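The `l_out` method above implements the standard transposed-convolution output length:

$$l_{out} = (l_{in} - 1)\,s - 2p + d\,(k - 1) + p_{out} + 1$$

As a worked check with illustrative values $l_{in} = 5$, stride $s = 2$, padding $p = 1$, dilation $d = 1$, kernel size $k = 3$, output padding $p_{out} = 0$: $l_{out} = 4 \cdot 2 - 2 + 2 + 0 + 1 = 9$.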
@@ -160,6 +187,49 @@ impl Tensor {
         }
     }

+    /// Applies a 1D transposed convolution over the input tensor.
+    pub fn conv_transpose1d(
+        &self,
+        kernel: &Self,
+        padding: usize,
+        output_padding: usize,
+        stride: usize,
+        dilation: usize,
+    ) -> Result<Self> {
+        let (b_size, c_in, l_in) = self.dims3()?;
+        let (c_in_k, c_out, k_size) = kernel.dims3()?;
+        if c_in != c_in_k {
+            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
+        }
+        let params = ParamsConvTranspose1D {
+            b_size,
+            l_in,
+            k_size,
+            c_out,
+            c_in,
+            padding,
+            output_padding,
+            stride,
+            dilation,
+        };
+        let storage = self.storage().conv_transpose1d(
+            self.layout(),
+            &kernel.storage(),
+            kernel.layout(),
+            &params,
+        )?;
+        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
+            arg,
+            kernel,
+            padding: params.padding,
+            output_padding: params.output_padding,
+            stride: params.stride,
+            dilation: params.dilation,
+        });
+        let out_dims = params.out_dims();
+        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
+    }
+
     fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
         let storage =
             self.storage()
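For orientation, a hedged usage sketch of the new public method (the shapes follow the `dims3` checks above and the expected output length follows the `l_out` formula; this snippet is illustrative and not part of the diff):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Input is (b_size, c_in, l_in); the kernel is (c_in, c_out, k_size).
    let x = Tensor::randn(0f32, 1., (1, 2, 5), &Device::Cpu)?;
    let k = Tensor::randn(0f32, 1., (2, 3, 4), &Device::Cpu)?;
    // padding = 0, output_padding = 0, stride = 1, dilation = 1:
    // l_out = (5 - 1) * 1 - 0 + 1 * (4 - 1) + 0 + 1 = 8
    let y = x.conv_transpose1d(&k, 0, 0, 1, 1)?;
    assert_eq!(y.dims(), &[1, 3, 8]);
    Ok(())
}
```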
candle-core/src/cpu_backend.rs (likely; file header not captured). Several `RequiresContiguous` error sites now call `.bt()`, which attaches a captured backtrace to the error so the failing call site can be located:

@@ -804,11 +804,11 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> {
     fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
         let ids = match self.ids_l.contiguous_offsets() {
             Some((a, b)) => &self.ids[a..b],
-            None => Err(Error::RequiresContiguous { op: "gather" })?,
+            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
         };
         let src = match src_l.contiguous_offsets() {
             Some((a, b)) => &src[a..b],
-            None => Err(Error::RequiresContiguous { op: "gather" })?,
+            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
         };
         let dim = self.dim;
         let ids_dims = self.ids_l.dims();

@@ -857,7 +857,7 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
     fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
         let src = match layout.contiguous_offsets() {
             Some((a, b)) => &src[a..b],
-            None => Err(Error::RequiresContiguous { op: "index-select" })?,
+            None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
         };
         let dim = self.dim;
         let n_ids = match self.ids_l.dims() {

@@ -913,7 +913,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
         let mut dst = vec![T::zero(); dst_len];
         copy_strided_src_(v1, &mut dst, 0, l1);
         let src = match src_l.contiguous_offsets() {
-            None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
+            None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
             Some((o1, o2)) => &src[o1..o2],
         };

@@ -929,7 +929,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {

         let ids = match self.ids_l.contiguous_offsets() {
             Some((a, b)) => &self.ids[a..b],
-            None => Err(Error::RequiresContiguous { op: "gather" })?,
+            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
         };
         for left_i in 0..ids_left_len {
             let start_ids_idx = left_i * ids_right_len * ids_dim_len;

@@ -971,7 +971,7 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
         let mut dst = vec![T::zero(); dst_len];
         copy_strided_src_(v1, &mut dst, 0, l1);
         let src = match src_l.contiguous_offsets() {
-            None => Err(Error::RequiresContiguous { op: "index-add" })?,
+            None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
             Some((o1, o2)) => &src[o1..o2],
         };
         let dim = self.dim;
@@ -1256,6 +1256,74 @@ impl Map1 for Im2Col {
     }
 }

+struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
+
+impl<'a> Map2 for ConvTranspose1D<'a> {
+    const OP: &'static str = "conv_transpose1d";
+    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
+        let p = self.0;
+        let inp = &inp[inp_l.start_offset()..];
+        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
+        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
+        let l_out = p.l_out();
+
+        // Output shape: [b_size, c_out, l_out].
+        let dst_elems = p.c_out * l_out * p.b_size;
+        let dst = vec![T::zero(); dst_elems];
+        let dst_s0 = p.c_out * l_out;
+        let dst_s1 = l_out;
+        let dst_s2 = 1;
+
+        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
+        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
+        let cont_s0 = p.l_in * p.c_in;
+        let cont_s1 = p.c_in;
+        for b_idx in 0..p.b_size {
+            for l_idx in 0..p.l_in {
+                for c_idx in 0..p.c_in {
+                    let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
+                    let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
+                    inp_cont[dst_idx] = inp[src_idx]
+                }
+            }
+        }
+
+        for k_idx in 0..p.k_size {
+            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
+                let k_cont = (0..p.c_in)
+                    .map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
+                    .collect::<Vec<_>>();
+                for b_idx in 0..p.b_size {
+                    for l_idx in 0..p.l_in {
+                        let out_idx = l_idx * p.stride + k_idx * p.dilation;
+                        if out_idx < p.padding {
+                            continue;
+                        }
+                        let out_idx = out_idx - p.padding;
+                        if out_idx < l_out {
+                            let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
+                            let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
+                            let mut d = T::zero();
+                            unsafe {
+                                T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
+                            }
+                            let dst_p = dst.as_ptr();
+                            // Safety: dst_idx are uniques per dst_c_idx which is used to
+                            // parallelise the different tasks so no two threads can try to
+                            // write at the same location.
+                            unsafe {
+                                let ptr = dst_p.add(dst_idx) as *mut T;
+                                *ptr += d
+                            }
+                        }
+                    }
+                }
+            })
+        }
+        Ok(dst)
+    }
+}
+
 struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);

 impl<'a> Map2 for Conv2D<'a> {
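The kernel above scatters rather than gathers: an input position $l$ combined with kernel tap $k$ contributes to the output position

$$o = l \cdot \mathrm{stride} + k \cdot \mathrm{dilation} - \mathrm{padding}, \qquad 0 \le k < k_{\mathrm{size}},$$

with contributions dropped when $o$ falls outside $[0, l_{out})$; this is exactly the `out_idx` computation and the two bounds checks in the inner loop. Because several $(l, k)$ pairs can map to the same $o$, writes accumulate with `+=`, and the work is parallelised over `c_out` so that each thread owns a disjoint set of `dst` indices.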
@@ -2435,6 +2503,16 @@ impl BackendStorage for CpuStorage {
         Ok(res_t)
     }

+    fn conv_transpose1d(
+        &self,
+        l: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        params: &crate::conv::ParamsConvTranspose1D,
+    ) -> Result<Self> {
+        ConvTranspose1D(params).map(self, l, kernel, kernel_l)
+    }
+
     fn conv2d(
         &self,
         l: &Layout,
@@ -2539,25 +2617,25 @@ impl BackendStorage for CpuStorage {
             Self::U8(ids) => {
                 let ids = match ids_l.contiguous_offsets() {
                     Some((a, b)) => &ids[a..b],
-                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
+                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                 };
                 IndexAdd { ids, dim }.map(self, l, src, src_l)
             }
             Self::U32(ids) => {
                 let ids = match ids_l.contiguous_offsets() {
                     Some((a, b)) => &ids[a..b],
-                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
+                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                 };
                 IndexAdd { ids, dim }.map(self, l, src, src_l)
             }
             Self::I64(ids) => {
                 let ids = match ids_l.contiguous_offsets() {
                     Some((a, b)) => &ids[a..b],
-                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
+                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                 };
                 IndexAdd { ids, dim }.map(self, l, src, src_l)
             }
-            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
         }
     }
candle-core/src/cuda_backend.rs (likely; file header not captured). The first hunk stubs the new trait method; the second fixes a kernel-name typo: candle's unary kernels are suffixed with the dtype, so the f64 strided copy is `ucopy_f64`, not `ucopy_64`:

@@ -1808,6 +1808,16 @@ impl BackendStorage for CudaStorage {
         Ok(res_t)
     }

+    fn conv_transpose1d(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &crate::conv::ParamsConvTranspose1D,
+    ) -> Result<Self> {
+        todo!()
+    }
+
     #[cfg(not(feature = "cudnn"))]
     fn conv2d(
         &self,

@@ -2171,7 +2181,7 @@ impl BackendStorage for CudaStorage {
         if src_l.is_contiguous() {
             dev.dtod_copy(&src, &mut dst).w()?
         } else {
-            let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
+            let func = dev.get_or_load_func("ucopy_f64", kernels::UNARY)?;
             // SAFETY: Set later by running the kernel.
             let params = (el_count, dims.len(), &ds, &src, &mut dst);
             // SAFETY: ffi.
candle-core/src/device.rs (likely; file header not captured). The new Metal device is threaded through the `Device` API. Note the f16/bf16 special case on CUDA: random values are sampled in f32 and then cast, since (per the TODO in the hunk) those dtypes cannot be generated directly yet:

@@ -8,12 +8,14 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
 pub enum DeviceLocation {
     Cpu,
     Cuda { gpu_id: usize },
+    Metal { gpu_id: usize },
 }

 #[derive(Debug, Clone)]
 pub enum Device {
     Cpu,
     Cuda(crate::CudaDevice),
+    Metal(crate::MetalDevice),
 }

 pub trait NdArray {

@@ -128,10 +130,15 @@ impl Device {
         Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
     }

+    pub fn new_metal(ordinal: usize) -> Result<Self> {
+        Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
+    }
+
     pub fn set_seed(&self, seed: u64) -> Result<()> {
         match self {
-            Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed),
+            Self::Cpu => CpuDevice.set_seed(seed),
             Self::Cuda(c) => c.set_seed(seed),
+            Self::Metal(m) => m.set_seed(seed),
         }
     }

@@ -139,6 +146,7 @@ impl Device {
         match (self, rhs) {
             (Self::Cpu, Self::Cpu) => true,
             (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
+            (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
             _ => false,
         }
     }

@@ -147,21 +155,20 @@ impl Device {
         match self {
             Self::Cpu => DeviceLocation::Cpu,
             Self::Cuda(device) => device.location(),
+            Device::Metal(device) => device.location(),
         }
     }

     pub fn is_cpu(&self) -> bool {
-        match self {
-            Self::Cpu => true,
-            Self::Cuda(_) => false,
-        }
+        matches!(self, Self::Cpu)
     }

     pub fn is_cuda(&self) -> bool {
-        match self {
-            Self::Cpu => false,
-            Self::Cuda(_) => true,
-        }
+        matches!(self, Self::Cuda(_))
+    }
+
+    pub fn is_metal(&self) -> bool {
+        matches!(self, Self::Metal(_))
     }

     pub fn cuda_if_available(ordinal: usize) -> Result<Self> {

@@ -185,8 +192,19 @@ impl Device {
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
-                let storage = device.rand_uniform(shape, dtype, lo, up)?;
-                Ok(Storage::Cuda(storage))
+                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
+                if dtype == DType::F16 || dtype == DType::BF16 {
+                    let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
+                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
+                } else {
+                    let storage = device.rand_uniform(shape, dtype, lo, up)?;
+                    Ok(Storage::Cuda(storage))
+                }
+            }
+            Device::Metal(_device) => {
+                // let storage = device.rand_uniform(shape, dtype, lo, up)?;
+                // Ok(Storage::Metal(storage))
+                crate::bail!("Metal rand_uniform not implemented")
             }
         }
     }

@@ -213,8 +231,18 @@ impl Device {
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
+                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
+                if dtype == DType::F16 || dtype == DType::BF16 {
+                    let storage = device.rand_normal(shape, DType::F32, mean, std)?;
+                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
+                } else {
+                    let storage = device.rand_normal(shape, dtype, mean, std)?;
+                    Ok(Storage::Cuda(storage))
+                }
+            }
+            Device::Metal(device) => {
                 let storage = device.rand_normal(shape, dtype, mean, std)?;
-                Ok(Storage::Cuda(storage))
+                Ok(Storage::Metal(storage))
             }
         }
     }

@@ -238,6 +266,10 @@ impl Device {
                 let storage = device.ones_impl(shape, dtype)?;
                 Ok(Storage::Cuda(storage))
             }
+            Device::Metal(device) => {
+                let storage = device.ones_impl(shape, dtype)?;
+                Ok(Storage::Metal(storage))
+            }
         }
     }

@@ -251,6 +283,10 @@ impl Device {
                 let storage = device.zeros_impl(shape, dtype)?;
                 Ok(Storage::Cuda(storage))
             }
+            Device::Metal(device) => {
+                let storage = device.zeros_impl(shape, dtype)?;
+                Ok(Storage::Metal(storage))
+            }
         }
     }

@@ -262,6 +298,11 @@ impl Device {
                 let storage = device.storage_from_cpu_storage(&storage)?;
                 Ok(Storage::Cuda(storage))
             }
+            Device::Metal(device) => {
+                let storage = array.to_cpu_storage();
+                let storage = device.storage_from_cpu_storage(&storage)?;
+                Ok(Storage::Metal(storage))
+            }
         }
     }

@@ -273,6 +314,11 @@ impl Device {
                 let storage = device.storage_from_cpu_storage(&storage)?;
                 Ok(Storage::Cuda(storage))
             }
+            Device::Metal(device) => {
+                let storage = S::to_cpu_storage_owned(data);
+                let storage = device.storage_from_cpu_storage(&storage)?;
+                Ok(Storage::Metal(storage))
+            }
         }
     }
 }
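With the Metal variants threaded through, device selection stays uniform across back ends. A minimal sketch of how downstream code can probe the new predicates (illustrative only; `Device::new_metal` is the constructor added above and fails unless the `metal` feature is compiled in):

```rust
use candle_core::{Device, Result};

fn main() -> Result<()> {
    // Falls back to CPU when no CUDA device is available.
    let device = Device::cuda_if_available(0)?;
    println!(
        "cpu: {}, cuda: {}, metal: {}",
        device.is_cpu(),
        device.is_cuda(),
        device.is_metal()
    );
    Ok(())
}
```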
candle-core/src/display.rs (likely; file header not captured). Tensor display learns the metal:<gpu_id> suffix in both formatting paths:

@@ -14,6 +14,9 @@ impl Tensor {
             crate::DeviceLocation::Cuda { gpu_id } => {
                 format!(", cuda:{}", gpu_id)
             }
+            crate::DeviceLocation::Metal { gpu_id } => {
+                format!(", metal:{}", gpu_id)
+            }
         };

         write!(f, "Tensor[")?;

@@ -476,6 +479,9 @@ impl std::fmt::Display for Tensor {
             crate::DeviceLocation::Cuda { gpu_id } => {
                 format!(", cuda:{}", gpu_id)
             }
+            crate::DeviceLocation::Metal { gpu_id } => {
+                format!(", metal:{}", gpu_id)
+            }
         };

         write!(
candle-core/src/dummy_cuda_backend.rs (likely; file header not captured). The stub backend used in non-CUDA builds also gains the new trait method:

@@ -79,6 +79,16 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }

+    fn conv_transpose1d(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &crate::conv::ParamsConvTranspose1D,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     fn conv2d(
         &self,
         _: &Layout,
candle-core/src/dummy_metal_backend.rs (new file, 223 lines). This is the stand-in backend for builds without the `metal` feature; every method either fails with `NotCompiledWithMetalSupport` or hits the `fail!` macro. The capture of this page ends partway through the file:

@@ -0,0 +1,223 @@
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};

#[derive(Debug, Clone)]
pub struct MetalDevice;

#[derive(Debug)]
pub struct MetalStorage;

#[derive(thiserror::Error, Debug)]
pub enum MetalError {
    #[error("{0}")]
    Message(String),
}

impl From<String> for MetalError {
    fn from(e: String) -> Self {
        MetalError::Message(e)
    }
}

macro_rules! fail {
    () => {
        unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
    };
}

impl crate::backend::BackendStorage for MetalStorage {
    type Device = MetalDevice;

    fn try_clone(&self, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn dtype(&self) -> DType {
        fail!()
    }

    fn device(&self) -> &Self::Device {
        fail!()
    }

    fn to_cpu_storage(&self) -> Result<CpuStorage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn conv1d(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &crate::conv::ParamsConv1D,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn conv_transpose1d(
        &self,
        _l: &Layout,
        _kernel: &Self,
        _kernel_l: &Layout,
        _params: &crate::conv::ParamsConvTranspose1D,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn conv2d(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &crate::conv::ParamsConv2D,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn conv_transpose2d(
        &self,
        _l: &Layout,
        _kernel: &Self,
        _kernel_l: &Layout,
        _params: &crate::conv::ParamsConvTranspose2D,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }
    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn scatter_add(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: usize,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn index_add(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: usize,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn matmul(
        &self,
        _: &Self,
        _: (usize, usize, usize, usize),
        _: &Layout,
        _: &Layout,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)

(The capture ends here; the remaining lines of this new file were not included in the mirrored page.)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl crate::backend::BackendDevice for MetalDevice {
|
||||||
|
type Storage = MetalStorage;
|
||||||
|
fn new(_: usize) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_seed(&self, _: u64) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn location(&self) -> crate::DeviceLocation {
|
||||||
|
fail!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn same_device(&self, _: &Self) -> bool {
|
||||||
|
fail!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
}
|
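The dummy backend mirrors the existing dummy CUDA backend: the same type names are re-exported whether or not the `metal` feature is enabled, so downstream code always compiles and only fails at runtime. A minimal sketch of hitting these stubs, assuming a build without the `metal` feature:

use candle_core::backend::BackendDevice;
use candle_core::MetalDevice;

fn main() {
    // Resolves to the dummy implementation here, so instead of a device we
    // get Error::NotCompiledWithMetalSupport back.
    let err = MetalDevice::new(0).unwrap_err();
    println!("{err}"); // "the candle crate has not been built with metal support"
}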
@@ -1,4 +1,4 @@
-use crate::{DType, DeviceLocation, Layout, Shape};
+use crate::{DType, DeviceLocation, Layout, MetalError, Shape};

 #[derive(Debug, Clone)]
 pub struct MatMulUnexpectedStriding {
@@ -142,6 +142,9 @@ pub enum Error {
     #[error("{op} expects at least one tensor")]
     OpRequiresAtLeastOneTensor { op: &'static str },

+    #[error("{op} expects at least two tensors")]
+    OpRequiresAtLeastTwoTensors { op: &'static str },
+
     #[error("backward is not supported for {op}")]
     BackwardNotSupported { op: &'static str },

@@ -149,6 +152,9 @@ pub enum Error {
     #[error("the candle crate has not been built with cuda support")]
     NotCompiledWithCudaSupport,

+    #[error("the candle crate has not been built with metal support")]
+    NotCompiledWithMetalSupport,
+
     #[error("cannot find tensor {path}")]
     CannotFindTensor { path: String },

@@ -156,6 +162,9 @@ pub enum Error {
     #[error(transparent)]
     Cuda(Box<dyn std::error::Error + Send + Sync>),

+    #[error("Metal error {0}")]
+    Metal(#[from] MetalError),
+
     #[error(transparent)]
     TryFromIntError(#[from] core::num::TryFromIntError),

@@ -104,37 +104,31 @@ impl From<&Tensor> for TensorIndexer {
     }
 }

-macro_rules! impl_from_range {
-    ($range_type:ty) => {
-        impl From<$range_type> for TensorIndexer {
-            fn from(range: $range_type) -> Self {
-                use std::ops::Bound::*;
-
-                let start = match range.start_bound() {
-                    Included(idx) => Included(*idx),
-                    Excluded(idx) => Excluded(*idx),
-                    Unbounded => Unbounded,
-                };
-
-                let end = match range.end_bound() {
-                    Included(idx) => Included(*idx),
-                    Excluded(idx) => Excluded(*idx),
-                    Unbounded => Unbounded,
-                };
-
-                TensorIndexer::Narrow(start, end)
-            }
-        }
-    };
-}
-
-impl_from_range!(Range<usize>);
-impl_from_range!(RangeFrom<usize>);
-impl_from_range!(RangeFull);
-impl_from_range!(RangeInclusive<usize>);
-impl_from_range!(RangeTo<usize>);
-impl_from_range!(RangeToInclusive<usize>);
+trait RB: RangeBounds<usize> {}
+impl RB for Range<usize> {}
+impl RB for RangeFrom<usize> {}
+impl RB for RangeFull {}
+impl RB for RangeInclusive<usize> {}
+impl RB for RangeTo<usize> {}
+impl RB for RangeToInclusive<usize> {}
+
+impl<T: RB> From<T> for TensorIndexer {
+    fn from(range: T) -> Self {
+        use std::ops::Bound::*;
+        let start = match range.start_bound() {
+            Included(idx) => Included(*idx),
+            Excluded(idx) => Excluded(*idx),
+            Unbounded => Unbounded,
+        };
+        let end = match range.end_bound() {
+            Included(idx) => Included(*idx),
+            Excluded(idx) => Excluded(*idx),
+            Unbounded => Unbounded,
+        };
+        TensorIndexer::Narrow(start, end)
+    }
+}

 /// Trait used to implement multiple signatures for ease of use of the slicing
 /// of a tensor
 pub trait IndexOp<T> {
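The blanket `impl<T: RB> From<T> for TensorIndexer` replaces six macro expansions with one generic impl; every standard range type still converts the same way. A hedged usage sketch (the `i` method comes from the `IndexOp` trait defined just below):

use candle_core::{Device, IndexOp, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::arange(0u32, 12u32, &Device::Cpu)?.reshape((3, 4))?;
    let rows = t.i(1..)?;        // RangeFrom<usize>, via the generic From
    let cols = t.i((.., 1..3))?; // RangeFull and Range<usize> in a tuple
    println!("{rows}\n{cols}");
    Ok(())
}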
@@ -49,9 +49,12 @@ mod device;
 pub mod display;
 mod dtype;
 mod dummy_cuda_backend;
+mod dummy_metal_backend;
 pub mod error;
 mod indexer;
 pub mod layout;
+#[cfg(feature = "metal")]
+pub mod metal_backend;
 #[cfg(feature = "mkl")]
 mod mkl;
 pub mod npy;
@@ -87,6 +90,12 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
 #[cfg(not(feature = "cuda"))]
 pub use dummy_cuda_backend::{CudaDevice, CudaStorage};

+#[cfg(feature = "metal")]
+pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
+
+#[cfg(not(feature = "metal"))]
+pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
+
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;

@@ -114,14 +123,20 @@ pub trait Module {
     fn forward(&self, xs: &Tensor) -> Result<Tensor>;
 }

-impl Module for quantized::QMatMul {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        self.forward(xs)
-    }
-}
-
 impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         self(xs)
     }
 }
+
+// A trait defining a module with forward method using a single tensor argument and a flag to
+// separate the training and evaluation behaviors.
+pub trait ModuleT {
+    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
+}
+
+impl<M: Module> ModuleT for M {
+    fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
+        self.forward(xs)
+    }
+}
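The blanket `impl<M: Module> ModuleT for M` means anything that is already a `Module` (including closures, via the `Fn` impl above) can be passed where a `ModuleT` is expected, while train-aware layers implement `forward_t` directly. A sketch under those assumptions; the `ScaleInTrain` layer is hypothetical:

use candle_core::{Module, ModuleT, Result, Tensor};

// Hypothetical train-aware layer: scales activations only while training.
struct ScaleInTrain(f64);

impl ModuleT for ScaleInTrain {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        if train {
            xs * self.0
        } else {
            Ok(xs.clone())
        }
    }
}

fn run<M: ModuleT>(m: &M, xs: &Tensor, train: bool) -> Result<Tensor> {
    m.forward_t(xs, train)
}
// A plain Module also fits: run(&|t: &Tensor| t.sqr(), &xs, true) works
// because the blanket impl forwards to Module::forward and ignores the flag.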
candle-core/src/metal_backend.rs (new file, 1010 lines): diff suppressed because it is too large.
@@ -1,5 +1,5 @@
 #![allow(clippy::redundant_closure_call)]
-use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
+use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
 use half::{bf16, f16};
 use num_traits::float::Float;

@@ -90,6 +90,16 @@ pub enum Op {
         dilation: usize,
     },

+    #[allow(dead_code)]
+    ConvTranspose1D {
+        arg: Tensor,
+        kernel: Tensor,
+        padding: usize,
+        output_padding: usize,
+        stride: usize,
+        dilation: usize,
+    },
+
     #[allow(dead_code)]
     Conv2D {
         arg: Tensor,
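For reference, the usual transposed-convolution length relation that these fields encode (the standard PyTorch-style formula, stated here as an assumption rather than quoted from this diff):

// Output length of a 1d transposed convolution. As a sketch this will
// panic on usize underflow if padding is too large for the other values.
fn conv_transpose1d_out_len(
    l_in: usize,
    stride: usize,
    padding: usize,
    output_padding: usize,
    dilation: usize,
    kernel_size: usize,
) -> usize {
    (l_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
}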
@@ -174,6 +184,18 @@ pub trait CustomOp1 {
         ))
     }

+    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn metal_fwd(
+        &self,
+        _storage: &MetalStorage,
+        _layout: &Layout,
+    ) -> Result<(MetalStorage, Shape)> {
+        Err(crate::Error::Metal(
+            format!("no metal implementation for {}", self.name()).into(),
+        ))
+    }
+
     /// This function takes as argument the argument `arg` used in the forward pass, the result
     /// produced by the forward operation `res` and the gradient of the result `grad_res`.
     /// The function should return the gradient of the argument.
@@ -209,6 +231,20 @@ pub trait CustomOp2 {
         ))
     }

+    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn metal_fwd(
+        &self,
+        _: &MetalStorage,
+        _: &Layout,
+        _: &MetalStorage,
+        _: &Layout,
+    ) -> Result<(MetalStorage, Shape)> {
+        Err(crate::Error::Metal(
+            format!("no metal implementation for {}", self.name()).into(),
+        ))
+    }
+
     fn bwd(
         &self,
         _arg1: &Tensor,
@@ -251,6 +287,22 @@ pub trait CustomOp3 {
         ))
     }

+    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn metal_fwd(
+        &self,
+        _: &MetalStorage,
+        _: &Layout,
+        _: &MetalStorage,
+        _: &Layout,
+        _: &MetalStorage,
+        _: &Layout,
+    ) -> Result<(MetalStorage, Shape)> {
+        Err(crate::Error::Metal(
+            format!("no metal implementation for {}", self.name()).into(),
+        ))
+    }
+
     fn bwd(
         &self,
         _arg1: &Tensor,
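Each `CustomOp*` trait now carries a defaulted `metal_fwd` alongside `cpu_fwd` and `cuda_fwd`, so existing custom ops keep compiling and only error out when actually run on a Metal device. A hedged sketch of a CPU-only `CustomOp1` (the op itself is hypothetical; `as_slice` and `strided_index` are assumed to have their usual candle signatures):

use candle_core::{CpuStorage, CustomOp1, Layout, Result, Shape};

// Hypothetical op that doubles every element, with only a CPU kernel.
struct Double;

impl CustomOp1 for Double {
    fn name(&self) -> &'static str {
        "double"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        // Only f32 is handled in this sketch; the layout drives the traversal
        // since the storage may be strided or offset.
        let data = storage.as_slice::<f32>()?;
        let out: Vec<f32> = layout.strided_index().map(|i| data[i] * 2.0).collect();
        Ok((CpuStorage::F32(out), layout.shape().clone()))
    }
    // cuda_fwd and the new metal_fwd keep their defaults and report
    // "no metal implementation for double" when reached.
}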
@@ -536,13 +588,13 @@ unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
 unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
 unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
 unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
-unary_op!(Abs, "abs", v, v.abs());
 unary_op!(Neg, "neg", v, -v);
 unary_op!(Recip, "recip", v, v.recip());
 unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
 unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);

-/// `gelu` operation
+/// Tanh based approximation of the `gelu` operation
+/// GeluErf is the more precise one.
 /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
 impl UnaryOpT for Gelu {
     const NAME: &'static str = "gelu";
@@ -632,6 +684,8 @@ impl UnaryOpT for Gelu {
     }
 }

+/// `erf` operation
+/// <https://en.wikipedia.org/wiki/Error_function>
 impl UnaryOpT for Erf {
     const NAME: &'static str = "erf";
     const KERNEL: &'static str = "uerf";
@@ -666,6 +720,40 @@ impl UnaryOpT for Erf {
     }
 }

+impl UnaryOpT for Abs {
+    const NAME: &'static str = "abs";
+    const KERNEL: &'static str = "uabs";
+    const V: Self = Abs;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        v.abs()
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        v.abs()
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        v.abs()
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        v.abs()
+    }
+    #[inline(always)]
+    fn u8(v: u8) -> u8 {
+        v
+    }
+    #[inline(always)]
+    fn u32(v: u32) -> u32 {
+        v
+    }
+    #[inline(always)]
+    fn i64(v: i64) -> i64 {
+        v.abs()
+    }
+}
+
 impl UnaryOpT for Ceil {
     const NAME: &'static str = "ceil";
     const KERNEL: &'static str = "uceil";
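Moving `Abs` out of the `unary_op!` macro allows dtype-specific behavior: `v.abs()` does not exist for `u8`/`u32`, so the manual impl makes abs the identity there. A small hedged check:

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::new(&[-3i64, 0, 2], &Device::Cpu)?;
    assert_eq!(t.abs()?.to_vec1::<i64>()?, vec![3, 0, 2]);
    // For unsigned dtypes abs is a no-op.
    let u = Tensor::new(&[3u32, 7], &Device::Cpu)?;
    assert_eq!(u.abs()?.to_vec1::<u32>()?, vec![3, 7]);
    Ok(())
}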
@@ -887,6 +975,10 @@ impl BackpropOp {
         };
         Self(op)
     }
+
+    pub(crate) fn is_none(&self) -> bool {
+        self.0.is_none()
+    }
 }

 impl std::ops::Deref for BackpropOp {
@@ -193,6 +193,50 @@ impl Object {
             _ => Err(self),
         }
     }

+    pub fn into_tensor_info(
+        self,
+        name: Self,
+        dir_name: &std::path::Path,
+    ) -> Result<Option<TensorInfo>> {
+        let name = match name.unicode() {
+            Ok(name) => name,
+            Err(_) => return Ok(None),
+        };
+        let (callable, args) = match self.reduce() {
+            Ok(callable_args) => callable_args,
+            _ => return Ok(None),
+        };
+        let (callable, args) = match callable {
+            Object::Class {
+                module_name,
+                class_name,
+            } if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
+                let mut args = args.tuple()?;
+                let callable = args.remove(0);
+                let args = args.remove(1);
+                (callable, args)
+            }
+            _ => (callable, args),
+        };
+        match callable {
+            Object::Class {
+                module_name,
+                class_name,
+            } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
+            _ => return Ok(None),
+        };
+        let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
+        let mut path = dir_name.to_path_buf();
+        path.push(file_path);
+        Ok(Some(TensorInfo {
+            name,
+            dtype,
+            layout,
+            path: path.to_string_lossy().into_owned(),
+            storage_size,
+        }))
+    }
 }

 impl TryFrom<Object> for String {
@@ -565,6 +609,7 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
         "HalfStorage" => DType::F16,
         "BFloat16Storage" => DType::BF16,
         "ByteStorage" => DType::U8,
+        "LongStorage" => DType::I64,
         other => {
             crate::bail!("unsupported storage type {other}")
         }
@@ -623,50 +668,10 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
     };
     if let Object::Dict(key_values) = obj {
         for (name, value) in key_values.into_iter() {
-            let name = match name.unicode() {
-                Ok(name) => name,
-                Err(_) => continue,
-            };
-            let (callable, args) = match value.reduce() {
-                Ok(callable_args) => callable_args,
-                _ => continue,
-            };
-            let (callable, args) = match callable {
-                Object::Class {
-                    module_name,
-                    class_name,
-                } if module_name == "torch._tensor"
-                    && class_name == "_rebuild_from_type_v2" =>
-                {
-                    let mut args = args.tuple()?;
-                    let callable = args.remove(0);
-                    let args = args.remove(1);
-                    (callable, args)
-                }
-                _ => (callable, args),
-            };
-            match callable {
-                Object::Class {
-                    module_name,
-                    class_name,
-                } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
-                _ => continue,
-            };
-            match rebuild_args(args) {
-                Ok((layout, dtype, file_path, storage_size)) => {
-                    let mut path = dir_name.clone();
-                    path.push(file_path);
-                    tensor_infos.push(TensorInfo {
-                        name,
-                        dtype,
-                        layout,
-                        path: path.to_string_lossy().into_owned(),
-                        storage_size,
-                    })
-                }
-                Err(err) => {
-                    eprintln!("skipping {name}: {err:?}")
-                }
+            match value.into_tensor_info(name, &dir_name) {
+                Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
+                Ok(None) => {}
+                Err(err) => eprintln!("skipping: {err:?}"),
             }
         }
     }
@@ -723,3 +728,16 @@ impl PthTensors {
         Ok(Some(tensor))
     }
 }
+
+/// Read all the tensors from a PyTorch pth file.
+pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
+    let pth = PthTensors::new(path)?;
+    let tensor_names = pth.tensor_infos.keys();
+    let mut tensors = Vec::with_capacity(tensor_names.len());
+    for name in tensor_names {
+        if let Some(tensor) = pth.get(name)? {
+            tensors.push((name.to_string(), tensor))
+        }
+    }
+    Ok(tensors)
+}
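A hedged usage sketch of the new helper; the checkpoint path is illustrative:

use candle_core::pickle;

fn main() -> candle_core::Result<()> {
    // Loads every tensor from the checkpoint in one call.
    let tensors = pickle::read_all("model.pth")?;
    for (name, tensor) in tensors.iter() {
        println!("{name}: {:?}", tensor.shape());
    }
    Ok(())
}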
@@ -50,14 +50,9 @@ pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
 #[inline(always)]
 pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
     let qk = QK8_0;
-    let nb = n / qk;
     if n % QK8_0 != 0 {
         crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
     }
-    if nb % 2 != 0 {
-        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
-    }

     unsafe {
         let mut acc = _mm256_setzero_ps();
         for (x, y) in xs.iter().zip(ys.iter()) {
@@ -29,6 +29,7 @@ impl TryFrom<u32> for Magic {
 pub enum VersionedMagic {
     GgufV1,
     GgufV2,
+    GgufV3,
 }

 impl VersionedMagic {
@@ -39,6 +40,7 @@ impl VersionedMagic {
         let versioned_magic = match (magic, version) {
            (Magic::Gguf, 1) => Self::GgufV1,
            (Magic::Gguf, 2) => Self::GgufV2,
+           (Magic::Gguf, 3) => Self::GgufV3,
            _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
         };
         Ok(versioned_magic)
@@ -84,7 +86,9 @@ pub struct Content {
 fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
     let len = match magic {
         VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
-        VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
+        VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
+            reader.read_u64::<LittleEndian>()? as usize
+        }
     };
     let mut v = vec![0u8; len];
     reader.read_exact(&mut v)?;
@@ -284,7 +288,9 @@ impl Value {
         let value_type = ValueType::from_u32(value_type)?;
         let len = match magic {
             VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
-            VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
+            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
+                reader.read_u64::<LittleEndian>()? as usize
+            }
         };
         let mut vs = Vec::with_capacity(len);
         for _ in 0..len {
@@ -381,11 +387,15 @@ impl Content {

         let tensor_count = match magic {
             VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
-            VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
+            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
+                reader.read_u64::<LittleEndian>()? as usize
+            }
         };
         let metadata_kv_count = match magic {
             VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
-            VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
+            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
+                reader.read_u64::<LittleEndian>()? as usize
+            }
         };

         let mut metadata = HashMap::new();
@@ -407,7 +417,7 @@ impl Content {
                 reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
                 dimensions.into_iter().map(|c| c as usize).collect()
             }
-            VersionedMagic::GgufV2 => {
+            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
                 let mut dimensions = vec![0; n_dimensions as usize];
                 reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
                 dimensions.into_iter().map(|c| c as usize).collect()
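GGUF v3 keeps the v2 on-disk encoding for these length fields (little-endian u64, versus u32 in v1), so every v2 arm simply gains the v3 case. A self-contained sketch of the shared pattern, with a hypothetical `read_len` helper:

use byteorder::{LittleEndian, ReadBytesExt};

enum VersionedMagic {
    GgufV1,
    GgufV2,
    GgufV3,
}

// Hypothetical helper factoring out the repeated match arms.
fn read_len<R: std::io::Read>(r: &mut R, magic: &VersionedMagic) -> std::io::Result<usize> {
    Ok(match magic {
        VersionedMagic::GgufV1 => r.read_u32::<LittleEndian>()? as usize,
        VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
            r.read_u64::<LittleEndian>()? as usize
        }
    })
}

fn main() -> std::io::Result<()> {
    let mut v3 = std::io::Cursor::new(7u64.to_le_bytes().to_vec());
    assert_eq!(read_len(&mut v3, &VersionedMagic::GgufV3)?, 7);
    Ok(())
}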
@@ -236,14 +236,9 @@ impl GgmlType for BlockQ4_0 {

     fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         let qk = QK8_0;
-        let nb = n / qk;
         if n % QK8_0 != 0 {
             crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
         }
-        if nb % 2 != 0 {
-            crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
-        }

         // Generic implementation.
         let mut sumf = 0f32;
         for (xs, ys) in xs.iter().zip(ys.iter()) {
@@ -307,8 +307,8 @@ impl crate::CustomOp1 for QTensor {
     }
 }

-impl QMatMul {
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl crate::Module for QMatMul {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
             Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
             Self::Tensor(w) => {
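Call sites keep the same syntax but must now import the trait; a hedged sketch:

use candle_core::quantized::QMatMul;
use candle_core::{Module, Result, Tensor};

fn apply(qm: &QMatMul, xs: &Tensor) -> Result<Tensor> {
    // Resolved through Module::forward now, so QMatMul can be used
    // anywhere a generic `M: Module` is expected.
    qm.forward(xs)
}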
@@ -19,42 +19,29 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
     if n % QK8_0 != 0 {
         crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
     }
-    if nb % 2 != 0 {
-        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
-    }

     unsafe {
         let mut sumv0 = vdupq_n_f32(0.0f32);
-        let mut sumv1 = vdupq_n_f32(0.0f32);
-        for i in (0..nb).step_by(2) {
+        for i in 0..nb {
             let x0 = &xs[i];
-            let x1 = &xs[i + 1];
             let y0 = &ys[i];
-            let y1 = &ys[i + 1];

             let m4b = vdupq_n_u8(0x0F);
             let s8b = vdupq_n_s8(0x8);

             let v0_0 = vld1q_u8(x0.qs.as_ptr());
-            let v0_1 = vld1q_u8(x1.qs.as_ptr());

             // 4-bit -> 8-bit
             let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
             let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-            let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
-            let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

             // sub 8
             let v0_0ls = vsubq_s8(v0_0l, s8b);
             let v0_0hs = vsubq_s8(v0_0h, s8b);
-            let v0_1ls = vsubq_s8(v0_1l, s8b);
-            let v0_1hs = vsubq_s8(v0_1h, s8b);

             // load y
             let v1_0l = vld1q_s8(y0.qs.as_ptr());
             let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
-            let v1_1l = vld1q_s8(y1.qs.as_ptr());
-            let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16));

             // TODO: Support dotprod when it's available outside of nightly.
             let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
@@ -62,28 +49,16 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
             let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
             let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));

-            let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l));
-            let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
-            let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h));
-            let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
-
             let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
             let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-            let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-            let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));

             sumv0 = vmlaq_n_f32(
                 sumv0,
                 vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
                 x0.d.to_f32() * y0.d.to_f32(),
             );
-            sumv1 = vmlaq_n_f32(
-                sumv1,
-                vcvtq_f32_s32(vaddq_s32(pl1, ph1)),
-                x1.d.to_f32() * y1.d.to_f32(),
-            );
         }
-        Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
+        Ok(vaddvq_f32(sumv0))
     }
 }

@@ -94,28 +69,18 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
         crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
     }
     let nb = n / QK8_0;
-    if nb % 2 != 0 {
-        crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
-    }
     unsafe {
         let mut sumv0 = vdupq_n_f32(0.0f32);
-        let mut sumv1 = vdupq_n_f32(0.0f32);
-        for i in (0..nb).step_by(2) {
+        for i in 0..nb {
             let x0 = &xs[i];
-            let x1 = &xs[i + 1];
             let y0 = &ys[i];
-            let y1 = &ys[i + 1];

             let x0_0 = vld1q_s8(x0.qs.as_ptr());
             let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
-            let x1_0 = vld1q_s8(x1.qs.as_ptr());
-            let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16));

             // load y
             let y0_0 = vld1q_s8(y0.qs.as_ptr());
             let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
-            let y1_0 = vld1q_s8(y1.qs.as_ptr());
-            let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16));

             // TODO dotprod once this is the intrinsics are.
             let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
@@ -123,28 +88,16 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
             let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
             let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));

-            let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0));
-            let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
-            let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1));
-            let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
-
             let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
             let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
-            let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
-            let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));

             sumv0 = vmlaq_n_f32(
                 sumv0,
                 vcvtq_f32_s32(vaddq_s32(p0, p1)),
                 x0.d.to_f32() * y0.d.to_f32(),
             );
-            sumv1 = vmlaq_n_f32(
-                sumv1,
-                vcvtq_f32_s32(vaddq_s32(p2, p3)),
-                x1.d.to_f32() * y1.d.to_f32(),
-            );
         }
-        Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
+        Ok(vaddvq_f32(sumv0))
     }
 }

|
|||||||
if n % QK8_0 != 0 {
|
if n % QK8_0 != 0 {
|
||||||
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
|
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
|
||||||
}
|
}
|
||||||
let nb = n / QK8_0;
|
|
||||||
if nb % 2 != 0 {
|
|
||||||
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
|
|
||||||
}
|
|
||||||
unsafe {
|
unsafe {
|
||||||
let mut acc = f32x4_splat(0.0f32);
|
let mut acc = f32x4_splat(0.0f32);
|
||||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||||
@ -61,10 +57,6 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
|
|||||||
if n % QK8_0 != 0 {
|
if n % QK8_0 != 0 {
|
||||||
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
|
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
|
||||||
}
|
}
|
||||||
let nb = n / QK8_0;
|
|
||||||
if nb % 2 != 0 {
|
|
||||||
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
|
|
||||||
}
|
|
||||||
unsafe {
|
unsafe {
|
||||||
let mut acc = f32x4_splat(0.0f32);
|
let mut acc = f32x4_splat(0.0f32);
|
||||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||||
|
@@ -203,7 +203,7 @@ impl Shape {

     /// Check whether the two shapes are compatible for broadcast, and if it is the case return the
     /// broadcasted shape. This is to be used for binary pointwise ops.
-    pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
+    pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
         let lhs = self;
         let lhs_dims = lhs.dims();
         let rhs_dims = rhs.dims();
@@ -511,154 +511,119 @@ impl ShapeWithOneHole for ((),) {
     }
 }

+fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
+    if prod_d == 0 {
+        crate::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
+    }
+    if el_count % prod_d != 0 {
+        crate::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
+    }
+    Ok(el_count / prod_d)
+}
+
 impl ShapeWithOneHole for ((), usize) {
     fn into_shape(self, el_count: usize) -> Result<Shape> {
         let ((), d1) = self;
-        if el_count % d1 != 0 {
-            crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
-        }
-        Ok((el_count / d1, d1).into())
+        Ok((hole_size(el_count, d1, &self)?, d1).into())
     }
 }

 impl ShapeWithOneHole for (usize, ()) {
     fn into_shape(self, el_count: usize) -> Result<Shape> {
         let (d1, ()) = self;
-        if el_count % d1 != 0 {
-            crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
-        }
-        Ok((d1, el_count / d1).into())
+        Ok((d1, hole_size(el_count, d1, &self)?).into())
     }
 }

 impl ShapeWithOneHole for ((), usize, usize) {
     fn into_shape(self, el_count: usize) -> Result<Shape> {
         let ((), d1, d2) = self;
-        let d = d1 * d2;
-        if el_count % d != 0 {
-            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
-        }
-        Ok((el_count / d, d1, d2).into())
+        Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into())
     }
 }

 impl ShapeWithOneHole for (usize, (), usize) {
     fn into_shape(self, el_count: usize) -> Result<Shape> {
         let (d1, (), d2) = self;
-        let d = d1 * d2;
-        if el_count % d != 0 {
-            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
-        }
-        Ok((d1, el_count / d, d2).into())
+        Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into())
     }
 }

 impl ShapeWithOneHole for (usize, usize, ()) {
     fn into_shape(self, el_count: usize) -> Result<Shape> {
         let (d1, d2, ()) = self;
-        let d = d1 * d2;
-        if el_count % d != 0 {
-            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
-        }
-        Ok((d1, d2, el_count / d).into())
+        Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into())
    }
 }

 impl ShapeWithOneHole for ((), usize, usize, usize) {
     fn into_shape(self, el_count: usize) -> Result<Shape> {
         let ((), d1, d2, d3) = self;
-        let d = d1 * d2 * d3;
-        if el_count % d != 0 {
-            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
-        }
-        Ok((el_count / d, d1, d2, d3).into())
+        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
+        Ok((d, d1, d2, d3).into())
     }
 }

 impl ShapeWithOneHole for (usize, (), usize, usize) {
     fn into_shape(self, el_count: usize) -> Result<Shape> {
         let (d1, (), d2, d3) = self;
-        let d = d1 * d2 * d3;
-        if el_count % d != 0 {
-            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
-        }
-        Ok((d1, el_count / d, d2, d3).into())
+        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
+        Ok((d1, d, d2, d3).into())
     }
 }

 impl ShapeWithOneHole for (usize, usize, (), usize) {
     fn into_shape(self, el_count: usize) -> Result<Shape> {
         let (d1, d2, (), d3) = self;
-        let d = d1 * d2 * d3;
-        if el_count % d != 0 {
-            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
-        }
-        Ok((d1, d2, el_count / d, d3).into())
+        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
+        Ok((d1, d2, d, d3).into())
     }
 }

 impl ShapeWithOneHole for (usize, usize, usize, ()) {
     fn into_shape(self, el_count: usize) -> Result<Shape> {
         let (d1, d2, d3, ()) = self;
-        let d = d1 * d2 * d3;
-        if el_count % d != 0 {
-            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
-        }
-        Ok((d1, d2, d3, el_count / d).into())
+        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
+        Ok((d1, d2, d3, d).into())
     }
 }

 impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
     fn into_shape(self, el_count: usize) -> Result<Shape> {
         let ((), d1, d2, d3, d4) = self;
-        let d = d1 * d2 * d3 * d4;
-        if el_count % d != 0 {
-            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
-        }
-        Ok((el_count / d, d1, d2, d3, d4).into())
+        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
+        Ok((d, d1, d2, d3, d4).into())
     }
 }

 impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
     fn into_shape(self, el_count: usize) -> Result<Shape> {
         let (d1, (), d2, d3, d4) = self;
-        let d = d1 * d2 * d3 * d4;
-        if el_count % d != 0 {
-            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
-        }
-        Ok((d1, el_count / d, d2, d3, d4).into())
+        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
+        Ok((d1, d, d2, d3, d4).into())
     }
 }

 impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
     fn into_shape(self, el_count: usize) -> Result<Shape> {
         let (d1, d2, (), d3, d4) = self;
-        let d = d1 * d2 * d3 * d4;
-        if el_count % d != 0 {
-            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
-        }
-        Ok((d1, d2, el_count / d, d3, d4).into())
+        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
+        Ok((d1, d2, d, d3, d4).into())
     }
 }

 impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
     fn into_shape(self, el_count: usize) -> Result<Shape> {
         let (d1, d2, d3, (), d4) = self;
-        let d = d1 * d2 * d3 * d4;
-        if el_count % d != 0 {
-            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
-        }
-        Ok((d1, d2, d3, el_count / d, d4).into())
+        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
+        Ok((d1, d2, d3, d, d4).into())
     }
 }

 impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
     fn into_shape(self, el_count: usize) -> Result<Shape> {
         let (d1, d2, d3, d4, ()) = self;
-        let d = d1 * d2 * d3 * d4;
-        if el_count % d != 0 {
-            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
-        }
-        Ok((d1, d2, d3, d4, el_count / d).into())
+        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
+        Ok((d1, d2, d3, d4, d).into())
     }
 }
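A quick worked example of the helper's arithmetic: reshaping 24 elements to `(2, (), 3)` computes `hole_size(24, 2 * 3, ..) = 24 / 6 = 4`, giving the shape `(2, 4, 3)`. As a hedged usage sketch:

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::arange(0f32, 24f32, &Device::Cpu)?;
    // The `()` marks the hole; its size is inferred as 24 / (2 * 3) = 4.
    let r = t.reshape((2, (), 3))?;
    assert_eq!(r.dims(), &[2, 4, 3]);
    Ok(())
}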
@@ -1,6 +1,6 @@
 use crate::backend::BackendStorage;
 use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
-use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
+use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};

 // We do not want to implement Clone on Storage as cloning may fail because of
 // out of memory. Instead try_clone should be used.
@@ -8,6 +8,7 @@ use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape
 pub enum Storage {
     Cpu(CpuStorage),
     Cuda(CudaStorage),
+    Metal(MetalStorage),
 }

 impl Storage {
@@ -18,6 +19,10 @@ impl Storage {
                 let storage = storage.try_clone(layout)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.try_clone(layout)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -25,6 +30,7 @@ impl Storage {
         match self {
             Self::Cpu(_) => Device::Cpu,
             Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
+            Self::Metal(storage) => Device::Metal(storage.device().clone()),
         }
     }

@@ -32,6 +38,7 @@ impl Storage {
         match self {
             Self::Cpu(storage) => storage.dtype(),
             Self::Cuda(storage) => storage.dtype(),
+            Self::Metal(storage) => storage.dtype(),
         }
     }

@@ -65,6 +72,10 @@ impl Storage {
                 let storage = storage.affine(layout, mul, add)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.affine(layout, mul, add)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -78,6 +89,10 @@ impl Storage {
                 let storage = storage.powf(layout, alpha)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.powf(layout, alpha)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -91,6 +106,10 @@ impl Storage {
                 let storage = storage.elu(layout, alpha)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.elu(layout, alpha)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -112,6 +131,10 @@ impl Storage {
                 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(lhs), Self::Metal(rhs)) => {
+                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
+                Ok(Self::Metal(storage))
+            }
             (lhs, rhs) => {
                 // Should not happen because of the same device check above but we're defensive
                 // anyway.
@@ -135,6 +158,10 @@ impl Storage {
                 let storage = storage.reduce_op(op, layout, s)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.reduce_op(op, layout, s)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -148,6 +175,10 @@ impl Storage {
                 let storage = storage.to_dtype(layout, dtype)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.to_dtype(layout, dtype)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -161,6 +192,10 @@ impl Storage {
                 let (storage, shape) = c.cuda_fwd(storage, l)?;
                 Ok((Self::Cuda(storage), shape))
             }
+            Self::Metal(storage) => {
+                let (storage, shape) = c.metal_fwd(storage, l)?;
+                Ok((Self::Metal(storage), shape))
+            }
         }
     }

@@ -181,6 +216,10 @@ impl Storage {
                 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
                 Ok((Self::Cuda(s), shape))
             }
+            (Self::Metal(s1), Self::Metal(s2)) => {
+                let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
+                Ok((Self::Metal(s), shape))
+            }
             _ => unreachable!(),
         }
     }
@@ -205,6 +244,10 @@ impl Storage {
                 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
                 Ok((Self::Cuda(s), shape))
             }
+            (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
+                let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
+                Ok((Self::Metal(s), shape))
+            }
             _ => unreachable!(),
         }
     }
@@ -219,6 +262,10 @@ impl Storage {
                 let storage = storage.unary_impl::<B>(layout)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.unary_impl::<B>(layout)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -239,6 +286,10 @@ impl Storage {
                 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(lhs), Self::Metal(rhs)) => {
+                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
+                Ok(Self::Metal(storage))
+            }
             (lhs, rhs) => {
                 // Should not happen because of the same device check above but we're defensive
                 // anyway.
@@ -270,6 +321,10 @@ impl Storage {
                 let s = inp.conv1d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
+            (Storage::Metal(inp), Storage::Metal(kernel)) => {
+                let s = inp.conv1d(l, kernel, kernel_l, params)?;
+                Ok(Self::Metal(s))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -279,6 +334,33 @@ impl Storage {
         }
     }

+    pub(crate) fn conv_transpose1d(
+        &self,
+        l: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        params: &crate::conv::ParamsConvTranspose1D,
+    ) -> Result<Self> {
+        self.same_device(kernel, "conv-transpose1d")?;
+        self.same_dtype(kernel, "conv-transpose1d")?;
+        match (self, &kernel) {
+            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
+                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
+                Ok(Self::Cpu(s))
+            }
+            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
+                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
+                Ok(Self::Cuda(s))
+            }
+            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+                lhs: lhs.device().location(),
+                rhs: rhs.device().location(),
+                op: "conv-transpose1d",
+            }
+            .bt()),
+        }
+    }
+
     pub(crate) fn conv2d(
         &self,
         l: &Layout,
@@ -297,6 +379,10 @@ impl Storage {
                 let s = inp.conv2d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
+            (Storage::Metal(inp), Storage::Metal(kernel)) => {
+                let s = inp.conv2d(l, kernel, kernel_l, params)?;
+                Ok(Self::Metal(s))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -324,6 +410,10 @@ impl Storage {
                 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
+            (Storage::Metal(inp), Storage::Metal(kernel)) => {
+                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
+                Ok(Self::Metal(s))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -348,6 +438,10 @@ impl Storage {
                 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }
 
@@ -366,6 +460,10 @@ impl Storage {
                 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }
 
@@ -379,6 +477,10 @@ impl Storage {
                 let storage = storage.upsample_nearest1d(layout, sz)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.upsample_nearest1d(layout, sz)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }
 
@@ -392,6 +494,10 @@ impl Storage {
                 let storage = storage.upsample_nearest2d(layout, h, w)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.upsample_nearest2d(layout, h, w)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }
 
@@ -415,6 +521,10 @@ impl Storage {
                 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
+                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
+                Ok(Self::Metal(storage))
+            }
             (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -441,6 +551,10 @@ impl Storage {
                 let storage = s.gather(l, indexes, indexes_l, d)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(s), Self::Metal(indexes)) => {
+                let storage = s.gather(l, indexes, indexes_l, d)?;
+                Ok(Self::Metal(storage))
+            }
             _ => unreachable!(),
         }
     }
@@ -465,6 +579,10 @@ impl Storage {
                 let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
+                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
+                Ok(Self::Metal(storage))
+            }
             _ => unreachable!(),
         }
     }
@@ -489,6 +607,10 @@ impl Storage {
                 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
+                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
+                Ok(Self::Metal(storage))
+            }
             _ => unreachable!(),
         }
     }
@@ -510,6 +632,10 @@ impl Storage {
                 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(lhs), Self::Metal(rhs)) => {
+                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
+                Ok(Self::Metal(storage))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -537,6 +663,10 @@ impl Storage {
                 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(lhs), Self::Metal(rhs)) => {
+                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
+                Ok(Self::Metal(storage))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -556,6 +686,9 @@ impl Storage {
         match (self, dst) {
             (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
             (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
+            (Self::Metal(src), Self::Metal(dst)) => {
+                Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
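All of the hunks above follow one mechanical pattern: each `match` over the storage variants gains a `Self::Metal(..)` arm that mirrors the existing `Self::Cuda(..)` arm and forwards to the same-named method on the Metal storage type. A condensed sketch of that shape (`some_op` and `args` are placeholders, not actual candle names):

    // Hypothetical distillation of the dispatch added throughout `impl Storage`.
    match self {
        Self::Cpu(s) => Ok(Self::Cpu(s.some_op(args)?)),
        Self::Cuda(s) => Ok(Self::Cuda(s.some_op(args)?)),
        Self::Metal(s) => Ok(Self::Metal(s.some_op(args)?)),
    }

Because the dispatch is uniform, bringing up the Metal backend is mostly a matter of implementing the storage-level methods; the enum plumbing stays identical across ops.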
@ -6,7 +6,7 @@ use crate::op::{
|
|||||||
};
|
};
|
||||||
use crate::scalar::TensorOrScalar;
|
use crate::scalar::TensorOrScalar;
|
||||||
use crate::shape::{Dim, Dims};
|
use crate::shape::{Dim, Dims};
|
||||||
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
|
|
||||||
/// Unique identifier for tensors.
|
/// Unique identifier for tensors.
|
||||||
@@ -385,11 +385,21 @@ impl Tensor {
         step: D,
         device: &Device,
     ) -> Result<Self> {
+        if D::is_zero(&step) {
+            crate::bail!("step cannot be zero")
+        }
         let mut data = vec![];
         let mut current = start;
-        while current < end {
-            data.push(current);
-            current += step;
+        if step >= D::zero() {
+            while current < end {
+                data.push(current);
+                current += step;
+            }
+        } else {
+            while current > end {
+                data.push(current);
+                current += step;
+            }
         }
         let len = data.len();
         Self::from_vec_impl(data, len, device, false)
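The change above makes the stepped range constructor reject a zero step and iterate downwards when the step is negative. A small usage sketch, mirroring the `arange` test added later in this diff:

    use candle_core::{Device, Tensor};

    // Ascending: 0, 2, 4 (stops before `end`).
    let up = Tensor::arange_step(0u8, 5u8, 2, &Device::Cpu)?;
    // Descending: 5, 4, 3, 2, 1 (now valid thanks to the negative-step branch).
    let down = Tensor::arange_step(5i64, 0i64, -1, &Device::Cpu)?;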
@@ -449,7 +459,7 @@ impl Tensor {
 
     /// Returns true if the computation graph should track this op, that is if it is
     /// a variable or if it has some variable as dependencies.
-    pub(crate) fn track_op(&self) -> bool {
+    pub fn track_op(&self) -> bool {
         self.is_variable || self.op.is_some()
     }
 
@@ -467,6 +477,12 @@ impl Tensor {
     broadcast_binary_op!(broadcast_div, div);
     broadcast_binary_op!(broadcast_maximum, maximum);
     broadcast_binary_op!(broadcast_minimum, minimum);
+    broadcast_binary_op!(broadcast_eq, eq);
+    broadcast_binary_op!(broadcast_ne, ne);
+    broadcast_binary_op!(broadcast_lt, lt);
+    broadcast_binary_op!(broadcast_le, le);
+    broadcast_binary_op!(broadcast_gt, gt);
+    broadcast_binary_op!(broadcast_ge, ge);
 
     unary_op!(recip, Recip);
     unary_op!(neg, Neg);
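The six new macro invocations generate broadcasting variants of the comparison operators, which return a `U8` mask like their non-broadcast counterparts. A minimal sketch of how `broadcast_eq` would be used (the shapes follow the usual broadcasting rules; this example is an assumption, not taken from the diff):

    use candle_core::{Device, Tensor};

    let a = Tensor::new(&[[0u32], [1], [2]], &Device::Cpu)?; // shape (3, 1)
    let b = Tensor::new(&[[0u32, 1, 2]], &Device::Cpu)?;     // shape (1, 3)
    let mask = a.broadcast_eq(&b)?;                          // shape (3, 3), 0/1 values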
@@ -513,6 +529,7 @@ impl Tensor {
         match &*self.storage() {
             Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
+            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
     }
 
@@ -540,6 +557,73 @@ impl Tensor {
         Ok(inp)
     }
 
+    /// Creates grids of coordinates specified by the 1D inputs.
+    ///
+    /// # Arguments
+    ///
+    /// * `args` - A slice of 1D tensors.
+    /// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
+    ///   first dimension corresponds to the cardinality of the second input and the second
+    ///   dimension corresponds to the cardinality of the first input. If ij is selected, the
+    ///   dimensions are in the same order as the cardinality of the inputs.
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use candle_core::{Tensor, Device, Shape};
+    /// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
+    /// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
+    ///
+    /// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?;
+    ///
+    /// assert_eq!(grids_xy.len(), 2);
+    /// assert_eq!(grids_xy[0].dims(), &[3, 3]);
+    ///
+    /// assert_eq!(grids_xy[0].to_vec2::<f32>()?, &[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]);
+    /// assert_eq!(grids_xy[1].to_vec2::<f32>()?, &[[4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]);
+    ///
+    /// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?;
+    ///
+    /// assert_eq!(grids_ij[0].to_vec2::<f32>()?, &[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]);
+    /// assert_eq!(grids_ij[1].to_vec2::<f32>()?, &[[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
+    ///
+    /// # Errors
+    ///
+    /// * Will return `Err` if `args` contains less than 2 tensors.
+    ///
+    pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
+        if args.len() <= 1 {
+            Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
+        }
+        let args: Vec<_> = if xy_indexing {
+            args.iter().rev().collect()
+        } else {
+            args.iter().collect()
+        };
+
+        let mut shape = Vec::with_capacity(args.len());
+        for arg in args.iter() {
+            shape.push(arg.as_ref().dims1()?)
+        }
+
+        let mut grids = Vec::with_capacity(args.len());
+        for idx in 0..args.len() {
+            let mut ones = vec![1usize; args.len()];
+            ones[idx] = shape[idx];
+            let arg = args[idx].as_ref().reshape(ones)?;
+            let mut repeats = shape.clone();
+            repeats[idx] = 1;
+            let repeated_tensor = arg.repeat(repeats)?;
+            grids.push(repeated_tensor);
+        }
+        if xy_indexing {
+            grids.reverse();
+        }
+        Ok(grids)
+    }
+
     /// This operation multiplies the input tensor by `mul` then adds `add` and return the result.
     /// The input values `mul` and `add` are casted to the appropriate type so some rounding might
     /// be performed.
@@ -615,15 +699,23 @@ impl Tensor {
     pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
         let dims = self.dims();
         let dim = dim.to_index(self.shape(), "narrow")?;
-        if start + len > dims[dim] {
-            Err(Error::NarrowInvalidArgs {
-                shape: self.shape().clone(),
-                dim,
-                start,
-                len,
-                msg: "start + len > dim_len",
-            }
-            .bt())?
+        let err = |msg| {
+            Err::<(), _>(
+                Error::NarrowInvalidArgs {
+                    shape: self.shape().clone(),
+                    dim,
+                    start,
+                    len,
+                    msg,
+                }
+                .bt(),
+            )
+        };
+        if start > dims[dim] {
+            err("start > dim_len")?
+        }
+        if start.saturating_add(len) > dims[dim] {
+            err("start + len > dim_len")?
         }
         if start == 0 && dims[dim] == len {
             Ok(self.clone())
@@ -764,6 +856,20 @@ impl Tensor {
         self.sum_impl(mean_dims, false)? * scale
     }
 
+    /// Returns the unbiased variance over the selected dimension.
+    pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "var")?;
+        let mean = self.mean_keepdim(dim)?;
+        let squares = self.broadcast_sub(&mean)?.sqr()?;
+        squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
+    }
+
+    /// Returns the unbiased variance over the selected dimension.
+    pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "var")?;
+        self.var_keepdim(dim)?.squeeze(dim)
+    }
+
     /// Gathers the maximum value across the selected dimension. The resulting shape has the same
     /// number of dimensions as the original tensor and the select dimension has a single element.
     pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
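`var_keepdim` computes the unbiased variance (n - 1 denominator) from the broadcast mean, and `var` then squeezes the kept dimension away. A quick sketch matching the `var` test added further down in this diff:

    use candle_core::{Device, Tensor};

    let t = Tensor::new(&[[0.2035f32, 1.2959, 1.8101, -0.4644]], &Device::Cpu)?;
    let v_keep = t.var_keepdim(1)?; // shape (1, 1)
    let v = t.var(1)?;              // shape (1,)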
@@ -1111,14 +1217,16 @@ impl Tensor {
                 op: "scatter-add (self, src)",
                 lhs: self.shape().clone(),
                 rhs: source.shape().clone(),
-            })?
+            }
+            .bt())?
         }
         if indexes.dims() != source.dims() {
             Err(Error::ShapeMismatchBinaryOp {
                 op: "scatter-add (indexes, src)",
                 lhs: indexes.shape().clone(),
                 rhs: source.shape().clone(),
-            })?
+            }
+            .bt())?
         }
         let storage = self.storage().scatter_add(
             self.layout(),
@@ -1190,7 +1298,8 @@ impl Tensor {
                 op: "slice-scatter (self, src)",
                 lhs: self.shape().clone(),
                 rhs: src.shape().clone(),
-            })?
+            }
+            .bt())?
         }
         let mut storage = self.device().zeros(self.shape(), self.dtype())?;
         self.storage()
@@ -1224,7 +1333,8 @@ impl Tensor {
                 op: "index-add (self, source)",
                 lhs: self.shape().clone(),
                 rhs: source.shape().clone(),
-            })?
+            }
+            .bt())?
         }
         // The number of element in indexes must match the dimension on which the add is
         // performed on the source tensor (and the index values from `indexes` are taken from
@@ -1235,7 +1345,8 @@ impl Tensor {
                 op: "index-add (ids, source))",
                 lhs: indexes.shape().clone(),
                 rhs: source.shape().clone(),
-            })?
+            }
+            .bt())?
         }
         let storage = self.storage().index_add(
             self.layout(),
@@ -1283,7 +1394,8 @@ impl Tensor {
                 op: "gather",
                 lhs: self.shape().clone(),
                 rhs: indexes.shape().clone(),
-            })?
+            }
+            .bt())?
         }
         let storage =
             self.storage()
@@ -1357,6 +1469,7 @@ impl Tensor {
         match &*self.storage() {
             Storage::Cpu(storage) => from_cpu_storage(storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
+            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
     }
 
@@ -1387,6 +1500,7 @@ impl Tensor {
         match &*self.storage() {
             Storage::Cpu(storage) => from_cpu_storage(storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
+            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
     }
 
@@ -1427,6 +1541,7 @@ impl Tensor {
         match &*self.storage() {
             Storage::Cpu(storage) => from_cpu_storage(storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
+            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
     }
 
@@ -1590,6 +1705,24 @@ impl Tensor {
         }
     }
 
+    /// Returns the sub-tensor fixing the index at `index` on the dimension `dim`.
+    ///
+    /// ```rust
+    /// use candle_core::{Tensor, Device};
+    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
+    /// let t = tensor.get_on_dim(1, 0)?;
+    /// assert_eq!(t.to_vec1::<f32>()?, &[0., 2., 4.]);
+    /// let t = tensor.get_on_dim(1, 1)?;
+    /// assert_eq!(t.to_vec1::<f32>()?, &[1., 3., 5.]);
+    /// let t = tensor.get_on_dim(0, 1)?;
+    /// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
+    pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
+        let dim = dim.to_index(self.shape(), "get_on_dim")?;
+        self.narrow(dim, index, 1)?.squeeze(dim)
+    }
+
     /// Returns a tensor that is a transposed version of the input, the two last dimensions of the
     /// input are swapped.
     ///
@@ -1698,17 +1831,23 @@ impl Tensor {
 
     /// Returns a new tensor detached from the current graph, gradient are not propagated through
     /// this new node. The storage of this tensor is shared with the initial tensor.
+    ///
+    /// If the tensor is already detached from the computation graph, the same tensor is returned.
     pub fn detach(&self) -> Result<Tensor> {
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage: self.storage.clone(),
-            layout: self.layout.clone(),
-            op: BackpropOp::none(),
-            is_variable: false,
-            dtype: self.dtype,
-            device: self.device.clone(),
-        };
-        Ok(Tensor(Arc::new(tensor_)))
+        if self.op.is_none() && !self.is_variable {
+            Ok(self.clone())
+        } else {
+            let tensor_ = Tensor_ {
+                id: TensorId::new(),
+                storage: self.storage.clone(),
+                layout: self.layout.clone(),
+                op: BackpropOp::none(),
+                is_variable: false,
+                dtype: self.dtype,
+                device: self.device.clone(),
+            };
+            Ok(Tensor(Arc::new(tensor_)))
+        }
     }
 
     /// If the target device is the same as the tensor device, only a shallow copy is performed.
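The rewritten `detach` adds a fast path: when a tensor has no backprop op and is not a variable, building a fresh detached node would change nothing, so the same tensor is returned via a cheap `clone`. Since the diff shows tensors are `Arc`-backed (`Tensor(Arc::new(tensor_))`), that clone copies a pointer, not data. Sketch:

    use candle_core::{Device, Tensor};

    let t = Tensor::new(&[1f32, 2.], &Device::Cpu)?; // no op, not a variable
    let d = t.detach()?;                             // fast path: clone of the same node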
@@ -1720,7 +1859,14 @@ impl Tensor {
             (Storage::Cpu(storage), Device::Cuda(cuda)) => {
                 Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
             }
+            (Storage::Cpu(storage), Device::Metal(metal)) => {
+                Storage::Metal(metal.storage_from_cpu_storage(storage)?)
+            }
             (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
+            (Storage::Metal(storage), Device::Cpu) => {
+                println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
+                Storage::Cpu(storage.to_cpu_storage()?)
+            }
             (Storage::Cuda(storage), Device::Cuda(cuda)) => {
                 // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
                 // are the same.
@@ -1728,6 +1874,9 @@ impl Tensor {
                 Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
             }
             (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
+            _ => {
+                bail!("not implemented yet")
+            }
         };
         let op = BackpropOp::new1(self, Op::ToDevice);
         let tensor_ = Tensor_ {
@@ -2127,11 +2276,56 @@ impl Tensor {
         }
     }
 
+    /// Pad the input tensor using same values along dimension `dim`. This adds `left` elements before the
+    /// input tensor values and `right` elements after.
+    pub fn pad_with_same<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
+        if left == 0 && right == 0 {
+            Ok(self.clone())
+        } else if self.elem_count() == 0 {
+            crate::bail!("cannot use pad_with_same on an empty tensor")
+        } else if left == 0 {
+            let dim = dim.to_index(self.shape(), "pad_with_same")?;
+            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
+            let mut v = vec![self];
+            for _ in 0..right {
+                v.push(&r)
+            }
+            Tensor::cat(&v, dim)
+        } else if right == 0 {
+            let dim = dim.to_index(self.shape(), "pad_with_same")?;
+            let l = self.narrow(dim, 0, 1)?;
+            let mut v = vec![];
+            for _ in 0..left {
+                v.push(&l)
+            }
+            v.push(self);
+            Tensor::cat(&v, dim)
+        } else {
+            let dim = dim.to_index(self.shape(), "pad_with_same")?;
+            let l = self.narrow(dim, 0, 1)?;
+            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
+            let mut v = vec![];
+            for _ in 0..left {
+                v.push(&l)
+            }
+            v.push(self);
+            for _ in 0..right {
+                v.push(&r)
+            }
+            Tensor::cat(&v, dim)
+        }
+    }
+
     /// Run the `forward` method of `m` on `self`.
     pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
         m.forward(self)
     }
 
+    /// Run the `forward` method of `m` on `self`.
+    pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
+        m.forward_t(self, train)
+    }
+
     pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
         self.storage.read().unwrap()
     }
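`pad_with_same` implements replicate-style padding: it repeats the edge slice along `dim`, built entirely out of `narrow` plus `cat` rather than a dedicated kernel. Sketch, matching the test added later in this diff:

    use candle_core::{Device, Tensor};

    let t = Tensor::arange(1f32, 5f32, &Device::Cpu)?.reshape((2, 2))?;
    // One copy of the first row before, two copies of the last row after.
    let padded = t.pad_with_same(0, 1, 2)?; // shape (5, 2)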
@@ -2246,6 +2440,127 @@ impl Tensor {
     ) -> Result<Self> {
         self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
     }
+
+    /// Normalize a 'relative' axis value: positive values are kept, negative
+    /// values means counting the dimensions from the back.
+    pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
+        let rank = self.rank() as i64;
+        if rank <= axis {
+            crate::bail!("axis {axis} is too large, tensor rank {rank}")
+        } else if 0 <= axis {
+            Ok(axis as usize)
+        } else {
+            let naxis = rank + axis;
+            if naxis < 0 {
+                crate::bail!("axis {axis} is too small, tensor rank {rank}")
+            }
+            Ok(naxis as usize)
+        }
+    }
+
+    /// Returns a lower triangular matrix of ones of size n by n.
+    pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
+        let t = Tensor::arange(0u32, n as u32, device)?;
+        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
+        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
+        t1.le(&t2)?.to_dtype(dtype)
+    }
+
+    /// Returns an upper triangular matrix of ones of size n by n.
+    pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
+        let t = Tensor::arange(0u32, n as u32, device)?;
+        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
+        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
+        t1.ge(&t2)?.to_dtype(dtype)
+    }
+
+    /// Returns a matrix with a diagonal of ones of size n by n.
+    pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
+        let t = Tensor::arange(0u32, n as u32, device)?;
+        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
+        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
+        t1.eq(&t2)?.to_dtype(dtype)
+    }
+
+    /// Returns the cumulative sum of elements of the input tensor summed over the specified
+    /// dimension.
+    ///
+    /// This operation is most efficient when dim is the last dimension of the tensor.
+    pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "cumsum")?;
+        let rank = self.rank();
+        if rank == 0 {
+            return Ok(self.clone());
+        }
+        let n_axis = self.dim(dim)?;
+        let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
+        if rank == 1 {
+            self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
+        } else {
+            let last = rank - 1;
+            let t = self.transpose(dim, last)?;
+            let t = t.broadcast_matmul(&triu)?;
+            t.transpose(dim, last)
+        }
+    }
+
+    /// Returns a copy of `self` where the values within `ranges` have been replaced with the
+    /// content of `src`.
+    pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
+        &self,
+        ranges: &[D],
+        src: &Tensor,
+    ) -> Result<Self> {
+        let src_dims = src.dims();
+        let self_dims = self.dims();
+        if self_dims.len() != src_dims.len() {
+            crate::bail!(
+                "slice-assign requires input with the same rank {} <> {}",
+                self_dims.len(),
+                src_dims.len()
+            )
+        }
+        if self_dims.len() != ranges.len() {
+            crate::bail!(
+                "slice-assign requires input with the same rank as there are ranges {} <> {}",
+                self_dims.len(),
+                ranges.len()
+            )
+        }
+        let mut src = src.clone();
+        let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
+        for (i, range) in ranges.iter().enumerate() {
+            let start_included = match range.start_bound() {
+                std::ops::Bound::Unbounded => 0,
+                std::ops::Bound::Included(v) => *v,
+                std::ops::Bound::Excluded(v) => *v + 1,
+            };
+            let end_excluded = match range.end_bound() {
+                std::ops::Bound::Unbounded => self_dims[i],
+                std::ops::Bound::Included(v) => *v + 1,
+                std::ops::Bound::Excluded(v) => *v,
+            };
+            if end_excluded <= start_included {
+                crate::bail!(
+                    "slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
+                )
+            }
+            if self_dims[i] < end_excluded {
+                crate::bail!(
+                    "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
+                    self_dims[i]
+                )
+            }
+            if end_excluded - start_included != src_dims[i] {
+                crate::bail!(
+                    "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
+                )
+            }
+            src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
+            mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
+        }
+        mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
+    }
 }
 
 macro_rules! bin_trait {
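A design choice worth noting in this batch: `tril2`, `triu2` and `eye` are all built from one `arange` broadcast against itself with a different comparison (`le`, `ge`, `eq`), and `cumsum` reuses `triu2` as a matmul mask instead of a dedicated scan kernel, which is why it is cheapest along the last dimension. Sketch:

    use candle_core::{DType, Device, Tensor};

    let eye = Tensor::eye(3, DType::F32, &Device::Cpu)?;
    let t = Tensor::new(&[3f32, 1., 4.], &Device::Cpu)?;
    // (1, 3) row vector times the 3x3 upper-triangular mask: [3., 4., 8.]
    let c = t.cumsum(0)?;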
@@ -4,7 +4,7 @@ use crate::{Result, Tensor};
 macro_rules! test_device {
     // TODO: Switch to generating the two last arguments automatically once concat_idents is
     // stable. https://github.com/rust-lang/rust/issues/29599
-    ($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
+    ($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
         #[test]
         fn $test_cpu() -> Result<()> {
             $fn_name(&Device::Cpu)
@@ -15,6 +15,12 @@ macro_rules! test_device {
         fn $test_cuda() -> Result<()> {
            $fn_name(&Device::new_cuda(0)?)
         }
+
+        #[cfg(feature = "metal")]
+        #[test]
+        fn $test_metal() -> Result<()> {
+            $fn_name(&Device::new_metal(0)?)
+        }
     };
 }
 
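With the extra macro argument, every call site now names its Metal test explicitly, and the generated test is gated behind the `metal` feature so default builds are unaffected. A sketch of a call site (`my_op` is a placeholder test function, not from this diff):

    fn my_op(device: &Device) -> Result<()> {
        // assertions against `device` go here
        Ok(())
    }
    test_device!(my_op, my_op_cpu, my_op_gpu, my_op_metal);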
@@ -23,6 +23,10 @@ pub fn cuda_is_available() -> bool {
     cfg!(feature = "cuda")
 }
 
+pub fn metal_is_available() -> bool {
+    cfg!(feature = "metal")
+}
+
 pub fn with_avx() -> bool {
     cfg!(target_feature = "avx")
 }
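`metal_is_available` mirrors `cuda_is_available`: it only reports whether the crate was compiled with the feature, not whether a device can actually be created. A plausible device-selection sketch (the `utils` module path is assumed from the surrounding file, so treat it as an illustration):

    use candle_core::Device;

    let device = if candle_core::utils::metal_is_available() {
        Device::new_metal(0)?
    } else if candle_core::utils::cuda_is_available() {
        Device::new_cuda(0)?
    } else {
        Device::Cpu
    };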
@@ -13,6 +13,11 @@ res = torch.nn.functional.conv1d(t, w)
 print(res.flatten())
 res = torch.nn.functional.conv1d(t, w, padding=1)
 print(res.flatten())
+
+w_t = w.transpose(0, 1)
+res = torch.nn.functional.conv_transpose1d(t, w_t)
+print(res.shape)
+print(res)
 */
 fn conv1d(dev: &Device) -> Result<()> {
     let t = Tensor::new(
@@ -45,6 +50,17 @@ fn conv1d(dev: &Device) -> Result<()> {
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
     );
+    if dev.is_cpu() {
+        let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
+        assert_eq!(res.dims(), [1, 2, 7]);
+        assert_eq!(
+            test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+            [
+                0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
+                4.7076, -5.9745, -0.8276, 1.621
+            ],
+        );
+    }
     Ok(())
 }
 
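The new test exercises `Tensor::conv_transpose1d` end to end; judging from the call values and the expected output length, the four integer arguments after the kernel appear to be padding, output padding, stride and dilation in that order (an inference from this diff, not stated in it). A sketch reusing `t` and `w` from the test above:

    // t: (1, 4, 5) input; the conv1d weight w transposed to (in, out, k) layout.
    let w_t = w.transpose(0, 1)?;
    let res = t.conv_transpose1d(&w_t, 0, 0, 1, 1)?;
    // L_out = (L_in - 1) * stride - 2 * padding + dilation * (k - 1) + output_padding + 1 = 7
    assert_eq!(res.dims(), [1, 2, 7]);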
@@ -479,17 +495,103 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
             ]
         ]
     );
+
+    // Replicate the issue from https://github.com/huggingface/candle/issues/1212
+    let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?;
+    let loss = res.sqr()?.sum_all()?;
+    assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32);
+    let grads = loss.backward()?;
+    let grad_t = grads.get(&t).unwrap();
+    let grad_w = grads.get(&w).unwrap();
+    assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
+    assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
+    assert_eq!(
+        test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
+        [
+            [
+                [9.29, -7.03, 7.87, 0.0, 0.0],
+                [-1.8, -7.82, 5.9, 0.0, 0.0],
+                [-3.12, 4.49, 5.52, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0]
+            ],
+            [
+                [21.73, 3.39, 4.77, 0.0, 0.0],
+                [8.25, 3.73, 27.61, 0.0, 0.0],
+                [-20.55, -5.61, -2.77, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0]
+            ],
+            [
+                [-8.98, 9.91, -7.15, 0.0, 0.0],
+                [4.93, -0.33, 4.56, 0.0, 0.0],
+                [-6.7, -5.76, -8.05, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0]
+            ],
+            [
+                [23.54, 6.98, -10.0, 0.0, 0.0],
+                [9.65, 6.18, 18.72, 0.0, 0.0],
+                [3.29, -5.27, 0.79, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0]
+            ]
+        ]
+    );
+    assert_eq!(
+        test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
+        [
+            [
+                [-3.47, 7.44, 0.66],
+                [12.89, -3.4, -9.29],
+                [-14.16, -0.83, 7.14]
+            ],
+            [
+                [-3.23, 5.37, -3.02],
+                [-2.12, -11.24, 1.94],
+                [6.97, 7.2, 2.99]
+            ],
+            [
+                [-4.04, -3.31, 4.87],
+                [-6.68, -5.68, 1.73],
+                [-5.54, 4.32, 0.52]
+            ],
+            [[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]]
+        ]
+    );
+
     Ok(())
 }
 
-test_device!(conv1d, conv1d_cpu, conv1d_gpu);
-test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
-test_device!(conv2d, conv2d_cpu, conv2d_gpu);
+test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
+test_device!(
+    conv1d_small,
+    conv1d_small_cpu,
+    conv1d_small_gpu,
+    conv1d_small_metal
+);
+test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
 test_device!(
     conv2d_non_square,
     conv2d_non_square_cpu,
-    conv2d_non_square_gpu
+    conv2d_non_square_gpu,
+    conv2d_non_square_metal
+);
+test_device!(
+    conv2d_small,
+    conv2d_small_cpu,
+    conv2d_small_gpu,
+    conv2d_small_metal
+);
+test_device!(
+    conv2d_smaller,
+    conv2d_smaller_cpu,
+    conv2d_smaller_gpu,
+    conv2d_smaller_metal
+);
+test_device!(
+    conv2d_grad,
+    conv2d_grad_cpu,
+    conv2d_grad_gpu,
+    conv2_grad_metal
 );
-test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
-test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
-test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);
@@ -192,6 +192,84 @@ fn unary_grad(device: &Device) -> Result<()> {
         test_utils::to_vec1_round(grad_x, 2)?,
         [0.01, 0.42, 0.0, 0.98],
     );
+
+    // testing compared to pytorch nn.GELU(approximate = 'tanh')
+    let y = x.gelu()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [2.9964, 0.8412, 3.9999, 0.0839]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [1.0116, 1.0830, 1.0003, 0.6188],
+    );
+
+    // Testing compared to pytorch torch.erf
+    //
+    // import torch
+    // x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
+    // y = x.erf()
+    // print(y)
+    // loss = y.sum()
+    // loss.backward()
+    // print(x.grad)
+    let y = x.erf()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(test_utils::to_vec1_round(&y, 4)?, [1.0, 0.8427, 1.0, 0.168]);
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [0.0001, 0.4151, 0.0, 1.1033],
+    );
+
+    // Testing compared to pytorch nn.GELU(approximate = 'none')
+    //
+    // import torch
+    // import torch.nn.functional as F
+    // x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
+    // y = F.gelu(x, approximate='none')
+    // print(y)
+    // loss = y.sum()
+    // loss.backward()
+    // print(x.grad)
+    let y = x.gelu_erf()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [2.9960, 0.8413, 3.9999, 0.0839]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [1.0119, 1.0833, 1.0005, 0.6188],
+    );
+
+    // Testing compared to pytorch elu
+    //
+    // import torch
+    // import torch.nn.functional as F
+    // x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True)
+    // y = F.elu(x, alpha=2.0)
+    // print(y)
+    // loss = y.min
+    // loss = y.sum()
+    // loss.backward()
+    // print(x.grad)
+    let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?;
+    let y = elu_x.elu(2.)?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&elu_x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [-1.2642, 0.0000, -1.7293, 3.0000]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [0.7358, 2.0000, 0.2707, 1.0000]
+    );
+
     Ok(())
 }
 
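For reference, the two GELU variants these gradients are checked against are the standard definitions (well-known formulas, not quoted from the diff): the tanh approximation used by `gelu`,

    \mathrm{gelu}(x) = \tfrac{1}{2}\,x\left(1 + \tanh\!\left(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\right)\right)

and the exact erf form used by `gelu_erf`,

    \mathrm{gelu\_erf}(x) = \tfrac{1}{2}\,x\left(1 + \mathrm{erf}\!\left(x/\sqrt{2}\right)\right)

which is why the expected values for the two tests agree only to a few decimal places.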
@@ -237,9 +315,29 @@ fn binary_grad(device: &Device) -> Result<()> {
     Ok(())
 }
 
-test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
-test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
-test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
-test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
-test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
-test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);
+test_device!(
+    simple_grad,
+    simple_grad_cpu,
+    simple_grad_gpu,
+    simple_grad_metal
+);
+test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
+test_device!(
+    matmul_grad,
+    matmul_grad_cpu,
+    matmul_grad_gpu,
+    matmul_grad_metal
+);
+test_device!(
+    grad_descent,
+    grad_descent_cpu,
+    grad_descent_gpu,
+    grad_descent_metal
+);
+test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
+test_device!(
+    binary_grad,
+    binary_grad_cpu,
+    binary_grad_gpu,
+    binary_grad_metal
+);
@@ -91,3 +91,32 @@ fn index_3d() -> Result<()> {
     assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
     Ok(())
 }
+
+#[test]
+fn slice_assign() -> Result<()> {
+    let dev = Device::Cpu;
+
+    let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
+    let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
+    let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
+    assert_eq!(
+        out.to_vec2::<u32>()?,
+        &[
+            [0, 1, 2, 3, 4],
+            [5, 6, 7, 0, 1],
+            [10, 11, 12, 2, 3],
+            [15, 16, 17, 4, 5]
+        ]
+    );
+    let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
+    assert_eq!(
+        out.to_vec2::<u32>()?,
+        &[
+            [0, 1, 2, 3, 4],
+            [2, 3, 7, 8, 9],
+            [4, 5, 12, 13, 14],
+            [15, 16, 17, 18, 19]
+        ]
+    );
+    Ok(())
+}
@@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> {
     Ok(())
 }
 
-test_device!(contiguous, contiguous_cpu, contiguous_gpu);
+test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);
 
 #[test]
 fn strided_blocks() -> Result<()> {
@@ -98,15 +98,17 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
     Ok(())
 }
 
-test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
+test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal);
 test_device!(
     avg_pool2d_pytorch,
     avg_pool2d_pytorch_cpu,
-    avg_pool2d_pytorch_gpu
+    avg_pool2d_pytorch_gpu,
+    avg_pool2d_pytorch_metal
 );
-test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
+test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
 test_device!(
     upsample_nearest2d,
     upsample_nearest2d_cpu,
-    upsample_nearest2d_gpu
+    upsample_nearest2d_gpu,
+    upsample_nearest2d_metal
 );
@@ -1,7 +1,7 @@
 use candle_core::{
     quantized::{self, GgmlDType},
     test_utils::to_vec2_round,
-    Device, Result, Tensor,
+    Device, Module, Result, Tensor,
 };
 use quantized::{k_quants, GgmlType};
 use rand::prelude::*;
@@ -29,7 +29,26 @@ fn ones(device: &Device) -> Result<()> {
         Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
         [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
     );
+    Ok(())
+}
+
+fn arange(device: &Device) -> Result<()> {
+    assert_eq!(
+        Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
+        [0, 1, 2, 3, 4],
+    );
+    assert_eq!(
+        Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::<u8>()?,
+        [0, 2, 4],
+    );
+    assert_eq!(
+        Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::<u8>()?,
+        [0, 3],
+    );
+    assert_eq!(
+        Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::<i64>()?,
+        [5, 4, 3, 2, 1],
+    );
     Ok(())
 }
 
@@ -161,6 +180,22 @@ fn transpose(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn var(device: &Device) -> Result<()> {
+    // Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
+    let data = &[
+        [0.2035f32, 1.2959, 1.8101, -0.4644],
+        [1.5027, -0.3270, 0.5905, 0.6538],
+        [-1.5745, 1.3330, -0.5596, -0.6548],
+        [0.1264, -0.5080, 1.6420, 0.1992],
+    ];
+    let tensor = Tensor::new(data, device)?;
+    assert_eq!(
+        test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
+        &[[1.0631], [0.559], [1.4893], [0.8258]]
+    );
+    Ok(())
+}
+
 fn sum(device: &Device) -> Result<()> {
     let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
     let tensor = Tensor::new(data, device)?;
@@ -1035,33 +1070,60 @@ fn randn(device: &Device) -> Result<()> {
     Ok(())
 }
 
-test_device!(zeros, zeros_cpu, zeros_gpu);
-test_device!(ones, ones_cpu, ones_gpu);
-test_device!(add_mul, add_mul_cpu, add_mul_gpu);
-test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
-test_device!(narrow, narrow_cpu, narrow_gpu);
-test_device!(broadcast, broadcast_cpu, broadcast_gpu);
-test_device!(cat, cat_cpu, cat_gpu);
-test_device!(sum, sum_cpu, sum_gpu);
-test_device!(min, min_cpu, min_gpu);
-test_device!(max, max_cpu, max_gpu);
-test_device!(argmax, argmax_cpu, argmax_gpu);
-test_device!(argmin, argmin_cpu, argmin_gpu);
-test_device!(transpose, transpose_cpu, transpose_gpu);
-test_device!(unary_op, unary_op_cpu, unary_op_gpu);
-test_device!(binary_op, binary_op_cpu, binary_op_gpu);
-test_device!(embeddings, embeddings_cpu, embeddings_gpu);
-test_device!(cmp, cmp_cpu, cmp_gpu);
-test_device!(matmul, matmul_cpu, matmul_gpu);
-test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
-test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
-test_device!(index_select, index_select_cpu, index_select_gpu);
-test_device!(index_add, index_add_cpu, index_add_gpu);
-test_device!(gather, gather_cpu, gather_gpu);
-test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
-test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
-test_device!(randn, randn_cpu, randn_gpu);
-test_device!(clamp, clamp_cpu, clamp_gpu);
+test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
+test_device!(ones, ones_cpu, ones_gpu, ones_metal);
+test_device!(arange, arange_cpu, arange_gpu, arange_metal);
+test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
+test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
+test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
+test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
+test_device!(cat, cat_cpu, cat_gpu, cat_metal);
+test_device!(sum, sum_cpu, sum_gpu, sum_metal);
+test_device!(min, min_cpu, min_gpu, min_metal);
+test_device!(max, max_cpu, max_gpu, max_metal);
+test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
+test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
+test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
+test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
+test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
+test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
+test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
+test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
+test_device!(
+    broadcast_matmul,
+    broadcast_matmul_cpu,
+    broadcast_matmul_gpu,
+    broadcast_matmul_metal
+);
+test_device!(
+    broadcasting,
+    broadcasting_cpu,
+    broadcasting_gpu,
+    broadcasting_metal
+);
+test_device!(
+    index_select,
+    index_select_cpu,
+    index_select_gpu,
+    index_select_metal
+);
+test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
+test_device!(gather, gather_cpu, gather_gpu, gather_metal);
+test_device!(
+    scatter_add,
+    scatter_add_cpu,
+    scatter_add_gpu,
+    scatter_add_metal
+);
+test_device!(
+    slice_scatter,
+    slice_scatter_cpu,
+    slice_scatter_gpu,
+    slice_scatter_metal
+);
+test_device!(randn, randn_cpu, randn_gpu, randn_metal);
+test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
+test_device!(var, var_cpu, var_gpu, var_metal);
 
 // There was originally a bug on the CPU implementation for randn
 // https://github.com/huggingface/candle/issues/381
@@ -1073,3 +1135,89 @@ fn randn_hasneg() -> Result<()> {
     }
     Ok(())
 }
+
+#[test]
+fn pad_with_same() -> Result<()> {
+    let t = Tensor::arange(1f32, 5f32, &Device::Cpu)?.reshape((2, 2))?;
+    let t0 = t.pad_with_same(0, 1, 2)?;
+    assert_eq!(
+        t0.to_vec2::<f32>()?,
+        [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
+    );
+    let t1 = t.pad_with_same(1, 1, 2)?;
+    assert_eq!(
+        t1.to_vec2::<f32>()?,
+        [[1.0, 1.0, 2.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0, 4.0]]
+    );
+    Ok(())
+}
+
+#[test]
+fn i64_abs() -> Result<()> {
+    let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?;
+    let t = t.abs()?;
+    assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
+    Ok(())
+}
+
+#[test]
+fn tril_triu_eye() -> Result<()> {
+    let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
+    assert_eq!(
+        t.to_vec2::<f32>()?,
+        [
+            [1.0, 0.0, 0.0, 0.0],
+            [1.0, 1.0, 0.0, 0.0],
+            [1.0, 1.0, 1.0, 0.0],
+            [1.0, 1.0, 1.0, 1.0]
+        ],
+    );
+    let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
+    assert_eq!(
+        t.to_vec2::<f32>()?,
+        [
+            [1.0, 1.0, 1.0, 1.0],
+            [0.0, 1.0, 1.0, 1.0],
+            [0.0, 0.0, 1.0, 1.0],
+            [0.0, 0.0, 0.0, 1.0]
+        ]
+    );
+    let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
+    assert_eq!(
+        t.to_vec2::<f32>()?,
+        [
+            [1.0, 0.0, 0.0, 0.0],
+            [0.0, 1.0, 0.0, 0.0],
+            [0.0, 0.0, 1.0, 0.0],
+            [0.0, 0.0, 0.0, 1.0]
+        ]
+    );
+    Ok(())
+}
+
+#[test]
+fn cumsum() -> Result<()> {
+    let t = &[3f32, 1., 4., 1., 5.];
+    let t = Tensor::new(t, &Device::Cpu)?;
+    assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]);
+    let t = t.unsqueeze(1)?;
+    assert_eq!(
+        t.cumsum(0)?.to_vec2::<f32>()?,
+        [[3.0], [4.0], [8.0], [9.0], [14.0]]
+    );
+    assert_eq!(
+        t.cumsum(1)?.to_vec2::<f32>()?,
+        [[3.0], [1.0], [4.0], [1.0], [5.0]]
+    );
+    let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
+    let t = Tensor::new(t, &Device::Cpu)?;
+    assert_eq!(
+        t.cumsum(1)?.to_vec2::<f32>()?,
+        [[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
+    );
+    assert_eq!(
+        t.cumsum(0)?.to_vec2::<f32>()?,
+        [[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
+    );
+    Ok(())
+}
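As a reading aid for the `cumsum` expectations above: along the chosen dimension, each output element is the running total of the input up to and including that position,

```latex
y_i = \sum_{j \le i} x_j
```

so `[3, 1, 4, 1, 5]` becomes `[3, 4, 8, 9, 14]`, while summing along dimension 1 of a single-column tensor leaves it unchanged.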
@@ -11,8 +11,8 @@ readme = "README.md"

 [dependencies]
 byteorder = { workspace = true }
-candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.3.0" }
+candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
+candle-nn = { path = "../candle-nn", version = "0.3.1" }
 hf-hub = { workspace = true}
 intel-mkl-src = { workspace = true, optional = true }
 memmap2 = { workspace = true }
@@ -4,7 +4,9 @@
 //! <https://www.cs.toronto.edu/~kriz/cifar.html>
 //! The binary version of the dataset is used.
 use crate::vision::Dataset;
-use candle::{DType, Device, Result, Tensor};
+use candle::{DType, Device, Error, Result, Tensor};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use parquet::file::reader::{FileReader, SerializedFileReader};
 use std::fs::File;
 use std::io::{BufReader, Read};

@@ -60,3 +62,58 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<Dataset> {
         labels: 10,
     })
 }
+
+fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
+    let samples = parquet.metadata().file_metadata().num_rows() as usize;
+    let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 1_024);
+    let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
+    for row in parquet.into_iter().flatten() {
+        for (_name, field) in row.get_column_iter() {
+            if let parquet::record::Field::Group(subrow) = field {
+                for (_name, field) in subrow.get_column_iter() {
+                    if let parquet::record::Field::Bytes(value) = field {
+                        let image = image::load_from_memory(value.data()).unwrap();
+                        buffer_images.extend(image.to_rgb8().as_raw());
+                    }
+                }
+            } else if let parquet::record::Field::Long(label) = field {
+                buffer_labels.push(*label as u8);
+            }
+        }
+    }
+    let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
+        .to_dtype(DType::U8)?
+        / 255.)?;
+    let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
+    Ok((images, labels))
+}
+
+pub fn load() -> Result<Dataset> {
+    let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+    let dataset_id = "cifar10".to_string();
+    let repo = Repo::with_revision(
+        dataset_id,
+        RepoType::Dataset,
+        "refs/convert/parquet".to_string(),
+    );
+    let repo = api.repo(repo);
+    let test_parquet_filename = repo
+        .get("plain_text/test/0000.parquet")
+        .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+    let train_parquet_filename = repo
+        .get("plain_text/train/0000.parquet")
+        .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+    let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
+        .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
+    let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
+        .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
+    let (test_images, test_labels) = load_parquet(test_parquet)?;
+    let (train_images, train_labels) = load_parquet(train_parquet)?;
+    Ok(crate::vision::Dataset {
+        train_images,
+        train_labels,
+        test_images,
+        test_labels,
+        labels: 10,
+    })
+}
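A minimal caller for the new hub-backed loader might look like the sketch below; the `candle_datasets::vision::cifar` module path is an assumption based on this file's location, not something shown in the diff.

```rust
// Hypothetical usage of the parquet-backed CIFAR-10 loader added above.
fn main() -> candle::Result<()> {
    let ds = candle_datasets::vision::cifar::load()?;
    // CIFAR-10 ships 50k training and 10k test images of shape (3, 32, 32).
    println!("train images: {:?}", ds.train_images.shape());
    println!("test labels: {:?}", ds.test_labels.shape());
    Ok(())
}
```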
@@ -11,17 +11,18 @@ readme = "README.md"

 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
-candle-nn = { path = "../candle-nn", version = "0.3.0" }
-candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
+candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
+candle-nn = { path = "../candle-nn", version = "0.3.1" }
+candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
+candle-onnx = { path = "../candle-onnx", version = "0.3.1", optional = true }
 cudarc = { workspace = true, optional = true }
 half = { workspace = true, optional = true }
 image = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
 num-traits = { workspace = true }
-pyo3 = { version = "0.19.0", features = ["auto-initialize"], optional = true }
+pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
 rayon = { workspace = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
@@ -55,6 +56,8 @@ cudnn = ["candle/cudnn"]
 flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
 nccl = ["cuda", "cudarc/nccl", "dep:half"]
+onnx = ["candle-onnx"]
+metal = ["candle/metal", "candle-nn/metal"]

 [[example]]
 name = "llama_multiprocess"
@@ -63,3 +66,11 @@ required-features = ["cuda", "nccl", "flash-attn"]
 [[example]]
 name = "reinforcement-learning"
 required-features = ["pyo3"]
+
+[[example]]
+name = "onnx"
+required-features = ["onnx"]
+
+[[example]]
+name = "onnx_basics"
+required-features = ["onnx"]
@@ -5,11 +5,11 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 use candle_transformers::models::bert::{BertModel, Config, DTYPE};

-use anyhow::{anyhow, Error as E, Result};
+use anyhow::{Error as E, Result};
 use candle::Tensor;
 use candle_nn::VarBuilder;
 use clap::Parser;
-use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
+use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::{PaddingParams, Tokenizer};

 #[derive(Parser, Debug)]
@@ -19,10 +19,6 @@ struct Args {
     #[arg(long)]
     cpu: bool,

-    /// Run offline (you must have the files already cached)
-    #[arg(long)]
-    offline: bool,
-
     /// Enable tracing (generates a trace-timestamp.json file).
     #[arg(long)]
     tracing: bool,
@@ -38,6 +34,10 @@ struct Args {
     #[arg(long)]
     prompt: Option<String>,

+    /// Use the pytorch weights rather than the safetensors ones
+    #[arg(long)]
+    use_pth: bool,
+
     /// The number of times to run the prompt.
     #[arg(long, default_value = "1")]
     n: usize,
@@ -60,34 +60,27 @@ impl Args {
         };

         let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
-            let cache = Cache::default().repo(repo);
-            (
-                cache
-                    .get("config.json")
-                    .ok_or(anyhow!("Missing config file in cache"))?,
-                cache
-                    .get("tokenizer.json")
-                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
-                cache
-                    .get("model.safetensors")
-                    .ok_or(anyhow!("Missing weights file in cache"))?,
-            )
-        } else {
+        let (config_filename, tokenizer_filename, weights_filename) = {
             let api = Api::new()?;
             let api = api.repo(repo);
-            (
-                api.get("config.json")?,
-                api.get("tokenizer.json")?,
-                api.get("model.safetensors")?,
-            )
+            let config = api.get("config.json")?;
+            let tokenizer = api.get("tokenizer.json")?;
+            let weights = if self.use_pth {
+                api.get("pytorch_model.bin")?
+            } else {
+                api.get("model.safetensors")?
+            };
+            (config, tokenizer, weights)
         };
         let config = std::fs::read_to_string(config_filename)?;
         let config: Config = serde_json::from_str(&config)?;
         let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

-        let vb =
-            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
+        let vb = if self.use_pth {
+            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
+        } else {
+            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
+        };
         let model = BertModel::load(vb, &config)?;
         Ok((model, tokenizer))
     }
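With this change, switching weight formats is a single flag. Assuming the example binary is registered as `bert`, the two paths can be exercised as:

```bash
# Default: mmap the safetensors weights.
cargo run --example bert --release -- --prompt "Here is a test sentence"
# Load pytorch_model.bin via VarBuilder::from_pth instead.
cargo run --example bert --release -- --prompt "Here is a test sentence" --use_pth
```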
candle-examples/examples/blip/README.md (new file)
@@ -0,0 +1,19 @@
# candle-blip

The
[blip-image-captioning](https://huggingface.co/Salesforce/blip-image-captioning-base)
model can generate captions for an input image.

## Running on an example

```bash
cargo run --example blip --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```

```
Running on CPU, to run on GPU, build this example with `--features cuda`
loaded image Tensor[dims 3, 384, 384; f32]
model built
several cyclists are riding down a road with cars behind them
```

candle-examples/examples/blip/main.rs (new file)
@@ -0,0 +1,154 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::Parser;

use candle::{DType, Device, Result, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::blip;
use candle_transformers::models::quantized_blip;

use tokenizers::Tokenizer;

enum Model {
    M(blip::BlipForConditionalGeneration),
    Q(quantized_blip::BlipForConditionalGeneration),
}

impl Model {
    fn text_decoder_forward(&mut self, xs: &Tensor, img_xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::M(m) => m.text_decoder().forward(xs, img_xs),
            Self::Q(m) => m.text_decoder().forward(xs, img_xs),
        }
    }
}

// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Use the quantized version of the model.
    #[arg(long)]
    quantized: bool,
}

const SEP_TOKEN_ID: u32 = 102;

/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 384, 384). OpenAI normalization is applied.
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
    let img = image::io::Reader::open(p)?
        .decode()
        .map_err(candle::Error::wrap)?
        .resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
    let img = img.to_rgb8();
    let data = img.into_raw();
    let data = Tensor::from_vec(data, (384, 384, 3), &Device::Cpu)?.permute((2, 0, 1))?;
    let mean =
        Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?;
    let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)?
        .reshape((3, 1, 1))?;
    (data.to_dtype(candle::DType::F32)? / 255.)?
        .broadcast_sub(&mean)?
        .broadcast_div(&std)
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            if args.quantized {
                let api = api.model("lmz/candle-blip".to_string());
                api.get("blip-image-captioning-large-q4k.gguf")?
            } else {
                let api = api.repo(hf_hub::Repo::with_revision(
                    "Salesforce/blip-image-captioning-large".to_string(),
                    hf_hub::RepoType::Model,
                    "refs/pr/18".to_string(),
                ));
                api.get("model.safetensors")?
            }
        }
        Some(model) => model.into(),
    };
    let tokenizer = match args.tokenizer {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("Salesforce/blip-image-captioning-large".to_string());
            api.get("tokenizer.json")?
        }
        Some(file) => file.into(),
    };
    let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
    let mut tokenizer = TokenOutputStream::new(tokenizer);
    let mut logits_processor =
        candle_transformers::generation::LogitsProcessor::new(1337, None, None);

    let config = blip::Config::image_captioning_large();

    let (image_embeds, device, mut model) = if args.quantized {
        let device = Device::Cpu;
        let image = load_image(args.image)?.to_device(&device)?;
        println!("loaded image {image:?}");

        let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
        let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
        let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
        (image_embeds, device, Model::Q(model))
    } else {
        let device = candle_examples::device(args.cpu)?;
        let image = load_image(args.image)?.to_device(&device)?;
        println!("loaded image {image:?}");

        let vb =
            unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
        let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
        let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
        (image_embeds, device, Model::M(model))
    };

    let mut token_ids = vec![30522u32];
    for index in 0..1000 {
        let context_size = if index > 0 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
        let logits = model.text_decoder_forward(&input_ids, &image_embeds)?;
        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        if token == SEP_TOKEN_ID {
            break;
        }
        token_ids.push(token);
        if let Some(t) = tokenizer.next_token(token)? {
            use std::io::Write;
            print!("{t}");
            std::io::stdout().flush()?;
        }
    }
    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    println!();
    Ok(())
}
candle-examples/examples/distilbert/README.md (new file)
@@ -0,0 +1,22 @@
# candle-distilbert

DistilBert is a distilled version of the Bert model.

## Sentence embeddings

DistilBert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.

```bash
cargo run --example distilbert --release -- --prompt "Here is a test sentence"

> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
>  [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
>  [ 0.0702, -0.1311, -0.4914, ..., 0.3483, -0.6194, 0.1829],
>  ...
>  [ 0.2993, -0.0106, -0.4640, ..., 0.2844, -0.6732, 0.0042],
>  [ 0.1066, -0.0081, -0.4299, ..., 0.3435, -0.7729, 0.0190],
>  [ 0.8903, 0.2055, -0.2541, ..., 0.3208, -0.6585, 0.0586]]]
> Tensor[[1, 7, 768], f32]

```
candle-examples/examples/distilbert/main.rs (new file)
@@ -0,0 +1,135 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};

use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    /// When set, compute embeddings for this prompt.
    #[arg(long)]
    prompt: String,

    /// Use the pytorch weights rather than the safetensors ones
    #[arg(long)]
    use_pth: bool,

    /// The number of times to run the prompt.
    #[arg(long, default_value = "1")]
    n: usize,

    /// L2 normalization for embeddings.
    #[arg(long, default_value = "true")]
    normalize_embeddings: bool,
}

impl Args {
    fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
        let device = candle_examples::device(self.cpu)?;
        let default_model = "distilbert-base-uncased".to_string();
        let default_revision = "main".to_string();
        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
            (Some(model_id), Some(revision)) => (model_id, revision),
            (Some(model_id), None) => (model_id, "main".to_string()),
            (None, Some(revision)) => (default_model, revision),
            (None, None) => (default_model, default_revision),
        };

        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
        let (config_filename, tokenizer_filename, weights_filename) = {
            let api = Api::new()?;
            let api = api.repo(repo);
            let config = api.get("config.json")?;
            let tokenizer = api.get("tokenizer.json")?;
            let weights = if self.use_pth {
                api.get("pytorch_model.bin")?
            } else {
                api.get("model.safetensors")?
            };
            (config, tokenizer, weights)
        };
        let config = std::fs::read_to_string(config_filename)?;
        let config: Config = serde_json::from_str(&config)?;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

        let vb = if self.use_pth {
            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
        } else {
            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
        };
        let model = DistilBertModel::load(vb, &config)?;
        Ok((model, tokenizer))
    }
}

fn get_mask(size: usize, device: &Device) -> Tensor {
    let mask: Vec<_> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (size, size), device).unwrap()
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        println!("tracing...");
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
    let device = &model.device;

    let tokenizer = tokenizer
        .with_padding(None)
        .with_truncation(None)
        .map_err(E::msg)?;
    let tokens = tokenizer
        .encode(args.prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
    let mask = get_mask(tokens.len(), device);

    println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
    println!("mask: {:?}", mask.to_vec2::<u8>());

    let ys = model.forward(&token_ids, &mask)?;
    println!("{ys}");

    Ok(())
}

pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}
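The trailing `normalize_l2` helper divides each row of the embedding matrix by its Euclidean norm, i.e.

```latex
\hat{v}_i = \frac{v_i}{\lVert v_i \rVert_2} = \frac{v_i}{\sqrt{\sum_k v_{ik}^2}}
```

which is exactly what the `sqr` / `sum_keepdim(1)` / `sqrt` / `broadcast_div` chain computes.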
candle-examples/examples/jina-bert/README.md (new file)
@@ -0,0 +1,45 @@
# candle-jina-bert

Jina-Bert is a general large language model with a context size of 8192, [model
card](https://huggingface.co/jinaai/jina-embeddings-v2-base-en). In this example
it can be used for two different tasks:
- Compute sentence embeddings for a prompt.
- Compute similarities between a set of sentences.

## Sentence embeddings

Jina-Bert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.

```bash
cargo run --example jina-bert --release -- --prompt "Here is a test sentence"

> [[[ 0.1595, -0.9885, 0.6494, ..., 0.3003, -0.6901, -1.2355],
>  [ 0.0374, -0.1798, 1.3359, ..., 0.6731, 0.2133, -1.6807],
>  [ 0.1700, -0.8534, 0.8924, ..., -0.1785, -0.0727, -1.5087],
>  ...
>  [-0.3113, -1.3665, 0.2027, ..., -0.2519, 0.1711, -1.5811],
>  [ 0.0907, -1.0492, 0.5382, ..., 0.0242, -0.7077, -1.0830],
>  [ 0.0369, -0.6343, 0.6105, ..., 0.0671, 0.3778, -1.1505]]]
> Tensor[[1, 7, 768], f32]
```

## Similarities

In this example, Jina-Bert is used to compute the sentence embeddings for a set of
sentences (hardcoded in the example). Cosine similarities are then computed for
each sentence pair and reported in decreasing order, so the first reported pair
contains the two sentences with the highest similarity score.
The sentence embeddings are computed using average pooling over all the
sentence tokens, including any padding.

```bash
cargo run --example jina-bert --release

> score: 0.94 'The new movie is awesome' 'The new movie is so great'
> score: 0.81 'The cat sits outside' 'The cat plays in the garden'
> score: 0.78 'I love pasta' 'Do you like pizza?'
> score: 0.68 'I love pasta' 'The new movie is awesome'
> score: 0.67 'A man is playing guitar' 'A woman watches TV'
```
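For reference, the score printed above is the cosine similarity of the two mean-pooled embeddings:

```latex
\mathrm{sim}(u, v) = \frac{u \cdot v}{\lVert u \rVert_2 \, \lVert v \rVert_2}
```

The example below computes exactly this ratio from the raw dot products.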
candle-examples/examples/jina-bert/main.rs (new file)
@@ -0,0 +1,180 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle_transformers::models::jina_bert::{BertModel, Config};

use anyhow::Error as E;
use candle::{DType, Module, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// When set, compute embeddings for this prompt.
    #[arg(long)]
    prompt: Option<String>,

    /// The number of times to run the prompt.
    #[arg(long, default_value = "1")]
    n: usize,

    /// L2 normalization for embeddings.
    #[arg(long, default_value = "true")]
    normalize_embeddings: bool,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    model: Option<String>,
}

impl Args {
    fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
        use hf_hub::{api::sync::Api, Repo, RepoType};
        let model = match &self.model {
            Some(model_file) => std::path::PathBuf::from(model_file),
            None => Api::new()?
                .repo(Repo::new(
                    "jinaai/jina-embeddings-v2-base-en".to_string(),
                    RepoType::Model,
                ))
                .get("model.safetensors")?,
        };
        let tokenizer = match &self.tokenizer {
            Some(file) => std::path::PathBuf::from(file),
            None => Api::new()?
                .repo(Repo::new(
                    "sentence-transformers/all-MiniLM-L6-v2".to_string(),
                    RepoType::Model,
                ))
                .get("tokenizer.json")?,
        };
        let device = candle_examples::device(self.cpu)?;
        let config = Config::v2_base();
        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
        let model = BertModel::new(vb, &config)?;
        Ok((model, tokenizer))
    }
}

fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        println!("tracing...");
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    let start = std::time::Instant::now();

    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
    let device = &model.device;

    if let Some(prompt) = args.prompt {
        let tokenizer = tokenizer
            .with_padding(None)
            .with_truncation(None)
            .map_err(E::msg)?;
        let tokens = tokenizer
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
        println!("Loaded and encoded {:?}", start.elapsed());
        for idx in 0..args.n {
            let start = std::time::Instant::now();
            let ys = model.forward(&token_ids)?;
            if idx == 0 {
                println!("{ys}");
            }
            println!("Took {:?}", start.elapsed());
        }
    } else {
        let sentences = [
            "The cat sits outside",
            "A man is playing guitar",
            "I love pasta",
            "The new movie is awesome",
            "The cat plays in the garden",
            "A woman watches TV",
            "The new movie is so great",
            "Do you like pizza?",
        ];
        let n_sentences = sentences.len();
        if let Some(pp) = tokenizer.get_padding_mut() {
            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
        } else {
            let pp = tokenizers::PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            };
            tokenizer.with_padding(Some(pp));
        }
        let tokens = tokenizer
            .encode_batch(sentences.to_vec(), true)
            .map_err(E::msg)?;
        let token_ids = tokens
            .iter()
            .map(|tokens| {
                let tokens = tokens.get_ids().to_vec();
                Tensor::new(tokens.as_slice(), device)
            })
            .collect::<candle::Result<Vec<_>>>()?;

        let token_ids = Tensor::stack(&token_ids, 0)?;
        println!("running inference on batch {:?}", token_ids.shape());
        let embeddings = model.forward(&token_ids)?;
        println!("generated embeddings {:?}", embeddings.shape());
        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
        let embeddings = if args.normalize_embeddings {
            normalize_l2(&embeddings)?
        } else {
            embeddings
        };
        println!("pooled embeddings {:?}", embeddings.shape());

        let mut similarities = vec![];
        for i in 0..n_sentences {
            let e_i = embeddings.get(i)?;
            for j in (i + 1)..n_sentences {
                let e_j = embeddings.get(j)?;
                let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
                let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
                let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
                let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
                similarities.push((cosine_similarity, i, j))
            }
        }
        similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
        for &(score, i, j) in similarities[..5].iter() {
            println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
        }
    }
    Ok(())
}

pub fn normalize_l2(v: &Tensor) -> candle::Result<Tensor> {
    v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}
@@ -6,9 +6,10 @@ extern crate accelerate_src;
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;

-mod model;
+use candle_transformers::models::llama2_c as model;
+use candle_transformers::models::llama2_c_weights as weights;
+use candle_transformers::models::quantized_llama2_c as qmodel;
 mod training;
-mod weights;
 use clap::{Parser, Subcommand};

 use anyhow::{Error as E, Result};
@@ -19,6 +20,7 @@ use std::io::Write;
 use tokenizers::Tokenizer;

 use model::{Config, Llama};
+use qmodel::QLlama;
 use weights::TransformerWeights;

 #[derive(Parser, Debug, Clone)]
@@ -152,6 +154,20 @@ fn main() -> anyhow::Result<()> {
     Ok(())
 }

+enum Model {
+    Llama(Llama),
+    QLlama(QLlama),
+}
+
+impl Model {
+    fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
+        match self {
+            Self::Llama(l) => Ok(l.forward(xs, pos)?),
+            Self::QLlama(l) => Ok(l.forward(xs, pos)?),
+        }
+    }
+}
+
 fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
     use std::io::BufRead;

@@ -241,24 +257,66 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {

     let device = candle_examples::device(common_args.cpu)?;

+    let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
     let is_safetensors = config_path
         .extension()
         .map_or(false, |v| v == "safetensors");
-    let (vb, config) = if is_safetensors {
-        let config = Config::tiny();
+    let (model, config) = if is_gguf {
+        let vb = qmodel::VarBuilder::from_gguf(config_path)?;
+        let (_vocab_size, dim) = vb
+            .get_no_shape("model.embed_tokens.weight")?
+            .shape()
+            .dims2()?;
+        let config = match dim {
+            64 => Config::tiny_260k(),
+            288 => Config::tiny_15m(),
+            512 => Config::tiny_42m(),
+            768 => Config::tiny_110m(),
+            _ => anyhow::bail!("no config for dim {dim}"),
+        };
+        let freq_cis_real = vb
+            .get(
+                (config.seq_len, config.head_size() / 2),
+                "rot.freq_cis_real",
+            )?
+            .dequantize(&candle::Device::Cpu)?;
+        let freq_cis_imag = vb
+            .get(
+                (config.seq_len, config.head_size() / 2),
+                "rot.freq_cis_imag",
+            )?
+            .dequantize(&candle::Device::Cpu)?;
+
+        let fake_vb = candle_nn::VarBuilder::from_tensors(
+            [
+                ("freq_cis_real".to_string(), freq_cis_real),
+                ("freq_cis_imag".to_string(), freq_cis_imag),
+            ]
+            .into_iter()
+            .collect(),
+            candle::DType::F32,
+            &candle::Device::Cpu,
+        );
+        let cache = model::Cache::new(true, &config, fake_vb)?;
+        let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
+        (model, config)
+    } else if is_safetensors {
+        let config = Config::tiny_15m();
         let tensors = candle::safetensors::load(config_path, &device)?;
         let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
-        (vb, config)
+        let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
+        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
+        (model, config)
     } else {
         let mut file = std::fs::File::open(config_path)?;
         let config = Config::from_reader(&mut file)?;
         println!("{config:?}");
         let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
         let vb = weights.var_builder(&config, &device)?;
-        (vb, config)
+        let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
+        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
+        (model, config)
     };
-    let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
-    let model = Llama::load(vb, &cache, config)?;

     println!("starting the inference loop");
     let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
@@ -273,7 +331,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {

     let start_gen = std::time::Instant::now();
     for index in 0.. {
-        if tokens.len() >= model.config.seq_len {
+        if tokens.len() >= config.seq_len {
             break;
         }
         let context_size = if index > 0 { 1 } else { tokens.len() };
@@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
     );
     let varmap = candle_nn::VarMap::new();
     let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
-    let config = Config::tiny();
+    let config = Config::tiny_15m();
     let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
     let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);

candle-examples/examples/marian-mt/README.md
Normal file
38
candle-examples/examples/marian-mt/README.md
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
# candle-marian-mt
|
||||||
|
|
||||||
|
`marian-mt` is a neural machine translation model. In this example it is used to
|
||||||
|
translate text from French to English. See the associated [model
|
||||||
|
card](https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en) for details on
|
||||||
|
the model itself.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example marian-mt --release -- \
|
||||||
|
--text "Demain, dès l'aube, à l'heure où blanchit la campagne, Je partirai. Vois-tu, je sais que tu m'attends. J'irai par la forêt, j'irai par la montagne. Je ne puis demeurer loin de toi plus longtemps."
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
<NIL> Tomorrow, at dawn, at the time when the country is whitening, I will go. See,
|
||||||
|
I know you are waiting for me. I will go through the forest, I will go through the
|
||||||
|
mountain. I cannot stay far from you any longer.</s>
|
||||||
|
```
|
||||||
|
|
||||||
|
## Generating the tokenizer.json files
|
||||||
|
|
||||||
|
You can use the following script to generate the `tokenizer.json` config files
|
||||||
|
from the hf-hub repos. This requires the `tokenizers` and `sentencepiece`
|
||||||
|
packages to be install and use the `convert_slow_tokenizer.py` script from this
|
||||||
|
directory.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from convert_slow_tokenizer import MarianConverter
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
|
||||||
|
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
|
||||||
|
fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
|
||||||
|
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
|
||||||
|
fast_tokenizer.save(f"tokenizer-marian-base-en.json")
|
||||||
|
```
|
candle-examples/examples/marian-mt/convert_slow_tokenizer.py (new file, 1385 lines)
File diff suppressed because it is too large.
candle-examples/examples/marian-mt/main.rs (new file)
@@ -0,0 +1,152 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::{Parser, ValueEnum};

use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::marian;

use tokenizers::Tokenizer;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    Base,
    Big,
}

// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    tokenizer_dec: Option<String>,

    /// Choose the variant of the model to run.
    #[arg(long, default_value = "big")]
    which: Which,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Use the quantized version of the model.
    #[arg(long)]
    quantized: bool,

    /// Text to be translated
    #[arg(long)]
    text: String,
}

pub fn main() -> anyhow::Result<()> {
    use hf_hub::api::sync::Api;
    let args = Args::parse();

    let config = match args.which {
        Which::Base => marian::Config::opus_mt_fr_en(),
        Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
    };
    let tokenizer = {
        let tokenizer = match args.tokenizer {
            Some(tokenizer) => std::path::PathBuf::from(tokenizer),
            None => {
                let name = match args.which {
                    Which::Base => "tokenizer-marian-base-fr.json",
                    Which::Big => "tokenizer-marian-fr.json",
                };
                Api::new()?
                    .model("lmz/candle-marian".to_string())
                    .get(name)?
            }
        };
        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
    };

    let tokenizer_dec = {
        let tokenizer = match args.tokenizer_dec {
            Some(tokenizer) => std::path::PathBuf::from(tokenizer),
            None => {
                let name = match args.which {
                    Which::Base => "tokenizer-marian-base-en.json",
                    Which::Big => "tokenizer-marian-en.json",
                };
                Api::new()?
                    .model("lmz/candle-marian".to_string())
                    .get(name)?
            }
        };
        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
    };
    let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);

    let device = candle_examples::device(args.cpu)?;
    let vb = {
        let model = match args.model {
            Some(model) => std::path::PathBuf::from(model),
            None => match args.which {
                Which::Base => Api::new()?
                    .repo(hf_hub::Repo::with_revision(
                        "Helsinki-NLP/opus-mt-fr-en".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/4".to_string(),
                    ))
                    .get("model.safetensors")?,
                Which::Big => Api::new()?
                    .model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
                    .get("model.safetensors")?,
            },
        };
        unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
    };
    let mut model = marian::MTModel::new(&config, vb)?;

    let mut logits_processor =
        candle_transformers::generation::LogitsProcessor::new(1337, None, None);

    let encoder_xs = {
        let mut tokens = tokenizer
            .encode(args.text, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        tokens.push(config.eos_token_id);
        let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
        model.encoder().forward(&tokens, 0)?
    };

    let mut token_ids = vec![config.decoder_start_token_id];
    for index in 0..1000 {
        let context_size = if index >= 1 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
        let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        token_ids.push(token);
        if let Some(t) = tokenizer_dec.next_token(token)? {
            use std::io::Write;
            print!("{t}");
            std::io::stdout().flush()?;
        }
        if token == config.eos_token_id || token == config.forced_eos_token_id {
            break;
        }
    }
    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    println!();
    Ok(())
}
@@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum};
 use rand::prelude::*;

 use candle::{DType, Result, Tensor, D};
-use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};
+use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};

 const IMAGE_DIM: usize = 784;
 const LABELS: usize = 10;
@@ -95,7 +95,7 @@ impl ConvNet {
             .flatten_from(1)?
             .apply(&self.fc1)?
             .relu()?;
-        self.dropout.forward(&xs, train)?.apply(&self.fc2)
+        self.dropout.forward_t(&xs, train)?.apply(&self.fc2)
     }
 }

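The `forward` to `forward_t` switch reflects candle-nn's `ModuleT` trait, whose forward pass takes an explicit `train` flag so that layers like dropout can mask activations during training and act as the identity at evaluation time. A minimal sketch of the pattern, with signatures inferred from this diff:

```rust
use candle::{Result, Tensor};
use candle_nn::ModuleT;

// Dropout implements ModuleT rather than Module: its behaviour depends on
// whether we are currently training.
fn apply_dropout(dropout: &candle_nn::Dropout, xs: &Tensor, train: bool) -> Result<Tensor> {
    dropout.forward_t(xs, train) // masks when train == true, identity otherwise
}
```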
@@ -8,6 +8,7 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
 #[derive(Debug, Clone, PartialEq)]
 enum NormType {
     WeightNorm,
+    TimeGroupNorm,
     None,
 }

@@ -268,6 +269,7 @@ impl Module for EncodecConvTranspose1d {
 struct EncodecConv1d {
     causal: bool,
     conv: Conv1d,
+    norm: Option<candle_nn::GroupNorm>,
 }

 impl EncodecConv1d {
@@ -292,7 +294,7 @@ impl EncodecConv1d {
                 },
                 vb.pp("conv"),
             )?,
-            NormType::None => conv1d(
+            NormType::None | NormType::TimeGroupNorm => conv1d(
                 in_c,
                 out_c,
                 kernel_size,
@@ -305,9 +307,17 @@ impl EncodecConv1d {
                 vb.pp("conv"),
             )?,
         };
+        let norm = match cfg.norm_type {
+            NormType::None | NormType::WeightNorm => None,
+            NormType::TimeGroupNorm => {
+                let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
+                Some(gn)
+            }
+        };
         Ok(Self {
             causal: cfg.use_causal_conv,
             conv,
+            norm,
         })
     }
 }
@@ -316,8 +326,10 @@ impl Module for EncodecConv1d {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         // TODO: padding, depending on causal.
         let xs = self.conv.forward(xs)?;
-        // If we add support for NormType "time_group_norm", we should add some normalization here.
-        Ok(xs)
+        match &self.norm {
+            None => Ok(xs),
+            Some(norm) => xs.apply(norm),
+        }
     }
 }

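For context on the new `TimeGroupNorm` branch: `group_norm(1, out_c, 1e-5, vb.pp("norm"))` builds a GroupNorm with a single group, i.e. a normalization across all output channels at each time step, which is what encodec's `time_group_norm` calls for. A standalone sketch of that layer on a dummy `(batch, channels, time)` tensor; the shapes here are illustrative only:

```rust
use candle::{DType, Device, Result, Tensor};

fn time_group_norm_demo() -> Result<()> {
    let varmap = candle_nn::VarMap::new();
    let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    // One group over 8 channels: normalizes across channels per (batch, time) slot.
    let gn = candle_nn::group_norm(1, 8, 1e-5, vb)?;
    let xs = Tensor::randn(0f32, 1f32, (2, 8, 16), &Device::Cpu)?;
    let _ys = xs.apply(&gn)?;
    Ok(())
}
```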
candle-examples/examples/onnx/README.md (new file)
@@ -0,0 +1,10 @@
## Using ONNX models in Candle

This example demonstrates how to run ONNX based models in Candle; the model
being used here is a small squeezenet variant.

You can run the example with the following command (the example is registered
as `onnx` in Cargo.toml and requires the `onnx` feature):

```bash
cargo run --example onnx --features onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```
candle-examples/examples/onnx/main.rs (new file)
@@ -0,0 +1,78 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::{IndexOp, D};
use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    SqueezeNet,
    EfficientNet,
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    image: String,

    #[arg(long)]
    model: Option<String>,

    /// The model to be used.
    #[arg(value_enum, long, default_value_t = Which::SqueezeNet)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let image = candle_examples::imagenet::load_image224(args.image)?;
    let image = match args.which {
        Which::SqueezeNet => image,
        Which::EfficientNet => image.permute((1, 2, 0))?,
    };

    println!("loaded image {image:?}");

    let model = match args.model {
        Some(model) => std::path::PathBuf::from(model),
        None => match args.which {
            Which::SqueezeNet => hf_hub::api::sync::Api::new()?
                .model("lmz/candle-onnx".into())
                .get("squeezenet1.1-7.onnx")?,
            Which::EfficientNet => hf_hub::api::sync::Api::new()?
                .model("onnx/EfficientNet-Lite4".into())
                .get("efficientnet-lite4-11.onnx")?,
        },
    };

    let model = candle_onnx::read_file(model)?;
    let graph = model.graph.as_ref().unwrap();
    let mut inputs = std::collections::HashMap::new();
    inputs.insert(graph.input[0].name.to_string(), image.unsqueeze(0)?);
    let mut outputs = candle_onnx::simple_eval(&model, inputs)?;
    let output = outputs.remove(&graph.output[0].name).unwrap();
    let prs = match args.which {
        Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
        Which::EfficientNet => output,
    };
    let prs = prs.i(0)?.to_vec1::<f32>()?;

    // Sort the predictions and take the top 5
    let mut top: Vec<_> = prs.iter().enumerate().collect();
    top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
    let top = top.into_iter().take(5).collect::<Vec<_>>();

    // Print the top predictions
    for &(i, p) in &top {
        println!(
            "{:50}: {:.2}%",
            candle_examples::imagenet::CLASSES[i],
            p * 100.0
        );
    }

    Ok(())
}
candle-examples/examples/onnx_basics.rs (new file, 87 lines)
@ -0,0 +1,87 @@
use anyhow::Result;
use candle::{Device, Tensor};

use clap::{Parser, Subcommand};

#[derive(Subcommand, Debug, Clone)]
enum Command {
    Print {
        #[arg(long)]
        file: String,
    },
    SimpleEval {
        #[arg(long)]
        file: String,
    },
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
    #[command(subcommand)]
    command: Command,
}

pub fn main() -> Result<()> {
    let args = Args::parse();
    match args.command {
        Command::Print { file } => {
            let model = candle_onnx::read_file(file)?;
            println!("{model:?}");
            let graph = model.graph.unwrap();
            for node in graph.node.iter() {
                println!("{node:?}");
            }
        }
        Command::SimpleEval { file } => {
            let model = candle_onnx::read_file(file)?;
            let graph = model.graph.as_ref().unwrap();
            let constants: std::collections::HashSet<_> =
                graph.initializer.iter().map(|i| i.name.as_str()).collect();
            let mut inputs = std::collections::HashMap::new();
            for input in graph.input.iter() {
                use candle_onnx::onnx::tensor_proto::DataType;
                if constants.contains(input.name.as_str()) {
                    continue;
                }

                let type_ = input.r#type.as_ref().expect("no type for input");
                let type_ = type_.value.as_ref().expect("no type.value for input");
                let value = match type_ {
                    candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
                        let dt = match DataType::try_from(tt.elem_type) {
                            Ok(dt) => match candle_onnx::dtype(dt) {
                                Some(dt) => dt,
                                None => {
                                    anyhow::bail!(
                                        "unsupported 'value' data-type {dt:?} for {}",
                                        input.name
                                    )
                                }
                            },
                            type_ => anyhow::bail!("unsupported input type {type_:?}"),
                        };
                        let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
                        let dims = shape
                            .dim
                            .iter()
                            .map(|dim| match dim.value.as_ref().expect("no dim value") {
                                candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
                                candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => Ok(42),
                            })
                            .collect::<Result<Vec<usize>>>()?;
                        Tensor::zeros(dims, dt, &Device::Cpu)?
                    }
                    type_ => anyhow::bail!("unsupported input type {type_:?}"),
                };
                println!("input {}: {value:?}", input.name);
                inputs.insert(input.name.clone(), value);
            }
            let outputs = candle_onnx::simple_eval(&model, inputs)?;
            for (name, value) in outputs.iter() {
                println!("output {name}: {value:?}")
            }
        }
    }
    Ok(())
}
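One detail worth flagging in the `SimpleEval` command above: non-initializer inputs are fabricated as zero tensors, and symbolic dimensions (`DimParam`) are replaced by an arbitrary placeholder of 42, so the printed shapes for dynamic-dimension models are illustrative rather than meaningful. A tiny isolated sketch of that dimension-resolution rule, where `Option<usize>` stands in for the `DimValue`/`DimParam` cases:

```rust
use candle::{DType, Device, Result, Tensor};

// `None` plays the role of an ONNX DimParam (a named symbolic dimension).
fn placeholder_input(dims: &[Option<usize>]) -> Result<Tensor> {
    let dims: Vec<usize> = dims.iter().map(|d| d.unwrap_or(42)).collect();
    Tensor::zeros(dims, DType::F32, &Device::Cpu)
}
```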
@ -41,3 +41,16 @@ def median(arr):
     else:
         return arr[n//2]
 ```
+
+This also supports the [Puffin Phi v2
+model](https://huggingface.co/teknium/Puffin-Phi-v2) for human interaction.
+```
+$ cargo run --example phi --release -- \
+    --prompt "USER: What would you do on a sunny day in Paris?\nASSISTANT:" \
+    --sample-len 200 --model puffin-phi-v2 --quantized
+USER: What would you do on a sunny day in Paris?
+ASSISTANT: On a sunny day in Paris, you could visit the Musée du Louvre to admire the famous
+painting "Mona Lisa" by Leonardo da Vinci. You might also want to stroll along the Champs-Élysées
+and enjoy the beautiful architecture of the buildings around you. Don't forget to stop by a café
+for a cup of coffee and to soak up the sun!"
+```
@ -5,7 +5,7 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use anyhow::{Error as E, Result};
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 
 use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
 use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;

@ -28,6 +28,7 @@ struct TextGeneration {
     logits_processor: LogitsProcessor,
     repeat_penalty: f32,
     repeat_last_n: usize,
+    verbose_prompt: bool,
 }
 
 impl TextGeneration {

@ -40,6 +41,7 @@ impl TextGeneration {
         top_p: Option<f64>,
         repeat_penalty: f32,
         repeat_last_n: usize,
+        verbose_prompt: bool,
         device: &Device,
     ) -> Self {
         let logits_processor = LogitsProcessor::new(seed, temp, top_p);

@ -49,6 +51,7 @@ impl TextGeneration {
             logits_processor,
             repeat_penalty,
             repeat_last_n,
+            verbose_prompt,
             device: device.clone(),
         }
     }

@ -56,20 +59,24 @@ impl TextGeneration {
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         use std::io::Write;
         println!("starting the inference loop");
-        print!("{prompt}");
-        std::io::stdout().flush()?;
-        let mut tokens = self
-            .tokenizer
-            .encode(prompt, true)
-            .map_err(E::msg)?
-            .get_ids()
-            .to_vec();
+        let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
+        if tokens.is_empty() {
+            anyhow::bail!("Empty prompts are not supported in the phi model.")
+        }
+        if self.verbose_prompt {
+            for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
+                let token = token.replace('▁', " ").replace("<0x0A>", "\n");
+                println!("{id:7} -> '{token}'");
+            }
+        }
+        let mut tokens = tokens.get_ids().to_vec();
         let mut generated_tokens = 0usize;
         let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
             Some(token) => *token,
             None => anyhow::bail!("cannot find the endoftext token"),
         };
+        print!("{prompt}");
+        std::io::stdout().flush()?;
         let start_gen = std::time::Instant::now();
         for index in 0..sample_len {
             let context_size = if index > 0 { 1 } else { tokens.len() };

@ -110,6 +117,16 @@ impl TextGeneration {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum WhichModel {
+    #[value(name = "1")]
+    V1,
+    #[value(name = "1.5")]
+    V1_5,
+    PuffinPhiV2,
+    PhiHermes,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {

@ -121,6 +138,10 @@ struct Args {
     #[arg(long)]
     tracing: bool,
 
+    /// Display the token for the specified prompt.
+    #[arg(long)]
+    verbose_prompt: bool,
+
     #[arg(long)]
     prompt: String,
 

@ -140,15 +161,21 @@ struct Args {
     #[arg(long, short = 'n', default_value_t = 100)]
     sample_len: usize,
 
-    #[arg(long, default_value = "microsoft/phi-1_5")]
-    model_id: String,
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long, default_value = "1.5")]
+    model: WhichModel,
 
-    #[arg(long, default_value = "refs/pr/18")]
-    revision: String,
+    #[arg(long)]
+    revision: Option<String>,
 
     #[arg(long)]
     weight_file: Option<String>,
 
+    #[arg(long)]
+    tokenizer: Option<String>,
+
     #[arg(long)]
     quantized: bool,
 

@ -189,20 +216,62 @@ fn main() -> Result<()> {
 
     let start = std::time::Instant::now();
     let api = Api::new()?;
-    let repo = api.repo(Repo::with_revision(
-        args.model_id,
-        RepoType::Model,
-        args.revision,
-    ));
-    let tokenizer_filename = repo.get("tokenizer.json")?;
+    let model_id = match args.model_id {
+        Some(model_id) => model_id.to_string(),
+        None => {
+            if args.quantized {
+                "lmz/candle-quantized-phi".to_string()
+            } else {
+                match args.model {
+                    WhichModel::V1 => "microsoft/phi-1".to_string(),
+                    WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
+                    WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
+                        "lmz/candle-quantized-phi".to_string()
+                    }
+                }
+            }
+        }
+    };
+    let revision = match args.revision {
+        Some(rev) => rev.to_string(),
+        None => {
+            if args.quantized {
+                "main".to_string()
+            } else {
+                match args.model {
+                    WhichModel::V1 => "refs/pr/2".to_string(),
+                    WhichModel::V1_5 => "refs/pr/18".to_string(),
+                    WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(),
+                }
+            }
+        }
+    };
+    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
+    let tokenizer_filename = match args.tokenizer {
+        Some(file) => std::path::PathBuf::from(file),
+        None => match args.model {
+            WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?,
+            WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
+                repo.get("tokenizer-puffin-phi-v2.json")?
+            }
+        },
+    };
     let filename = match args.weight_file {
         Some(weight_file) => std::path::PathBuf::from(weight_file),
         None => {
             if args.quantized {
-                api.model("lmz/candle-quantized-phi".to_string())
-                    .get("model-q4k.gguf")?
+                match args.model {
+                    WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
+                    WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
+                    WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?,
+                    WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?,
+                }
             } else {
-                repo.get("model.safetensors")?
+                match args.model {
+                    WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?,
+                    WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?,
+                    WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?,
+                }
             }
         }
     };

@ -210,7 +279,12 @@ fn main() -> Result<()> {
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let start = std::time::Instant::now();
-    let config = Config::v1_5();
+    let config = match args.model {
+        WhichModel::V1 => Config::v1(),
+        WhichModel::V1_5 => Config::v1_5(),
+        WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
+        WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
+    };
     let (model, device) = if args.quantized {
         let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
         let model = QMixFormer::new(&config, vb)?;

@ -231,6 +305,7 @@ fn main() -> Result<()> {
         args.top_p,
         args.repeat_penalty,
         args.repeat_last_n,
+        args.verbose_prompt,
         &device,
     );
     pipeline.run(&args.prompt, args.sample_len)?;
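On the default revisions above: `refs/pr/N` is a Hugging Face Hub pull-request ref, which this example pins, presumably so the files it expects (such as converted weights) are present even before the PR is merged. A small sketch of fetching from a pinned revision, reusing the same `hf_hub` API the diff itself calls:

```rust
use hf_hub::{api::sync::Api, Repo, RepoType};

// Fetch a file from a pinned hub pull-request revision rather than `main`.
fn tokenizer_from_pinned_revision() -> anyhow::Result<std::path::PathBuf> {
    let repo = Api::new()?.repo(Repo::with_revision(
        "microsoft/phi-1_5".to_string(),
        RepoType::Model,
        "refs/pr/18".to_string(),
    ));
    Ok(repo.get("tokenizer.json")?)
}
```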
@ -1,5 +1,7 @@
 # candle-quantized-t5
 
+## Seq2Seq example
+
 This example uses a quantized version of the t5 model.
 
 ```bash

@ -8,6 +10,8 @@ $ cargo run --example quantized-t5 --release -- --prompt "translate to German: A
 Eine schöne Kerze.
 ```
 
+## Generating Quantized weight files
+
 The weight file is automatically retrieved from the hub. It is also possible to
 generate quantized weight files from the original safetensors file by using the
 `tensor-tools` command line utility via:

@ -16,8 +20,11 @@ generate quantized weight files from the original safetensors file by using the
 $ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
 ```
 
-To use a different model, specify the `model-id`. For example, you can use
-quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
+## Using custom models
+
+To use a different model, specify the `model-id`.
+
+For example, for text editing, you can use quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
 
 ```bash
 $ cargo run --example quantized-t5 --release -- \

@ -26,6 +33,7 @@ $ cargo run --example quantized-t5 --release -- \
 --temperature 0
 ...
 Although their flight is weak, they run quickly through the tree canopy.
+```
 
 By default, it will look for `model.gguf` and `config.json`, but you can specify
 custom local or remote `weight-file` and `config-file`s:

@ -40,3 +48,16 @@ cargo run --example quantized-t5 --release -- \
 ...
 Note that a storm surge is what forecasters consider a hurricane's most dangerous part.
 ```
+
+### [MADLAD-400](https://arxiv.org/abs/2309.04662)
+
+MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
+
+```bash
+cargo run --example quantized-t5 --release -- \
+  --model-id "jbochi/madlad400-3b-mt" --weight-file "model-q4k.gguf" \
+  --prompt "<2de> How are you, my friend?" \
+  --temperature 0
+...
+Wie geht es dir, mein Freund?
+```
@ -173,7 +173,11 @@ fn main() -> Result<()> {
         .to_vec();
     let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
     let mut model = builder.build_model()?;
-    let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
+    let mut output_token_ids = [builder
+        .config
+        .decoder_start_token_id
+        .unwrap_or(builder.config.pad_token_id) as u32]
+    .to_vec();
     let temperature = if args.temperature <= 0. {
         None
     } else {
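This change seeds decoding from `decoder_start_token_id` when the config provides one, falling back to the pad token otherwise, which matters for checkpoints whose decoder start differs from padding. A tiny self-contained sketch of that fallback; the `Cfg` struct here is a made-up stand-in for the builder config above:

```rust
struct Cfg {
    decoder_start_token_id: Option<usize>,
    pad_token_id: usize,
}

// The explicit decoder start token wins; otherwise pad seeds the decoder.
fn start_token(cfg: &Cfg) -> u32 {
    cfg.decoder_start_token_id.unwrap_or(cfg.pad_token_id) as u32
}

#[test]
fn start_token_fallback() {
    assert_eq!(start_token(&Cfg { decoder_start_token_id: Some(2), pad_token_id: 0 }), 2);
    assert_eq!(start_token(&Cfg { decoder_start_token_id: None, pad_token_id: 0 }), 0);
}
```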
@ -12,6 +12,7 @@ use candle::quantized::{ggml_file, gguf_file};
 use candle::{Device, Tensor};
 use candle_transformers::generation::LogitsProcessor;
 
+use candle_examples::token_output_stream::TokenOutputStream;
 use candle_transformers::models::quantized_llama as model;
 use model::ModelWeights;
 

@ -24,7 +25,7 @@ enum Prompt {
     One(String),
 }
 
-#[derive(Clone, Debug, Copy, ValueEnum)]
+#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
 enum Which {
     #[value(name = "7b")]
     L7b,

@ -48,6 +49,12 @@ enum Which {
     Mistral7b,
     #[value(name = "7b-mistral-instruct")]
     Mistral7bInstruct,
+    #[value(name = "7b-zephyr-a")]
+    Zephyr7bAlpha,
+    #[value(name = "7b-zephyr-b")]
+    Zephyr7bBeta,
+    #[value(name = "7b-open-chat-3.5")]
+    OpenChat35,
 }
 
 impl Which {

@ -62,7 +69,50 @@ impl Which {
             | Self::L7bCode
             | Self::L13bCode
             | Self::L34bCode => false,
-            Self::Mistral7b | Self::Mistral7bInstruct => true,
+            // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
+            // same way.
+            Self::OpenChat35
+            | Self::Zephyr7bAlpha
+            | Self::Zephyr7bBeta
+            | Self::Mistral7b
+            | Self::Mistral7bInstruct => true,
+        }
+    }
+
+    fn is_zephyr(&self) -> bool {
+        match self {
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode
+            | Self::Mistral7b
+            | Self::Mistral7bInstruct
+            | Self::OpenChat35 => false,
+            Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
+        }
+    }
+
+    fn is_open_chat(&self) -> bool {
+        match self {
+            Which::L7b
+            | Which::L13b
+            | Which::L70b
+            | Which::L7bChat
+            | Which::L13bChat
+            | Which::L70bChat
+            | Which::L7bCode
+            | Which::L13bCode
+            | Which::L34bCode
+            | Which::Mistral7b
+            | Which::Mistral7bInstruct
+            | Which::Zephyr7bAlpha
+            | Which::Zephyr7bBeta => false,
+            Which::OpenChat35 => true,
         }
     }
 }

@ -81,7 +131,7 @@ struct Args {
     prompt: Option<String>,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(short = 'n', long, default_value_t = 100)]
+    #[arg(short = 'n', long, default_value_t = 1000)]
     sample_len: usize,
 
     /// The tokenizer config in json format.

@ -131,7 +181,9 @@ impl Args {
             Some(config) => std::path::PathBuf::from(config),
             None => {
                 let api = hf_hub::api::sync::Api::new()?;
-                let repo = if self.which.is_mistral() {
+                let repo = if self.which.is_open_chat() {
+                    "openchat/openchat_3.5"
+                } else if self.which.is_mistral() {
                     "mistralai/Mistral-7B-v0.1"
                 } else {
                     "hf-internal-testing/llama-tokenizer"

@ -174,6 +226,14 @@ impl Args {
                     "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
                     "mistral-7b-instruct-v0.1.Q4_K_S.gguf",
                 ),
+                Which::Zephyr7bAlpha => (
+                    "TheBloke/zephyr-7B-alpha-GGUF",
+                    "zephyr-7b-alpha.Q4_K_M.gguf",
+                ),
+                Which::Zephyr7bBeta => {
+                    ("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
+                }
+                Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"),
             };
             let api = hf_hub::api::sync::Api::new()?;
             let api = api.model(repo.to_string());

@ -184,31 +244,6 @@ impl Args {
     }
 }
 
-fn print_token(next_token: u32, tokenizer: &Tokenizer) {
-    // Extracting the last token as a string is complicated, here we just apply some simple
-    // heuristics as it seems to work well enough for this example. See the following for more
-    // details:
-    // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
-    if let Some(text) = tokenizer.id_to_token(next_token) {
-        let text = text.replace('▁', " ");
-        let ascii = text
-            .strip_prefix("<0x")
-            .and_then(|t| t.strip_suffix('>'))
-            .and_then(|t| u8::from_str_radix(t, 16).ok());
-        match ascii {
-            None => print!("{text}"),
-            Some(ascii) => {
-                if let Some(chr) = char::from_u32(ascii as u32) {
-                    if chr.is_ascii() {
-                        print!("{chr}")
-                    }
-                }
-            }
-        }
-        let _ = std::io::stdout().flush();
-    }
-}
-
 fn format_size(size_in_bytes: usize) -> String {
     if size_in_bytes < 1_000 {
         format!("{}B", size_in_bytes)

@ -295,7 +330,13 @@ fn main() -> anyhow::Result<()> {
                 | Which::L7bCode
                 | Which::L13bCode
                 | Which::L34bCode => 1,
-                Which::Mistral7b | Which::Mistral7bInstruct | Which::L70b | Which::L70bChat => 8,
+                Which::Mistral7b
+                | Which::Mistral7bInstruct
+                | Which::Zephyr7bAlpha
+                | Which::Zephyr7bBeta
+                | Which::L70b
+                | Which::L70bChat
+                | Which::OpenChat35 => 8,
             };
             ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
         }

@ -303,6 +344,7 @@ fn main() -> anyhow::Result<()> {
     println!("model built");
 
     let tokenizer = args.tokenizer()?;
+    let mut tos = TokenOutputStream::new(tokenizer);
     let prompt = match args.prompt.as_deref() {
         Some("chat") => Prompt::Chat,
         Some("interactive") => Prompt::Interactive,

@ -311,10 +353,11 @@ fn main() -> anyhow::Result<()> {
     };
 
     let mut pre_prompt_tokens = vec![];
-    loop {
+    for prompt_index in 0.. {
         let prompt_str = match &prompt {
             Prompt::One(prompt) => prompt.clone(),
             Prompt::Interactive | Prompt::Chat => {
+                let is_interactive = matches!(prompt, Prompt::Interactive);
                 print!("> ");
                 std::io::stdout().flush()?;
                 let mut prompt = String::new();

@ -325,7 +368,15 @@ fn main() -> anyhow::Result<()> {
                     prompt.pop();
                 }
             }
-                if args.which.is_mistral() {
+                if args.which.is_open_chat() {
+                    format!("User: {prompt}<|end_of_turn|>Assistant: ")
+                } else if args.which.is_zephyr() {
+                    if prompt_index == 0 || is_interactive {
+                        format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
+                    } else {
+                        format!("<|user|>\n{prompt}</s>\n<|assistant|>")
+                    }
+                } else if args.which.is_mistral() {
                     format!("[INST] {prompt} [/INST]")
                 } else {
                     prompt

@ -333,7 +384,8 @@ fn main() -> anyhow::Result<()> {
             }
         };
         print!("{}", &prompt_str);
-        let tokens = tokenizer
+        let tokens = tos
+            .tokenizer()
             .encode(prompt_str, true)
             .map_err(anyhow::Error::msg)?;
         if args.verbose_prompt {

@ -363,11 +415,19 @@ fn main() -> anyhow::Result<()> {
         };
         let prompt_dt = start_prompt_processing.elapsed();
         all_tokens.push(next_token);
-        print_token(next_token, &tokenizer);
+        if let Some(t) = tos.next_token(next_token)? {
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }
 
-        let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
+        let eos_token = if args.which.is_open_chat() {
+            "<|end_of_turn|>"
+        } else {
+            "</s>"
+        };
+        let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
         let start_post_prompt = std::time::Instant::now();
+        let mut sampled = 0;
         for index in 0..to_sample {
             let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
             let logits = model.forward(&input, prompt_tokens.len() + index)?;

@ -384,11 +444,19 @@ fn main() -> anyhow::Result<()> {
         };
         next_token = logits_processor.sample(&logits)?;
         all_tokens.push(next_token);
-        print_token(next_token, &tokenizer);
+        if let Some(t) = tos.next_token(next_token)? {
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }
+        sampled += 1;
         if next_token == eos_token {
            break;
         };
     }
+    if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
+        print!("{rest}");
+    }
+    std::io::stdout().flush()?;
     let dt = start_post_prompt.elapsed();
     println!(
         "\n\n{:4} prompt tokens processed: {:.2} token/s",

@ -396,9 +464,8 @@ fn main() -> anyhow::Result<()> {
         prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
     );
     println!(
-        "{:4} tokens generated: {:.2} token/s",
-        to_sample,
-        to_sample as f64 / dt.as_secs_f64(),
+        "{sampled:4} tokens generated: {:.2} token/s",
+        sampled as f64 / dt.as_secs_f64(),
     );
 
     match prompt {
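The switch from the ad-hoc `print_token` heuristics to `TokenOutputStream` centralizes incremental detokenization: token ids are printed as soon as they decode to valid text, and any leftover bytes are flushed at the end. A sketch of the streaming pattern as used above, with the method signatures assumed from this diff:

```rust
use candle_examples::token_output_stream::TokenOutputStream;

// Stream-print a sequence of token ids, then flush whatever remains buffered.
fn stream_print(tokenizer: tokenizers::Tokenizer, ids: &[u32]) -> anyhow::Result<()> {
    use std::io::Write;
    let mut tos = TokenOutputStream::new(tokenizer);
    for &id in ids {
        if let Some(text) = tos.next_token(id)? {
            print!("{text}");
            std::io::stdout().flush()?;
        }
    }
    if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
        print!("{rest}");
    }
    Ok(())
}
```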
@ -1,360 +1,451 @@
-/* Deep Deterministic Policy Gradient.
-
-   Continuous control with deep reinforcement learning, Lillicrap et al. 2015
-   https://arxiv.org/abs/1509.02971
-
-   See https://spinningup.openai.com/en/latest/algorithms/ddpg.html for a
-   reference python implementation.
-*/
-use super::gym_env::GymEnv;
-use candle::{DType, Device, Result, Tensor};
-use candle_nn::VarMap;
-
-// The impact of the q value of the next state on the current state's q value.
-const GAMMA: f64 = 0.99;
-// The weight for updating the target networks.
-const TAU: f64 = 0.005;
-// The capacity of the replay buffer used for sampling training data.
-const REPLAY_BUFFER_CAPACITY: usize = 100_000;
-// The training batch size for each training iteration.
-const TRAINING_BATCH_SIZE: usize = 100;
-// The total number of episodes.
-const MAX_EPISODES: usize = 100;
-// The maximum length of an episode.
-const EPISODE_LENGTH: usize = 200;
-// The number of training iterations after one episode finishes.
-const TRAINING_ITERATIONS: usize = 200;
-
-// Ornstein-Uhlenbeck process parameters.
-const MU: f64 = 0.0;
-const THETA: f64 = 0.15;
-const SIGMA: f64 = 0.1;
-
-const ACTOR_LEARNING_RATE: f64 = 1e-4;
-const CRITIC_LEARNING_RATE: f64 = 1e-3;
-
-struct OuNoise {
-    mu: f64,
-    theta: f64,
-    sigma: f64,
-    state: Tensor,
-}
-
-impl OuNoise {
-    fn new(mu: f64, theta: f64, sigma: f64, num_actions: usize) -> Result<Self> {
-        let state = Tensor::ones(num_actions, DType::F32, &Device::Cpu)?;
-        Ok(Self {
-            mu,
-            theta,
-            sigma,
-            state,
-        })
-    }
-
-    fn sample(&mut self) -> Result<Tensor> {
-        let dx = (((self.mu - &self.state)? * self.theta)?
-            + (self.state.randn_like(0., 1.)? * self.beta)?)?;
-        self.state = (self.state + dx)?;
-        Ok(self.state.clone())
-    }
-}
-
-struct ReplayBuffer {
-    obs: Tensor,
-    next_obs: Vec<Tensor>,
-    rewards: Vec<Tensor>,
-    actions: Vec<Tensor>,
-    capacity: usize,
-    len: usize,
-    i: usize,
-}
-
-impl ReplayBuffer {
-    fn new(capacity: usize, num_obs: usize, num_actions: usize) -> Self {
-        let cpu = Device::Cpu;
-        let obs = vec![Tensor::zeros(num_obs, DType::F32, &cpu)?; capacity];
-        let next_obs = vec![Tensor::zeros(num_obs, DType::F32, &cpu)?; capacity];
-        let rewards = vec![Tensor::zeros(1, DType::F32, &cpu)?; capacity];
-        let actions = vec![Tensor::zeros(num_actions, DType::F32, &cpu)?; capacity];
-        Ok(Self {
-            obs,
-            next_obs,
-            rewards,
-            actions,
-            capacity,
-            len: 0,
-            i: 0,
-        })
-    }
-
-    fn push(&mut self, obs: &Tensor, actions: &Tensor, reward: &Tensor, next_obs: &Tensor) {
-        let i = self.i % self.capacity;
-        self.obs.get(i as _).copy_(obs);
-        self.rewards.get(i as _).copy_(reward);
-        self.actions.get(i as _).copy_(actions);
-        self.next_obs.get(i as _).copy_(next_obs);
-        self.i += 1;
-        if self.len < self.capacity {
-            self.len += 1;
-        }
-    }
-
-    fn random_batch(&self, batch_size: usize) -> Option<(Tensor, Tensor, Tensor, Tensor)> {
-        if self.len < 3 {
-            return None;
-        }
-
-        let batch_size = batch_size.min(self.len - 1);
-        let batch_indexes = Tensor::randint((self.len - 2) as _, [batch_size as _], INT64_CPU);
-
-        let states = self.obs.index_select(0, &batch_indexes);
-        let next_states = self.next_obs.index_select(0, &batch_indexes);
-        let actions = self.actions.index_select(0, &batch_indexes);
-        let rewards = self.rewards.index_select(0, &batch_indexes);
-
-        Some((states, actions, rewards, next_states))
-    }
-}
-
-struct Actor {
-    varmap: VarMap,
-    network: candle_nn::Func,
-    num_obs: usize,
-    num_actions: usize,
-    opt: candle_nn::AdamW,
-    learning_rate: f64,
-}
-
-impl Actor {
-    fn new(num_obs: usize, num_actions: usize, learning_rate: f64) -> Self {
-        let mut varmap = VarMap::new();
-        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
-        let al1 = candle_nn::linear(num_obs, 400, vb.pp("al1"))?;
-        let al2 = candle_nn::linear(400, 300, vb.pp("al2"))?;
-        let al3 = candle_nn::linear(300, num_actions, vb.pp("al3"))?;
-        let network = Func::new(|xs| {
-            xs.apply(al1)?
-                .relu()?
-                .apply(al2)?
-                .relu()?
-                .apply(al3)?
-                .tanh()
-        });
-        let opt = nn::Adam::default()
-            .build(&var_store, learning_rate)
-            .unwrap();
-        let p = &var_store.root();
-        Self {
-            network,
-            num_obs,
-            num_actions,
-            varmap,
-            opt,
-            learning_rate,
-        }
-    }
-
-    fn forward(&self, obs: &Tensor) -> Result<Tensor> {
-        obs.apply(&self.network)
-    }
-}
-
-struct Critic {
-    varmap: VarMap,
-    network: candle_nn::Func,
-    num_obs: usize,
-    num_actions: usize,
-    opt: candle_nn::AdamW,
-    learning_rate: f64,
-}
-
-impl Critic {
-    fn new(num_obs: usize, num_actions: usize, learning_rate: f64) -> Result<Self> {
-        let varmap = VarMap::new();
-        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
-        let cl1 = candle_nn::linear(num_obs + num_actions, 400, vb.pp("cl1"))?;
-        let cl2 = candle_nn::linear(400, 300, vb.pp("cl2"))?;
-        let cl3 = candle_nn::linear(300, 1, vb.pp("cl3"))?;
-        let network = Func::new(|xs| xs.apply(cl1)?.relu()?.apply(&cl2)?.relu()?.apply(cl3));
-        let adamw_params = candle_nn::ParamsAdamW {
-            lr: 1e-3,
-            ..Default::default()
-        };
-        let opt = AdamW::new(varmap.all_vars(), adamw_params);
-        Ok(Self {
-            network,
-            varmap,
-            num_obs,
-            num_actions,
-            opt,
-            learning_rate,
-        })
-    }
-
-    fn forward(&self, obs: &Tensor, actions: &Tensor) -> Result<Tensor> {
-        let xs = Tensor::cat(&[actions, obs], 1)?;
-        xs.apply(&self.network)
-    }
-}
-
-/* TODO: enable tracking
-fn track(dest: &mut nn::VarStore, src: &nn::VarStore, tau: f64) {
-    tch::no_grad(|| {
-        for (dest, src) in dest
-            .trainable_variables()
-            .iter_mut()
-            .zip(src.trainable_variables().iter())
-        {
-            dest.copy_(&(tau * src + (1.0 - tau) * &*dest));
-        }
-    })
-}
-*/
-
-struct Agent {
-    actor: Actor,
-    actor_target: Actor,
-
-    critic: Critic,
-    critic_target: Critic,
-
-    replay_buffer: ReplayBuffer,
-
-    ou_noise: OuNoise,
-
-    train: bool,
-
-    gamma: f64,
-    tau: f64,
-}
-
-impl Agent {
-    fn new(
-        actor: Actor,
-        critic: Critic,
-        ou_noise: OuNoise,
-        replay_buffer_capacity: usize,
-        train: bool,
-        gamma: f64,
-        tau: f64,
-    ) -> Self {
-        let actor_target = actor.clone();
-        let critic_target = critic.clone();
-        let replay_buffer =
-            ReplayBuffer::new(replay_buffer_capacity, actor.num_obs, actor.num_actions);
-        Self {
-            actor,
-            actor_target,
-            critic,
-            critic_target,
-            replay_buffer,
-            ou_noise,
-            train,
-            gamma,
-            tau,
-        }
-    }
-
-    fn actions(&mut self, obs: &Tensor) -> Result<Tensor> {
-        let mut actions = tch::no_grad(|| self.actor.forward(obs));
-        if self.train {
-            actions += self.ou_noise.sample();
-        }
-        actions
-    }
-
-    fn remember(&mut self, obs: &Tensor, actions: &Tensor, reward: &Tensor, next_obs: &Tensor) {
-        self.replay_buffer.push(obs, actions, reward, next_obs);
-    }
-
-    fn train(&mut self, batch_size: usize) {
-        let (states, actions, rewards, next_states) =
-            match self.replay_buffer.random_batch(batch_size) {
-                Some(v) => v,
-                _ => return, // We don't have enough samples for training yet.
-            };
-
-        let mut q_target = self
-            .critic_target
-            .forward(&next_states, &self.actor_target.forward(&next_states));
-        q_target = rewards + (self.gamma * q_target).detach();
-
-        let q = self.critic.forward(&states, &actions);
-
-        let diff = q_target - q;
-        let critic_loss = (&diff * &diff).mean(Float);
-
-        self.critic.opt.zero_grad();
-        critic_loss.backward();
-        self.critic.opt.step();
-
-        let actor_loss = -self
-            .critic
-            .forward(&states, &self.actor.forward(&states))
-            .mean(Float);
-
-        self.actor.opt.zero_grad();
-        actor_loss.backward();
-        self.actor.opt.step();
-
-        track(
-            &mut self.critic_target.var_store,
-            &self.critic.var_store,
-            self.tau,
-        );
-        track(
-            &mut self.actor_target.var_store,
-            &self.actor.var_store,
-            self.tau,
-        );
-    }
-}
-
-pub fn run() -> Result<()> {
-    let env = GymEnv::new("Pendulum-v1")?;
-    println!("action space: {}", env.action_space());
-    println!("observation space: {:?}", env.observation_space());
-
-    let num_obs = env.observation_space().iter().product::<usize>();
-    let num_actions = env.action_space();
-
-    let actor = Actor::new(num_obs, num_actions, ACTOR_LEARNING_RATE);
-    let critic = Critic::new(num_obs, num_actions, CRITIC_LEARNING_RATE);
-    let ou_noise = OuNoise::new(MU, THETA, SIGMA, num_actions);
-    let mut agent = Agent::new(
-        actor,
-        critic,
-        ou_noise,
-        REPLAY_BUFFER_CAPACITY,
-        true,
-        GAMMA,
-        TAU,
-    );
-
-    for episode in 0..MAX_EPISODES as u64 {
-        let mut obs = env.reset(episode)?;
-
-        let mut total_reward = 0.0;
-        for _ in 0..EPISODE_LENGTH {
-            let actions: f32 = 2.0 * agent.actions(&obs)?.to_vec0::<f32>()?;
-            let actions = actions.clamp(-2.0, 2.0);
-            let step = env.step(vec![action_vec])?;
-            total_reward += step.reward;
-
-            agent.remember(&obs, &actions.into(), &step.reward.into(), &step.obs);
-
-            if step.is_done {
-                break;
-            }
-            obs = step.obs;
-        }
-
-        println!("episode {episode} with total reward of {total_reward}");
-
-        for _ in 0..TRAINING_ITERATIONS {
-            agent.train(TRAINING_BATCH_SIZE);
-        }
-    }
-    Ok(())
-}
+use std::collections::VecDeque;
+use std::fmt::Display;
+
+use candle::{DType, Device, Error, Module, Result, Tensor, Var};
+use candle_nn::{
+    func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
+    VarBuilder, VarMap,
+};
+use rand::{distributions::Uniform, thread_rng, Rng};
+
+pub struct OuNoise {
+    mu: f64,
+    theta: f64,
+    sigma: f64,
+    state: Tensor,
+}
+impl OuNoise {
+    pub fn new(mu: f64, theta: f64, sigma: f64, size_action: usize) -> Result<Self> {
+        Ok(Self {
+            mu,
+            theta,
+            sigma,
+            state: Tensor::ones(size_action, DType::F32, &Device::Cpu)?,
+        })
+    }
+
+    pub fn sample(&mut self) -> Result<Tensor> {
+        let rand = Tensor::randn_like(&self.state, 0.0, 1.0)?;
+        let dx = ((self.theta * (self.mu - &self.state)?)? + (self.sigma * rand)?)?;
+        self.state = (&self.state + dx)?;
+        Ok(self.state.clone())
+    }
+}
+
+#[derive(Clone)]
+struct Transition {
+    state: Tensor,
+    action: Tensor,
+    reward: Tensor,
+    next_state: Tensor,
+    terminated: bool,
+    truncated: bool,
+}
+impl Transition {
+    fn new(
+        state: &Tensor,
+        action: &Tensor,
+        reward: &Tensor,
+        next_state: &Tensor,
+        terminated: bool,
+        truncated: bool,
+    ) -> Self {
+        Self {
+            state: state.clone(),
+            action: action.clone(),
+            reward: reward.clone(),
+            next_state: next_state.clone(),
+            terminated,
+            truncated,
+        }
+    }
+}
+
+pub struct ReplayBuffer {
+    buffer: VecDeque<Transition>,
+    capacity: usize,
+    size: usize,
+}
+impl ReplayBuffer {
+    pub fn new(capacity: usize) -> Self {
+        Self {
+            buffer: VecDeque::with_capacity(capacity),
+            capacity,
+            size: 0,
+        }
+    }
+
+    pub fn push(
+        &mut self,
+        state: &Tensor,
+        action: &Tensor,
+        reward: &Tensor,
+        next_state: &Tensor,
+        terminated: bool,
+        truncated: bool,
+    ) {
+        if self.size == self.capacity {
+            self.buffer.pop_front();
+        } else {
+            self.size += 1;
+        }
+        self.buffer.push_back(Transition::new(
+            state, action, reward, next_state, terminated, truncated,
+        ));
+    }
+
+    #[allow(clippy::type_complexity)]
+    pub fn random_batch(
+        &self,
+        batch_size: usize,
+    ) -> Result<Option<(Tensor, Tensor, Tensor, Tensor, Vec<bool>, Vec<bool>)>> {
+        if self.size < batch_size {
+            Ok(None)
+        } else {
+            let transitions: Vec<&Transition> = thread_rng()
+                .sample_iter(Uniform::from(0..self.size))
+                .take(batch_size)
+                .map(|i| self.buffer.get(i).unwrap())
+                .collect();
+
+            let states: Vec<Tensor> = transitions
+                .iter()
+                .map(|t| t.state.unsqueeze(0))
+                .collect::<Result<_>>()?;
+            let actions: Vec<Tensor> = transitions
+                .iter()
+                .map(|t| t.action.unsqueeze(0))
+                .collect::<Result<_>>()?;
+            let rewards: Vec<Tensor> = transitions
+                .iter()
+                .map(|t| t.reward.unsqueeze(0))
+                .collect::<Result<_>>()?;
+            let next_states: Vec<Tensor> = transitions
+                .iter()
+                .map(|t| t.next_state.unsqueeze(0))
+                .collect::<Result<_>>()?;
+            let terminateds: Vec<bool> = transitions.iter().map(|t| t.terminated).collect();
+            let truncateds: Vec<bool> = transitions.iter().map(|t| t.truncated).collect();
+
+            Ok(Some((
+                Tensor::cat(&states, 0)?,
+                Tensor::cat(&actions, 0)?,
+                Tensor::cat(&rewards, 0)?,
+                Tensor::cat(&next_states, 0)?,
+                terminateds,
+                truncateds,
+            )))
+        }
+    }
+}
+
+fn track(
+    varmap: &mut VarMap,
+    vb: &VarBuilder,
+    target_prefix: &str,
+    network_prefix: &str,
+    dims: &[(usize, usize)],
+    tau: f64,
+) -> Result<()> {
+    for (i, &(in_dim, out_dim)) in dims.iter().enumerate() {
+        let target_w = vb.get((out_dim, in_dim), &format!("{target_prefix}-fc{i}.weight"))?;
+        let network_w = vb.get((out_dim, in_dim), &format!("{network_prefix}-fc{i}.weight"))?;
+        varmap.set_one(
+            format!("{target_prefix}-fc{i}.weight"),
+            ((tau * network_w)? + ((1.0 - tau) * target_w)?)?,
+        )?;
+
+        let target_b = vb.get(out_dim, &format!("{target_prefix}-fc{i}.bias"))?;
+        let network_b = vb.get(out_dim, &format!("{network_prefix}-fc{i}.bias"))?;
+        varmap.set_one(
+            format!("{target_prefix}-fc{i}.bias"),
+            ((tau * network_b)? + ((1.0 - tau) * target_b)?)?,
+        )?;
+    }
+    Ok(())
+}
+
+struct Actor<'a> {
+    varmap: VarMap,
+    vb: VarBuilder<'a>,
+    network: Sequential,
+    target_network: Sequential,
+    size_state: usize,
+    size_action: usize,
+    dims: Vec<(usize, usize)>,
+}
+
+impl Actor<'_> {
+    fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
+        let mut varmap = VarMap::new();
+        let vb = VarBuilder::from_varmap(&varmap, dtype, device);
+
+        let dims = vec![(size_state, 400), (400, 300), (300, size_action)];
+
+        let make_network = |prefix: &str| {
+            let seq = seq()
+                .add(linear(
+                    dims[0].0,
+                    dims[0].1,
+                    vb.pp(format!("{prefix}-fc0")),
+                )?)
+                .add(Activation::Relu)
+                .add(linear(
+                    dims[1].0,
+                    dims[1].1,
+                    vb.pp(format!("{prefix}-fc1")),
+                )?)
+                .add(Activation::Relu)
+                .add(linear(
+                    dims[2].0,
+                    dims[2].1,
+                    vb.pp(format!("{prefix}-fc2")),
+                )?)
+                .add(func(|xs| xs.tanh()));
+            Ok::<Sequential, Error>(seq)
+        };
+
+        let network = make_network("actor")?;
+        let target_network = make_network("target-actor")?;
+
+        // this sets the two networks to be equal to each other using tau = 1.0
+        track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0);
+
+        Ok(Self {
+            varmap,
+            vb,
+            network,
+            target_network,
+            size_state,
+            size_action,
+            dims,
+        })
+    }
+
+    fn forward(&self, state: &Tensor) -> Result<Tensor> {
+        self.network.forward(state)
+    }
+
+    fn target_forward(&self, state: &Tensor) -> Result<Tensor> {
+        self.target_network.forward(state)
+    }
+
+    fn track(&mut self, tau: f64) -> Result<()> {
+        track(
+            &mut self.varmap,
+            &self.vb,
+            "target-actor",
+            "actor",
+            &self.dims,
+            tau,
+        )
+    }
+}
+
+struct Critic<'a> {
+    varmap: VarMap,
+    vb: VarBuilder<'a>,
+    network: Sequential,
+    target_network: Sequential,
+    size_state: usize,
+    size_action: usize,
+    dims: Vec<(usize, usize)>,
+}
+
+impl Critic<'_> {
+    fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
+        let mut varmap = VarMap::new();
+        let vb = VarBuilder::from_varmap(&varmap, dtype, device);
+
+        let dims: Vec<(usize, usize)> = vec![(size_state + size_action, 400), (400, 300), (300, 1)];
+
+        let make_network = |prefix: &str| {
+            let seq = seq()
+                .add(linear(
+                    dims[0].0,
+                    dims[0].1,
+                    vb.pp(format!("{prefix}-fc0")),
+                )?)
+                .add(Activation::Relu)
+                .add(linear(
+                    dims[1].0,
+                    dims[1].1,
+                    vb.pp(format!("{prefix}-fc1")),
+                )?)
+                .add(Activation::Relu)
+                .add(linear(
+                    dims[2].0,
+                    dims[2].1,
+                    vb.pp(format!("{prefix}-fc2")),
+                )?);
+            Ok::<Sequential, Error>(seq)
+        };
+
+        let network = make_network("critic")?;
+        let target_network = make_network("target-critic")?;
+
+        // this sets the two networks to be equal to each other using tau = 1.0
+        track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0);
+
+        Ok(Self {
+            varmap,
+            vb,
+            network,
+            target_network,
+            size_state,
+            size_action,
+            dims,
+        })
+    }
+
+    fn forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
+        let xs = Tensor::cat(&[action, state], 1)?;
+        self.network.forward(&xs)
+    }
+
+    fn target_forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
+        let xs = Tensor::cat(&[action, state], 1)?;
+        self.target_network.forward(&xs)
+    }
+
+    fn track(&mut self, tau: f64) -> Result<()> {
+        track(
+            &mut self.varmap,
+            &self.vb,
+            "target-critic",
+            "critic",
+            &self.dims,
+            tau,
+        )
+    }
+}
+
+#[allow(clippy::upper_case_acronyms)]
+pub struct DDPG<'a> {
+    actor: Actor<'a>,
+    actor_optim: AdamW,
+    critic: Critic<'a>,
+    critic_optim: AdamW,
+    gamma: f64,
+    tau: f64,
+    replay_buffer: ReplayBuffer,
+    ou_noise: OuNoise,
+
+    size_state: usize,
+    size_action: usize,
+    pub train: bool,
+}
+
+impl DDPG<'_> {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        device: &Device,
+        size_state: usize,
+        size_action: usize,
+        train: bool,
+        actor_lr: f64,
+        critic_lr: f64,
+        gamma: f64,
+        tau: f64,
+        buffer_capacity: usize,
+        ou_noise: OuNoise,
+    ) -> Result<Self> {
+        let filter_by_prefix = |varmap: &VarMap, prefix: &str| {
+            varmap
+                .data()
+                .lock()
+                .unwrap()
+                .iter()
+                .filter_map(|(name, var)| name.starts_with(prefix).then_some(var.clone()))
+                .collect::<Vec<Var>>()
+        };
+
+        let actor = Actor::new(device, DType::F32, size_state, size_action)?;
+        let actor_optim = AdamW::new(
+            filter_by_prefix(&actor.varmap, "actor"),
+            ParamsAdamW {
+                lr: actor_lr,
+                ..Default::default()
+            },
+        )?;
+
+        let critic = Critic::new(device, DType::F32, size_state, size_action)?;
+        let critic_optim = AdamW::new(
+            filter_by_prefix(&critic.varmap, "critic"),
+            ParamsAdamW {
+                lr: critic_lr,
+                ..Default::default()
+            },
+        )?;
+
+        Ok(Self {
+            actor,
+            actor_optim,
+            critic,
+            critic_optim,
+            gamma,
+            tau,
+            replay_buffer: ReplayBuffer::new(buffer_capacity),
+            ou_noise,
+            size_state,
+            size_action,
+            train,
+        })
+    }
+
+    pub fn remember(
+        &mut self,
+        state: &Tensor,
+        action: &Tensor,
+        reward: &Tensor,
+        next_state: &Tensor,
+        terminated: bool,
+        truncated: bool,
+    ) {
+        self.replay_buffer
+            .push(state, action, reward, next_state, terminated, truncated)
+    }
+
+    pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
+        let actions = self
+            .actor
+            .forward(&state.detach()?.unsqueeze(0)?)?
+            .squeeze(0)?;
+        let actions = if self.train {
+            (actions + self.ou_noise.sample()?)?
+        } else {
+            actions
+        };
+        actions.squeeze(0)?.to_scalar::<f32>()
+    }
+
+    pub fn train(&mut self, batch_size: usize) -> Result<()> {
+        let (states, actions, rewards, next_states, _, _) =
+            match self.replay_buffer.random_batch(batch_size)? {
+                Some(v) => v,
+                _ => return Ok(()),
+            };
+
+        let q_target = self
+            .critic
+            .target_forward(&next_states, &self.actor.target_forward(&next_states)?)?;
+        let q_target = (rewards + (self.gamma * q_target)?.detach())?;
+        let q = self.critic.forward(&states, &actions)?;
+        let diff = (q_target - q)?;
+
+        let critic_loss = diff.sqr()?.mean_all()?;
+        self.critic_optim.backward_step(&critic_loss)?;
+
+        let actor_loss = self
+            .critic
+            .forward(&states, &self.actor.forward(&states)?)?
+            .mean_all()?
+            .neg()?;
+        self.actor_optim.backward_step(&actor_loss)?;
+
+        self.critic.track(self.tau)?;
+        self.actor.track(self.tau)?;
+
+        Ok(())
+    }
+}
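For reference, the update rules the new ddpg.rs above implements (standard DDPG; the notation is added here and is not part of the diff): `OuNoise::sample` performs the Ornstein-Uhlenbeck step, `DDPG::train` forms the bootstrapped critic target and the two losses, and `track` applies the soft target update per weight and bias.

$$x_{t+1} = x_t + \theta\,(\mu - x_t) + \sigma\,\varepsilon_t, \qquad \varepsilon_t \sim \mathcal{N}(0, I)$$

$$y = r + \gamma\, Q_{\mathrm{target}}\!\left(s',\, \pi_{\mathrm{target}}(s')\right), \qquad
\mathcal{L}_{\mathrm{critic}} = \mathbb{E}\!\left[(y - Q(s, a))^2\right], \qquad
\mathcal{L}_{\mathrm{actor}} = -\,\mathbb{E}\!\left[Q(s, \pi(s))\right]$$

$$\theta_{\mathrm{target}} \leftarrow \tau\,\theta + (1 - \tau)\,\theta_{\mathrm{target}}$$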
@ -7,20 +7,22 @@ use pyo3::types::PyDict;
|
|||||||
/// The return value for a step.
|
/// The return value for a step.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Step<A> {
|
pub struct Step<A> {
|
||||||
pub obs: Tensor,
|
pub state: Tensor,
|
||||||
pub action: A,
|
pub action: A,
|
||||||
pub reward: f64,
|
pub reward: f64,
|
||||||
pub is_done: bool,
|
pub terminated: bool,
|
||||||
|
pub truncated: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<A: Copy> Step<A> {
|
impl<A: Copy> Step<A> {
|
||||||
/// Returns a copy of this step changing the observation tensor.
|
/// Returns a copy of this step changing the observation tensor.
|
||||||
pub fn copy_with_obs(&self, obs: &Tensor) -> Step<A> {
|
pub fn copy_with_obs(&self, state: &Tensor) -> Step<A> {
|
||||||
Step {
|
Step {
|
||||||
obs: obs.clone(),
|
state: state.clone(),
|
||||||
action: self.action,
|
action: self.action,
|
||||||
reward: self.reward,
|
reward: self.reward,
|
||||||
is_done: self.is_done,
|
terminated: self.terminated,
|
||||||
|
truncated: self.truncated,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -63,14 +65,14 @@ impl GymEnv {

     /// Resets the environment, returning the observation tensor.
     pub fn reset(&self, seed: u64) -> Result<Tensor> {
-        let obs: Vec<f32> = Python::with_gil(|py| {
+        let state: Vec<f32> = Python::with_gil(|py| {
             let kwargs = PyDict::new(py);
             kwargs.set_item("seed", seed)?;
-            let obs = self.env.call_method(py, "reset", (), Some(kwargs))?;
-            obs.as_ref(py).get_item(0)?.extract()
+            let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
+            state.as_ref(py).get_item(0)?.extract()
         })
         .map_err(w)?;
-        Tensor::new(obs, &Device::Cpu)
+        Tensor::new(state, &Device::Cpu)
     }

     /// Applies an environment step using the specified action.
@@ -78,21 +80,23 @@ impl GymEnv {
         &self,
         action: A,
     ) -> Result<Step<A>> {
-        let (obs, reward, is_done) = Python::with_gil(|py| {
+        let (state, reward, terminated, truncated) = Python::with_gil(|py| {
             let step = self.env.call_method(py, "step", (action.clone(),), None)?;
             let step = step.as_ref(py);
-            let obs: Vec<f32> = step.get_item(0)?.extract()?;
+            let state: Vec<f32> = step.get_item(0)?.extract()?;
             let reward: f64 = step.get_item(1)?.extract()?;
-            let is_done: bool = step.get_item(2)?.extract()?;
-            Ok((obs, reward, is_done))
+            let terminated: bool = step.get_item(2)?.extract()?;
+            let truncated: bool = step.get_item(3)?.extract()?;
+            Ok((state, reward, terminated, truncated))
         })
         .map_err(w)?;
-        let obs = Tensor::new(obs, &Device::Cpu)?;
+        let state = Tensor::new(state, &Device::Cpu)?;
         Ok(Step {
-            obs,
-            reward,
-            is_done,
+            state,
             action,
+            reward,
+            terminated,
+            truncated,
         })
     }
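The switch from a single `is_done` flag to Gymnasium's `terminated`/`truncated` pair matters for bootstrapping: a truncated episode was cut off by a time limit, so the value of the next state should still be bootstrapped, while a true termination should zero it out. A minimal sketch of the distinction when forming a TD target (hypothetical helper, not part of this diff):

```rust
/// TD target for one transition: only a true termination zeroes the bootstrap.
/// A time-limit truncation is not a real end of the MDP, so `gamma * q_next` is kept.
fn td_target(reward: f64, q_next: f64, gamma: f64, terminated: bool, _truncated: bool) -> f64 {
    if terminated {
        reward
    } else {
        reward + gamma * q_next
    }
}
```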
@@ -6,18 +6,37 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;

-mod ddpg;
 mod gym_env;
 mod vec_gym_env;

-use candle::Result;
+mod ddpg;
+
+use candle::{Device, Result, Tensor};
 use clap::Parser;
 use rand::Rng;

+// The impact of the q value of the next state on the current state's q value.
+const GAMMA: f64 = 0.99;
+// The weight for updating the target networks.
+const TAU: f64 = 0.005;
+// The capacity of the replay buffer used for sampling training data.
+const REPLAY_BUFFER_CAPACITY: usize = 100_000;
+// The training batch size for each training iteration.
+const TRAINING_BATCH_SIZE: usize = 100;
 // The total number of episodes.
 const MAX_EPISODES: usize = 100;
 // The maximum length of an episode.
 const EPISODE_LENGTH: usize = 200;
+// The number of training iterations after one episode finishes.
+const TRAINING_ITERATIONS: usize = 200;
+
+// Ornstein-Uhlenbeck process parameters.
+const MU: f64 = 0.0;
+const THETA: f64 = 0.15;
+const SIGMA: f64 = 0.1;
+
+const ACTOR_LEARNING_RATE: f64 = 1e-4;
+const CRITIC_LEARNING_RATE: f64 = 1e-3;

 #[derive(Parser, Debug, Clone)]
 #[command(author, version, about, long_about = None)]
@@ -49,28 +68,77 @@ fn main() -> Result<()> {
     println!("action space: {}", env.action_space());
     println!("observation space: {:?}", env.observation_space());

-    let _num_obs = env.observation_space().iter().product::<usize>();
-    let _num_actions = env.action_space();
+    let size_state = env.observation_space().iter().product::<usize>();
+    let size_action = env.action_space();
+
+    let mut agent = ddpg::DDPG::new(
+        &Device::Cpu,
+        size_state,
+        size_action,
+        true,
+        ACTOR_LEARNING_RATE,
+        CRITIC_LEARNING_RATE,
+        GAMMA,
+        TAU,
+        REPLAY_BUFFER_CAPACITY,
+        ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?,
+    )?;

     let mut rng = rand::thread_rng();

     for episode in 0..MAX_EPISODES {
-        let mut obs = env.reset(episode as u64)?;
+        // let mut state = env.reset(episode as u64)?;
+        let mut state = env.reset(rng.gen::<u64>())?;

         let mut total_reward = 0.0;
         for _ in 0..EPISODE_LENGTH {
-            let actions = rng.gen_range(-2.0..2.0);
+            let mut action = 2.0 * agent.actions(&state)?;
+            action = action.clamp(-2.0, 2.0);

-            let step = env.step(vec![actions])?;
+            let step = env.step(vec![action])?;
             total_reward += step.reward;

-            if step.is_done {
+            agent.remember(
+                &state,
+                &Tensor::new(vec![action], &Device::Cpu)?,
+                &Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
+                &step.state,
+                step.terminated,
+                step.truncated,
+            );
+
+            if step.terminated || step.truncated {
                 break;
             }
-            obs = step.obs;
+            state = step.state;
         }

         println!("episode {episode} with total reward of {total_reward}");
+
+        for _ in 0..TRAINING_ITERATIONS {
+            agent.train(TRAINING_BATCH_SIZE)?;
+        }
+    }
+
+    println!("Testing...");
+    agent.train = false;
+    for episode in 0..10 {
+        // let mut state = env.reset(episode as u64)?;
+        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut total_reward = 0.0;
+        for _ in 0..EPISODE_LENGTH {
+            let mut action = 2.0 * agent.actions(&state)?;
+            action = action.clamp(-2.0, 2.0);
+
+            let step = env.step(vec![action])?;
+            total_reward += step.reward;
+
+            if step.terminated || step.truncated {
+                break;
+            }
+            state = step.state;
+        }
+        println!("episode {episode} with total reward of {total_reward}");
     }
     Ok(())
 }
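The `MU`/`THETA`/`SIGMA` constants parameterize the Ornstein-Uhlenbeck exploration noise: a mean-reverting random walk `x += theta * (mu - x) + sigma * N(0, 1)` whose successive samples are temporally correlated, which suits physical control tasks better than independent Gaussian noise. A minimal scalar sketch, assuming the `rand_distr` crate (the real `ddpg::OuNoise` is tensor-valued):

```rust
use rand::rngs::StdRng;
use rand::SeedableRng;
use rand_distr::{Distribution, StandardNormal};

struct OuNoise {
    mu: f64,
    theta: f64,
    sigma: f64,
    state: f64,
    rng: StdRng,
}

impl OuNoise {
    fn new(mu: f64, theta: f64, sigma: f64) -> Self {
        Self { mu, theta, sigma, state: mu, rng: StdRng::seed_from_u64(42) }
    }

    // One step of the mean-reverting process; successive samples are correlated.
    fn sample(&mut self) -> f64 {
        let z: f64 = StandardNormal.sample(&mut self.rng);
        self.state += self.theta * (self.mu - self.state) + self.sigma * z;
        self.state
    }
}
```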
candle-examples/examples/replit-code/README.md (new file, 40 lines)
@@ -0,0 +1,40 @@
# candle-replit-code: code completion specialized model.

[replit-code-v1_5-3b](https://huggingface.co/replit/replit-code-v1_5-3b) is a
language model specialized for code completion. This model uses 3.3B parameters
in `bfloat16` (so the GPU version will only work on recent nvidia cards).

## Running some examples

```bash
cargo run --example replit-code --release -- --prompt 'def fibonacci(n): '
```
This produces the following output.

```
def fibonacci(n): # write Fibonacci series up to n
    """Print a Fibonacci series up to n."""
    a, b = 0, 1
    while a < n:
        print(a, end=' ')
        a, b = b, a+b
    print()


def fibonacci_loop(n): # write Fibonacci series up to n
    """Print a Fibonacci series up to n."""
    result = []
    a, b = 0, 1
    while a < n:
        result.append(a)
        a, b = b, a+b
    return result


def fibonacci_generator(n): # write Fibonacci series up to n
    """Print a Fibonacci series up to n."""
    a, b = 0, 1
    while a < n:
        yield a
        a, b = b, a+b
```
candle-examples/examples/replit-code/main.rs (new file, 265 lines)
@@ -0,0 +1,265 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::mpt::{Config, Model as M};
use candle_transformers::models::quantized_mpt::Model as Q;

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

enum Model {
    M(M),
    Q(Q),
}

impl Model {
    fn forward(&mut self, xs: &Tensor) -> candle::Result<Tensor> {
        match self {
            Self::M(model) => model.forward(xs),
            Self::Q(model) => model.forward(xs),
        }
    }
}

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: Tokenizer,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
    verbose_prompt: bool,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        verbose_prompt: bool,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer,
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            verbose_prompt,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        println!("starting the inference loop");
        let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
        if tokens.is_empty() {
            anyhow::bail!("Empty prompts are not supported in the replit model.")
        }
        if self.verbose_prompt {
            for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
                let token = token.replace('▁', " ").replace("<0x0A>", "\n");
                println!("{id:7} -> '{token}'");
            }
        }
        let mut tokens = tokens.get_ids().to_vec();
        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
            Some(token) => *token,
            None => anyhow::bail!("cannot find the endoftext token"),
        };
        print!("{prompt}");
        std::io::stdout().flush()?;
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input)?;
            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
            print!("{token}");
            std::io::stdout().flush()?;
        }
        let dt = start_gen.elapsed();
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Display the tokens for the specified prompt.
    #[arg(long)]
    verbose_prompt: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 1000)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    quantized: bool,

    #[arg(long)]
    weight_file: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match args.model_id {
        Some(model_id) => model_id.to_string(),
        None => "lmz/candle-replit-code".to_string(),
    };
    let revision = match args.revision {
        Some(rev) => rev.to_string(),
        None => "main".to_string(),
    };
    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
    let tokenizer_filename = match args.tokenizer {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filename = match args.weight_file {
        Some(weight_file) => std::path::PathBuf::from(weight_file),
        None => {
            if args.quantized {
                repo.get("model-replit-code-v1_5-q4k.gguf")?
            } else {
                repo.get("model.safetensors")?
            }
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = Config::replit_code_v1_5_3b();
    let (model, device) = if args.quantized {
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
        let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
        (model, Device::Cpu)
    } else {
        let device = candle_examples::device(args.cpu)?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
        let model = Model::M(M::new(&config, vb.pp("transformer"))?);
        (model, device)
    };
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        args.verbose_prompt,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
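The generation loop above gates on `repeat_penalty == 1.` before calling `candle_transformers::utils::apply_repeat_penalty` over the last `repeat_last_n` tokens. The underlying idea, in the CTRL style, is to divide positive logits by the penalty and multiply negative ones, for any token seen in the recent context. A minimal sketch of that logic over a plain `Vec<f32>` (illustrative only, not the crate's implementation):

```rust
use std::collections::HashSet;

/// Penalize logits of recently generated tokens so they are less likely
/// to be sampled again.
fn repeat_penalty(logits: &mut [f32], penalty: f32, context: &[u32]) {
    let recent: HashSet<u32> = context.iter().copied().collect();
    for &tok in recent.iter() {
        if let Some(l) = logits.get_mut(tok as usize) {
            // Dividing a positive logit (or multiplying a negative one)
            // by `penalty > 1` always lowers the token's score.
            *l = if *l >= 0.0 { *l / penalty } else { *l * penalty };
        }
    }
}
```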
candle-examples/examples/resnet/README.md (new file, 19 lines)
@@ -0,0 +1,19 @@
# candle-resnet

A candle implementation of inference using a pre-trained [ResNet](https://arxiv.org/abs/1512.03385).
This uses a classification head trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example resnet --release -- --image tiger.jpg

loaded image Tensor[dims 3, 224, 224; f32]
model built
tiger, Panthera tigris : 90.21%
tiger cat               : 8.93%
lion, king of beasts, Panthera leo: 0.35%
leopard, Panthera pardus: 0.16%
jaguar, panther, Panthera onca, Felis onca: 0.09%
```
candle-examples/examples/resnet/export_models.py (new file, 12 lines)
@@ -0,0 +1,12 @@
# This script exports pre-trained model weights in the safetensors format.
import numpy as np
import torch
import torchvision
from safetensors import torch as stt

m = torchvision.models.resnet50(pretrained=True)
stt.save_file(m.state_dict(), 'resnet50.safetensors')
m = torchvision.models.resnet101(pretrained=True)
stt.save_file(m.state_dict(), 'resnet101.safetensors')
m = torchvision.models.resnet152(pretrained=True)
stt.save_file(m.state_dict(), 'resnet152.safetensors')
candle-examples/examples/resnet/main.rs (new file, 90 lines)
@@ -0,0 +1,90 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::resnet;
use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    #[value(name = "18")]
    Resnet18,
    #[value(name = "34")]
    Resnet34,
    #[value(name = "50")]
    Resnet50,
    #[value(name = "101")]
    Resnet101,
    #[value(name = "152")]
    Resnet152,
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Variant of the model to use.
    #[arg(value_enum, long, default_value_t = Which::Resnet18)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("lmz/candle-resnet".into());
            let filename = match args.which {
                Which::Resnet18 => "resnet18.safetensors",
                Which::Resnet34 => "resnet34.safetensors",
                Which::Resnet50 => "resnet50.safetensors",
                Which::Resnet101 => "resnet101.safetensors",
                Which::Resnet152 => "resnet152.safetensors",
            };
            api.get(filename)?
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let class_count = candle_examples::imagenet::CLASS_COUNT as usize;
    let model = match args.which {
        Which::Resnet18 => resnet::resnet18(class_count, vb)?,
        Which::Resnet34 => resnet::resnet34(class_count, vb)?,
        Which::Resnet50 => resnet::resnet50(class_count, vb)?,
        Which::Resnet101 => resnet::resnet101(class_count, vb)?,
        Which::Resnet152 => resnet::resnet152(class_count, vb)?,
    };
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
@@ -50,6 +50,9 @@ cached.
 Enabling flash-attention requires both a feature flag, `--feature flash-attn`
 and using the command line flag `--use-flash-attn`.

+Note that flash-attention-v2 is only compatible with Ampere, Ada, or Hopper GPUs
+(e.g., A100/H100, RTX 3090/4090).
+
 ## Image to Image Pipeline
 ...
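Putting the two requirements from the note above together, a flash-attention run of the stable-diffusion example would look roughly like this (the prompt is only illustrative):

```bash
cargo run --example stable-diffusion --release --features flash-attn \
    -- --use-flash-attn --prompt "a rusty robot holding a candle"
```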
@@ -416,7 +416,7 @@ fn run(args: Args) -> Result<()> {

     println!("Building the autoencoder.");
     let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
-    let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
+    let vae = sd_config.build_vae(vae_weights, &device, dtype)?;
     let init_latent_dist = match &img2img {
         None => None,
         Some(image) => {
@@ -426,7 +426,7 @@ fn run(args: Args) -> Result<()> {
     };
     println!("Building the unet.");
     let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
-    let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
+    let unet = sd_config.build_unet(unet_weights, &device, 4, use_flash_attn, dtype)?;

     let t_start = if img2img.is_some() {
         n_steps - (n_steps as f64 * img2img_strength) as usize
@@ -5,12 +5,26 @@
 ```bash
 $ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
 ...
-Running on CPU, to run on GPU, build this example with `--features cuda`
 Eine schöne Kerze.
 9 tokens generated (2.42 token/s)
 ```

-## Sentence embedding example:
+Variants such as [flan-t5](https://huggingface.co/google/flan-t5-small), [flan-ul2](https://huggingface.co/google/flan-ul2) (with `--revision "refs/pr/25"`), and [Co-EdIT](https://huggingface.co/grammarly/coedit-large) are also supported.
+
+## Translation with [MADLAD-400](https://arxiv.org/abs/2309.04662)
+
+MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
+
+```bash
+cargo run --example t5 --release -- \
+  --model-id "jbochi/madlad400-3b-mt" \
+  --prompt "<2de> How are you, my friend?" \
+  --decode --temperature 0
+...
+Wie geht es dir, mein Freund?
+```
+
+## Sentence embedding example
+
 ```bash
 $ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."
@@ -104,6 +104,17 @@ impl T5ModelBuilder {
                 api.get("model-00004-of-00005.safetensors")?,
                 api.get("model-00005-of-00005.safetensors")?,
             ]
+        } else if model_id == "google/flan-ul2" {
+            vec![
+                api.get("model-00001-of-00008.safetensors")?,
+                api.get("model-00002-of-00008.safetensors")?,
+                api.get("model-00003-of-00008.safetensors")?,
+                api.get("model-00004-of-00008.safetensors")?,
+                api.get("model-00005-of-00008.safetensors")?,
+                api.get("model-00006-of-00008.safetensors")?,
+                api.get("model-00007-of-00008.safetensors")?,
+                api.get("model-00008-of-00008.safetensors")?,
+            ]
         } else {
             vec![api.get("model.safetensors")?]
         };
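All of these shard lists end up in a single `VarBuilder`, which memory-maps every file and resolves each tensor by name across the shards. A minimal sketch of how such a list is typically consumed in candle (treat `DType::F32` as an assumption; the checkpoint may be stored in another dtype):

```rust
use candle::{DType, Device};
use candle_nn::VarBuilder;

fn load_sharded(
    filenames: &[std::path::PathBuf],
    device: &Device,
) -> candle::Result<VarBuilder<'static>> {
    // Safety: mmap-ing the files assumes they are not mutated while in use.
    unsafe { VarBuilder::from_mmaped_safetensors(filenames, DType::F32, device) }
}
```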
@@ -172,7 +183,12 @@ fn main() -> Result<()> {
         println!("Took {:?}", start.elapsed());
     } else {
         let mut model = builder.build_conditional_generation()?;
-        let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
+        let mut output_token_ids = [builder
+            .config
+            .decoder_start_token_id
+            .unwrap_or(builder.config.pad_token_id)
+            as u32]
+        .to_vec();
         if let Some(decoder_prompt) = &args.decoder_prompt {
             print!("{decoder_prompt}");
             output_token_ids.extend(

candle-examples/examples/trocr/assets/trocr.png (new binary file, 36 KiB; not shown)
candle-examples/examples/trocr/image_processor.rs (new file, 154 lines)
@@ -0,0 +1,154 @@
use image::{DynamicImage, ImageBuffer};
use serde::Deserialize;
use std::collections::HashMap;

use candle::{DType, Device, Result, Tensor};

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ProcessorConfig {
    do_resize: bool,
    height: u32,
    width: u32,
    do_rescale: bool,
    do_normalize: bool,
    image_mean: Vec<f32>,
    image_std: Vec<f32>,
}

impl Default for ProcessorConfig {
    fn default() -> Self {
        Self {
            do_resize: true,
            height: 384,
            width: 384,
            do_rescale: true,
            do_normalize: true,
            image_mean: vec![0.5, 0.5, 0.5],
            image_std: vec![0.5, 0.5, 0.5],
        }
    }
}

pub struct ViTImageProcessor {
    do_resize: bool,
    height: u32,
    width: u32,
    do_normalize: bool,
    image_mean: Vec<f32>,
    image_std: Vec<f32>,
}

impl ViTImageProcessor {
    pub fn new(config: &ProcessorConfig) -> Self {
        Self {
            do_resize: config.do_resize,
            height: config.height,
            width: config.width,
            do_normalize: config.do_normalize,
            image_mean: config.image_mean.clone(),
            image_std: config.image_std.clone(),
        }
    }

    pub fn preprocess(&self, images: Vec<&str>) -> Result<Tensor> {
        let height = self.height as usize;
        let width = self.width as usize;
        let channels = 3;

        let images = self.load_images(images)?;

        let resized_images: Vec<DynamicImage> = if self.do_resize {
            images
                .iter()
                .map(|image| self.resize(image.clone(), None).unwrap())
                .collect()
        } else {
            images
        };

        let normalized_images: Vec<Tensor> = if self.do_normalize {
            resized_images
                .iter()
                .map(|image| self.normalize(image.clone(), None, None).unwrap())
                .collect()
        } else {
            let resized_images: Vec<ImageBuffer<image::Rgb<u8>, Vec<u8>>> =
                resized_images.iter().map(|image| image.to_rgb8()).collect();
            let data = resized_images
                .into_iter()
                .map(|image| image.into_raw())
                .collect::<Vec<Vec<u8>>>();

            data.iter()
                .map(|image| {
                    Tensor::from_vec(image.clone(), (height, width, channels), &Device::Cpu)
                        .unwrap()
                        .permute((2, 0, 1))
                        .unwrap()
                })
                .collect::<Vec<Tensor>>()
        };

        Tensor::stack(&normalized_images, 0)
    }

    fn resize(
        &self,
        image: image::DynamicImage,
        size: Option<HashMap<String, u32>>,
    ) -> Result<image::DynamicImage> {
        let (height, width) = match &size {
            Some(size) => (size.get("height").unwrap(), size.get("width").unwrap()),
            None => (&self.height, &self.width),
        };

        let resized_image =
            image.resize_exact(*width, *height, image::imageops::FilterType::Triangle);

        Ok(resized_image)
    }

    fn normalize(
        &self,
        image: image::DynamicImage,
        mean: Option<Vec<f32>>,
        std: Option<Vec<f32>>,
    ) -> Result<Tensor> {
        let mean = match mean {
            Some(mean) => mean,
            None => self.image_mean.clone(),
        };

        let std = match std {
            Some(std) => std,
            None => self.image_std.clone(),
        };

        let mean = Tensor::from_vec(mean, (3, 1, 1), &Device::Cpu)?;
        let std = Tensor::from_vec(std, (3, 1, 1), &Device::Cpu)?;

        let image = image.to_rgb8();
        let data = image.into_raw();

        let height = self.height as usize;
        let width = self.width as usize;
        let channels = 3;

        let data =
            Tensor::from_vec(data, &[height, width, channels], &Device::Cpu)?.permute((2, 0, 1))?;

        (data.to_dtype(DType::F32)? / 255.)?
            .broadcast_sub(&mean)?
            .broadcast_div(&std)
    }

    pub fn load_images(&self, image_path: Vec<&str>) -> Result<Vec<image::DynamicImage>> {
        let mut images: Vec<image::DynamicImage> = Vec::new();
        for path in image_path {
            let img = image::io::Reader::open(path)?.decode().unwrap();
            images.push(img);
        }

        Ok(images)
    }
}
candle-examples/examples/trocr/main.rs (new file, 132 lines)
@@ -0,0 +1,132 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::{Parser, ValueEnum};

use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::trocr;

use tokenizers::Tokenizer;
mod image_processor;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    Base,
    Large,
}

#[derive(Parser, Debug)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    /// Choose the variant of the model to run.
    #[arg(long, default_value = "base")]
    which: Which,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// The image file to transcribe.
    #[arg(long)]
    image: String,
}

pub fn main() -> anyhow::Result<()> {
    use hf_hub::api::sync::Api;
    let args = Args::parse();

    let tokenizer_dec = {
        let tokenizer = Api::new()?
            .model(String::from("ToluClassics/candle-trocr-tokenizer"))
            .get("tokenizer.json")?;

        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
    };

    let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);

    let device = candle_examples::device(args.cpu)?;

    let vb = {
        let model = match args.model {
            Some(model) => std::path::PathBuf::from(model),
            None => match args.which {
                Which::Base => Api::new()?
                    .repo(hf_hub::Repo::with_revision(
                        "microsoft/trocr-base-handwritten".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/3".to_string(),
                    ))
                    .get("model.safetensors")?,
                Which::Large => Api::new()?
                    .repo(hf_hub::Repo::with_revision(
                        "microsoft/trocr-large-handwritten".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/6".to_string(),
                    ))
                    .get("model.safetensors")?,
            },
        };
        println!("model: {:?}", model);
        unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
    };

    // Note: the large variant currently reuses the base encoder config.
    let encoder_config = match args.which {
        Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
        Which::Large => {
            candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
        }
    };

    let decoder_config = trocr::TrOCRConfig::default();
    let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;

    let config = image_processor::ProcessorConfig::default();
    let processor = image_processor::ViTImageProcessor::new(&config);

    let image = vec![args.image.as_str()];
    let image = processor.preprocess(image)?;

    let encoder_xs = model.encoder().forward(&image)?;

    let mut logits_processor =
        candle_transformers::generation::LogitsProcessor::new(1337, None, None);

    let mut token_ids: Vec<u32> = vec![decoder_config.decoder_start_token_id];
    for index in 0..1000 {
        let context_size = if index >= 1 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;

        let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;

        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        token_ids.push(token);

        if let Some(t) = tokenizer_dec.next_token(token)? {
            use std::io::Write;
            print!("{t}");
            std::io::stdout().flush()?;
        }
        if token == decoder_config.eos_token_id {
            break;
        }
    }

    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    println!();

    Ok(())
}
candle-examples/examples/trocr/readme.md (new file, 16 lines)
@@ -0,0 +1,16 @@
# candle-trocr

`TrOCR` is a transformer OCR Model. In this example it is used to
transcribe image text. See the associated [model
card](https://huggingface.co/microsoft/trocr-base-printed) for details on
the model itself.

## Running an example

```bash
cargo run --example trocr --release -- --which base --cpu --image candle-examples/examples/trocr/assets/trocr.png
```

```
<s> industry , Mr. Brown commented icily . " Let us have a</s>
```
candle-examples/examples/vgg/README.md (new file, 13 lines)
@@ -0,0 +1,13 @@
## VGG Model Implementation

This example demonstrates the implementation of VGG models (VGG13, VGG16, VGG19) using the Candle library.

The VGG models are defined in `candle-transformers/src/models/vgg.rs`. The main function in `candle-examples/examples/vgg/main.rs` loads an image, selects the VGG model based on the provided argument, and applies the model to the loaded image.

You can run the example with the following command:

```bash
cargo run --example vgg --release -- --image ../yolo-v8/assets/bike.jpg --which vgg13
```

In the command above, `--image` specifies the path to the image file and `--which` specifies the VGG model to use (vgg13, vgg16, or vgg19).
candle-examples/examples/vgg/main.rs (new file, 77 lines)
@@ -0,0 +1,77 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::{DType, IndexOp, D};
use candle_nn::{ModuleT, VarBuilder};
use candle_transformers::models::vgg::{Models, Vgg};
use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    Vgg13,
    Vgg16,
    Vgg19,
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Variant of the model to use.
    #[arg(value_enum, long, default_value_t = Which::Vgg13)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;
    let image = candle_examples::imagenet::load_image224(args.image)?;

    println!("loaded image {image:?}");

    let api = hf_hub::api::sync::Api::new()?;
    let repo = match args.which {
        Which::Vgg13 => "timm/vgg13.tv_in1k",
        Which::Vgg16 => "timm/vgg16.tv_in1k",
        Which::Vgg19 => "timm/vgg19.tv_in1k",
    };
    let api = api.model(repo.into());
    let filename = "model.safetensors";
    let model_file = api.get(filename)?;

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = match args.which {
        Which::Vgg13 => Vgg::new(vb, Models::Vgg13)?,
        Which::Vgg16 => Vgg::new(vb, Models::Vgg16)?,
        Which::Vgg19 => Vgg::new(vb, Models::Vgg19)?,
    };
    let logits = model.forward_t(&image, /*train=*/ false)?;

    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;

    // Sort the predictions and take the top 5
    let mut top: Vec<_> = prs.iter().enumerate().collect();
    top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
    let top = top.into_iter().take(5).collect::<Vec<_>>();

    // Print the top predictions
    for &(i, p) in &top {
        println!(
            "{:50}: {:.2}%",
            candle_examples::imagenet::CLASSES[i],
            p * 100.0
        );
    }

    Ok(())
}
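`forward_t` comes from candle's `ModuleT` trait, which threads a `train` flag through layers whose behavior differs between training and inference, such as dropout. A minimal sketch of the pattern with a hypothetical wrapper type (not the VGG implementation itself):

```rust
use candle::{Result, Tensor};
use candle_nn::{Dropout, ModuleT};

struct Head {
    dropout: Dropout,
}

impl ModuleT for Head {
    // With train == false (as in the VGG example above) dropout is a no-op.
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        self.dropout.forward_t(xs, train)
    }
}
```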
candle-examples/examples/vit/README.md (new file, 20 lines)
@@ -0,0 +1,20 @@
# candle-vit

Vision Transformer (ViT) model implementation following the lines of
[vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224).
This uses a classification head trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example vit --release -- --image tiger.jpg

loaded image Tensor[dims 3, 224, 224; f32]
model built
tiger, Panthera tigris : 100.00%
tiger cat               : 0.00%
jaguar, panther, Panthera onca, Felis onca: 0.00%
leopard, Panthera pardus: 0.00%
lion, king of beasts, Panthera leo: 0.00%
```
candle-examples/examples/vit/main.rs (new file, 59 lines)
@@ -0,0 +1,59 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::Parser;

use candle::{DType, IndexOp, D};
use candle_nn::VarBuilder;
use candle_transformers::models::vit;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("google/vit-base-patch16-224".into());
            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = vit::Model::new(&vit::Config::vit_base_patch16_224(), 1000, vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
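Several of these examples lean on `candle_examples::imagenet::load_image224` for preprocessing. A minimal sketch of the usual ImageNet recipe it corresponds to, resize to 224x224 followed by per-channel standardization (the constants below are the standard ImageNet statistics; treat this as an assumption about the helper, not its source):

```rust
use candle::{Result, Tensor};

// Standard ImageNet channel statistics.
const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const STD: [f32; 3] = [0.229, 0.224, 0.225];

fn normalize_imagenet(img: &Tensor) -> Result<Tensor> {
    // `img` is a (3, 224, 224) tensor with values already scaled to [0, 1].
    let mean = Tensor::new(&MEAN, img.device())?.reshape((3, 1, 1))?;
    let std = Tensor::new(&STD, img.device())?.reshape((3, 1, 1))?;
    img.broadcast_sub(&mean)?.broadcast_div(&std)
}
```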
@@ -128,7 +128,13 @@ impl Decoder {
         let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
         let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
         let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
-        let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
+        let no_speech_token = m::NO_SPEECH_TOKENS
+            .iter()
+            .find_map(|token| token_id(&tokenizer, token).ok());
+        let no_speech_token = match no_speech_token {
+            None => anyhow::bail!("unable to find any non-speech token"),
+            Some(n) => n,
+        };
         Ok(Self {
             model,
             rng: rand::rngs::StdRng::seed_from_u64(seed),
@@ -345,7 +351,7 @@ enum Task {
     Translate,
 }

-#[derive(Clone, Copy, Debug, ValueEnum)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
 enum WhichModel {
     Tiny,
     #[value(name = "tiny.en")]
@@ -361,15 +367,27 @@ enum WhichModel {
     MediumEn,
     Large,
     LargeV2,
+    LargeV3,
+    #[value(name = "distil-medium.en")]
+    DistilMediumEn,
+    #[value(name = "distil-large-v2")]
+    DistilLargeV2,
 }

 impl WhichModel {
     fn is_multilingual(&self) -> bool {
         match self {
-            Self::Tiny | Self::Base | Self::Small | Self::Medium | Self::Large | Self::LargeV2 => {
-                true
-            }
-            Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
+            Self::Tiny
+            | Self::Base
+            | Self::Small
+            | Self::Medium
+            | Self::Large
+            | Self::LargeV2
+            | Self::LargeV3
+            | Self::DistilLargeV2 => true,
+            Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
+                false
+            }
         }
     }
 }
@@ -385,6 +403,9 @@ impl WhichModel {
             Self::MediumEn => ("openai/whisper-medium.en", "main"),
             Self::Large => ("openai/whisper-large", "refs/pr/36"),
             Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
+            Self::LargeV3 => ("openai/whisper-large-v3", "main"),
+            Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
+            Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
         }
     }
 }
@@ -496,17 +517,21 @@ fn main() -> Result<()> {
                     repo.get(&format!("model-{ext}-q80.gguf"))?,
                 )
             } else {
-                (
-                    repo.get("config.json")?,
-                    repo.get("tokenizer.json")?,
-                    repo.get("model.safetensors")?,
-                )
+                let config = repo.get("config.json")?;
+                let tokenizer = repo.get("tokenizer.json")?;
+                let model = repo.get("model.safetensors")?;
+                (config, tokenizer, model)
             };
             (config, tokenizer, model, sample)
         };
+    let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

-    let mel_bytes = include_bytes!("melfilters.bytes");
+    let mel_bytes = match config.num_mel_bins {
+        80 => include_bytes!("melfilters.bytes").as_slice(),
+        128 => include_bytes!("melfilters128.bytes").as_slice(),
+        nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"),
+    };
     let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
     <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);

@@ -522,12 +547,15 @@ fn main() -> Result<()> {
         .map(|v| *v as f32 / 32768.)
         .collect();
     println!("pcm data loaded {}", pcm_data.len());
-    let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
+    let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
     let mel_len = mel.len();
-    let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
+    let mel = Tensor::from_vec(
+        mel,
+        (1, config.num_mel_bins, mel_len / config.num_mel_bins),
+        &device,
+    )?;
     println!("loaded mel: {:?}", mel.dims());

-    let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
     let mut model = if args.quantized {
         let vb =
             candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
candle-examples/examples/whisper/melfilters128.bytes (new binary file; not shown)

candle-examples/examples/yi/main.rs (new file, 268 lines)
@@ -0,0 +1,268 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};

use candle_transformers::models::yi::{Config, Model};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    #[value(name = "6b")]
    L6b,
    #[value(name = "34b")]
    L34b,
}

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <|endoftext|> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 100)]
    sample_len: usize,

    #[arg(long, default_value = "01-ai/Yi-6B")]
    model_id: String,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model size to use.
    #[arg(long, default_value = "6b")]
    which: Which,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        args.model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => match args.which {
            Which::L6b => vec![
                repo.get("model-00001-of-00002.safetensors")?,
                repo.get("model-00002-of-00002.safetensors")?,
            ],
            Which::L34b => vec![
                repo.get("model-00001-of-00007.safetensors")?,
                repo.get("model-00002-of-00007.safetensors")?,
                repo.get("model-00003-of-00007.safetensors")?,
                repo.get("model-00004-of-00007.safetensors")?,
                repo.get("model-00005-of-00007.safetensors")?,
                repo.get("model-00006-of-00007.safetensors")?,
                repo.get("model-00007-of-00007.safetensors")?,
            ],
        },
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = match args.which {
        Which::L6b => Config::config_6b(),
        Which::L34b => Config::config_34b(),
|
||||||
|
};
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let dtype = if device.is_cuda() {
|
||||||
|
DType::BF16
|
||||||
|
} else {
|
||||||
|
DType::F32
|
||||||
|
};
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
|
let model = Model::new(&config, vb)?;
|
||||||
|
|
||||||
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
let mut pipeline = TextGeneration::new(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
args.seed,
|
||||||
|
args.temperature,
|
||||||
|
args.top_p,
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n,
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
pipeline.run(&args.prompt, args.sample_len)?;
|
||||||
|
Ok(())
|
||||||
|
}
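
For reference, a minimal sketch of how to invoke this Yi example, assuming it is registered as the `yi` example in candle-examples (the long flags are derived by clap from the `Args` field names above; the prompt is illustrative):

```bash
cargo run --example yi --release -- \
  --prompt "Here is a poem about tensors:" \
  --which 6b --sample-len 100 --temperature 0.8
```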

@@ -108,7 +108,7 @@ pub fn parse_config<T: AsRef<Path>>(path: T) -> Result<Darknet> {
 }

 enum Bl {
-    Layer(Box<dyn candle_nn::Module + Send>),
+    Layer(Box<dyn candle_nn::Module + Send + Sync>),
     Route(Vec<usize>),
     Shortcut(usize),
     Yolo(usize, Vec<(usize, usize)>),
@@ -43,6 +43,7 @@ pub fn report(
     confidence_threshold: f32,
     nms_threshold: f32,
 ) -> Result<DynamicImage> {
+    let pred = pred.to_device(&Device::Cpu)?;
     let (npreds, pred_size) = pred.dims2()?;
     let nclasses = pred_size - 5;
     // The bounding boxes grouped by (maximum) class index.
@@ -32,7 +32,7 @@ Image source:
 ### Pose Estimation
 ```bash
 cargo run --example yolo-v8 --release -- \
-  candle-examples/examples/yolo-v8/assets/peoples.jpeg --task pose
+  candle-examples/examples/yolo-v8/assets/bike.jpg --task pose
 ```

 
@@ -7,7 +7,7 @@ extern crate accelerate_src;
 mod model;
 use model::{Multiples, YoloV8, YoloV8Pose};

-use candle::{DType, IndexOp, Result, Tensor};
+use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Module, VarBuilder};
 use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
 use clap::{Parser, ValueEnum};
@@ -61,6 +61,7 @@ pub fn report_detect(
     nms_threshold: f32,
     legend_size: u32,
 ) -> Result<DynamicImage> {
+    let pred = pred.to_device(&Device::Cpu)?;
     let (pred_size, npreds) = pred.dims2()?;
     let nclasses = pred_size - 4;
     // The bounding boxes grouped by (maximum) class index.
@@ -153,6 +154,7 @@ pub fn report_pose(
     confidence_threshold: f32,
     nms_threshold: f32,
 ) -> Result<DynamicImage> {
+    let pred = pred.to_device(&Device::Cpu)?;
     let (pred_size, npreds) = pred.dims2()?;
     if pred_size != 17 * 3 + 4 + 1 {
         candle::bail!("unexpected pred-size {pred_size}");
@@ -1,7 +1,5 @@
 use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{
-    batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
-};
+use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Conv2d, Conv2dConfig, Module, VarBuilder};

 #[derive(Clone, Copy, PartialEq, Debug)]
 pub struct Multiples {
@@ -76,7 +74,6 @@ impl Module for Upsample {
 #[derive(Debug)]
 struct ConvBlock {
     conv: Conv2d,
-    bn: BatchNorm,
     span: tracing::Span,
 }

@@ -96,11 +93,10 @@ impl ConvBlock {
             groups: 1,
             dilation: 1,
         };
-        let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
         let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
+        let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
         Ok(Self {
             conv,
-            bn,
             span: tracing::span!(tracing::Level::TRACE, "conv-block"),
         })
     }
@@ -110,7 +106,6 @@ impl Module for ConvBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let xs = self.conv.forward(xs)?;
-        let xs = self.bn.forward(&xs)?;
         candle_nn::ops::silu(&xs)
     }
 }
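
The two hunks above drop the separate batch-norm pass: `absorb_bn` folds the normalization statistics into the convolution weights at load time, so inference runs conv + SiLU only. A sketch of the standard conv/BN folding identity this relies on (generic notation, not candle-specific): with BN scale $\gamma$, shift $\beta$, running mean $\mu$, variance $\sigma^2$ and epsilon $\epsilon$,

$$w' = w \cdot \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}, \qquad b' = \beta - \frac{\gamma \mu}{\sqrt{\sigma^2 + \epsilon}}$$

so a conv with weights $w'$ and bias $b'$ reproduces conv-then-BN exactly in eval mode, saving one pass over the activations.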
@@ -2,17 +2,28 @@ pub mod coco_classes;
 pub mod imagenet;
 pub mod token_output_stream;

+use candle::utils::{cuda_is_available, metal_is_available};
 use candle::{Device, Result, Tensor};

 pub fn device(cpu: bool) -> Result<Device> {
     if cpu {
         Ok(Device::Cpu)
+    } else if cuda_is_available() {
+        Ok(Device::new_cuda(0)?)
+    } else if metal_is_available() {
+        Ok(Device::new_metal(0)?)
     } else {
-        let device = Device::cuda_if_available(0)?;
-        if !device.is_cuda() {
+        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
+        {
+            println!(
+                "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
+            );
+        }
+        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
+        {
             println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
         }
-        Ok(device)
+        Ok(Device::Cpu)
     }
 }

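With this change the helper prefers CUDA, then Metal, then CPU, depending on which backend features the example binary was compiled with. A hedged sketch of the corresponding invocations (asset path taken from the README hunk above):

```bash
# NVIDIA GPU:
cargo run --example yolo-v8 --release --features cuda -- \
  candle-examples/examples/yolo-v8/assets/bike.jpg
# Apple Silicon (Metal):
cargo run --example yolo-v8 --release --features metal -- \
  candle-examples/examples/yolo-v8/assets/bike.jpg
```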
@@ -1,6 +1,6 @@
 [package]
 name = "candle-flash-attn"
-version = "0.3.0"
+version = "0.3.1"
 edition = "2021"

 description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
 readme = "README.md"

 [dependencies]
-candle = { path = "../candle-core", features = ["cuda"], version = "0.3.0", package = "candle-core" }
+candle = { path = "../candle-core", features = ["cuda"], version = "0.3.1", package = "candle-core" }
 half = { version = "2.3.1", features = ["num-traits"] }

 [build-dependencies]
@@ -21,4 +21,4 @@ rayon = "1.7.0"

 [dev-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }
-candle-nn = { path = "../candle-nn", version = "0.3.0", features = ["cuda"] }
+candle-nn = { path = "../candle-nn", version = "0.3.1", features = ["cuda"] }
@@ -84,12 +84,19 @@ fn main() -> Result<()> {
             (kernel_dir.join(f), obj_file)
         })
         .collect();
+    let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
     let should_compile = if out_file.exists() {
-        cu_files.iter().any(|(cu_file, _)| {
-            let out_modified = out_file.metadata().unwrap().modified().unwrap();
-            let in_modified = cu_file.metadata().unwrap().modified().unwrap();
-            in_modified.duration_since(out_modified).is_ok()
-        })
+        kernel_dir
+            .read_dir()
+            .expect("kernels folder should exist")
+            .any(|entry| {
+                if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) {
+                    let in_modified = entry.metadata().unwrap().modified().unwrap();
+                    in_modified.duration_since(*out_modified).is_ok()
+                } else {
+                    true
+                }
+            })
     } else {
         true
     };
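The rebuild check now walks the kernels directory and recompiles whenever any entry is newer than the combined output, or whenever a metadata read fails. A quick way to force a rebuild, assuming the kernel sources live in a `kernels/` folder inside the crate:

```bash
touch candle-flash-attn/kernels/*.cu
cargo build --release -p candle-flash-attn
```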
@@ -100,12 +107,19 @@ fn main() -> Result<()> {
         let mut command = std::process::Command::new("nvcc");
         command
             .arg("-std=c++17")
+            .arg("-O3")
+            .arg("-U__CUDA_NO_HALF_OPERATORS__")
+            .arg("-U__CUDA_NO_HALF_CONVERSIONS__")
+            .arg("-U__CUDA_NO_HALF2_OPERATORS__")
+            .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
             .arg(format!("--gpu-architecture=sm_{compute_cap}"))
             .arg("-c")
             .args(["-o", obj_file.to_str().unwrap()])
             .args(["--default-stream", "per-thread"])
             .arg("-Icutlass/include")
             .arg("--expt-relaxed-constexpr")
+            .arg("--expt-extended-lambda")
+            .arg("--use_fast_math")
             .arg("--verbose");
         if let Ok(ccbin_path) = &ccbin_env {
             command
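Taken together, the build script now assembles roughly the following nvcc invocation; the object and source file names below are illustrative, only the flags come from the code above:

```bash
nvcc -std=c++17 -O3 \
  -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ \
  -U__CUDA_NO_HALF2_OPERATORS__ -U__CUDA_NO_BFLOAT16_CONVERSIONS__ \
  --gpu-architecture=sm_80 -c -o example_kernel.o \
  --default-stream per-thread -Icutlass/include \
  --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math --verbose \
  example_kernel.cu
```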
@@ -203,13 +217,21 @@ fn set_cuda_include_dir() -> Result<()> {

 #[allow(unused)]
 fn compute_cap() -> Result<usize> {
-    // Grab compute code from nvidia-smi
-    let mut compute_cap = {
+    println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
+
+    // Try to parse compute caps from env
+    let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
+        println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
+        compute_cap_str
+            .parse::<usize>()
+            .context("Could not parse compute cap")?
+    } else {
+        // Use nvidia-smi to get the current compute cap
         let out = std::process::Command::new("nvidia-smi")
             .arg("--query-gpu=compute_cap")
             .arg("--format=csv")
             .output()
             .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
         let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
         let mut lines = out.lines();
         assert_eq!(
@@ -220,16 +242,19 @@ fn compute_cap() -> Result<usize> {
             .next()
             .context("missing line in stdout")?
             .replace('.', "");
-        cap.parse::<usize>()
-            .with_context(|| format!("cannot parse as int {cap}"))?
+        let cap = cap
+            .parse::<usize>()
+            .with_context(|| format!("cannot parse as int {cap}"))?;
+        println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
+        cap
     };

     // Grab available GPU codes from nvcc and select the highest one
-    let max_nvcc_code = {
+    let (supported_nvcc_codes, max_nvcc_code) = {
         let out = std::process::Command::new("nvcc")
             .arg("--list-gpu-code")
             .output()
             .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
         let out = std::str::from_utf8(&out.stdout).unwrap();

         let out = out.lines().collect::<Vec<&str>>();
@@ -243,30 +268,21 @@ fn compute_cap() -> Result<usize> {
             }
         }
         codes.sort();
-        if !codes.contains(&compute_cap) {
-            anyhow::bail!(
-                "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
-            );
-        }
-        *codes.last().unwrap()
+        let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
+        (codes, max_nvcc_code)
     };

-    // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
-    // then choose the highest gpu code in nvcc
-    if compute_cap > max_nvcc_code {
-        println!(
-            "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
+    // Check that nvcc supports the asked compute caps
+    if !supported_nvcc_codes.contains(&compute_cap) {
+        anyhow::bail!(
+            "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
+        );
+    }
+    if compute_cap > max_nvcc_code {
+        anyhow::bail!(
+            "CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
         );
-        compute_cap = max_nvcc_code;
     }

-    println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
-    if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
-        compute_cap = compute_cap_str
-            .parse::<usize>()
-            .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
-        println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
-    }
-    println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
     Ok(compute_cap)
 }
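Because the script now consults `CUDA_COMPUTE_CAP` before falling back to nvidia-smi, the target architecture can be pinned from the shell, e.g. for an sm_80 card:

```bash
CUDA_COMPUTE_CAP=80 cargo build --release -p candle-flash-attn
```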
@@ -233,8 +233,8 @@ impl FlashAttnVarLen {

         let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout();
         let seqlens_q = match &*seqlens_q {
-            candle::Storage::Cpu(_) => candle::bail!("seqlens_q must be a cuda tensor"),
             candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
+            _ => candle::bail!("seqlens_q must be a cuda tensor"),
         };
         let seqlens_q = match seqlens_q_layout.contiguous_offsets() {
             Some((o1, o2)) => seqlens_q.slice(o1..o2),
@@ -243,8 +243,8 @@ impl FlashAttnVarLen {

         let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout();
         let seqlens_k = match &*seqlens_k {
-            candle::Storage::Cpu(_) => candle::bail!("seqlens_k must be a cuda tensor"),
             candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
+            _ => candle::bail!("seqlens_k must be a cuda tensor"),
         };
         let seqlens_k = match seqlens_k_layout.contiguous_offsets() {
             Some((o1, o2)) => seqlens_k.slice(o1..o2),
@@ -1,6 +1,6 @@
 [package]
 name = "candle-kernels"
-version = "0.3.0"
+version = "0.3.1"
 edition = "2021"

 description = "CUDA kernels for Candle"
@@ -12,5 +12,6 @@ license = "MIT OR Apache-2.0"
 [dependencies]

 [build-dependencies]
+anyhow = { version = "1", features = ["backtrace"] }
 glob = "0.3.1"
 rayon = "1.7.0"
Some files were not shown because too many files have changed in this diff.