mirror of https://github.com/huggingface/candle.git
synced 2025-06-22 12:28:06 +00:00

Compare commits
308 commits on branch meshgrid-f (author: ivarflakst)
Commit list (308 commits, SHA1 only): 933716b374 … c698e17619.
.github/dependabot.yml (vendored, new file, +7)

@@ -0,0 +1,7 @@
+version: 2
+updates:
+  - package-ecosystem: "cargo"
+    directory: "/"
+    schedule:
+      interval: "weekly"
+    open-pull-requests-limit: 5
.github/workflows/ci_cuda.yaml (vendored, 6 lines changed)

@@ -8,6 +8,8 @@ jobs:
   start-runner:
     name: Start self-hosted EC2 runner
     runs-on: ubuntu-latest
+    # Don't run on forks, they won't have access to secrets anyway.
+    if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
     env:
       AWS_REGION: us-east-1
       EC2_AMI_ID: ami-03cfed9ea28f4b002
@@ -59,7 +61,7 @@ jobs:
       - name: Install Rust Stable
        run: curl https://sh.rustup.rs -sSf | sh -s -- -y
      - uses: Swatinem/rust-cache@v2
-      - run: apt-get update -y && apt-get install libssl-dev -y
+      - run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
      - name: Test (cuda)
        run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
  stop-runner:
@@ -70,7 +72,7 @@ jobs:
    runs-on: ubuntu-latest
    env:
      AWS_REGION: us-east-1
-    if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
+    if: ${{ (success() || failure()) && github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }} # required to stop the runner even if the error happened in the previous jobs
    steps:
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v1
.github/workflows/maturin.yml (vendored, new file) — Binary file not shown.
.github/workflows/python.yml (vendored, 8 lines changed)

@@ -39,6 +39,12 @@ jobs:
          path: ~/.cargo/registry
          key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
+
+      - name: Install Protoc
+        uses: arduino/setup-protoc@v2
+        with:
+          version: "25.0"
+          repo-token: ${{ secrets.GITHUB_TOKEN }}
+
      - name: Install
        working-directory: ./candle-pyo3
        run: |
@@ -46,7 +52,7 @@ jobs:
          source .env/bin/activate
          pip install -U pip
          pip install pytest maturin black
-          python -m maturin develop -r
+          python -m maturin develop -r --features onnx
 
      - name: Check style
        working-directory: ./candle-pyo3
CHANGELOG.md

@@ -63,7 +63,7 @@ This documents the main changes to the `candle` crate.
   [760](https://github.com/huggingface/candle/pull/760).
 - Add the Segment-Anything Model (SAM) as an example
   [773](https://github.com/huggingface/candle/pull/773).
-- TinyViT backbone for the segemnt anything example
+- TinyViT backbone for the segment anything example
   [787](https://github.com/huggingface/candle/pull/787).
 - Shape with holes support
   [770](https://github.com/huggingface/candle/pull/770).
Cargo.toml (40 lines changed)

@@ -7,20 +7,19 @@ members = [
     "candle-nn",
     "candle-pyo3",
     "candle-transformers",
-    "candle-wasm-examples/llama2-c",
-    "candle-wasm-examples/segment-anything",
-    "candle-wasm-examples/whisper",
-    "candle-wasm-examples/yolo",
-    "candle-wasm-examples/bert",
-    "candle-wasm-examples/phi",
-    "candle-wasm-examples/t5",
+    "candle-wasm-examples/*",
     "candle-wasm-tests",
 ]
-exclude = ["candle-flash-attn", "candle-kernels"]
+exclude = [
+    "candle-flash-attn",
+    "candle-kernels",
+    "candle-metal-kernels",
+    "candle-onnx",
+]
 resolver = "2"
 
 [workspace.package]
-version = "0.3.0"
+version = "0.3.3"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -32,9 +31,18 @@ license = "MIT OR Apache-2.0"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
+candle = { path = "./candle-core", package = "candle-core" }
+candle-datasets = { path = "./candle-datasets" }
+candle-flash-attn = { path = "./candle-flash-attn" }
+candle-kernels = { path = "./candle-kernels" }
+candle-metal-kernels = { path = "./candle-metal-kernels" }
+candle-nn = { path = "./candle-nn" }
+candle-onnx = { path = "./candle-onnx" }
+candle-transformers = { path = "./candle-transformers" }
 clap = { version = "4.2.4", features = ["derive"] }
-cudarc = { version = "0.9.14", features = ["f16"] }
-gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
+criterion = { version = "0.5.1", default-features=false }
+cudarc = { version = "0.10.0", features = ["f16"] }
+gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
 hf-hub = "0.3.0"
 half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
 image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
@@ -42,25 +50,27 @@ imageproc = { version = "0.23.0", default-features = false }
 intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
 libc = { version = "0.2.147" }
 log = "0.4"
-memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
+memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
 num_cpus = "1.15.0"
 num-traits = "0.2.15"
-parquet = { version = "45.0.0" }
+parquet = { version = "50.0.0" }
 rand = "0.8.5"
 rand_distr = "0.4.3"
 rayon = "1.7.0"
 rusttype = { version = "0.9", default-features = false }
-safetensors = "0.3.1"
+safetensors = "0.4.1"
 serde = { version = "1.0.171", features = ["derive"] }
+serde_plain = "1.0.2"
 serde_json = "1.0.99"
 thiserror = "1"
-tokenizers = { version = "0.13.4", default-features = false }
+tokenizers = { version = "0.15.0", default-features = false }
 tracing = "0.1.37"
 tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "0.6.6", default-features = false }
+metal = { version = "0.27.0", features = ["mps"]}
 
 [profile.release-with-debug]
 inherits = "release"
README.md (67 lines changed)

@@ -51,23 +51,32 @@ For more advanced examples, please have a look at the following section.
 These online demos run entirely in your browser:
 - [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
   object recognition.
-- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): text to speech.
+- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
 - [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
 - [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
-- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
+- [Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
 - [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
+- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.
 
 We also provide a some command line based examples using state of the art models:
 
-- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
+- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
+  the SOLAR-10.7B variant.
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [Phi-v1 and Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
+- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
 - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
   pre-trained on 1T tokens of English and code datasets.
+- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal
+  implementation of the Mamba state space model.
 - [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
-  performance larger than all publicly available 13b models as of 2023-09-28.
+  better performance than all publicly available 13b models as of 2023-09-28.
+- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
+  experts 8x7b general LLM with better performance than a Llama 2 70B model with
+  much faster inference.
 - [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
 - [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
+- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
+  (English/Chinese) general LLMs with 6b and 34b parameters.
 - [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
   the LLaMA model using the same quantization techniques as
   [llama.cpp](https://github.com/ggerganov/llama.cpp).
@@ -75,7 +84,7 @@ We also provide a some command line based examples using state of the art models
   <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">
 
 - [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
-  image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
+  image generative model, support for the 1.5, 2.1, SDXL 1.0 and Turbo versions.
 
   <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">
 
@@ -95,12 +104,18 @@ We also provide a some command line based examples using state of the art models
   <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
 
 - [Whisper](./candle-examples/examples/whisper/): speech recognition model.
-- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
+- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
+  [JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
 - [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
   using self-supervision (can be used for imagenet classification, depth
   evaluation, segmentation).
+- [VGG](./candle-examples/examples/vgg/),
+  [RepVGG](./candle-examples/examples/repvgg): computer vision models.
+- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
 - [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
   generate captions for an image.
+- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
+  model, generates the translated text from the input text.
 
 Run them using commands like:
 ```
@@ -116,7 +131,7 @@ There are also some wasm examples for whisper and
 [whisper](https://huggingface.co/spaces/lmz/candle-whisper),
 [llama2](https://huggingface.co/spaces/lmz/candle-llama2),
 [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
-[Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
+[Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
 [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
 
 For LLaMA2, run the following command to retrieve the weight files and start a
@@ -133,10 +148,20 @@ And then head over to
 <!--- ANCHOR: useful_libraries --->
 
 ## Useful External Resources
-- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a
+- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
   very detailed tutorial showing how to convert a PyTorch model to Candle.
-- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
-  that conforms to the official `peft` implementation.
+- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and
+  ergonomic LoRA implementation for Candle. `candle-lora` has
+  out-of-the-box LoRA support for many models from Candle, which can be found
+  [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
+- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
+  including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
+- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
+  serving local LLMs including an OpenAI compatible API server.
+- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
+- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
+- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
+- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
 
 If you have an addition to this list, please submit a pull request.
 
@@ -155,15 +180,26 @@ If you have an addition to this list, please submit a pull request.
 - WASM support, run your models in a browser.
 - Included models.
   - Language Models.
-    - LLaMA v1 and v2.
+    - LLaMA v1 and v2 with variants such as SOLAR-10.7B.
     - Falcon.
     - StarCoder.
-    - Phi v1.5.
+    - Phi 1, 1.5, and 2.
+    - Minimal Mamba
     - Mistral 7b v0.1.
+    - Mixtral 8x7b v0.1.
     - StableLM-3B-4E1T.
     - Replit-code-v1.5-3B.
-    - T5.
     - Bert.
+    - Yi-6B and Yi-34B.
+  - Quantized LLMs.
+    - Llama 7b, 13b, 70b, as well as the chat and code variants.
+    - Mistral 7b, and 7b instruct.
+    - Mixtral 8x7b.
+    - Zephyr 7b a and b (Mistral-7b based).
+    - OpenChat 3.5 (Mistral-7b based).
+  - Text to text.
+    - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
+    - Marian MT (Machine Translation).
   - Whisper (multi-lingual support).
   - Text to image.
     - Stable Diffusion v1.5, v2.1, XL v1.0.
@@ -171,7 +207,7 @@ If you have an addition to this list, please submit a pull request.
   - Image to text.
     - BLIP.
   - Computer Vision Models.
-    - DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
+    - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG.
     - yolo-v3, yolo-v8.
     - Segment-Anything Model (SAM).
 - File formats: load models from safetensors, npz, ggml, or PyTorch files.
@@ -210,6 +246,7 @@ Cheatsheet:
 - [candle-datasets](./candle-datasets/): Datasets and data loaders.
 - [candle-transformers](./candle-transformers): transformers-related utilities.
 - [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
+- [candle-onnx](./candle-onnx/): ONNX model evaluation.
 
 ## FAQ
 
candle-examples/Cargo.toml

@@ -11,11 +11,11 @@ readme = "README.md"
 
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
-candle-nn = { path = "../candle-nn", version = "0.3.0" }
-candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
+candle = { workspace = true }
+candle-datasets = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
+candle-flash-attn = { workspace = true, optional = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
candle-book/src/lib.rs

@@ -28,6 +28,7 @@ let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap()
 #[rustfmt::skip]
 #[test]
 fn book_hub_2() {
+    {
     // ANCHOR: book_hub_2
     use candle::Device;
     use hf_hub::api::sync::Api;
@@ -45,9 +46,10 @@ let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap()
     assert_eq!(weights.len(), 206);
 }
 
-#[rustfmt::skip]
-#[test]
-fn book_hub_3() {
+// #[rustfmt::skip]
+// #[test]
+// fn book_hub_3() {
+    {
     // ANCHOR: book_hub_3
     use candle::{DType, Device, Tensor};
     use hf_hub::api::sync::Api;
@@ -102,6 +104,7 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
     assert_eq!(view.shape(), &[768, 768]);
     assert_eq!(tp_tensor.dims(), &[192, 768]);
 }
+}
 
 #[rustfmt::skip]
 #[test]
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
byteorder = { workspace = true }
|
byteorder = { workspace = true }
|
||||||
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
|
candle-kernels = { workspace = true, optional = true }
|
||||||
|
candle-metal-kernels = { workspace = true, optional = true }
|
||||||
|
metal = { workspace = true, optional = true}
|
||||||
cudarc = { workspace = true, optional = true }
|
cudarc = { workspace = true, optional = true }
|
||||||
gemm = { workspace = true }
|
gemm = { workspace = true }
|
||||||
half = { workspace = true }
|
half = { workspace = true }
|
||||||
@ -32,6 +34,8 @@ zip = { workspace = true }
|
|||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
clap = { workspace = true }
|
clap = { workspace = true }
|
||||||
|
criterion = { workspace = true }
|
||||||
|
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
@ -39,3 +43,8 @@ cuda = ["cudarc", "dep:candle-kernels"]
|
|||||||
cudnn = ["cuda", "cudarc/cudnn"]
|
cudnn = ["cuda", "cudarc/cudnn"]
|
||||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||||
accelerate = ["dep:libc", "dep:accelerate-src"]
|
accelerate = ["dep:libc", "dep:accelerate-src"]
|
||||||
|
metal = ["dep:metal", "dep:candle-metal-kernels"]
|
||||||
|
|
||||||
|
[[bench]]
|
||||||
|
name = "bench_main"
|
||||||
|
harness = false
|
||||||
|
candle-core/benches/bench_main.rs (new file, +9)

@@ -0,0 +1,9 @@
+mod benchmarks;
+
+use criterion::criterion_main;
+criterion_main!(
+    benchmarks::affine::benches,
+    benchmarks::matmul::benches,
+    benchmarks::random::benches,
+    benchmarks::where_cond::benches
+);
candle-core/benches/benchmarks/affine.rs (new file, +43)

@@ -0,0 +1,43 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{DType, Device, Tensor};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use std::time::Instant;
+
+fn run(a: &Tensor) {
+    a.affine(12.34, 56.78).unwrap();
+}
+
+fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
+    let b = 1;
+    let m = 1024;
+    let k = 1024;
+
+    let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
+
+    let flops = b * m * k * dtype.size_in_bytes();
+
+    let mut group = c.benchmark_group(device.bench_name(name));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run(black_box(&tensor));
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let handler = BenchDeviceHandler::new().unwrap();
+    for device in handler.devices {
+        run_affine_benchmark(c, &device, DType::F32, "affine_f32");
+        run_affine_benchmark(c, &device, DType::F16, "affine_f16");
+        run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
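For reference, `Tensor::affine(mul, add)` computes `mul * x + add` elementwise, so the op is memory-bound and the `flops` variable above actually counts bytes moved (hence `Throughput::Bytes`). A quick hedged check of the semantics, with illustrative values:

```rust
use candle_core::{Device, Result, Tensor};

fn affine_semantics() -> Result<()> {
    let x = Tensor::new(&[1f32, 2.0, 3.0], &Device::Cpu)?;
    // affine(mul, add): every element becomes mul * x + add.
    let y = x.affine(2.0, 0.5)?;
    assert_eq!(y.to_vec1::<f32>()?, [2.5, 4.5, 6.5]);
    Ok(())
}
```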
candle-core/benches/benchmarks/matmul.rs (new file, +44)

@@ -0,0 +1,44 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{DType, Device, Tensor};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use std::time::Instant;
+
+fn run(a: &Tensor, b: &Tensor) {
+    a.matmul(&b.t().unwrap()).unwrap();
+}
+
+fn run_bench(c: &mut Criterion, device: &Device) {
+    let b = 1;
+    let m = 1;
+    let n = 2048;
+    let k = 2048;
+
+    let dtype = DType::F32;
+    let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
+    let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
+
+    let flops = b * m * n * k;
+
+    let mut group = c.benchmark_group(device.bench_name("matmul"));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run(black_box(&lhs), black_box(&rhs));
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let handler = BenchDeviceHandler::new().unwrap();
+    for device in handler.devices {
+        run_bench(c, &device);
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
candle-core/benches/benchmarks/mod.rs (new file, +66)

@@ -0,0 +1,66 @@
+pub(crate) mod affine;
+pub(crate) mod matmul;
+pub(crate) mod random;
+pub(crate) mod where_cond;
+
+use candle_core::{Device, Result};
+
+pub(crate) trait BenchDevice {
+    fn sync(&self) -> Result<()>;
+
+    fn bench_name<S: Into<String>>(&self, name: S) -> String;
+}
+
+impl BenchDevice for Device {
+    fn sync(&self) -> Result<()> {
+        match self {
+            Device::Cpu => Ok(()),
+            Device::Cuda(device) => {
+                #[cfg(feature = "cuda")]
+                return Ok(device.synchronize()?);
+                #[cfg(not(feature = "cuda"))]
+                panic!("Cuda device without cuda feature enabled: {:?}", device)
+            }
+            Device::Metal(device) => {
+                #[cfg(feature = "metal")]
+                return Ok(device.wait_until_completed()?);
+                #[cfg(not(feature = "metal"))]
+                panic!("Metal device without metal feature enabled: {:?}", device)
+            }
+        }
+    }
+
+    fn bench_name<S: Into<String>>(&self, name: S) -> String {
+        match self {
+            Device::Cpu => {
+                let cpu_type = if cfg!(feature = "accelerate") {
+                    "accelerate"
+                } else if cfg!(feature = "mkl") {
+                    "mkl"
+                } else {
+                    "cpu"
+                };
+                format!("{}_{}", cpu_type, name.into())
+            }
+            Device::Cuda(_) => format!("cuda_{}", name.into()),
+            Device::Metal(_) => format!("metal_{}", name.into()),
+        }
+    }
+}
+
+struct BenchDeviceHandler {
+    devices: Vec<Device>,
+}
+
+impl BenchDeviceHandler {
+    pub fn new() -> Result<Self> {
+        let mut devices = Vec::new();
+        if cfg!(feature = "metal") {
+            devices.push(Device::new_metal(0)?);
+        } else if cfg!(feature = "cuda") {
+            devices.push(Device::new_cuda(0)?);
+        }
+        devices.push(Device::Cpu);
+        Ok(Self { devices })
+    }
+}
candle-core/benches/benchmarks/random.rs (new file, +63)

@@ -0,0 +1,63 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{DType, Device, Tensor};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use std::time::Instant;
+
+fn rand_uniform(a: &Tensor) {
+    a.rand_like(-1.0, 123.0).unwrap();
+}
+
+fn rand_normal(a: &Tensor) {
+    a.randn_like(100.0, 15.0).unwrap();
+}
+
+fn run_random_bench(c: &mut Criterion, device: &Device) {
+    let b = 1;
+
+    let rows = 2048;
+    let cols = 2048;
+
+    let dtype = DType::F32;
+    let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
+
+    let flops = b * rows * cols * dtype.size_in_bytes();
+
+    let mut group = c.benchmark_group(device.bench_name("random_uniform"));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |benches| {
+        benches.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                rand_uniform(black_box(&tensor));
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+
+    let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
+
+    let mut group = c.benchmark_group(device.bench_name("random_normal"));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |benches| {
+        benches.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                rand_normal(black_box(&tensor));
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let handler = BenchDeviceHandler::new().unwrap();
+    for device in handler.devices {
+        run_random_bench(c, &device);
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
candle-core/benches/benchmarks/where_cond.rs (new file, +64)

@@ -0,0 +1,64 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{DType, Device, Tensor};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use std::time::Instant;
+
+fn run(a: &Tensor, b: &Tensor, c: &Tensor) {
+    a.where_cond(b, c).unwrap();
+}
+
+const fn create_cond_arr<const N: usize>() -> [u8; N] {
+    let mut arr = [0u8; N];
+    let mut i = 0;
+    while i < N {
+        arr[i] = (i % 2) as u8;
+        i += 1;
+    }
+    arr
+}
+
+const B: usize = 1;
+const M: usize = 1024;
+const K: usize = 1024;
+const SIZE: usize = B * M * K;
+
+const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
+
+fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
+    let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
+    let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
+    let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
+
+    let elements = B * M * K;
+    // E.g. 2 f32 tensors + 1 u8 tensor
+    let flops = (2 * elements * dtype.size_in_bytes()) + elements;
+
+    let mut group = c.benchmark_group(device.bench_name(name));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run(
+                    black_box(&tensor),
+                    black_box(&on_true),
+                    black_box(&on_false),
+                );
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let device = BenchDeviceHandler::new().unwrap();
+    for d in device.devices {
+        run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32");
+        run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
+        run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
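All four benchmark modules above share one shape: allocate inputs once, run the op in a loop under `iter_custom`, and call `device.sync()` before reading the clock so that queued CUDA/Metal work is included in the measurement. A hedged sketch of a fifth module in the same pattern (the `sqrt` module and op choice are illustrative, not part of this diff; it would also need a `pub(crate) mod sqrt;` line in mod.rs and an entry in `criterion_main!`):

```rust
// Hypothetical candle-core/benches/benchmarks/sqrt.rs, mirroring the files above.
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;

fn run(a: &Tensor) {
    a.sqrt().unwrap();
}

fn run_bench(c: &mut Criterion, device: &Device) {
    let tensor = Tensor::ones((1, 1024, 1024), DType::F32, device).unwrap();
    let bytes = 1024 * 1024 * DType::F32.size_in_bytes();

    let mut group = c.benchmark_group(device.bench_name("sqrt"));
    group.throughput(Throughput::Bytes(bytes as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run(black_box(&tensor));
            }
            // Wait for asynchronous device work so the timing is honest.
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

fn criterion_benchmark(c: &mut Criterion) {
    let handler = BenchDeviceHandler::new().unwrap();
    for device in handler.devices {
        run_bench(c, &device);
    }
}

criterion_group!(benches, criterion_benchmark);
```

With the `[[bench]]` entry added to candle-core/Cargo.toml, the whole suite runs via `cargo bench` (optionally with `--features cuda` or `--features metal`).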
candle-core/examples/basics.rs

@@ -8,11 +8,10 @@ use anyhow::Result;
 use candle_core::{Device, Tensor};
 
 fn main() -> Result<()> {
-    let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
-    let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
-    let start = std::time::Instant::now();
-    let res = inp.conv2d(&w, 0, 1, 1, 1)?;
-    println!("{:?}", start.elapsed());
-    println!("{res:?}");
+    let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
+    let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
+    let new_a = a.slice_scatter(&b, 1, 2)?;
+    assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+    assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
     Ok(())
 }
|
|||||||
use candle_core::quantized::{gguf_file, k_quants, QTensor};
|
use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
|
||||||
use candle_core::{Device, Result, Tensor};
|
use candle_core::{Device, Result};
|
||||||
use clap::{Parser, Subcommand, ValueEnum};
|
use clap::{Parser, Subcommand, ValueEnum};
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
|
|
||||||
@ -11,12 +11,7 @@ enum QuantizationMode {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl QuantizationMode {
|
impl QuantizationMode {
|
||||||
fn quantize(
|
fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result<QTensor> {
|
||||||
&self,
|
|
||||||
name: &str,
|
|
||||||
tensor: QTensor,
|
|
||||||
default: fn(&Tensor) -> Result<QTensor>,
|
|
||||||
) -> Result<QTensor> {
|
|
||||||
match self {
|
match self {
|
||||||
Self::Llama => {
|
Self::Llama => {
|
||||||
// Same behavior as the llama.cpp quantization.
|
// Same behavior as the llama.cpp quantization.
|
||||||
@ -24,9 +19,9 @@ impl QuantizationMode {
|
|||||||
if should_quantize {
|
if should_quantize {
|
||||||
let tensor = tensor.dequantize(&Device::Cpu)?;
|
let tensor = tensor.dequantize(&Device::Cpu)?;
|
||||||
if name == "output.weight" {
|
if name == "output.weight" {
|
||||||
QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
|
QTensor::quantize(&tensor, GgmlDType::Q6K)
|
||||||
} else {
|
} else {
|
||||||
default(&tensor)
|
QTensor::quantize(&tensor, dtype)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Ok(tensor)
|
Ok(tensor)
|
||||||
@ -60,6 +55,27 @@ enum Quantization {
|
|||||||
F32,
|
F32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Quantization {
|
||||||
|
fn dtype(&self) -> GgmlDType {
|
||||||
|
match self {
|
||||||
|
Quantization::Q4_0 => GgmlDType::Q4_0,
|
||||||
|
Quantization::Q4_1 => GgmlDType::Q4_1,
|
||||||
|
Quantization::Q5_0 => GgmlDType::Q5_0,
|
||||||
|
Quantization::Q5_1 => GgmlDType::Q5_1,
|
||||||
|
Quantization::Q8_0 => GgmlDType::Q8_0,
|
||||||
|
Quantization::Q8_1 => GgmlDType::Q8_1,
|
||||||
|
Quantization::Q2k => GgmlDType::Q2K,
|
||||||
|
Quantization::Q3k => GgmlDType::Q3K,
|
||||||
|
Quantization::Q4k => GgmlDType::Q4K,
|
||||||
|
Quantization::Q5k => GgmlDType::Q5K,
|
||||||
|
Quantization::Q6k => GgmlDType::Q6K,
|
||||||
|
Quantization::Q8k => GgmlDType::Q8K,
|
||||||
|
Quantization::F16 => GgmlDType::F16,
|
||||||
|
Quantization::F32 => GgmlDType::F32,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(ValueEnum, Debug, Clone)]
|
#[derive(ValueEnum, Debug, Clone)]
|
||||||
enum Format {
|
enum Format {
|
||||||
Safetensors,
|
Safetensors,
|
||||||
@ -102,7 +118,7 @@ enum Command {
|
|||||||
},
|
},
|
||||||
|
|
||||||
Quantize {
|
Quantize {
|
||||||
/// The input file, in gguf format.
|
/// The input file(s), in safetensors format.
|
||||||
in_file: Vec<std::path::PathBuf>,
|
in_file: Vec<std::path::PathBuf>,
|
||||||
|
|
||||||
/// The output file, in gguf format.
|
/// The output file, in gguf format.
|
||||||
@ -117,6 +133,15 @@ enum Command {
|
|||||||
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
|
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
|
||||||
mode: QuantizationMode,
|
mode: QuantizationMode,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
Dequantize {
|
||||||
|
/// The input file, in gguf format.
|
||||||
|
in_file: std::path::PathBuf,
|
||||||
|
|
||||||
|
/// The output file, in safetensors format.
|
||||||
|
#[arg(long)]
|
||||||
|
out_file: std::path::PathBuf,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug, Clone)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
@ -125,7 +150,12 @@ struct Args {
|
|||||||
command: Command,
|
command: Command,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
|
fn run_ls(
|
||||||
|
file: &std::path::PathBuf,
|
||||||
|
format: Option<Format>,
|
||||||
|
verbose: bool,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<()> {
|
||||||
let format = match format {
|
let format = match format {
|
||||||
Some(format) => format,
|
Some(format) => format,
|
||||||
None => match Format::infer(file) {
|
None => match Format::infer(file) {
|
||||||
@ -191,7 +221,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
|
|||||||
}
|
}
|
||||||
Format::Ggml => {
|
Format::Ggml => {
|
||||||
let mut file = std::fs::File::open(file)?;
|
let mut file = std::fs::File::open(file)?;
|
||||||
let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
|
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
|
||||||
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
|
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
|
||||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||||
for (name, qtensor) in tensors.iter() {
|
for (name, qtensor) in tensors.iter() {
|
||||||
@ -232,37 +262,8 @@ fn run_quantize_safetensors(
|
|||||||
}
|
}
|
||||||
println!("tensors: {}", tensors.len());
|
println!("tensors: {}", tensors.len());
|
||||||
|
|
||||||
let quantize_fn = match q {
|
let dtype = q.dtype();
|
||||||
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
|
let block_size = dtype.block_size();
|
||||||
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
|
|
||||||
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
|
|
||||||
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
|
|
||||||
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
|
|
||||||
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
|
|
||||||
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
|
|
||||||
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
|
|
||||||
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
|
|
||||||
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
|
|
||||||
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
|
|
||||||
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
|
|
||||||
Quantization::F16 => QTensor::quantize::<half::f16>,
|
|
||||||
Quantization::F32 => QTensor::quantize::<f32>,
|
|
||||||
};
|
|
||||||
let block_size = match q {
|
|
||||||
Quantization::Q4_0 => k_quants::QK4_0,
|
|
||||||
Quantization::Q4_1 => k_quants::QK4_1,
|
|
||||||
Quantization::Q5_0 => k_quants::QK5_0,
|
|
||||||
Quantization::Q5_1 => k_quants::QK5_1,
|
|
||||||
Quantization::Q8_0 => k_quants::QK8_0,
|
|
||||||
Quantization::Q8_1 => k_quants::QK8_1,
|
|
||||||
Quantization::Q2k
|
|
||||||
| Quantization::Q3k
|
|
||||||
| Quantization::Q4k
|
|
||||||
| Quantization::Q5k
|
|
||||||
| Quantization::Q6k
|
|
||||||
| Quantization::Q8k => k_quants::QK_K,
|
|
||||||
Quantization::F16 | Quantization::F32 => 1,
|
|
||||||
};
|
|
||||||
|
|
||||||
let qtensors = tensors
|
let qtensors = tensors
|
||||||
.into_par_iter()
|
.into_par_iter()
|
||||||
@ -270,9 +271,9 @@ fn run_quantize_safetensors(
|
|||||||
let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
|
let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
|
||||||
println!(" quantizing {name} {tensor:?} {should_quantize}");
|
println!(" quantizing {name} {tensor:?} {should_quantize}");
|
||||||
let tensor = if should_quantize {
|
let tensor = if should_quantize {
|
||||||
quantize_fn(&tensor)?
|
QTensor::quantize(&tensor, dtype)?
|
||||||
} else {
|
} else {
|
||||||
QTensor::quantize::<f32>(&tensor)?
|
QTensor::quantize(&tensor, GgmlDType::F32)?
|
||||||
};
|
};
|
||||||
Ok((name, tensor))
|
Ok((name, tensor))
|
||||||
})
|
})
|
||||||
@ -285,11 +286,29 @@ fn run_quantize_safetensors(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn run_dequantize(
|
||||||
|
in_file: std::path::PathBuf,
|
||||||
|
out_file: std::path::PathBuf,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<()> {
|
||||||
|
let mut in_file = std::fs::File::open(in_file)?;
|
||||||
|
let content = gguf_file::Content::read(&mut in_file)?;
|
||||||
|
let mut tensors = std::collections::HashMap::new();
|
||||||
|
for (tensor_name, _) in content.tensor_infos.iter() {
|
||||||
|
let tensor = content.tensor(&mut in_file, tensor_name, device)?;
|
||||||
|
let tensor = tensor.dequantize(device)?;
|
||||||
|
tensors.insert(tensor_name.to_string(), tensor);
|
||||||
|
}
|
||||||
|
candle_core::safetensors::save(&tensors, out_file)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn run_quantize(
|
fn run_quantize(
|
||||||
in_files: &[std::path::PathBuf],
|
in_files: &[std::path::PathBuf],
|
||||||
out_file: std::path::PathBuf,
|
out_file: std::path::PathBuf,
|
||||||
q: Quantization,
|
q: Quantization,
|
||||||
qmode: QuantizationMode,
|
qmode: QuantizationMode,
|
||||||
|
device: &Device,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
if in_files.is_empty() {
|
if in_files.is_empty() {
|
||||||
candle_core::bail!("no specified input files")
|
candle_core::bail!("no specified input files")
|
||||||
@ -315,31 +334,15 @@ fn run_quantize(
|
|||||||
let content = gguf_file::Content::read(&mut in_)?;
|
let content = gguf_file::Content::read(&mut in_)?;
|
||||||
println!("tensors: {}", content.tensor_infos.len());
|
println!("tensors: {}", content.tensor_infos.len());
|
||||||
|
|
||||||
let quantize_fn = match q {
|
let dtype = q.dtype();
|
||||||
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
|
|
||||||
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
|
|
||||||
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
|
|
||||||
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
|
|
||||||
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
|
|
||||||
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
|
|
||||||
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
|
|
||||||
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
|
|
||||||
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
|
|
||||||
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
|
|
||||||
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
|
|
||||||
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
|
|
||||||
Quantization::F16 => QTensor::quantize::<half::f16>,
|
|
||||||
Quantization::F32 => QTensor::quantize::<f32>,
|
|
||||||
};
|
|
||||||
|
|
||||||
let qtensors = content
|
let qtensors = content
|
||||||
.tensor_infos
|
.tensor_infos
|
||||||
.par_iter()
|
.par_iter()
|
||||||
.map(|(name, _)| {
|
.map(|(name, _)| {
|
||||||
println!(" quantizing {name}");
|
println!(" quantizing {name}");
|
||||||
let mut in_file = std::fs::File::open(&in_files[0])?;
|
let mut in_file = std::fs::File::open(&in_files[0])?;
|
||||||
let tensor = content.tensor(&mut in_file, name)?;
|
let tensor = content.tensor(&mut in_file, name, device)?;
|
||||||
let tensor = qmode.quantize(name, tensor, quantize_fn)?;
|
let tensor = qmode.quantize(name, tensor, dtype)?;
|
||||||
Ok((name, tensor))
|
Ok((name, tensor))
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
@ -359,6 +362,7 @@ fn run_quantize(
|
|||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
let device = Device::Cpu;
|
||||||
match args.command {
|
match args.command {
|
||||||
Command::Ls {
|
Command::Ls {
|
||||||
files,
|
files,
|
||||||
@ -370,7 +374,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
if multiple_files {
|
if multiple_files {
|
||||||
println!("--- {file:?} ---");
|
println!("--- {file:?} ---");
|
||||||
}
|
}
|
||||||
run_ls(file, format.clone(), verbose)?
|
run_ls(file, format.clone(), verbose, &device)?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Command::Quantize {
|
Command::Quantize {
|
||||||
@ -378,7 +382,8 @@ fn main() -> anyhow::Result<()> {
|
|||||||
out_file,
|
out_file,
|
||||||
quantization,
|
quantization,
|
||||||
mode,
|
mode,
|
||||||
} => run_quantize(&in_file, out_file, quantization, mode)?,
|
} => run_quantize(&in_file, out_file, quantization, mode, &device)?,
|
||||||
|
Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
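The refactor above collapses the per-block-type generics (`QTensor::quantize::<k_quants::BlockQ6K>` and friends) into a single dtype-driven entry point, so the CLI only needs the `Quantization` to `GgmlDType` mapping plus `dtype.block_size()`. A hedged round-trip sketch against the new API (shapes are illustrative; k-quant dtypes need the last dimension to be a multiple of the 256-element block size):

```rust
use candle_core::quantized::{GgmlDType, QTensor};
use candle_core::{DType, Device, Result, Tensor};

fn quantize_roundtrip() -> Result<()> {
    let device = Device::Cpu;
    // 512 is a multiple of the 256-wide k-quants block, so Q4K applies cleanly.
    let t = Tensor::rand(0f32, 1.0, (4, 512), &device)?;
    let q = QTensor::quantize(&t, GgmlDType::Q4K)?;
    // Dequantizing recovers an f32 tensor of the original shape (values are lossy).
    let back = q.dequantize(&device)?;
    assert_eq!(back.dims(), t.dims());
    assert_eq!(back.dtype(), DType::F32);
    Ok(())
}
```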
|
@ -39,6 +39,14 @@ pub trait BackendStorage: Sized {
|
|||||||
_params: &crate::conv::ParamsConv1D,
|
_params: &crate::conv::ParamsConv1D,
|
||||||
) -> Result<Self>;
|
) -> Result<Self>;
|
||||||
|
|
||||||
|
fn conv_transpose1d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_params: &crate::conv::ParamsConvTranspose1D,
|
||||||
|
) -> Result<Self>;
|
||||||
|
|
||||||
fn conv2d(
|
fn conv2d(
|
||||||
&self,
|
&self,
|
||||||
_l: &Layout,
|
_l: &Layout,
|
||||||
|
@ -15,6 +15,17 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
thread_local! {
|
||||||
|
static CANDLE_GRAD_DO_NOT_DETACH: bool = {
|
||||||
|
match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
|
||||||
|
Ok(s) => {
|
||||||
|
!s.is_empty() && s != "0"
|
||||||
|
},
|
||||||
|
Err(_) => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Tensor {
|
impl Tensor {
|
||||||
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
|
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
|
||||||
/// elements having dependencies on the latter ones, e.g. the first element if any is the
|
/// elements having dependencies on the latter ones, e.g. the first element if any is the
|
||||||
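
The thread-local above is read once per thread; any non-empty value other than "0" disables the detach step introduced later in this diff, which can help when debugging gradient graphs. For instance, under a unix shell (the test selection is illustrative):

CANDLE_GRAD_DO_NOT_DETACH=1 cargo test -p candle-core
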
@@ -57,6 +68,11 @@ impl Tensor {
                 kernel: rhs,
                 ..
             }
+            | Op::ConvTranspose1D {
+                arg: lhs,
+                kernel: rhs,
+                ..
+            }
             | Op::Conv2D {
                 arg: lhs,
                 kernel: rhs,

@@ -98,7 +114,7 @@ impl Tensor {
                 | Op::Unary(_node, UnaryOp::Round) => nodes,
                 Op::Reshape(node)
                 | Op::UpsampleNearest1D(node)
-                | Op::UpsampleNearest2D(node)
+                | Op::UpsampleNearest2D { arg: node, .. }
                 | Op::AvgPool2D { arg: node, .. }
                 | Op::MaxPool2D { arg: node, .. }
                 | Op::Copy(node)

@@ -150,10 +166,16 @@ impl Tensor {
             if node.is_variable() {
                 continue;
             }
-            let grad = grads.remove(node).unwrap();
-            // TODO: We should perform all these operations in place (or at least not track the
-            // whole graph). The only drawback would be if we wanted to support grad of grad but
-            // this is out of scope.
+            let grad = grads
+                .remove(node)
+                .expect("candle internal error - grad not populated");
+            // https://github.com/huggingface/candle/issues/1241
+            // Ideally, we would make these operations in place where possible to ensure that we
+            // do not have to allocate too often. Here we just call `.detach` to avoid computing
+            // the backprop graph of the backprop itself. This would be an issue for second order
+            // derivatives but these are out of scope at the moment.
+            let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
+            let grad = if do_not_detach { grad } else { grad.detach()? };
             if let Some(op) = node.op() {
                 match op {
                     Op::Binary(lhs, rhs, BinaryOp::Add) => {

@@ -208,7 +230,44 @@ impl Tensor {
                         let f_grad = pred.where_cond(&zeros, &grad)?;
                         *f_sum_grad = f_sum_grad.add(&f_grad)?;
                     }
-                    Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
+                    Op::Conv1D {
+                        arg,
+                        kernel,
+                        padding,
+                        stride,
+                        dilation,
+                    } => {
+                        // The output height for conv_transpose1d is:
+                        // (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
+                        let grad_l_in = grad.dim(2)?;
+                        let k_size = kernel.dim(2)?;
+                        let out_size =
+                            (grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
+                        let out_padding = arg.dim(2)? - out_size;
+                        let grad_arg = grad.conv_transpose1d(
+                            kernel,
+                            *padding,
+                            out_padding,
+                            *stride,
+                            *dilation,
+                        )?;
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&grad_arg)?;
+
+                        let grad_kernel = arg
+                            .transpose(0, 1)?
+                            .conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
+                            .transpose(0, 1)?;
+                        let sum_grad = grads.or_insert(kernel)?;
+                        let (_, _, k0) = kernel.dims3()?;
+                        let (_, _, g_k0) = grad_kernel.dims3()?;
+                        let grad_kernel = if g_k0 != k0 {
+                            grad_kernel.narrow(2, 0, k0)?
+                        } else {
+                            grad_kernel
+                        };
+                        *sum_grad = sum_grad.add(&grad_kernel)?;
+                    }
                     Op::Conv2D {
                         arg,
                         kernel,
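
The `out_padding` computed above compensates for the length information lost by a strided forward convolution: several input lengths map to the same output length, so the transposed convolution alone only recovers `out_size`, and the remainder is padded back. A standalone numeric check of the arithmetic (plain Rust, values illustrative):

fn main() {
    // Forward conv1d: l_in = 10, k_size = 3, stride = 2, padding = 1, dilation = 1.
    let (l_in, k_size, stride, padding, dilation) = (10usize, 3usize, 2usize, 1usize, 1usize);
    let l_out = (l_in + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1; // 5
    // Backward: conv_transpose1d over the gradient of length l_out.
    let out_size = (l_out - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding; // 9
    let out_padding = l_in - out_size; // 1, restores the odd length lost to the stride
    assert_eq!(out_size + out_padding, l_in);
    println!("l_out={l_out} out_size={out_size} out_padding={out_padding}");
}
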
@@ -238,8 +297,18 @@ impl Tensor {
                             .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                             .transpose(0, 1)?;
                         let sum_grad = grads.or_insert(kernel)?;
+                        let (_, _, k0, k1) = kernel.dims4()?;
+                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
+                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
+                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
+                        } else {
+                            grad_kernel
+                        };
                         *sum_grad = sum_grad.add(&grad_kernel)?;
                     }
+                    Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
+                        op: "conv-transpose1d",
+                    })?,
                     Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
                         op: "conv-transpose2d",
                     })?,

@@ -281,9 +350,27 @@ impl Tensor {
                     Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
                         op: "upsample-nearest1d",
                     })?,
-                    Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
-                        op: "upsample-nearest2d",
-                    })?,
+                    Op::UpsampleNearest2D {
+                        arg,
+                        target_h,
+                        target_w,
+                    } => {
+                        let (_n, c, h, w) = arg.dims4()?;
+                        if target_h % h != 0 || target_w % w != 0 {
+                            crate::bail!("backward not supported for non integer upscaling factors")
+                        }
+                        let scale_h = target_h / h;
+                        let scale_w = target_w / w;
+
+                        if scale_h != scale_w {
+                            crate::bail!("backward not supported for non uniform upscaling factors")
+                        };
+                        let kernel =
+                            Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
+                        let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = conv_sum;
+                    }
                     Op::SliceScatter0(lhs, rhs, start_rhs) => {
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;

@@ -480,16 +567,38 @@ impl Tensor {
                             + 0.5)?;
                         *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
                     }
-                    Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
-                    Op::Unary(_, UnaryOp::GeluErf) => {
-                        Err(Error::BackwardNotSupported { op: "gelu-erf" })?
+                    Op::Unary(arg, UnaryOp::Erf) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        // d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
+                        let erf_grad =
+                            (2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
+                        *sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
+                    }
+                    Op::Unary(arg, UnaryOp::GeluErf) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        // d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
+                        let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
+                        let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
+                        let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
+                        let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
+                        let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
+                        *sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
                     }
                     Op::Unary(arg, UnaryOp::Relu) => {
                         let sum_grad = grads.or_insert(arg)?;
                         let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
                         *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
                     }
-                    Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
+                    Op::Elu(arg, alpha) => {
+                        // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
+                        let sum_grad = grads.or_insert(arg)?;
+                        let zeros = arg.zeros_like()?;
+                        let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
+                        let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
+                        let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
+                        let combined_mask = (positive_mask + negative_exp_mask)?;
+                        *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
+                    }
                     Op::Powf(arg, e) => {
                         let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
                         let sum_grad = grads.or_insert(arg)?;
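
A quick way to sanity-check gradient formulas like the ELU one above is a scalar finite-difference comparison. A self-contained sketch (plain f64, independent of candle; the sample points avoid the kink at zero):

fn elu(x: f64, alpha: f64) -> f64 {
    if x > 0.0 { x } else { alpha * (x.exp() - 1.0) }
}

fn elu_grad(x: f64, alpha: f64) -> f64 {
    // Matches the backward rule: 1 for x > 0, alpha * e^x otherwise.
    if x > 0.0 { 1.0 } else { alpha * x.exp() }
}

fn main() {
    let (alpha, eps) = (1.5, 1e-6);
    for &x in &[-2.0, -0.1, 0.3, 2.0] {
        let fd = (elu(x + eps, alpha) - elu(x - eps, alpha)) / (2.0 * eps);
        assert!((fd - elu_grad(x, alpha)).abs() < 1e-4);
    }
}
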
@@ -25,6 +25,33 @@ impl ParamsConv1D {
     }
 }

+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ParamsConvTranspose1D {
+    pub(crate) b_size: usize,
+    pub(crate) l_in: usize,
+    pub(crate) c_out: usize,
+    pub(crate) c_in: usize,
+    pub(crate) k_size: usize,
+    pub(crate) padding: usize,
+    pub(crate) output_padding: usize,
+    pub(crate) stride: usize,
+    pub(crate) dilation: usize,
+}
+
+impl ParamsConvTranspose1D {
+    pub(crate) fn l_out(&self) -> usize {
+        (self.l_in - 1) * self.stride - 2 * self.padding
+            + self.dilation * (self.k_size - 1)
+            + self.output_padding
+            + 1
+    }
+
+    pub(crate) fn out_dims(&self) -> Vec<usize> {
+        let l_out = self.l_out();
+        vec![self.b_size, self.c_out, l_out]
+    }
+}
+
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub enum CudnnFwdAlgo {
     ImplicitGemm,
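
The `l_out` helper implements the usual transposed-convolution length formula. A worked instance (values chosen for illustration): with l_in = 5, stride = 2, padding = 1, k_size = 3, dilation = 1 and output_padding = 0, l_out = (5 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 0 + 1 = 9.
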
@@ -160,6 +187,49 @@ impl Tensor {
         }
     }

+    /// Applies a 1D transposed convolution over the input tensor.
+    pub fn conv_transpose1d(
+        &self,
+        kernel: &Self,
+        padding: usize,
+        output_padding: usize,
+        stride: usize,
+        dilation: usize,
+    ) -> Result<Self> {
+        let (b_size, c_in, l_in) = self.dims3()?;
+        let (c_in_k, c_out, k_size) = kernel.dims3()?;
+        if c_in != c_in_k {
+            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
+        }
+        let params = ParamsConvTranspose1D {
+            b_size,
+            l_in,
+            k_size,
+            c_out,
+            c_in,
+            padding,
+            output_padding,
+            stride,
+            dilation,
+        };
+        let storage = self.storage().conv_transpose1d(
+            self.layout(),
+            &kernel.storage(),
+            kernel.layout(),
+            &params,
+        )?;
+        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
+            arg,
+            kernel,
+            padding: params.padding,
+            output_padding: params.output_padding,
+            stride: params.stride,
+            dilation: params.dilation,
+        });
+        let out_dims = params.out_dims();
+        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
+    }
+
     fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
         let storage =
             self.storage()
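
A usage sketch for the new method, following the shape convention in the diff (input `(b_size, c_in, l_in)`, kernel `(c_in, c_out, k_size)`); the concrete sizes are illustrative:

use candle_core::{Device, Result, Tensor};

fn demo() -> Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::randn(0f32, 1f32, (1, 4, 5), &dev)?; // (b_size, c_in, l_in)
    let k = Tensor::randn(0f32, 1f32, (4, 2, 3), &dev)?; // (c_in, c_out, k_size)
    // padding = 1, output_padding = 0, stride = 2, dilation = 1
    let y = x.conv_transpose1d(&k, 1, 0, 2, 1)?;
    assert_eq!(y.dims(), &[1, 2, 9]); // (5 - 1) * 2 - 2 + 2 + 0 + 1 = 9
    Ok(())
}
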
@@ -804,11 +804,11 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> {
     fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
         let ids = match self.ids_l.contiguous_offsets() {
             Some((a, b)) => &self.ids[a..b],
-            None => Err(Error::RequiresContiguous { op: "gather" })?,
+            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
         };
         let src = match src_l.contiguous_offsets() {
             Some((a, b)) => &src[a..b],
-            None => Err(Error::RequiresContiguous { op: "gather" })?,
+            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
         };
         let dim = self.dim;
         let ids_dims = self.ids_l.dims();

@@ -857,7 +857,7 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
     fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
         let src = match layout.contiguous_offsets() {
             Some((a, b)) => &src[a..b],
-            None => Err(Error::RequiresContiguous { op: "index-select" })?,
+            None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
         };
         let dim = self.dim;
         let n_ids = match self.ids_l.dims() {

@@ -913,7 +913,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
         let mut dst = vec![T::zero(); dst_len];
         copy_strided_src_(v1, &mut dst, 0, l1);
         let src = match src_l.contiguous_offsets() {
-            None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
+            None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
             Some((o1, o2)) => &src[o1..o2],
         };

@@ -929,7 +929,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {

         let ids = match self.ids_l.contiguous_offsets() {
             Some((a, b)) => &self.ids[a..b],
-            None => Err(Error::RequiresContiguous { op: "gather" })?,
+            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
         };
         for left_i in 0..ids_left_len {
             let start_ids_idx = left_i * ids_right_len * ids_dim_len;

@@ -971,7 +971,7 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
         let mut dst = vec![T::zero(); dst_len];
         copy_strided_src_(v1, &mut dst, 0, l1);
         let src = match src_l.contiguous_offsets() {
-            None => Err(Error::RequiresContiguous { op: "index-add" })?,
+            None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
            Some((o1, o2)) => &src[o1..o2],
         };
         let dim = self.dim;
@@ -1256,6 +1256,74 @@ impl Map1 for Im2Col {
     }
 }

+struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
+
+impl<'a> Map2 for ConvTranspose1D<'a> {
+    const OP: &'static str = "conv_transpose1d";
+    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
+        let p = self.0;
+        let inp = &inp[inp_l.start_offset()..];
+        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
+        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
+        let l_out = p.l_out();
+
+        // Output shape: [b_size, c_out, l_out].
+        let dst_elems = p.c_out * l_out * p.b_size;
+        let dst = vec![T::zero(); dst_elems];
+        let dst_s0 = p.c_out * l_out;
+        let dst_s1 = l_out;
+        let dst_s2 = 1;
+
+        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
+        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
+        let cont_s0 = p.l_in * p.c_in;
+        let cont_s1 = p.c_in;
+        for b_idx in 0..p.b_size {
+            for l_idx in 0..p.l_in {
+                for c_idx in 0..p.c_in {
+                    let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
+                    let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
+                    inp_cont[dst_idx] = inp[src_idx]
+                }
+            }
+        }
+
+        for k_idx in 0..p.k_size {
+            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
+                let k_cont = (0..p.c_in)
+                    .map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
+                    .collect::<Vec<_>>();
+                for b_idx in 0..p.b_size {
+                    for l_idx in 0..p.l_in {
+                        let out_idx = l_idx * p.stride + k_idx * p.dilation;
+                        if out_idx < p.padding {
+                            continue;
+                        }
+                        let out_idx = out_idx - p.padding;
+                        if out_idx < l_out {
+                            let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
+                            let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
+                            let mut d = T::zero();
+                            unsafe {
+                                T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
+                            }
+                            let dst_p = dst.as_ptr();
+                            // Safety: dst_idx are uniques per dst_c_idx which is used to
+                            // parallelise the different tasks so no two threads can try to
+                            // write at the same location.
+                            unsafe {
+                                let ptr = dst_p.add(dst_idx) as *mut T;
+                                *ptr += d
+                            }
+                        }
+                    }
+                }
+            })
+        }
+        Ok(dst)
+    }
+}
+
 struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);

 impl<'a> Map2 for Conv2D<'a> {
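
The parallel kernel above scatters each input position into stride-spaced output slots and reduces over `c_in` with a dot product. It is easier to follow against a naive scalar version of the same computation; the following standalone sketch (plain Rust, contiguous layouts only, single-threaded) mirrors its index arithmetic:

// Naive conv_transpose1d over contiguous data:
// inp[b][c_in][l_in], k[c_in][c_out][k_size], dst[b][c_out][l_out].
fn conv_transpose1d_ref(
    inp: &[f32], k: &[f32],
    (b_sz, c_in, l_in): (usize, usize, usize),
    (c_out, k_size): (usize, usize),
    (stride, padding, out_padding, dilation): (usize, usize, usize, usize),
) -> Vec<f32> {
    let l_out = (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1;
    let mut dst = vec![0f32; b_sz * c_out * l_out];
    for b in 0..b_sz {
        for l in 0..l_in {
            for ki in 0..k_size {
                // Each input position contributes to one stride-spaced output slot.
                let o = l * stride + ki * dilation;
                if o < padding || o - padding >= l_out { continue; }
                let o = o - padding;
                for co in 0..c_out {
                    let mut acc = 0f32;
                    for ci in 0..c_in {
                        acc += inp[(b * c_in + ci) * l_in + l] * k[(ci * c_out + co) * k_size + ki];
                    }
                    dst[(b * c_out + co) * l_out + o] += acc;
                }
            }
        }
    }
    dst
}
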
@@ -2435,6 +2503,16 @@ impl BackendStorage for CpuStorage {
         Ok(res_t)
     }

+    fn conv_transpose1d(
+        &self,
+        l: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        params: &crate::conv::ParamsConvTranspose1D,
+    ) -> Result<Self> {
+        ConvTranspose1D(params).map(self, l, kernel, kernel_l)
+    }
+
     fn conv2d(
         &self,
         l: &Layout,

@@ -2539,25 +2617,25 @@ impl BackendStorage for CpuStorage {
             Self::U8(ids) => {
                 let ids = match ids_l.contiguous_offsets() {
                     Some((a, b)) => &ids[a..b],
-                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
+                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                 };
                 IndexAdd { ids, dim }.map(self, l, src, src_l)
             }
             Self::U32(ids) => {
                 let ids = match ids_l.contiguous_offsets() {
                     Some((a, b)) => &ids[a..b],
-                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
+                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                 };
                 IndexAdd { ids, dim }.map(self, l, src, src_l)
             }
             Self::I64(ids) => {
                 let ids = match ids_l.contiguous_offsets() {
                     Some((a, b)) => &ids[a..b],
-                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
+                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                 };
                 IndexAdd { ids, dim }.map(self, l, src, src_l)
             }
-            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
         }
     }

@@ -1808,6 +1808,16 @@ impl BackendStorage for CudaStorage {
         Ok(res_t)
     }

+    fn conv_transpose1d(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &crate::conv::ParamsConvTranspose1D,
+    ) -> Result<Self> {
+        todo!()
+    }
+
     #[cfg(not(feature = "cudnn"))]
     fn conv2d(
         &self,
         _: &Layout,
@@ -8,12 +8,14 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
 pub enum DeviceLocation {
     Cpu,
     Cuda { gpu_id: usize },
+    Metal { gpu_id: usize },
 }

 #[derive(Debug, Clone)]
 pub enum Device {
     Cpu,
     Cuda(crate::CudaDevice),
+    Metal(crate::MetalDevice),
 }

 pub trait NdArray {

@@ -128,10 +130,15 @@ impl Device {
         Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
     }

+    pub fn new_metal(ordinal: usize) -> Result<Self> {
+        Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
+    }
+
     pub fn set_seed(&self, seed: u64) -> Result<()> {
         match self {
-            Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed),
+            Self::Cpu => CpuDevice.set_seed(seed),
             Self::Cuda(c) => c.set_seed(seed),
+            Self::Metal(m) => m.set_seed(seed),
         }
     }

@@ -139,6 +146,7 @@ impl Device {
         match (self, rhs) {
             (Self::Cpu, Self::Cpu) => true,
             (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
+            (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
             _ => false,
         }
     }

@@ -147,21 +155,20 @@ impl Device {
         match self {
             Self::Cpu => DeviceLocation::Cpu,
             Self::Cuda(device) => device.location(),
+            Device::Metal(device) => device.location(),
         }
     }

     pub fn is_cpu(&self) -> bool {
-        match self {
-            Self::Cpu => true,
-            Self::Cuda(_) => false,
-        }
+        matches!(self, Self::Cpu)
     }

     pub fn is_cuda(&self) -> bool {
-        match self {
-            Self::Cpu => false,
-            Self::Cuda(_) => true,
-        }
+        matches!(self, Self::Cuda(_))
     }

+    pub fn is_metal(&self) -> bool {
+        matches!(self, Self::Metal(_))
+    }
+
     pub fn cuda_if_available(ordinal: usize) -> Result<Self> {

@@ -185,10 +192,20 @@ impl Device {
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
+                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
+                if dtype == DType::F16 || dtype == DType::BF16 {
+                    let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
+                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
+                } else {
                 let storage = device.rand_uniform(shape, dtype, lo, up)?;
                 Ok(Storage::Cuda(storage))
+                }
             }
+            Device::Metal(device) => {
+                let storage = device.rand_uniform(shape, dtype, lo, up)?;
+                Ok(Storage::Metal(storage))
+            }
         }
     }

     pub(crate) fn rand_uniform<T: crate::FloatDType>(
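
The f16/bf16 special case reflects that the CUDA RNG path here only samples in f32; sampling wide and converting down simply rounds each sample to the nearest representable value in the narrower type. The conversion step in isolation, using the `half` crate that candle already depends on (the value is illustrative):

use half::{bf16, f16};

fn main() {
    let x = 0.12345f32; // e.g. a sample drawn in f32
    let (h, b) = (f16::from_f32(x), bf16::from_f32(x));
    println!("f32={x} f16={h} bf16={b}"); // rounded to the nearest representable values
}
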
@@ -213,10 +230,20 @@ impl Device {
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
+                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
+                if dtype == DType::F16 || dtype == DType::BF16 {
+                    let storage = device.rand_normal(shape, DType::F32, mean, std)?;
+                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
+                } else {
                 let storage = device.rand_normal(shape, dtype, mean, std)?;
                 Ok(Storage::Cuda(storage))
+                }
             }
+            Device::Metal(device) => {
+                let storage = device.rand_normal(shape, dtype, mean, std)?;
+                Ok(Storage::Metal(storage))
+            }
         }
     }

     pub(crate) fn rand_normal<T: crate::FloatDType>(

@@ -238,6 +265,10 @@ impl Device {
                 let storage = device.ones_impl(shape, dtype)?;
                 Ok(Storage::Cuda(storage))
             }
+            Device::Metal(device) => {
+                let storage = device.ones_impl(shape, dtype)?;
+                Ok(Storage::Metal(storage))
+            }
         }
     }

@@ -251,6 +282,10 @@ impl Device {
                 let storage = device.zeros_impl(shape, dtype)?;
                 Ok(Storage::Cuda(storage))
             }
+            Device::Metal(device) => {
+                let storage = device.zeros_impl(shape, dtype)?;
+                Ok(Storage::Metal(storage))
+            }
         }
     }

@@ -262,6 +297,11 @@ impl Device {
                 let storage = device.storage_from_cpu_storage(&storage)?;
                 Ok(Storage::Cuda(storage))
             }
+            Device::Metal(device) => {
+                let storage = array.to_cpu_storage();
+                let storage = device.storage_from_cpu_storage(&storage)?;
+                Ok(Storage::Metal(storage))
+            }
         }
     }

@@ -273,6 +313,11 @@ impl Device {
                 let storage = device.storage_from_cpu_storage(&storage)?;
                 Ok(Storage::Cuda(storage))
             }
+            Device::Metal(device) => {
+                let storage = S::to_cpu_storage_owned(data);
+                let storage = device.storage_from_cpu_storage(&storage)?;
+                Ok(Storage::Metal(storage))
+            }
         }
     }
 }

@@ -14,6 +14,9 @@ impl Tensor {
             crate::DeviceLocation::Cuda { gpu_id } => {
                 format!(", cuda:{}", gpu_id)
             }
+            crate::DeviceLocation::Metal { gpu_id } => {
+                format!(", metal:{}", gpu_id)
+            }
         };

         write!(f, "Tensor[")?;

@@ -476,6 +479,9 @@ impl std::fmt::Display for Tensor {
             crate::DeviceLocation::Cuda { gpu_id } => {
                 format!(", cuda:{}", gpu_id)
             }
+            crate::DeviceLocation::Metal { gpu_id } => {
+                format!(", metal:{}", gpu_id)
+            }
         };

         write!(

@@ -79,6 +79,16 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }

+    fn conv_transpose1d(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &crate::conv::ParamsConvTranspose1D,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     fn conv2d(
         &self,
         _: &Layout,
candle-core/src/dummy_metal_backend.rs (new file, 223 lines)
@@ -0,0 +1,223 @@
+#![allow(dead_code)]
+use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
+use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
+
+#[derive(Debug, Clone)]
+pub struct MetalDevice;
+
+#[derive(Debug)]
+pub struct MetalStorage;
+
+#[derive(thiserror::Error, Debug)]
+pub enum MetalError {
+    #[error("{0}")]
+    Message(String),
+}
+
+impl From<String> for MetalError {
+    fn from(e: String) -> Self {
+        MetalError::Message(e)
+    }
+}
+
+macro_rules! fail {
+    () => {
+        unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
+    };
+}
+
+impl crate::backend::BackendStorage for MetalStorage {
+    type Device = MetalDevice;
+
+    fn try_clone(&self, _: &Layout) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn dtype(&self) -> DType {
+        fail!()
+    }
+
+    fn device(&self) -> &Self::Device {
+        fail!()
+    }
+
+    fn to_cpu_storage(&self) -> Result<CpuStorage> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn conv1d(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &crate::conv::ParamsConv1D,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn conv_transpose1d(
+        &self,
+        _l: &Layout,
+        _kernel: &Self,
+        _kernel_l: &Layout,
+        _params: &crate::conv::ParamsConvTranspose1D,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn conv2d(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &crate::conv::ParamsConv2D,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn conv_transpose2d(
+        &self,
+        _l: &Layout,
+        _kernel: &Self,
+        _kernel_l: &Layout,
+        _params: &crate::conv::ParamsConvTranspose2D,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn scatter_add(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn index_add(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn matmul(
+        &self,
+        _: &Self,
+        _: (usize, usize, usize, usize),
+        _: &Layout,
+        _: &Layout,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+}
+
+impl crate::backend::BackendDevice for MetalDevice {
+    type Storage = MetalStorage;
+    fn new(_: usize) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn set_seed(&self, _: u64) -> Result<()> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn location(&self) -> crate::DeviceLocation {
+        fail!()
+    }
+
+    fn same_device(&self, _: &Self) -> bool {
+        fail!()
+    }
+
+    fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+}
@@ -1,4 +1,4 @@
-use crate::{DType, DeviceLocation, Layout, Shape};
+use crate::{DType, DeviceLocation, Layout, MetalError, Shape};

 #[derive(Debug, Clone)]
 pub struct MatMulUnexpectedStriding {

@@ -152,6 +152,9 @@ pub enum Error {
     #[error("the candle crate has not been built with cuda support")]
     NotCompiledWithCudaSupport,

+    #[error("the candle crate has not been built with metal support")]
+    NotCompiledWithMetalSupport,
+
     #[error("cannot find tensor {path}")]
     CannotFindTensor { path: String },

@@ -159,6 +162,9 @@ pub enum Error {
     #[error(transparent)]
     Cuda(Box<dyn std::error::Error + Send + Sync>),

+    #[error("Metal error {0}")]
+    Metal(#[from] MetalError),
+
     #[error(transparent)]
     TryFromIntError(#[from] core::num::TryFromIntError),
@@ -64,7 +64,7 @@ impl Tensor {
 #[derive(Debug)]
 /// Generic structure used to index a slice of the tensor
 pub enum TensorIndexer {
-    /// This selects the elemnts for which an index has some specific value.
+    /// This selects the elements for which an index has some specific value.
     Select(usize),
     /// This is a regular slice, purely indexing a chunk of the tensor
     Narrow(Bound<usize>, Bound<usize>),

@@ -104,37 +104,31 @@ impl From<&Tensor> for TensorIndexer {
     }
 }

-macro_rules! impl_from_range {
-    ($range_type:ty) => {
-        impl From<$range_type> for TensorIndexer {
-            fn from(range: $range_type) -> Self {
-                use std::ops::Bound::*;
-
-                let start = match range.start_bound() {
-                    Included(idx) => Included(*idx),
-                    Excluded(idx) => Excluded(*idx),
-                    Unbounded => Unbounded,
-                };
-
-                let end = match range.end_bound() {
-                    Included(idx) => Included(*idx),
-                    Excluded(idx) => Excluded(*idx),
-                    Unbounded => Unbounded,
-                };
-
-                TensorIndexer::Narrow(start, end)
-            }
-        }
-    };
-}
-
-impl_from_range!(Range<usize>);
-impl_from_range!(RangeFrom<usize>);
-impl_from_range!(RangeFull);
-impl_from_range!(RangeInclusive<usize>);
-impl_from_range!(RangeTo<usize>);
-impl_from_range!(RangeToInclusive<usize>);
+trait RB: RangeBounds<usize> {}
+impl RB for Range<usize> {}
+impl RB for RangeFrom<usize> {}
+impl RB for RangeFull {}
+impl RB for RangeInclusive<usize> {}
+impl RB for RangeTo<usize> {}
+impl RB for RangeToInclusive<usize> {}
+
+impl<T: RB> From<T> for TensorIndexer {
+    fn from(range: T) -> Self {
+        use std::ops::Bound::*;
+        let start = match range.start_bound() {
+            Included(idx) => Included(*idx),
+            Excluded(idx) => Excluded(*idx),
+            Unbounded => Unbounded,
+        };
+        let end = match range.end_bound() {
+            Included(idx) => Included(*idx),
+            Excluded(idx) => Excluded(*idx),
+            Unbounded => Unbounded,
+        };
+        TensorIndexer::Narrow(start, end)
+    }
+}

 /// Trait used to implement multiple signatures for ease of use of the slicing
 /// of a tensor
 pub trait IndexOp<T> {
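
With the blanket `impl<T: RB> From<T>` replacing the macro, every standard range type converts to a `TensorIndexer` directly. A usage sketch via `IndexOp` (shapes illustrative; treat the tuple-indexing call as an assumption about the available `IndexOp` impls):

use candle_core::{Device, IndexOp, Result, Tensor};

fn demo() -> Result<()> {
    let t = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
    let a = t.i(1)?;         // Select: row 1, shape (4,)
    let b = t.i(0..2)?;      // Range: first two rows
    let c = t.i((.., 1..))?; // RangeFull and RangeFrom combined
    assert_eq!(a.dims(), &[4]);
    assert_eq!(b.dims(), &[2, 4]);
    assert_eq!(c.dims(), &[3, 3]);
    Ok(())
}
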
@@ -49,9 +49,12 @@ mod device;
 pub mod display;
 mod dtype;
 mod dummy_cuda_backend;
+mod dummy_metal_backend;
 pub mod error;
 mod indexer;
 pub mod layout;
+#[cfg(feature = "metal")]
+pub mod metal_backend;
 #[cfg(feature = "mkl")]
 mod mkl;
 pub mod npy;

@@ -69,7 +72,7 @@ pub mod utils;
 mod variable;

 pub use cpu_backend::CpuStorage;
-pub use device::{Device, DeviceLocation};
+pub use device::{Device, DeviceLocation, NdArray};
 pub use dtype::{DType, FloatDType, IntDType, WithDType};
 pub use error::{Error, Result};
 pub use indexer::IndexOp;

@@ -87,6 +90,12 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
 #[cfg(not(feature = "cuda"))]
 pub use dummy_cuda_backend::{CudaDevice, CudaStorage};

+#[cfg(feature = "metal")]
+pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
+
+#[cfg(not(feature = "metal"))]
+pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
+
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;

@@ -114,14 +123,20 @@ pub trait Module {
     fn forward(&self, xs: &Tensor) -> Result<Tensor>;
 }

-impl Module for quantized::QMatMul {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        self.forward(xs)
-    }
-}
-
 impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         self(xs)
     }
 }
+
+// A trait defining a module with forward method using a single tensor argument and a flag to
+// separate the training and evaluation behaviors.
+pub trait ModuleT {
+    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
+}
+
+impl<M: Module> ModuleT for M {
+    fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
+        self.forward(xs)
+    }
+}
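
The blanket `impl<M: Module> ModuleT for M` means existing modules keep working wherever the train flag is now threaded through, while train-time-only behavior (dropout, batch-norm statistics) can share the same interface. A minimal sketch of a custom `ModuleT` (the struct and its scaling rule are hypothetical):

use candle_core::{ModuleT, Result, Tensor};

struct ScaleInTrain(f64); // hypothetical: scales activations only while training

impl ModuleT for ScaleInTrain {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        if train { xs * self.0 } else { Ok(xs.clone()) }
    }
}
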
candle-core/src/metal_backend.rs (new file, 1698 lines)
(file diff suppressed because it is too large)
@@ -1,5 +1,5 @@
 #![allow(clippy::redundant_closure_call)]
-use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
+use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
 use half::{bf16, f16};
 use num_traits::float::Float;

@@ -90,6 +90,16 @@ pub enum Op {
         dilation: usize,
     },

+    #[allow(dead_code)]
+    ConvTranspose1D {
+        arg: Tensor,
+        kernel: Tensor,
+        padding: usize,
+        output_padding: usize,
+        stride: usize,
+        dilation: usize,
+    },
+
     #[allow(dead_code)]
     Conv2D {
         arg: Tensor,

@@ -122,7 +132,11 @@ pub enum Op {
     },

     UpsampleNearest1D(Tensor),
-    UpsampleNearest2D(Tensor),
+    UpsampleNearest2D {
+        arg: Tensor,
+        target_h: usize,
+        target_w: usize,
+    },

     Cat(Vec<Tensor>, usize),

@@ -174,6 +188,18 @@ pub trait CustomOp1 {
         ))
     }

+    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn metal_fwd(
+        &self,
+        _storage: &MetalStorage,
+        _layout: &Layout,
+    ) -> Result<(MetalStorage, Shape)> {
+        Err(crate::Error::Metal(
+            format!("no metal implementation for {}", self.name()).into(),
+        ))
+    }
+
     /// This function takes as argument the argument `arg` used in the forward pass, the result
     /// produced by the forward operation `res` and the gradient of the result `grad_res`.
     /// The function should return the gradient of the argument.

@@ -209,6 +235,20 @@ pub trait CustomOp2 {
         ))
     }

+    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn metal_fwd(
+        &self,
+        _: &MetalStorage,
+        _: &Layout,
+        _: &MetalStorage,
+        _: &Layout,
+    ) -> Result<(MetalStorage, Shape)> {
+        Err(crate::Error::Metal(
+            format!("no metal implementation for {}", self.name()).into(),
+        ))
+    }
+
     fn bwd(
         &self,
         _arg1: &Tensor,

@@ -251,6 +291,22 @@ pub trait CustomOp3 {
         ))
     }

+    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn metal_fwd(
+        &self,
+        _: &MetalStorage,
+        _: &Layout,
+        _: &MetalStorage,
+        _: &Layout,
+        _: &MetalStorage,
+        _: &Layout,
+    ) -> Result<(MetalStorage, Shape)> {
+        Err(crate::Error::Metal(
+            format!("no metal implementation for {}", self.name()).into(),
+        ))
+    }
+
     fn bwd(
         &self,
         _arg1: &Tensor,

@@ -536,13 +592,13 @@ unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
 unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
 unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
 unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
-unary_op!(Abs, "abs", v, v.abs());
 unary_op!(Neg, "neg", v, -v);
 unary_op!(Recip, "recip", v, v.recip());
 unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
 unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);

-/// `gelu` operation
+/// Tanh based approximation of the `gelu` operation
+/// GeluErf is the more precise one.
 /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
 impl UnaryOpT for Gelu {
     const NAME: &'static str = "gelu";

@@ -632,6 +688,8 @@ impl UnaryOpT for Gelu {
     }
 }

+/// `erf` operation
+/// <https://en.wikipedia.org/wiki/Error_function>
 impl UnaryOpT for Erf {
     const NAME: &'static str = "erf";
     const KERNEL: &'static str = "uerf";

@@ -666,6 +724,40 @@ impl UnaryOpT for Erf {
     }
 }

+impl UnaryOpT for Abs {
+    const NAME: &'static str = "abs";
+    const KERNEL: &'static str = "uabs";
+    const V: Self = Abs;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        v.abs()
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        v.abs()
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        v.abs()
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        v.abs()
+    }
+    #[inline(always)]
+    fn u8(v: u8) -> u8 {
+        v
+    }
+    #[inline(always)]
+    fn u32(v: u32) -> u32 {
+        v
+    }
+    #[inline(always)]
+    fn i64(v: i64) -> i64 {
+        v.abs()
+    }
+}
+
 impl UnaryOpT for Ceil {
     const NAME: &'static str = "ceil";
     const KERNEL: &'static str = "uceil";
@@ -887,6 +979,10 @@ impl BackpropOp {
         };
         Self(op)
     }
+
+    pub(crate) fn is_none(&self) -> bool {
+        self.0.is_none()
+    }
 }

 impl std::ops::Deref for BackpropOp {

@@ -703,6 +703,7 @@ impl PthTensors {
     }

     pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
+        use std::io::Read;
         let tensor_info = match self.tensor_infos.get(name) {
             None => return Ok(None),
             Some(tensor_info) => tensor_info,

@@ -712,14 +713,21 @@ impl PthTensors {
         let mut zip = zip::ZipArchive::new(zip_reader)?;
         let mut reader = zip.by_name(&tensor_info.path)?;

-        // Reading the data is a bit tricky as it can be strided, use an offset, etc.
-        // For now only support the basic case.
-        if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
+        // Reading the data is a bit tricky as it can be strided, for now only support the basic
+        // case.
+        if !tensor_info.layout.is_contiguous() {
             crate::bail!(
                 "cannot retrieve non-contiguous tensors {:?}",
                 tensor_info.layout
             )
         }
+        let start_offset = tensor_info.layout.start_offset();
+        if start_offset > 0 {
+            std::io::copy(
+                &mut reader.by_ref().take(start_offset as u64),
+                &mut std::io::sink(),
+            )?;
+        }
         let tensor = Tensor::from_reader(
             tensor_info.layout.shape().clone(),
             tensor_info.dtype,
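
Zip entries are read through a decompressing stream that offers `Read` but not `Seek`, which is presumably why the hunk above skips `start_offset` bytes by copying into a sink rather than seeking. The same idiom on a plain in-memory reader:

use std::io::{copy, sink, Cursor, Read};

fn main() -> std::io::Result<()> {
    let mut reader = Cursor::new(vec![0u8; 16]);
    let start_offset = 5u64;
    // Discard the first `start_offset` bytes without requiring Seek.
    copy(&mut reader.by_ref().take(start_offset), &mut sink())?;
    let mut rest = Vec::new();
    reader.read_to_end(&mut rest)?;
    assert_eq!(rest.len(), 11);
    Ok(())
}
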
@ -50,14 +50,9 @@ pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
|
|||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
|
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
|
||||||
let qk = QK8_0;
|
let qk = QK8_0;
|
||||||
let nb = n / qk;
|
|
||||||
if n % QK8_0 != 0 {
|
if n % QK8_0 != 0 {
|
||||||
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
|
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
|
||||||
}
|
}
|
||||||
if nb % 2 != 0 {
|
|
||||||
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
let mut acc = _mm256_setzero_ps();
|
let mut acc = _mm256_setzero_ps();
|
||||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||||
@ -358,7 +353,7 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
|
|||||||
q3 = q3.add(32);
|
q3 = q3.add(32);
|
||||||
|
|
||||||
// Prepare low and high bits
|
// Prepare low and high bits
|
||||||
// We hardcode the shifts here to avoid loading them into a seperate register
|
// We hardcode the shifts here to avoid loading them into a separate register
|
||||||
let q3l_0 = _mm256_and_si256(q3bits, m3);
|
let q3l_0 = _mm256_and_si256(q3bits, m3);
|
||||||
let q3h_0 = if j == 0 {
|
let q3h_0 = if j == 0 {
|
||||||
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
|
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
|
||||||
@ -591,7 +586,7 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
|
|||||||
let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
|
let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
|
||||||
q5 = q5.add(32);
|
q5 = q5.add(32);
|
||||||
|
|
||||||
//Similar to q3k we hardcode the shifts here to avoid loading them into a seperate register
|
//Similar to q3k we hardcode the shifts here to avoid loading them into a separate register
|
||||||
let q5l_0 = _mm256_and_si256(q5bits, m4);
|
let q5l_0 = _mm256_and_si256(q5bits, m4);
|
||||||
let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
|
let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
|
||||||
let q5l_0_right_shift = match j {
|
let q5l_0_right_shift = match j {
|
||||||
|
@@ -1,7 +1,9 @@
 //! Support for the GGML file format.

-use super::{k_quants, GgmlDType};
-use crate::Result;
+#[cfg(feature = "metal")]
+use super::metal::load_quantized_metal;
+use super::{k_quants, GgmlDType, QStorage};
+use crate::{Device, Result};
 use byteorder::{LittleEndian, ReadBytesExt};
 use std::collections::HashMap;

@@ -121,11 +123,22 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
     raw_data: &[u8],
     size_in_bytes: usize,
     dims: Vec<usize>,
+    device: &Device,
 ) -> Result<super::QTensor> {
     let raw_data_ptr = raw_data.as_ptr();
     let n_blocks = size_in_bytes / std::mem::size_of::<T>();
     let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
-    super::QTensor::new(data.to_vec(), dims)
+    let data: QStorage = match device {
+        Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
+        #[cfg(feature = "metal")]
+        Device::Metal(metal) => load_quantized_metal(metal, data)?,
+        #[cfg(not(feature = "metal"))]
+        Device::Metal(_metal) => {
+            crate::bail!("Metal backend requires `metal` feature")
+        }
+        device => unimplemented!("Implement quantized tensor for device {device:?}"),
+    };
+    super::QTensor::new(data, dims)
 }

 /// Creates a [Tensor] from a raw GGML tensor.
@@ -133,29 +146,50 @@ pub fn qtensor_from_ggml(
     ggml_dtype: GgmlDType,
     raw_data: &[u8],
     dims: Vec<usize>,
+    device: &Device,
 ) -> Result<super::QTensor> {
     let tensor_elems = dims.iter().product::<usize>();
-    let blck_size = ggml_dtype.blck_size();
-    if tensor_elems % blck_size != 0 {
+    let block_size = ggml_dtype.block_size();
+    if tensor_elems % block_size != 0 {
         crate::bail!(
-            "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
+            "the number of elements {tensor_elems} is not divisible by the block size {block_size}"
         )
     }
-    let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
+    let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size();

     match ggml_dtype {
-        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
-        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
+        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
+        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
+        GgmlDType::Q4_0 => {
+            from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q4_1 => {
+            from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5_0 => {
+            from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5_1 => {
+            from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q8_0 => {
+            from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q2K => {
+            from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q3K => {
+            from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q4K => {
+            from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5K => {
+            from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q6K => {
+            from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
+        }
         _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
     }
 }
@@ -163,6 +197,7 @@ pub fn qtensor_from_ggml(
 fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     reader: &mut R,
     magic: VersionedMagic,
+    device: &Device,
 ) -> Result<(String, super::QTensor)> {
     let n_dims = reader.read_u32::<LittleEndian>()?;
     let name_len = reader.read_u32::<LittleEndian>()?;
@@ -183,11 +218,11 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     }
     let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
     let tensor_elems = dims.iter().product::<usize>();
-    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
+    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size();
     // TODO: Mmap version to avoid copying the data around?
     let mut raw_data = vec![0u8; size_in_bytes];
     reader.read_exact(&mut raw_data)?;
-    match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
+    match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
         Ok(tensor) => Ok((name, tensor)),
         Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
    }
@@ -201,7 +236,10 @@ pub struct Content {
 }

 impl Content {
-    pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
+    pub fn read<R: std::io::Seek + std::io::Read>(
+        reader: &mut R,
+        device: &Device,
+    ) -> Result<Content> {
         // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
         let last_position = reader.seek(std::io::SeekFrom::End(0))?;
         reader.seek(std::io::SeekFrom::Start(0))?;
@@ -211,7 +249,7 @@ impl Content {
         let mut tensors = HashMap::new();

         while reader.stream_position()? != last_position {
-            let (name, tensor) = read_one_tensor(reader, magic)?;
+            let (name, tensor) = read_one_tensor(reader, magic, device)?;
             tensors.insert(name, tensor);
         }
         Ok(Self {
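With `device` threaded through `Content::read` and `read_one_tensor`, the target backend is chosen once at load time and every tensor in the file is materialized there. A sketch of the updated call site, assuming the crate's public paths stay as shown in this diff:

    use candle_core::quantized::ggml_file::Content;
    use candle_core::Device;

    fn load(path: &str) -> candle_core::Result<Content> {
        let mut file = std::fs::File::open(path)?;
        // Tensors are placed on the given device as they are read.
        Content::read(&mut file, &Device::Cpu)
    }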
@@ -3,7 +3,7 @@
 //! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md

 use super::{GgmlDType, QTensor};
-use crate::Result;
+use crate::{Device, Result};
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use std::collections::HashMap;

@@ -29,6 +29,7 @@ impl TryFrom<u32> for Magic {
 pub enum VersionedMagic {
     GgufV1,
     GgufV2,
+    GgufV3,
 }

 impl VersionedMagic {
@@ -39,7 +40,8 @@ impl VersionedMagic {
         let versioned_magic = match (magic, version) {
             (Magic::Gguf, 1) => Self::GgufV1,
             (Magic::Gguf, 2) => Self::GgufV2,
-            _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
+            (Magic::Gguf, 3) => Self::GgufV3,
+            _ => crate::bail!("gguf: unsupported magic/version {magic:?}/{version}"),
         };
         Ok(versioned_magic)
     }
@@ -57,19 +59,25 @@ impl TensorInfo {
         &self,
         reader: &mut R,
         tensor_data_offset: u64,
+        device: &Device,
     ) -> Result<QTensor> {
         let tensor_elems = self.shape.elem_count();
-        let blck_size = self.ggml_dtype.blck_size();
-        if tensor_elems % blck_size != 0 {
+        let block_size = self.ggml_dtype.block_size();
+        if tensor_elems % block_size != 0 {
             crate::bail!(
-            "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
+            "the number of elements {tensor_elems} is not divisible by the block size {block_size}"
             )
         }
-        let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
+        let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size();
         let mut raw_data = vec![0u8; size_in_bytes];
         reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
         reader.read_exact(&mut raw_data)?;
-        super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
+        super::ggml_file::qtensor_from_ggml(
+            self.ggml_dtype,
+            &raw_data,
+            self.shape.dims().to_vec(),
+            device,
+        )
     }
 }

@@ -84,7 +92,9 @@ pub struct Content {
 fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
     let len = match magic {
         VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
-        VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
+        VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
+            reader.read_u64::<LittleEndian>()? as usize
+        }
     };
     let mut v = vec![0u8; len];
     reader.read_exact(&mut v)?;
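GGUF v1 stores lengths and counts as little-endian u32 while v2 and v3 widen them to u64; the same two-arm match recurs below for array lengths, tensor counts, and metadata counts. A condensed sketch of that pattern (the helper name is hypothetical, not from the diff):

    use byteorder::{LittleEndian, ReadBytesExt};

    fn read_len<R: std::io::Read>(reader: &mut R, is_v1: bool) -> std::io::Result<usize> {
        Ok(if is_v1 {
            reader.read_u32::<LittleEndian>()? as usize // GGUF v1: 32-bit
        } else {
            reader.read_u64::<LittleEndian>()? as usize // GGUF v2/v3: 64-bit
        })
    }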
@@ -284,7 +294,9 @@ impl Value {
         let value_type = ValueType::from_u32(value_type)?;
         let len = match magic {
             VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
-            VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
+            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
+                reader.read_u64::<LittleEndian>()? as usize
+            }
         };
         let mut vs = Vec::with_capacity(len);
         for _ in 0..len {
@@ -381,11 +393,15 @@ impl Content {

         let tensor_count = match magic {
             VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
-            VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
+            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
+                reader.read_u64::<LittleEndian>()? as usize
+            }
         };
         let metadata_kv_count = match magic {
             VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
-            VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
+            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
+                reader.read_u64::<LittleEndian>()? as usize
+            }
         };

         let mut metadata = HashMap::new();
@@ -407,7 +423,7 @@ impl Content {
                 reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
                 dimensions.into_iter().map(|c| c as usize).collect()
             }
-            VersionedMagic::GgufV2 => {
+            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
                 let mut dimensions = vec![0; n_dimensions as usize];
                 reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
                 dimensions.into_iter().map(|c| c as usize).collect()
@@ -450,12 +466,13 @@ impl Content {
         &self,
         reader: &mut R,
         name: &str,
+        device: &Device,
     ) -> Result<QTensor> {
         let tensor_info = match self.tensor_infos.get(name) {
             Some(tensor_info) => tensor_info,
-            None => crate::bail!("cannot find tensor-infor for {name}"),
+            None => crate::bail!("cannot find tensor info for {name}"),
         };
-        tensor_info.read(reader, self.tensor_data_offset)
+        tensor_info.read(reader, self.tensor_data_offset, device)
     }
 }

@@ -507,10 +524,9 @@ pub fn write<W: std::io::Seek + std::io::Write>(
                 "internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
             )
         }
-        let data_ptr = tensor.as_ptr();
-        let size_in_bytes = tensor.storage_size_in_bytes();
-        let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
-        w.write_all(data)?;
+        let data = tensor.data()?;
+        let size_in_bytes = data.len();
+        w.write_all(&data)?;
         let padding = 31 - (31 + size_in_bytes) % 32;
         w.write_all(&vec![0u8; padding])?;
     }
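The padding expression rounds the running byte count up to the next 32-byte boundary, matching GGUF's default tensor-data alignment of 32: `31 - (31 + size) % 32` is the distance from `size` to the next multiple of 32, and is 0 when `size` is already aligned. A quick check:

    fn padding_to_32(size_in_bytes: usize) -> usize {
        31 - (31 + size_in_bytes) % 32
    }

    fn main() {
        assert_eq!(padding_to_32(100), 28); // 100 + 28 = 128 = 4 * 32
        assert_eq!(padding_to_32(64), 0);   // already aligned
    }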
@@ -236,14 +236,9 @@ impl GgmlType for BlockQ4_0 {

     fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         let qk = QK8_0;
-        let nb = n / qk;
         if n % QK8_0 != 0 {
             crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
         }
-        if nb % 2 != 0 {
-            crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
-        }
-
         // Generic implementation.
         let mut sumf = 0f32;
         for (xs, ys) in xs.iter().zip(ys.iter()) {
@@ -1550,13 +1545,13 @@ impl GgmlType for BlockQ5K {
                 let d2 = d * sc as f32;
                 let m2 = min * m as f32;
                 for (ql, qh) in ql.iter().zip(qh) {
-                    let to_add = if qh & u1 != 0 { 16 } else { 1 };
-                    y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1;
+                    let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
+                    y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1;
                     ys_index += 1;
                 }
                 for (ql, qh) in ql.iter().zip(qh) {
-                    let to_add = if qh & u2 != 0 { 16 } else { 1 };
-                    y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2;
+                    let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 };
+                    y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2;
                     ys_index += 1;
                 }
                 is += 2;
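This fixes Q5K dequantization: the extra bit from `qh` is bit 4 of the 5-bit quant, so it contributes 16 when set and nothing when clear, while the old code added 1 in the clear case. A scalar sketch of the corrected reconstruction:

    // Rebuild one Q5K value: four low bits from `ql`, one high bit from `qh`.
    fn dequant_q5(ql: u8, high_bit_set: bool, d: f32, m: f32) -> f32 {
        let q = (ql & 0xF) as f32 + if high_bit_set { 16f32 } else { 0f32 };
        d * q - m
    }

    fn main() {
        // Low nibble 7 with the high bit set decodes to 23 before scaling.
        assert_eq!(dequant_q5(0x07, true, 1.0, 0.0), 23.0);
    }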
candle-core/src/quantized/metal.rs (new file, 153 lines)
@@ -0,0 +1,153 @@
+use super::{GgmlDType, QStorage};
+use crate::{DType, MetalDevice, MetalStorage, Result};
+use metal::Buffer;
+use std::sync::Arc;
+
+pub struct QMetalStorage {
+    dtype: GgmlDType,
+    device: MetalDevice,
+    buffer: Arc<Buffer>,
+}
+
+impl QMetalStorage {
+    pub fn dtype(&self) -> GgmlDType {
+        self.dtype
+    }
+
+    pub fn buffer(&self) -> &Buffer {
+        &self.buffer
+    }
+
+    pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self {
+        Self {
+            device,
+            buffer,
+            dtype,
+        }
+    }
+
+    pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
+        let buffer = self.device.new_buffer_managed(self.buffer.length())?;
+        let command_buffer = self.device.command_buffer()?;
+        command_buffer.set_label("to_cpu");
+        let blit = command_buffer.new_blit_command_encoder();
+        blit.set_label("blit_to_cpu");
+        blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
+        blit.end_encoding();
+        self.device.wait_until_completed()?;
+        let mut out = vec![0.0; elem_count];
+        match self.dtype {
+            GgmlDType::F32 => {
+                let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                f32::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::F16 => {
+                let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                half::f16::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q4_0 => {
+                let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q4_1 => {
+                let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q5_0 => {
+                let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q5_1 => {
+                let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q8_0 => {
+                let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q8_1 => {
+                let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q2K => {
+                let vec: Vec<crate::quantized::BlockQ2K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q3K => {
+                let vec: Vec<crate::quantized::BlockQ3K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q4K => {
+                let vec: Vec<crate::quantized::BlockQ4K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q5K => {
+                let vec: Vec<crate::quantized::BlockQ5K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q6K => {
+                let vec: Vec<crate::quantized::BlockQ6K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q8K => {
+                let vec: Vec<crate::quantized::BlockQ8K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
+            }
+        }
+
+        let buffer = self.device.new_buffer_with_data(&out)?;
+        Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
+    }
+
+    pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
+        // Quantization only happens on CPU for now.
+        let src = src.to_cpu::<f32>()?;
+        let elem_count = src.len();
+        let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
+        let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
+        qcpu_storage.quantize(&src)?;
+        let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
+        self.buffer = buffer;
+        Ok(())
+    }
+}
+
+pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
+    device: &MetalDevice,
+    data: &[T],
+) -> Result<QStorage> {
+    let buffer = device.new_buffer_with_data(data)?;
+    let device = device.clone();
+    Ok(QStorage::Metal(QMetalStorage {
+        dtype: T::DTYPE,
+        device,
+        buffer,
+    }))
+}
+
+fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
+    let ptr = buffer.contents() as *const T;
+    assert!(!ptr.is_null());
+    let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
+    slice.to_vec()
+}
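Two details worth noting in this new file: `dequantize` first blits the (possibly GPU-private) buffer into a managed buffer and waits for completion, because `read_to_vec` reads straight through the `contents()` pointer, which is only meaningful for CPU-visible memory; and `quantize` round-trips through the CPU implementation rather than quantizing on the GPU. A slightly more defensive variant of the readback helper, offered only as a sketch:

    use metal::Buffer;

    // Copy `n` elements of `T` out of a CPU-visible Metal buffer, checking
    // that the buffer actually holds that many elements.
    fn read_to_vec_checked<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
        let needed = (n * std::mem::size_of::<T>()) as u64;
        assert!(needed <= buffer.length(), "buffer too small");
        let ptr = buffer.contents() as *const T;
        assert!(!ptr.is_null());
        let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
        slice.to_vec()
    }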
@@ -1,23 +1,125 @@
-use crate::{Device, Result, Shape, Tensor};
+#[cfg(feature = "metal")]
+use crate::{backend::BackendStorage, DType};
+use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
+use k_quants::*;
+use std::borrow::Cow;

 #[cfg(target_feature = "avx")]
 pub mod avx;
 pub mod ggml_file;
 pub mod gguf_file;
 pub mod k_quants;
+#[cfg(feature = "metal")]
+pub mod metal;
 #[cfg(target_feature = "neon")]
 pub mod neon;
 #[cfg(target_feature = "simd128")]
 pub mod simd128;
 pub mod utils;
+use half::f16;

 pub use k_quants::GgmlType;

 pub struct QTensor {
-    data: Box<dyn QuantizedType>,
+    storage: QStorage,
     shape: Shape,
 }

+impl Device {
+    fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
+        match self {
+            Device::Cpu => {
+                let storage = dtype.cpu_zeros(elem_count);
+                Ok(QStorage::Cpu(storage))
+            }
+            #[cfg(feature = "metal")]
+            Device::Metal(metal) => {
+                let size = elem_count * dtype.type_size() / dtype.block_size();
+                let buffer = metal.allocate_zeros(size)?;
+                Ok(QStorage::Metal(metal::QMetalStorage::new(
+                    buffer,
+                    metal.clone(),
+                    dtype,
+                )))
+            }
+            #[cfg(not(feature = "metal"))]
+            Device::Metal(_metal) => {
+                crate::bail!("Metal feature not activated");
+            }
+            Device::Cuda(_cuda) => {
+                crate::bail!("Cuda ggml quantization not supported");
+            }
+        }
+    }
+}
+
+pub enum QStorage {
+    Cpu(Box<dyn QuantizedType>),
+    #[cfg(feature = "metal")]
+    Metal(metal::QMetalStorage),
+}
+
+impl QStorage {
+    fn block_size(&self) -> usize {
+        match self {
+            QStorage::Cpu(storage) => storage.block_size(),
+            #[cfg(feature = "metal")]
+            QStorage::Metal(storage) => storage.dtype().block_size(),
+        }
+    }
+
+    fn dtype(&self) -> GgmlDType {
+        match self {
+            QStorage::Cpu(storage) => storage.dtype(),
+            #[cfg(feature = "metal")]
+            QStorage::Metal(storage) => storage.dtype(),
+        }
+    }
+
+    fn size_in_bytes(&self) -> usize {
+        match self {
+            QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
+            #[cfg(feature = "metal")]
+            QStorage::Metal(storage) => storage.buffer().length() as usize,
+        }
+    }
+
+    fn quantize(&mut self, src: &Storage) -> Result<()> {
+        match (self, src) {
+            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
+                storage.from_float(src.as_slice::<f32>()?)?;
+            }
+            #[cfg(feature = "metal")]
+            (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
+            _ => crate::bail!("Invalid dequantize storage locations do not match"),
+        }
+        Ok(())
+    }
+
+    fn dequantize(&self, elem_count: usize) -> Result<Storage> {
+        match self {
+            QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
+            #[cfg(feature = "metal")]
+            QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
+        }
+    }
+
+    fn data(&self) -> Result<Cow<[u8]>> {
+        match self {
+            QStorage::Cpu(storage) => {
+                let data_ptr = storage.as_ptr();
+                let size_in_bytes = storage.storage_size_in_bytes();
+                let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
+                Ok(Cow::from(data))
+            }
+            #[cfg(feature = "metal")]
+            QStorage::Metal(_storage) => {
+                crate::bail!("not implemented");
+            }
+        }
+    }
+}
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum GgmlDType {
     F32,
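The `QStorage` enum added above mirrors the unquantized `Storage` enum: one variant per backend, with every operation matching on the variant so `QTensor` itself stays device-agnostic, and `cfg` attributes keeping the Metal variant out of builds without the feature. The shape of the pattern in miniature (illustrative types only, not candle code):

    // One enum variant per backend; every method dispatches on the variant.
    enum MiniStorage {
        Cpu(Vec<u8>),
        Gpu { byte_len: usize },
    }

    impl MiniStorage {
        fn size_in_bytes(&self) -> usize {
            match self {
                MiniStorage::Cpu(v) => v.len(),
                MiniStorage::Gpu { byte_len } => *byte_len,
            }
        }
    }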
@@ -77,6 +179,25 @@ impl GgmlDType {
         }
     }

+    /// The block dtype
+    pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
+        match self {
+            Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
+            Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
+            Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
+            Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
+            Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
+            Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
+            Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
+            Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
+            Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
+            Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
+            Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
+            Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
+            Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
+            Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
+        }
+    }
     /// The type size for blocks in bytes.
     pub fn type_size(&self) -> usize {
         use k_quants::*;
@@ -100,7 +221,7 @@ impl GgmlDType {
     }

     /// The block size, i.e. the number of elements stored in each block.
-    pub fn blck_size(&self) -> usize {
+    pub fn block_size(&self) -> usize {
         match self {
             Self::F32 => 1,
             Self::F16 => 1,
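The rename from `blck_size` to `block_size` is purely cosmetic; the sizing arithmetic stays `elem_count / block_size() * type_size()`. For Q4_0, whose blocks pack 32 elements into 18 bytes (a 2-byte f16 scale plus 16 quant bytes), a 4096-element row therefore takes 4096 / 32 * 18 = 2304 bytes:

    fn storage_bytes(elem_count: usize, block_size: usize, type_size: usize) -> usize {
        elem_count / block_size * type_size
    }

    fn main() {
        assert_eq!(storage_bytes(4096, 32, 18), 2304); // Q4_0 sizing
    }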
@@ -119,9 +240,13 @@ impl GgmlDType {
 pub trait QuantizedType: Send + Sync {
     fn dtype(&self) -> GgmlDType;
     fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
-    fn to_float(&self, ys: &mut [f32]) -> Result<()>;
+    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
     fn storage_size_in_bytes(&self) -> usize;
     fn as_ptr(&self) -> *const u8;
+    fn block_size(&self) -> usize;
+    #[allow(clippy::wrong_self_convention)]
+    fn from_float(&mut self, xs: &[f32]) -> Result<()>;
+    fn size(&self) -> usize;
 }

 impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
@@ -129,12 +254,26 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
         k_quants::matmul(mkn, lhs, self.as_slice(), dst)
     }

+    fn size(&self) -> usize {
+        self.len() * core::mem::size_of::<T>()
+    }
+
+    fn from_float(&mut self, xs: &[f32]) -> Result<()> {
+        T::from_float(xs, self)
+    }
+
     fn dtype(&self) -> GgmlDType {
         T::DTYPE
     }

-    fn to_float(&self, ys: &mut [f32]) -> Result<()> {
-        T::to_float(self.as_slice(), ys)
+    fn block_size(&self) -> usize {
+        T::BLCK_SIZE
+    }
+
+    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
+        let mut ys = vec![0.0f32; elem_count];
+        T::to_float(self.as_slice(), &mut ys)?;
+        Ok(CpuStorage::F32(ys))
     }

     fn storage_size_in_bytes(&self) -> usize {
@@ -152,56 +291,49 @@ impl std::fmt::Debug for QTensor {
     }
 }

-fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
+fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
     let dims = shape.dims();
     if dims.is_empty() {
         crate::bail!("scalar tensor cannot be quantized {shape:?}")
     }
-    if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
+    if dims[dims.len() - 1] % block_size != 0 {
         crate::bail!(
             "quantized tensor must have their last dim divisible by block size {shape:?} {}",
-            T::BLCK_SIZE
+            block_size
         )
     }
     Ok(())
 }

 impl QTensor {
-    pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
-        data: Vec<T>,
-        shape: S,
-    ) -> Result<Self> {
+    pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
         let shape = shape.into();
-        check_shape::<T>(&shape)?;
-        Ok(Self {
-            data: Box::new(data),
-            shape,
-        })
+        check_shape(&shape, storage.block_size())?;
+        Ok(Self { storage, shape })
     }

-    pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
+    pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
         let shape = src.shape();
-        check_shape::<T>(shape)?;
-        let src = src
-            .to_dtype(crate::DType::F32)?
-            .flatten_all()?
-            .to_vec1::<f32>()?;
-        if src.len() % T::BLCK_SIZE != 0 {
+        let block_size = dtype.block_size();
+        check_shape(shape, block_size)?;
+        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
+        let elem_count = shape.elem_count();
+        if elem_count % block_size != 0 {
             crate::bail!(
                 "tensor size ({shape:?}) is not divisible by block size {}",
-                T::BLCK_SIZE
+                block_size
             )
         }
-        let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
-        T::from_float(&src, &mut data)?;
+        let mut storage = src.device().qzeros(elem_count, dtype)?;
+        storage.quantize(&src.storage())?;
         Ok(Self {
-            data: Box::new(data),
+            storage,
             shape: shape.clone(),
         })
     }

     pub fn dtype(&self) -> GgmlDType {
-        self.data.dtype()
+        self.storage.dtype()
     }

     pub fn rank(&self) -> usize {
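`QTensor::quantize` now takes the target format as a runtime `GgmlDType` value instead of a compile-time type parameter, so a single code path serves every format and device. A sketch of the new call shape, assuming the public paths stay as shown in this diff:

    use candle_core::quantized::{GgmlDType, QTensor};
    use candle_core::{DType, Device, Tensor};

    fn quantize_example() -> candle_core::Result<QTensor> {
        // The last dim (64) must be divisible by the Q4_0 block size of 32.
        let t = Tensor::zeros((8, 64), DType::F32, &Device::Cpu)?;
        QTensor::quantize(&t, GgmlDType::Q4_0)
    }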
@@ -213,21 +345,19 @@ impl QTensor {
     }

     pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
-        let mut f32_data = vec![0f32; self.shape.elem_count()];
-        self.data.to_float(&mut f32_data)?;
-        Tensor::from_vec(f32_data, &self.shape, device)
-    }
-
-    pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
-        self.data.matmul_t(mkn, lhs, dst)
+        let storage = self.storage.dequantize(self.shape.elem_count())?;
+        let none = crate::op::BackpropOp::none();
+        let is_variable = false;
+        crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
+            .to_device(device)
     }

     pub fn storage_size_in_bytes(&self) -> usize {
-        self.data.storage_size_in_bytes()
+        self.storage.size_in_bytes()
     }

-    pub fn as_ptr(&self) -> *const u8 {
-        self.data.as_ptr()
+    pub fn data(&self) -> Result<Cow<'_, [u8]>> {
+        self.storage.data()
     }
 }

@@ -294,21 +424,97 @@ impl crate::CustomOp1 for QTensor {
         }
         dst_shape.push(n);
         let dst_shape = Shape::from(dst_shape);
-        let storage = storage.as_slice::<f32>()?;
-        let storage =
-            &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
+        #[allow(clippy::infallible_destructuring_match)]
+        let self_storage = match &self.storage {
+            QStorage::Cpu(storage) => storage,
+            #[cfg(feature = "metal")]
+            _ => crate::bail!("Invalid storage"),
+        };
+        let slice = storage.as_slice::<f32>()?;
+        let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
         let mut dst_storage = vec![0f32; dst_shape.elem_count()];
-        self.matmul_t(
-            (dst_shape.elem_count() / n, k, n),
-            storage,
-            &mut dst_storage,
-        )?;
+        self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
         Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
     }

+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        storage: &crate::MetalStorage,
+        layout: &crate::Layout,
+    ) -> Result<(crate::MetalStorage, Shape)> {
+        use crate::MetalError;
+
+        if !layout.is_contiguous() {
+            crate::bail!("input tensor is not contiguous {layout:?}")
+        }
+        let src_shape = layout.shape();
+        // self is transposed so n is first then k.
+        if src_shape.rank() < 2 {
+            crate::bail!("input tensor has only one dimension {layout:?}")
+        }
+        let (n, k) = self.shape.dims2()?;
+        let mut dst_shape = src_shape.dims().to_vec();
+
+        let (b, m) = match dst_shape.len() {
+            3 => (dst_shape[0], dst_shape[1]),
+            2 => (1, dst_shape[0]),
+            n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
+        };
+        let last_k = dst_shape.pop().unwrap();
+        if last_k != k {
+            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
+        }
+        dst_shape.push(n);
+        let dst_shape = Shape::from(dst_shape);
+        let device = storage.device().clone();
+        let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
+        let (buffer, dtype) = match &self.storage {
+            QStorage::Metal(metal) => (metal.buffer(), metal.dtype()),
+            _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
+        };
+        let command_buffer = device.command_buffer()?;
+        candle_metal_kernels::call_quantized_matmul_t(
+            device.device(),
+            &command_buffer,
+            device.kernels(),
+            dtype.into(),
+            (b, m, n, k),
+            storage.buffer(),
+            layout.start_offset() * storage.dtype().size_in_bytes(),
+            buffer,
+            &dst,
+        )
+        .map_err(MetalError::from)?;
+        let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
+        Ok((dst_storage, dst_shape))
+    }
 }

-impl QMatMul {
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+#[cfg(feature = "metal")]
+impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
+    fn from(value: GgmlDType) -> Self {
+        match value {
+            GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
+            GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
+            GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
+            GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
+            GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
+            GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
+            GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
+            GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
+            GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
+            GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
+            GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
+            GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
+            GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
+            GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
+        }
+    }
+}
+
+impl crate::Module for QMatMul {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
             Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
             Self::Tensor(w) => {
@@ -12,6 +12,14 @@ use core::arch::arm::*;
 #[cfg(target_arch = "aarch64")]
 use core::arch::aarch64::*;

+#[inline(always)]
+unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
+    // TODO: dotprod
+    let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
+    let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
+    vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
+}
+
 #[inline(always)]
 pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
     let qk = QK8_0;
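Factoring the widening-multiply/pairwise-add sequence into `vdotq_s32` gives every kernel below a single seam where the ARM dot-product instruction can later be swapped in (hence the TODO). The lane layout of this emulation differs from the hardware SDOT grouping, but the callers only ever reduce the four i32 lanes with `vaddvq_s32`, so what matters is the full 16-element dot product. A scalar reference of that quantity:

    // Scalar meaning of the reduced result: the i8 dot product over 16 lanes.
    fn dot_i8(a: [i8; 16], b: [i8; 16]) -> i32 {
        a.iter().zip(b.iter()).map(|(&x, &y)| x as i32 * y as i32).sum()
    }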
@@ -19,71 +27,39 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
     if n % QK8_0 != 0 {
         crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
     }
-    if nb % 2 != 0 {
-        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
-    }
-
     unsafe {
         let mut sumv0 = vdupq_n_f32(0.0f32);
-        let mut sumv1 = vdupq_n_f32(0.0f32);
-        for i in (0..nb).step_by(2) {
+        for i in 0..nb {
             let x0 = &xs[i];
-            let x1 = &xs[i + 1];
             let y0 = &ys[i];
-            let y1 = &ys[i + 1];

             let m4b = vdupq_n_u8(0x0F);
             let s8b = vdupq_n_s8(0x8);

             let v0_0 = vld1q_u8(x0.qs.as_ptr());
-            let v0_1 = vld1q_u8(x1.qs.as_ptr());

             // 4-bit -> 8-bit
             let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
             let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-            let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
-            let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

             // sub 8
             let v0_0ls = vsubq_s8(v0_0l, s8b);
             let v0_0hs = vsubq_s8(v0_0h, s8b);
-            let v0_1ls = vsubq_s8(v0_1l, s8b);
-            let v0_1hs = vsubq_s8(v0_1h, s8b);

             // load y
             let v1_0l = vld1q_s8(y0.qs.as_ptr());
             let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
-            let v1_1l = vld1q_s8(y1.qs.as_ptr());
-            let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16));

-            // TODO: Support dotprod when it's available outside of nightly.
-            let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
-            let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
-            let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
-            let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
-
-            let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l));
-            let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
-            let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h));
-            let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
-
-            let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-            let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-            let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-            let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
+            let pl0 = vdotq_s32(v0_0ls, v1_0l);
+            let ph0 = vdotq_s32(v0_0hs, v1_0h);
             sumv0 = vmlaq_n_f32(
                 sumv0,
                 vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
                 x0.d.to_f32() * y0.d.to_f32(),
             );
-            sumv1 = vmlaq_n_f32(
-                sumv1,
-                vcvtq_f32_s32(vaddq_s32(pl1, ph1)),
-                x1.d.to_f32() * y1.d.to_f32(),
-            );
         }
-        Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
+        Ok(vaddvq_f32(sumv0))
     }
 }
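With the manual two-block unrolling gone, each iteration handles a single 32-element block: unpack the 4-bit values, subtract the implicit bias of 8, dot against the q8 values, and scale by both block scales. A scalar reference of the per-block computation (a sketch of the math, not the SIMD code):

    // One Q4_0 x Q8_0 block: 16 bytes hold 32 4-bit values (low nibbles are
    // elements 0..16, high nibbles 16..32), each biased by 8; `xd` and `yd`
    // are the f32 block scales.
    fn block_dot_q4_0_q8_0(xqs: &[u8; 16], yqs: &[i8; 32], xd: f32, yd: f32) -> f32 {
        let mut sum = 0i32;
        for i in 0..16 {
            let lo = (xqs[i] & 0x0F) as i32 - 8;
            let hi = (xqs[i] >> 4) as i32 - 8;
            sum += lo * yqs[i] as i32 + hi * yqs[i + 16] as i32;
        }
        sum as f32 * xd * yd
    }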
@@ -94,57 +70,29 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
         crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
     }
     let nb = n / QK8_0;
-    if nb % 2 != 0 {
-        crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
-    }
     unsafe {
         let mut sumv0 = vdupq_n_f32(0.0f32);
-        let mut sumv1 = vdupq_n_f32(0.0f32);
-        for i in (0..nb).step_by(2) {
+        for i in 0..nb {
             let x0 = &xs[i];
-            let x1 = &xs[i + 1];
             let y0 = &ys[i];
-            let y1 = &ys[i + 1];

             let x0_0 = vld1q_s8(x0.qs.as_ptr());
             let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
-            let x1_0 = vld1q_s8(x1.qs.as_ptr());
-            let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16));

             // load y
             let y0_0 = vld1q_s8(y0.qs.as_ptr());
             let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
-            let y1_0 = vld1q_s8(y1.qs.as_ptr());
-            let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16));

-            // TODO dotprod once this is the intrinsics are.
-            let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
-            let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
-            let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
-            let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
-
-            let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0));
-            let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
-            let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1));
-            let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
-
-            let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
-            let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
-            let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
-            let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
-
+            let p0 = vdotq_s32(x0_0, y0_0);
+            let p1 = vdotq_s32(x0_1, y0_1);
             sumv0 = vmlaq_n_f32(
                 sumv0,
                 vcvtq_f32_s32(vaddq_s32(p0, p1)),
                 x0.d.to_f32() * y0.d.to_f32(),
             );
-            sumv1 = vmlaq_n_f32(
-                sumv1,
-                vcvtq_f32_s32(vaddq_s32(p2, p3)),
-                x1.d.to_f32() * y1.d.to_f32(),
-            );
         }
-        Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
+        Ok(vaddvq_f32(sumv0))
     }
 }
@@ -165,10 +113,7 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res
             for i in (0..QK_K).step_by(16) {
                 let xs = vld1q_s8(xs.add(i));
                 let ys = vld1q_s8(ys.add(i));
-                let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
-                let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
-
-                let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
+                let xy = vdotq_s32(xs, ys);
                 sum_i = vaddq_s32(sum_i, xy)
             }
             sumf += vaddvq_s32(sum_i) as f32 * scale
@@ -238,30 +183,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
                 let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
                 let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));

-                // TODO: dotprod
-
-                let p0 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
-                    vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
-                );
-                let p1 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
-                    vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
-                );
+                let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
+                let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
                 let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-                isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
+                isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
                 scale = scale.add(2);

-                let p2 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
-                    vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
-                );
-                let p3 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
-                    vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
-                );
+                let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
+                let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
                 let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-                isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
+                isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
                 scale = scale.add(2);

                 let q8bytes = vld1q_s8_x4(q8);
@@ -281,29 +212,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
                 let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
                 let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));

-                // TODO: dotprod case.
-                let p0 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
-                    vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
-                );
-                let p1 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
-                    vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
-                );
+                let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
+                let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
                 let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-                isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
+                isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
                 scale = scale.add(2);

-                let p2 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
-                    vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
-                );
-                let p3 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
-                    vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
-                );
+                let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
+                let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
                 let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-                isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
+                isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
                 scale = scale.add(2);
             }
             sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
@@ -380,28 +298,14 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
                 let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
                 let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));

-                // TODO: dotprod
-
-                let p0 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
-                    vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
-                );
-                let p1 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
-                    vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
-                );
-                sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
+                let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
+                let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
+                sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
                 scales = scales.add(1);

-                let p2 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
-                    vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
-                );
-                let p3 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
-                    vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
-                );
-                sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
+                let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
+                let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
+                sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
                 scales = scales.add(1);
             }
             sumf += d * sumi as f32 - dmin * sumi_mins as f32;
@@ -464,22 +368,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
         for j in 0..QK_K / 64 {
             let q4bits = vld1q_u8_x2(q4);
             q4 = q4.add(32);
-            // TODO: dotprod
             let q8bytes = vld1q_s8_x2(q8);
             q8 = q8.add(32);
             let q4bytes = int8x16x2_t(
                 vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
                 vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
             );
-            let p0 = vaddq_s16(
-                vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
-                vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
-            );
-            let p1 = vaddq_s16(
-                vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
-                vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
-            );
-            sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
+            let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
+            let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
+            sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;

             let q8bytes = vld1q_s8_x2(q8);
             q8 = q8.add(32);
@@ -487,15 +384,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
                 vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
                 vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
             );
-            let p2 = vaddq_s16(
-                vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
-                vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
-            );
-            let p3 = vaddq_s16(
-                vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
-                vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
-            );
-            sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
+            let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
+            let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
+            sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
         }
         sumf += d * (sumi1 + sumi2) as f32;
     }
@@ -573,27 +464,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
                 vreinterpretq_s8_u8(q3h_3),
             );

-            // TODO: dotprod
-            let p0 = vaddq_s16(
-                vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
-                vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
-            );
-            let p1 = vaddq_s16(
-                vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
-                vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
-            );
-            let p2 = vaddq_s16(
-                vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
-                vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
-            );
-            let p3 = vaddq_s16(
-                vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
-                vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
-            );
-            isum += vaddvq_s16(p0) as i32 * *scale as i32
-                + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
-                + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
-                + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+            let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
+            let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
+            let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
+            let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
+            isum += vaddvq_s32(p0) * *scale as i32
+                + vaddvq_s32(p1) * *scale.add(1) as i32
+                + vaddvq_s32(p2) * *scale.add(2) as i32
+                + vaddvq_s32(p3) * *scale.add(3) as i32;
             scale = scale.add(4);

             let q3h_0 = vbicq_u8(m2, qhbits.0);
@@ -618,27 +496,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
                 vreinterpretq_s8_u8(q3h_3),
             );

-            // TODO: dotprod
-            let p0 = vaddq_s16(
-                vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
-                vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
-            );
-            let p1 = vaddq_s16(
-                vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
-                vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
-            );
-            let p2 = vaddq_s16(
-                vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
-                vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
-            );
-            let p3 = vaddq_s16(
-                vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
-                vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
-            );
-            isum += vaddvq_s16(p0) as i32 * *scale as i32
-                + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
-                + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
-                + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+            let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
+            let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
+            let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
+            let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
+            isum += vaddvq_s32(p0) * *scale as i32
+                + vaddvq_s32(p1) * *scale.add(1) as i32
+                + vaddvq_s32(p2) * *scale.add(2) as i32
+                + vaddvq_s32(p3) * *scale.add(3) as i32;
             scale = scale.add(4);

             if j == 0 {
@@ -696,7 +561,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
         let mut is = 0usize;

         // TODO: dotprod
-
         for _j in 0..QK_K / 128 {
             let q2bits = vld1q_u8_x2(q2);
             q2 = q2.add(32);
@@ -743,14 +607,7 @@ unsafe fn multiply_accum_with_scale(
    q2bytes: int8x16x2_t,
    q8bytes: int8x16x2_t,
) -> i32 {
-    let p1 = vaddq_s16(
-        vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
-        vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
-    );
-    let p2 = vaddq_s16(
-        vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
-        vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
-    );
-    vaddvq_s16(p1) as i32 * aux[is + index] as i32
-        + vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
+    let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
+    let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
+    vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
}
@@ -11,10 +11,6 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
     if n % QK8_0 != 0 {
         crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
     }
-    let nb = n / QK8_0;
-    if nb % 2 != 0 {
-        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
-    }
     unsafe {
         let mut acc = f32x4_splat(0.0f32);
         for (x, y) in xs.iter().zip(ys.iter()) {
@@ -61,10 +57,6 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
     if n % QK8_0 != 0 {
         crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
     }
-    let nb = n / QK8_0;
-    if nb % 2 != 0 {
-        crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
-    }
     unsafe {
         let mut acc = f32x4_splat(0.0f32);
         for (x, y) in xs.iter().zip(ys.iter()) {
@@ -203,7 +203,7 @@ impl Shape {

     /// Check whether the two shapes are compatible for broadcast, and if it is the case return the
     /// broadcasted shape. This is to be used for binary pointwise ops.
-    pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
+    pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
         let lhs = self;
         let lhs_dims = lhs.dims();
         let rhs_dims = rhs.dims();
@@ -478,23 +478,6 @@ extract_dims!(
     (usize, usize, usize, usize, usize)
 );

-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn stride() {
-        let shape = Shape::from(());
-        assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
-        let shape = Shape::from(42);
-        assert_eq!(shape.stride_contiguous(), [1]);
-        let shape = Shape::from((42, 1337));
-        assert_eq!(shape.stride_contiguous(), [1337, 1]);
-        let shape = Shape::from((299, 792, 458));
-        assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
-    }
-}
-
 pub trait ShapeWithOneHole {
     fn into_shape(self, el_count: usize) -> Result<Shape>;
 }
@@ -627,3 +610,20 @@ impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
         Ok((d1, d2, d3, d4, d).into())
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn stride() {
+        let shape = Shape::from(());
+        assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
+        let shape = Shape::from(42);
+        assert_eq!(shape.stride_contiguous(), [1]);
+        let shape = Shape::from((42, 1337));
+        assert_eq!(shape.stride_contiguous(), [1337, 1]);
+        let shape = Shape::from((299, 792, 458));
+        assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
+    }
+}
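Making `broadcast_shape_binary_op` `pub` exposes the usual right-aligned broadcasting rule to downstream crates. A self-contained sketch of that rule (not the crate's implementation):

    // Right-aligned broadcasting: walk both dim lists from the trailing end;
    // dimensions must match or one of them must be 1, and missing leading
    // dimensions are treated as 1. Returns None on incompatible shapes.
    fn broadcast_dims(lhs: &[usize], rhs: &[usize]) -> Option<Vec<usize>> {
        let out_rank = lhs.len().max(rhs.len());
        let mut out = vec![0; out_rank];
        for i in 0..out_rank {
            let l = if i < lhs.len() { lhs[lhs.len() - 1 - i] } else { 1 };
            let r = if i < rhs.len() { rhs[rhs.len() - 1 - i] } else { 1 };
            out[out_rank - 1 - i] = if l == r || r == 1 {
                l
            } else if l == 1 {
                r
            } else {
                return None;
            };
        }
        Some(out)
    }

    fn main() {
        assert_eq!(broadcast_dims(&[3, 1, 5], &[4, 5]), Some(vec![3, 4, 5]));
        assert_eq!(broadcast_dims(&[2, 3], &[4, 3]), None);
    }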
@@ -1,6 +1,6 @@
 use crate::backend::BackendStorage;
 use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
-use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
+use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};

 // We do not want to implement Clone on Storage as cloning may fail because of
 // out of memory. Instead try_clone should be used.
@@ -8,6 +8,7 @@ use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape
 pub enum Storage {
     Cpu(CpuStorage),
     Cuda(CudaStorage),
+    Metal(MetalStorage),
 }

 impl Storage {
@@ -18,6 +19,10 @@ impl Storage {
                 let storage = storage.try_clone(layout)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.try_clone(layout)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -25,6 +30,7 @@ impl Storage {
         match self {
             Self::Cpu(_) => Device::Cpu,
             Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
+            Self::Metal(storage) => Device::Metal(storage.device().clone()),
         }
     }

@@ -32,6 +38,7 @@ impl Storage {
         match self {
             Self::Cpu(storage) => storage.dtype(),
             Self::Cuda(storage) => storage.dtype(),
+            Self::Metal(storage) => storage.dtype(),
         }
     }

@@ -65,6 +72,10 @@ impl Storage {
                 let storage = storage.affine(layout, mul, add)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.affine(layout, mul, add)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -78,6 +89,10 @@ impl Storage {
                 let storage = storage.powf(layout, alpha)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.powf(layout, alpha)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -91,6 +106,10 @@ impl Storage {
                 let storage = storage.elu(layout, alpha)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.elu(layout, alpha)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -112,6 +131,10 @@ impl Storage {
                 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(lhs), Self::Metal(rhs)) => {
+                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
+                Ok(Self::Metal(storage))
+            }
             (lhs, rhs) => {
                 // Should not happen because of the same device check above but we're defensive
                 // anyway.
@@ -135,6 +158,10 @@ impl Storage {
                 let storage = storage.reduce_op(op, layout, s)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.reduce_op(op, layout, s)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -148,6 +175,10 @@ impl Storage {
                 let storage = storage.to_dtype(layout, dtype)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.to_dtype(layout, dtype)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -161,6 +192,10 @@ impl Storage {
                 let (storage, shape) = c.cuda_fwd(storage, l)?;
                 Ok((Self::Cuda(storage), shape))
             }
+            Self::Metal(storage) => {
+                let (storage, shape) = c.metal_fwd(storage, l)?;
+                Ok((Self::Metal(storage), shape))
+            }
         }
     }

@@ -181,6 +216,10 @@ impl Storage {
                 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
                 Ok((Self::Cuda(s), shape))
             }
+            (Self::Metal(s1), Self::Metal(s2)) => {
+                let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
+                Ok((Self::Metal(s), shape))
+            }
             _ => unreachable!(),
         }
     }
@@ -205,6 +244,10 @@ impl Storage {
                 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
                 Ok((Self::Cuda(s), shape))
             }
+            (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
+                let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
+                Ok((Self::Metal(s), shape))
+            }
             _ => unreachable!(),
         }
     }
@@ -219,6 +262,10 @@ impl Storage {
                 let storage = storage.unary_impl::<B>(layout)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.unary_impl::<B>(layout)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -239,6 +286,10 @@ impl Storage {
                 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(lhs), Self::Metal(rhs)) => {
+                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
+                Ok(Self::Metal(storage))
+            }
             (lhs, rhs) => {
                 // Should not happen because of the same device check above but we're defensive
                 // anyway.
@@ -270,6 +321,10 @@ impl Storage {
                 let s = inp.conv1d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
+            (Storage::Metal(inp), Storage::Metal(kernel)) => {
+                let s = inp.conv1d(l, kernel, kernel_l, params)?;
+                Ok(Self::Metal(s))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -279,6 +334,33 @@ impl Storage {
         }
     }

+    pub(crate) fn conv_transpose1d(
+        &self,
+        l: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        params: &crate::conv::ParamsConvTranspose1D,
+    ) -> Result<Self> {
+        self.same_device(kernel, "conv-transpose1d")?;
+        self.same_dtype(kernel, "conv-transpose1d")?;
+        match (self, &kernel) {
+            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
+                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
+                Ok(Self::Cpu(s))
+            }
+            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
+                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
+                Ok(Self::Cuda(s))
+            }
+            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+                lhs: lhs.device().location(),
+                rhs: rhs.device().location(),
+                op: "conv-transpose1d",
+            }
+            .bt()),
+        }
+    }
+
     pub(crate) fn conv2d(
         &self,
         l: &Layout,
@@ -297,6 +379,10 @@ impl Storage {
                 let s = inp.conv2d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
+            (Storage::Metal(inp), Storage::Metal(kernel)) => {
+                let s = inp.conv2d(l, kernel, kernel_l, params)?;
+                Ok(Self::Metal(s))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -324,6 +410,10 @@ impl Storage {
                 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
+            (Storage::Metal(inp), Storage::Metal(kernel)) => {
+                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
+                Ok(Self::Metal(s))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -348,6 +438,10 @@ impl Storage {
                 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -366,6 +460,10 @@ impl Storage {
                 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -379,6 +477,10 @@ impl Storage {
                 let storage = storage.upsample_nearest1d(layout, sz)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.upsample_nearest1d(layout, sz)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -392,6 +494,10 @@ impl Storage {
                 let storage = storage.upsample_nearest2d(layout, h, w)?;
                 Ok(Self::Cuda(storage))
             }
+            Self::Metal(storage) => {
+                let storage = storage.upsample_nearest2d(layout, h, w)?;
+                Ok(Self::Metal(storage))
+            }
         }
     }

@@ -415,6 +521,10 @@ impl Storage {
                 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
+                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
+                Ok(Self::Metal(storage))
+            }
             (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -441,6 +551,10 @@ impl Storage {
                 let storage = s.gather(l, indexes, indexes_l, d)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(s), Self::Metal(indexes)) => {
+                let storage = s.gather(l, indexes, indexes_l, d)?;
+                Ok(Self::Metal(storage))
+            }
             _ => unreachable!(),
         }
     }
@@ -465,6 +579,10 @@ impl Storage {
                 let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
+                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
+                Ok(Self::Metal(storage))
+            }
             _ => unreachable!(),
         }
     }
@@ -489,6 +607,10 @@ impl Storage {
                 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
+                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
+                Ok(Self::Metal(storage))
+            }
             _ => unreachable!(),
         }
     }
@@ -510,6 +632,10 @@ impl Storage {
                 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(lhs), Self::Metal(rhs)) => {
+                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
+                Ok(Self::Metal(storage))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -537,6 +663,10 @@ impl Storage {
                 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
                 Ok(Self::Cuda(storage))
             }
+            (Self::Metal(lhs), Self::Metal(rhs)) => {
+                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
+                Ok(Self::Metal(storage))
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -556,6 +686,9 @@ impl Storage {
         match (self, dst) {
             (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
             (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
+            (Self::Metal(src), Self::Metal(dst)) => {
+                Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
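The pattern in all of these hunks is plain enum dispatch: adding a `Metal(MetalStorage)` variant to `Storage` makes every `match` in the file non-exhaustive until each op grows a Metal arm, so the compiler itself enumerates the full backend surface that still needs kernels, while mixed-device pairs keep falling into the `DeviceMismatchBinaryOp` arm. A minimal model of the shape of this dispatch, with stand-in types rather than candle's:

    // Minimal model of the Storage dispatch pattern: one enum variant per
    // backend, same-backend pairs dispatch to the backend impl, mixed pairs
    // are rejected. All types here are stand-ins, not candle's.
    #[derive(Debug)]
    enum Storage {
        Cpu(Vec<f32>),
        Metal(usize), // stand-in for a device buffer handle
    }

    fn add(lhs: &Storage, rhs: &Storage) -> Result<Storage, String> {
        match (lhs, rhs) {
            (Storage::Cpu(a), Storage::Cpu(b)) => {
                Ok(Storage::Cpu(a.iter().zip(b).map(|(x, y)| x + y).collect()))
            }
            (Storage::Metal(_), Storage::Metal(_)) => {
                // A real impl would launch a Metal kernel here.
                Err("metal add not implemented in this sketch".to_string())
            }
            (lhs, rhs) => Err(format!("device mismatch: {lhs:?} vs {rhs:?}")),
        }
    }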
@@ -1,4 +1,4 @@
-//! Tensors are N-dimenional matrixes of elements using a single data type.
+//! Tensors are N-dimensional matrixes of elements using a single data type.
 #![allow(clippy::redundant_closure_call)]
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{
@@ -6,7 +6,7 @@ use crate::op::{
 };
 use crate::scalar::TensorOrScalar;
 use crate::shape::{Dim, Dims};
-use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
+use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};

 /// Unique identifier for tensors.
@@ -361,6 +361,16 @@ impl Tensor {
         Self::new_impl(array, shape, device, false)
     }

+    /// Returns a new tensor with all the elements having the same specified value. Note that
+    /// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
+    pub fn full<D: crate::WithDType, S: Into<Shape>>(
+        value: D,
+        shape: S,
+        device: &Device,
+    ) -> Result<Self> {
+        Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
+    }
+
     /// Creates a new 1D tensor from an iterator.
     pub fn from_iter<D: crate::WithDType>(
         iter: impl IntoIterator<Item = D>,
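`Tensor::full` builds a one-element tensor and broadcasts it, so the result carries stride-0 dimensions; as the doc comment says, call `.contiguous()` when a dense layout is needed. A usage sketch:

    use candle_core::{Device, Result, Tensor};

    fn demo() -> Result<()> {
        // A 2x3 tensor of sevens; broadcast_as gives stride-0 dims, so call
        // .contiguous() before ops that need an owned dense layout.
        let t = Tensor::full(7.0f32, (2, 3), &Device::Cpu)?;
        let dense = t.contiguous()?;
        assert_eq!(dense.dims(), [2, 3]);
        Ok(())
    }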
@ -385,12 +395,22 @@ impl Tensor {
|
|||||||
step: D,
|
step: D,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
|
if D::is_zero(&step) {
|
||||||
|
bail!("step cannot be zero")
|
||||||
|
}
|
||||||
let mut data = vec![];
|
let mut data = vec![];
|
||||||
let mut current = start;
|
let mut current = start;
|
||||||
|
if step >= D::zero() {
|
||||||
while current < end {
|
while current < end {
|
||||||
data.push(current);
|
data.push(current);
|
||||||
current += step;
|
current += step;
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
while current > end {
|
||||||
|
data.push(current);
|
||||||
|
current += step;
|
||||||
|
}
|
||||||
|
}
|
||||||
let len = data.len();
|
let len = data.len();
|
||||||
Self::from_vec_impl(data, len, device, false)
|
Self::from_vec_impl(data, len, device, false)
|
||||||
}
|
}
|
||||||
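With the added checks, the stepped range constructor bails on a zero step and walks downward when the step is negative, matching Python range semantics. A usage sketch, assuming this hunk belongs to the crate's `arange_step` (the function name is not visible in the extract):

    use candle_core::{Device, Result, Tensor};

    fn demo() -> Result<()> {
        // Descending range: 5, 4, 3, 2, 1 (end is exclusive).
        let t = Tensor::arange_step(5i64, 0i64, -1i64, &Device::Cpu)?;
        assert_eq!(t.to_vec1::<i64>()?, [5, 4, 3, 2, 1]);
        // A zero step now bails instead of looping forever.
        assert!(Tensor::arange_step(0f32, 1f32, 0f32, &Device::Cpu).is_err());
        Ok(())
    }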
@@ -467,6 +487,12 @@ impl Tensor {
     broadcast_binary_op!(broadcast_div, div);
     broadcast_binary_op!(broadcast_maximum, maximum);
     broadcast_binary_op!(broadcast_minimum, minimum);
+    broadcast_binary_op!(broadcast_eq, eq);
+    broadcast_binary_op!(broadcast_ne, ne);
+    broadcast_binary_op!(broadcast_lt, lt);
+    broadcast_binary_op!(broadcast_le, le);
+    broadcast_binary_op!(broadcast_gt, gt);
+    broadcast_binary_op!(broadcast_ge, ge);

     unary_op!(recip, Recip);
     unary_op!(neg, Neg);
@@ -513,6 +539,7 @@ impl Tensor {
         match &*self.storage() {
             Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
+            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
     }

@@ -652,7 +679,7 @@ impl Tensor {
     }

     /// Split a tensor into the specified number of chunks, this may return less chunks than
-    /// specificed.
+    /// specified.
     pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
         let dim = dim.to_index(self.shape(), "chunk")?;
         let size = self.dim(dim)?;
@@ -839,6 +866,20 @@ impl Tensor {
         self.sum_impl(mean_dims, false)? * scale
     }

+    /// Returns the unbiased variance over the selected dimension.
+    pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "var")?;
+        let mean = self.mean_keepdim(dim)?;
+        let squares = self.broadcast_sub(&mean)?.sqr()?;
+        squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
+    }
+
+    /// Returns the unbiased variance over the selected dimension.
+    pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "var")?;
+        self.var_keepdim(dim)?.squeeze(dim)
+    }
+
     /// Gathers the maximum value across the selected dimension. The resulting shape has the same
     /// number of dimensions as the original tensor and the select dimension has a single element.
     pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
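`var_keepdim` composes existing ops: a mean over the dimension, squared deviations via a broadcast subtract, then a sum divided by n - 1, i.e. the unbiased estimator. A small numeric check:

    use candle_core::{Device, Result, Tensor};

    fn demo() -> Result<()> {
        // Unbiased variance of [1, 2, 3, 4]: mean 2.5, sum of squared
        // deviations 5.0, divided by n - 1 = 3 -> ~1.6667.
        let t = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?;
        let v = t.var(0)?.to_scalar::<f32>()?;
        assert!((v - 5.0 / 3.0).abs() < 1e-6);
        Ok(())
    }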
@@ -963,7 +1004,11 @@ impl Tensor {
     /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
     pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
         let (n, c, _h, _w) = self.dims4()?;
-        let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
+        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
+            arg,
+            target_h,
+            target_w,
+        });
         let storage = self
             .storage()
             .upsample_nearest2d(self.layout(), target_h, target_w)?;
@@ -996,6 +1041,9 @@ impl Tensor {
         let kernel_size = kernel_size.to_usize2();
         let stride = stride.to_usize2();
         let (n, c, h, w) = self.dims4()?;
+        if h < kernel_size.0 || w < kernel_size.1 {
+            bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
+        }
         // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
         let h_out = (h - kernel_size.0) / stride.0 + 1;
         let w_out = (w - kernel_size.1) / stride.1 + 1;
@@ -1031,6 +1079,9 @@ impl Tensor {
         let kernel_size = kernel_size.to_usize2();
         let stride = stride.to_usize2();
         let (n, c, h, w) = self.dims4()?;
+        if h < kernel_size.0 || w < kernel_size.1 {
+            bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
+        }
         // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
         let h_out = (h - kernel_size.0) / stride.0 + 1;
         let w_out = (w - kernel_size.1) / stride.1 + 1;
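The new kernel-size guard is not just cosmetic: `h_out` is computed with `usize` arithmetic, so a kernel larger than the input would make `h - kernel_size.0` underflow (a panic in debug builds, a huge wrapped value in release) before any shape error could be reported. The unpadded output-size rule, sketched:

    // Output size of an unpadded pooling window; checking k <= h first keeps
    // the usize subtraction from underflowing.
    fn pool_out(h: usize, k: usize, s: usize) -> Option<usize> {
        if k > h {
            return None;
        }
        Some((h - k) / s + 1)
    }

    fn main() {
        assert_eq!(pool_out(5, 2, 2), Some(2));
        assert_eq!(pool_out(1, 3, 1), None); // previously a usize underflow
    }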
@@ -1186,14 +1237,16 @@ impl Tensor {
                 op: "scatter-add (self, src)",
                 lhs: self.shape().clone(),
                 rhs: source.shape().clone(),
-            })?
+            }
+            .bt())?
         }
         if indexes.dims() != source.dims() {
             Err(Error::ShapeMismatchBinaryOp {
                 op: "scatter-add (indexes, src)",
                 lhs: indexes.shape().clone(),
                 rhs: source.shape().clone(),
-            })?
+            }
+            .bt())?
         }
         let storage = self.storage().scatter_add(
             self.layout(),
@@ -1265,7 +1318,8 @@ impl Tensor {
                 op: "slice-scatter (self, src)",
                 lhs: self.shape().clone(),
                 rhs: src.shape().clone(),
-            })?
+            }
+            .bt())?
         }
         let mut storage = self.device().zeros(self.shape(), self.dtype())?;
         self.storage()
@@ -1299,7 +1353,8 @@ impl Tensor {
                 op: "index-add (self, source)",
                 lhs: self.shape().clone(),
                 rhs: source.shape().clone(),
-            })?
+            }
+            .bt())?
         }
         // The number of element in indexes must match the dimension on which the add is
         // performed on the source tensor (and the index values from `indexes` are taken from
@@ -1310,7 +1365,8 @@ impl Tensor {
                 op: "index-add (ids, source))",
                 lhs: indexes.shape().clone(),
                 rhs: source.shape().clone(),
-            })?
+            }
+            .bt())?
         }
         let storage = self.storage().index_add(
             self.layout(),
@@ -1358,7 +1414,8 @@ impl Tensor {
                 op: "gather",
                 lhs: self.shape().clone(),
                 rhs: indexes.shape().clone(),
-            })?
+            }
+            .bt())?
         }
         let storage =
             self.storage()
@@ -1432,6 +1489,7 @@ impl Tensor {
         match &*self.storage() {
             Storage::Cpu(storage) => from_cpu_storage(storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
+            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
     }
@@ -1462,6 +1520,7 @@ impl Tensor {
         match &*self.storage() {
             Storage::Cpu(storage) => from_cpu_storage(storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
+            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
     }
@@ -1502,6 +1561,7 @@ impl Tensor {
         match &*self.storage() {
             Storage::Cpu(storage) => from_cpu_storage(storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
+            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
     }

@@ -1744,7 +1804,7 @@ impl Tensor {
         let is_permutation =
             dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
         if !is_permutation {
-            crate::bail!(
+            bail!(
                 "dimension mismatch in permute, tensor {:?}, dims: {:?}",
                 self.dims(),
                 dims
@@ -1791,7 +1851,12 @@ impl Tensor {

     /// Returns a new tensor detached from the current graph, gradient are not propagated through
     /// this new node. The storage of this tensor is shared with the initial tensor.
+    ///
+    /// If the tensor is already detached from the computation graph, the same tensor is returned.
     pub fn detach(&self) -> Result<Tensor> {
+        if self.op.is_none() && !self.is_variable {
+            Ok(self.clone())
+        } else {
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
@@ -1803,6 +1868,7 @@ impl Tensor {
         };
         Ok(Tensor(Arc::new(tensor_)))
     }
+    }

     /// If the target device is the same as the tensor device, only a shallow copy is performed.
     pub fn to_device(&self, device: &Device) -> Result<Tensor> {
@@ -1813,7 +1879,11 @@ impl Tensor {
             (Storage::Cpu(storage), Device::Cuda(cuda)) => {
                 Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
             }
+            (Storage::Cpu(storage), Device::Metal(metal)) => {
+                Storage::Metal(metal.storage_from_cpu_storage(storage)?)
+            }
             (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
+            (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
             (Storage::Cuda(storage), Device::Cuda(cuda)) => {
                 // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
                 // are the same.
@@ -1821,6 +1891,9 @@ impl Tensor {
                 Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
             }
             (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
+            _ => {
+                bail!("not implemented yet")
+            }
         };
         let op = BackpropOp::new1(self, Op::ToDevice);
         let tensor_ = Tensor_ {
@@ -2226,7 +2299,7 @@ impl Tensor {
         if left == 0 && right == 0 {
             Ok(self.clone())
         } else if self.elem_count() == 0 {
-            crate::bail!("cannot use pad_with_same on an empty tensor")
+            bail!("cannot use pad_with_same on an empty tensor")
         } else if left == 0 {
             let dim = dim.to_index(self.shape(), "pad_with_same")?;
             let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
@@ -2265,6 +2338,11 @@ impl Tensor {
         m.forward(self)
     }

+    /// Run the `forward` method of `m` on `self`.
+    pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
+        m.forward_t(self, train)
+    }
+
     pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
         self.storage.read().unwrap()
     }
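`apply` forwards through `Module`; the new `apply_t` threads a `train` flag through `ModuleT`, which train-time-dependent layers such as dropout rely on. A sketch with a hypothetical `ModuleT` implementor (the `Scaler` type below is illustrative, not from the crate):

    use candle_core::{ModuleT, Result, Tensor};

    // A hypothetical train-aware layer: identity at eval time, scaling at
    // train time. Stands in for dropout-like layers.
    struct Scaler(f64);

    impl ModuleT for Scaler {
        fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
            if train {
                xs * self.0
            } else {
                Ok(xs.clone())
            }
        }
    }

    fn demo(xs: &Tensor) -> Result<Tensor> {
        xs.apply_t(&Scaler(0.5), /* train= */ true)
    }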
@@ -2379,6 +2457,142 @@ impl Tensor {
     ) -> Result<Self> {
         self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
     }
+
+    /// Normalize a 'relative' axis value: positive values are kept, negative
+    /// values means counting the dimensions from the back.
+    pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
+        let rank = self.rank() as i64;
+        if rank <= axis {
+            bail!("axis {axis} is too large, tensor rank {rank}")
+        } else if 0 <= axis {
+            Ok(axis as usize)
+        } else {
+            let naxis = rank + axis;
+            if naxis < 0 {
+                bail!("axis {axis} is too small, tensor rank {rank}")
+            }
+            Ok(naxis as usize)
+        }
+    }
+
+    /// Returns a lower triangular matrix of ones of size n by n.
+    pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
+        let t = Tensor::arange(0u32, n as u32, device)?;
+        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
+        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
+        t1.le(&t2)?.to_dtype(dtype)
+    }
+
+    /// Returns an upper triangular matrix of ones of size n by n.
+    pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
+        let t = Tensor::arange(0u32, n as u32, device)?;
+        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
+        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
+        t1.ge(&t2)?.to_dtype(dtype)
+    }
+
+    /// Returns a matrix with a diagonal of ones of size n by n.
+    pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
+        let t = Tensor::arange(0u32, n as u32, device)?;
+        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
+        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
+        t1.eq(&t2)?.to_dtype(dtype)
+    }
+
+    /// Returns the cumulative sum of elements of the input tensor summed over the specified
+    /// dimension.
+    ///
+    /// This operation is most efficient when dim is the last dimension of the tensor.
+    pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "cumsum")?;
+        let rank = self.rank();
+        if rank == 0 {
+            return Ok(self.clone());
+        }
+        let n_axis = self.dim(dim)?;
+        let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
+        if rank == 1 {
+            self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
+        } else {
+            let last = rank - 1;
+            let t = self.transpose(dim, last)?;
+            let t = t.broadcast_matmul(&triu)?;
+            t.transpose(dim, last)
+        }
+    }
+
+    /// Returns a copy of `self` where the values within `ranges` have been replaced with the
+    /// content of `src`.
+    pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
+        &self,
+        ranges: &[D],
+        src: &Tensor,
+    ) -> Result<Self> {
+        let src_dims = src.dims();
+        let self_dims = self.dims();
+        if self_dims.len() != src_dims.len() {
+            bail!(
+                "slice-assign requires input with the same rank {} <> {}",
+                self_dims.len(),
+                src_dims.len()
+            )
+        }
+        if self_dims.len() != ranges.len() {
+            bail!(
+                "slice-assign requires input with the same rank as there are ranges {} <> {}",
+                self_dims.len(),
+                ranges.len()
+            )
+        }
+        let mut src = src.clone();
+        let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
+        for (i, range) in ranges.iter().enumerate() {
+            let start_included = match range.start_bound() {
+                std::ops::Bound::Unbounded => 0,
+                std::ops::Bound::Included(v) => *v,
+                std::ops::Bound::Excluded(v) => *v + 1,
+            };
+            let end_excluded = match range.end_bound() {
+                std::ops::Bound::Unbounded => self_dims[i],
+                std::ops::Bound::Included(v) => *v + 1,
+                std::ops::Bound::Excluded(v) => *v,
+            };
+            if end_excluded <= start_included {
+                bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
+            }
+            if self_dims[i] < end_excluded {
+                bail!(
+                    "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
+                    self_dims[i]
+                )
+            }
+            if end_excluded - start_included != src_dims[i] {
+                bail!(
+                    "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
+                )
+            }
+            src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
+            mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
+        }
+        mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
+    }
+
+    /// Returns log(sum(exp(tensor), dim)).
+    pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
+        let exp = self.exp()?;
+        let sum = exp.sum(sum_dims)?;
+        sum.log()
+    }
+
+    /// Pointwise pow operation.
+    pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
+        rhs.mul(&self.log()?)?.exp()
+    }
+
+    /// Broadcasting version of `pow`.
+    pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
+        rhs.broadcast_mul(&self.log()?)?.exp()
+    }
 }

 macro_rules! bin_trait {
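A quick sketch exercising several of the additions above. Note how `cumsum` reduces to a matmul against `triu2`'s band of ones, which is also why its doc comment says it is most efficient along the last dimension:

    use candle_core::{DType, Device, Result, Tensor};

    fn demo() -> Result<()> {
        let dev = Device::Cpu;
        // eye/tril2/triu2 are built from broadcast comparisons of an arange.
        let id = Tensor::eye(3, DType::F32, &dev)?;
        assert_eq!(
            id.to_vec2::<f32>()?,
            [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]
        );

        // cumsum: column j of triu2 picks up every input index i <= j.
        let t = Tensor::new(&[1f32, 2., 3.], &dev)?;
        assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [1., 3., 6.]);

        // slice_assign patches a sub-range of self with src via a mask.
        let base = Tensor::zeros((2, 4), DType::F32, &dev)?;
        let src = Tensor::ones((2, 2), DType::F32, &dev)?;
        let out = base.slice_assign(&[0..2, 1..3], &src)?;
        assert_eq!(
            out.to_vec2::<f32>()?,
            [[0., 1., 1., 0.], [0., 1., 1., 0.]]
        );
        Ok(())
    }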
@@ -4,7 +4,7 @@ use crate::{Result, Tensor};
 macro_rules! test_device {
     // TODO: Switch to generating the two last arguments automatically once concat_idents is
     // stable. https://github.com/rust-lang/rust/issues/29599
-    ($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
+    ($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
         #[test]
         fn $test_cpu() -> Result<()> {
             $fn_name(&Device::Cpu)
@@ -15,6 +15,12 @@ macro_rules! test_device {
         fn $test_cuda() -> Result<()> {
             $fn_name(&Device::new_cuda(0)?)
         }
+
+        #[cfg(feature = "metal")]
+        #[test]
+        fn $test_metal() -> Result<()> {
+            $fn_name(&Device::new_metal(0)?)
+        }
     };
 }

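Hand-expanding the macro for one call shows the three tests it now generates; the cuda `cfg` attribute line is not visible in this hunk but is assumed to mirror the metal one:

    #[test]
    fn conv1d_cpu() -> Result<()> {
        conv1d(&Device::Cpu)
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn conv1d_gpu() -> Result<()> {
        conv1d(&Device::new_cuda(0)?)
    }

    #[cfg(feature = "metal")]
    #[test]
    fn conv1d_metal() -> Result<()> {
        conv1d(&Device::new_metal(0)?)
    }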
@@ -23,6 +23,10 @@ pub fn cuda_is_available() -> bool {
     cfg!(feature = "cuda")
 }

+pub fn metal_is_available() -> bool {
+    cfg!(feature = "metal")
+}
+
 pub fn with_avx() -> bool {
     cfg!(target_feature = "avx")
 }
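As with `cuda_is_available`, this reports a compile-time feature flag, not the presence of actual hardware. A device-picking helper in the style such flags tend to be used (a sketch, not crate code):

    use candle_core::{utils, Device, Result};

    // Pick the best available backend; compile-time feature checks only,
    // no runtime probing.
    fn best_device() -> Result<Device> {
        if utils::cuda_is_available() {
            Device::new_cuda(0)
        } else if utils::metal_is_available() {
            Device::new_metal(0)
        } else {
            Ok(Device::Cpu)
        }
    }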
@@ -13,6 +13,11 @@ res = torch.nn.functional.conv1d(t, w)
 print(res.flatten())
 res = torch.nn.functional.conv1d(t, w, padding=1)
 print(res.flatten())
+
+w_t = w.transpose(0, 1)
+res = torch.nn.functional.conv_transpose1d(t, w_t)
+print(res.shape)
+print(res)
 */
 fn conv1d(dev: &Device) -> Result<()> {
     let t = Tensor::new(
@@ -45,6 +50,17 @@ fn conv1d(dev: &Device) -> Result<()> {
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
     );
+    if dev.is_cpu() {
+        let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
+        assert_eq!(res.dims(), [1, 2, 7]);
+        assert_eq!(
+            test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+            [
+                0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
+                4.7076, -5.9745, -0.8276, 1.621
+            ],
+        );
+    }
     Ok(())
 }

@@ -479,17 +495,103 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
             ]
         ]
     );
+
+    // Replicate the issue from https://github.com/huggingface/candle/issues/1212
+    let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?;
+    let loss = res.sqr()?.sum_all()?;
+    assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32);
+    let grads = loss.backward()?;
+    let grad_t = grads.get(&t).unwrap();
+    let grad_w = grads.get(&w).unwrap();
+    assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
+    assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
+    assert_eq!(
+        test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
+        [
+            [
+                [9.29, -7.03, 7.87, 0.0, 0.0],
+                [-1.8, -7.82, 5.9, 0.0, 0.0],
+                [-3.12, 4.49, 5.52, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0]
+            ],
+            [
+                [21.73, 3.39, 4.77, 0.0, 0.0],
+                [8.25, 3.73, 27.61, 0.0, 0.0],
+                [-20.55, -5.61, -2.77, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0]
+            ],
+            [
+                [-8.98, 9.91, -7.15, 0.0, 0.0],
+                [4.93, -0.33, 4.56, 0.0, 0.0],
+                [-6.7, -5.76, -8.05, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0]
+            ],
+            [
+                [23.54, 6.98, -10.0, 0.0, 0.0],
+                [9.65, 6.18, 18.72, 0.0, 0.0],
+                [3.29, -5.27, 0.79, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0]
+            ]
+        ]
+    );
+    assert_eq!(
+        test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
+        [
+            [
+                [-3.47, 7.44, 0.66],
+                [12.89, -3.4, -9.29],
+                [-14.16, -0.83, 7.14]
+            ],
+            [
+                [-3.23, 5.37, -3.02],
+                [-2.12, -11.24, 1.94],
+                [6.97, 7.2, 2.99]
+            ],
+            [
+                [-4.04, -3.31, 4.87],
+                [-6.68, -5.68, 1.73],
+                [-5.54, 4.32, 0.52]
+            ],
+            [[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]]
+        ]
+    );
+
     Ok(())
 }

-test_device!(conv1d, conv1d_cpu, conv1d_gpu);
-test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
-test_device!(conv2d, conv2d_cpu, conv2d_gpu);
+test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
+test_device!(
+    conv1d_small,
+    conv1d_small_cpu,
+    conv1d_small_gpu,
+    conv1d_small_metal
+);
+test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
 test_device!(
     conv2d_non_square,
     conv2d_non_square_cpu,
-    conv2d_non_square_gpu
+    conv2d_non_square_gpu,
+    conv2d_non_square_metal
+);
+test_device!(
+    conv2d_small,
+    conv2d_small_cpu,
+    conv2d_small_gpu,
+    conv2d_small_metal
+);
+test_device!(
+    conv2d_smaller,
+    conv2d_smaller_cpu,
+    conv2d_smaller_gpu,
+    conv2d_smaller_metal
+);
+test_device!(
+    conv2d_grad,
+    conv2d_grad_cpu,
+    conv2d_grad_gpu,
+    conv2_grad_metal
 );
-test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
-test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
-test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);
|
@@ -205,6 +205,231 @@ fn unary_grad(device: &Device) -> Result<()> {
         test_utils::to_vec1_round(grad_x, 4)?,
         [1.0116, 1.0830, 1.0003, 0.6188],
     );

+    // Testing compared to pytorch torch.erf
+    //
+    // import torch
+    // x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
+    // y = x.erf()
+    // print(y)
+    // loss = y.sum()
+    // loss.backward()
+    // print(x.grad)
+    let y = x.erf()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(test_utils::to_vec1_round(&y, 4)?, [1.0, 0.8427, 1.0, 0.168]);
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [0.0001, 0.4151, 0.0, 1.1033],
+    );
+
+    // Testing compared to pytorch nn.GELU(approximate = 'none')
+    //
+    // import torch
+    // import torch.nn.functional as F
+    // x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
+    // y = F.gelu(x, approximate='none')
+    // print(y)
+    // loss = y.sum()
+    // loss.backward()
+    // print(x.grad)
+    let y = x.gelu_erf()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [2.9960, 0.8413, 3.9999, 0.0839]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [1.0119, 1.0833, 1.0005, 0.6188],
+    );
+
+    // Testing compared to pytorch elu
+    //
+    // import torch
+    // import torch.nn.functional as F
+    // x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True)
+    // y = F.elu(x, alpha=2.0)
+    // print(y)
+    // loss = y.sum()
+    // loss.backward()
+    // print(x.grad)
+    let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?;
+    let y = elu_x.elu(2.)?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&elu_x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [-1.2642, 0.0000, -1.7293, 3.0000]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [0.7358, 2.0000, 0.2707, 1.0000]
+    );
+
+    // manually checked: see comments
+    let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
+    let y = x.interpolate2d(6, 6)?.reshape(36)?;
+
+    #[rustfmt::skip]
+    let z = Tensor::new(
+        &[
+            1_f32, 02., 03., 04., 05., 06.,
+            07., 08., 09., 10., 11., 12.,
+            13., 14., 15., 16., 17., 18.,
+            19., 20., 21., 22., 23., 24.,
+            25., 26., 27., 28., 29., 30.,
+            31., 32., 33., 34., 35., 36.,
+        ],
+        device,
+    )?;
+    // gradient should be
+    // row 1
+    // 1+2+7+8 = 18
+    // 3+4+9+10 = 26
+    // 5+6+11+12 = 34
+    // row 2
+    // 13+14+19+20 = 66
+    // 15+16+21+22 = 74
+    // 17+18+23+24 = 82
+    // row 3
+    // 25+26+31+32 = 114
+    // 27+28+33+34 = 122
+    // 29+30+35+36 = 130
+    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+    let grads = loss.backward()?;
+
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
+        [[18_f32, 26., 34.], [66., 74., 82.], [114., 122., 130.]]
+    );
+
+    // manually checked: see comments
+    let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
+    let y = x.interpolate2d(6, 6)?.reshape(36)?;
+
+    #[rustfmt::skip]
+    let z = Tensor::new(
+        &[
+            1_f32, 02., 03., 04., 05., 06.,
+            07., 08., 09., 10., 11., 12.,
+            13., 14., 15., 16., 17., 18.,
+            19., 20., 21., 22., 23., 24.,
+            25., 26., 27., 28., 29., 30.,
+            31., 32., 33., 34., 35., 36.,
+        ],
+        device,
+    )?;
+    // gradient should be
+    // row 1
+    // 1+2+3+7+8+9+13+14+15 = 72
+    // 4+5+6+10+11+12+16+17+18 = 99
+    // row 2
+    // 19+20+21+25+26+27+31+32+33 = 234
+    // 22+23+24+28+29+30+34+35+36 = 261
+    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+    let grads = loss.backward()?;
+
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
+        [[72_f32, 99.], [234., 261.]]
+    );
+
+    // manually checked: see comments
+    let x = Var::new(&[[[[1f32, 2.], [4., 5.]], [[6f32, 7.], [8., 9.]]]], device)?;
+
+    let y = x.interpolate2d(4, 4)?.reshape(32)?;
+
+    #[rustfmt::skip]
+    let z = Tensor::new(
+        &[
+            1_f32, 02., 03., 04.,
+            05., 06., 07., 08.,
+            09., 10., 11., 12.,
+            13., 14., 15., 16.,
+            17., 18., 19., 20.,
+            21., 22., 23., 24.,
+            25., 26., 27., 28.,
+            29., 30., 31., 32.
+        ],
+        device,
+    )?;
+    // gradient should be
+    // m1r1
+    // 1+2+5+6=14
+    // 3+4+7+8=22
+    // m1r2
+    // 9+10+13+14=46
+    // 11+12+15+16=54
+    // m2r1
+    // 17+18+21+22=78
+    // 19+20+23+24=86
+    // m2r2
+    // 25+26+29+30=110
+    // 27+28+31+32=118
+    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+    let grads = loss.backward()?;
+
+    let grad_x = grads.get(&x).context("no grad for x")?;
+
+    assert_eq!(
+        test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
+        [[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
+    );
+
+    // manually checked: see comments
+    let x = Var::new(
+        &[[[[1f32, 2.], [4., 5.]]], [[[6f32, 7.], [8., 9.]]]],
+        device,
+    )?;
+
+    let y = x.interpolate2d(4, 4)?.reshape(32)?;
+
+    #[rustfmt::skip]
+    let z = Tensor::new(
+        &[
+            1_f32, 02., 03., 04.,
+            05., 06., 07., 08.,
+            09., 10., 11., 12.,
+            13., 14., 15., 16.,
+            17., 18., 19., 20.,
+            21., 22., 23., 24.,
+            25., 26., 27., 28.,
+            29., 30., 31., 32.
+        ],
+        device,
+    )?;
+    // gradient should be
+    // m1r1
+    // 1+2+5+6=14
+    // 3+4+7+8=22
+    // m1r2
+    // 9+10+13+14=46
+    // 11+12+15+16=54
+    // m2r1
+    // 17+18+21+22=78
+    // 19+20+23+24=86
+    // m2r2
+    // 25+26+29+30=110
+    // 27+28+31+32=118
+    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+    let grads = loss.backward()?;
+
+    let grad_x = grads.get(&x).context("no grad for x")?;
+
+    assert_eq!(
+        test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
+        [[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
+    );
     Ok(())
 }

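All of the "manually checked" expectations above follow one rule: the backward pass of nearest-neighbour upsampling sum-pools the upstream gradient over each input cell's replication block. A minimal standalone sketch of that rule, independent of candle's actual kernels (which may be implemented differently):

    /// Sum-pool an upstream gradient of shape (h*s, w*s) back onto an (h, w) input,
    /// matching nearest-neighbour upsampling with integer scale factor `s`.
    fn upsample_nearest_backward(grad_out: &[f32], h: usize, w: usize, s: usize) -> Vec<f32> {
        let mut grad_in = vec![0f32; h * w];
        for i in 0..h * s {
            for j in 0..w * s {
                grad_in[(i / s) * w + j / s] += grad_out[i * (w * s) + j];
            }
        }
        grad_in
    }

    // With h = w = 2, s = 3 and grad_out = [1., 2., ..., 36.] this returns
    // [72., 99., 234., 261.], matching the second assertion above.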
@@ -250,9 +475,29 @@ fn binary_grad(device: &Device) -> Result<()> {
     Ok(())
 }

-test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
-test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
-test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
-test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
-test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
-test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);
+test_device!(
+    simple_grad,
+    simple_grad_cpu,
+    simple_grad_gpu,
+    simple_grad_metal
+);
+test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
+test_device!(
+    matmul_grad,
+    matmul_grad_cpu,
+    matmul_grad_gpu,
+    matmul_grad_metal
+);
+test_device!(
+    grad_descent,
+    grad_descent_cpu,
+    grad_descent_gpu,
+    grad_descent_metal
+);
+test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
+test_device!(
+    binary_grad,
+    binary_grad_cpu,
+    binary_grad_gpu,
+    binary_grad_metal
+);
@@ -91,3 +91,32 @@ fn index_3d() -> Result<()> {
     assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
     Ok(())
 }
+
+#[test]
+fn slice_assign() -> Result<()> {
+    let dev = Device::Cpu;
+
+    let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
+    let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
+    let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
+    assert_eq!(
+        out.to_vec2::<u32>()?,
+        &[
+            [0, 1, 2, 3, 4],
+            [5, 6, 7, 0, 1],
+            [10, 11, 12, 2, 3],
+            [15, 16, 17, 4, 5]
+        ]
+    );
+    let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
+    assert_eq!(
+        out.to_vec2::<u32>()?,
+        &[
+            [0, 1, 2, 3, 4],
+            [2, 3, 7, 8, 9],
+            [4, 5, 12, 13, 14],
+            [15, 16, 17, 18, 19]
+        ]
+    );
+    Ok(())
+}
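A small reading aid for the second assertion: `src` is `[[0, 1], [2, 3], [4, 5]]`, so writing it into rows `0..3`, columns `0..2` leaves the first output row looking untouched only because its first two elements were already `0, 1`.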
@@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> {
     Ok(())
 }

-test_device!(contiguous, contiguous_cpu, contiguous_gpu);
+test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);

 #[test]
 fn strided_blocks() -> Result<()> {
@@ -98,15 +98,17 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
     Ok(())
 }

-test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
+test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal);
 test_device!(
     avg_pool2d_pytorch,
     avg_pool2d_pytorch_cpu,
-    avg_pool2d_pytorch_gpu
+    avg_pool2d_pytorch_gpu,
+    avg_pool2d_pytorch_metal
 );
-test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
+test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
 test_device!(
     upsample_nearest2d,
     upsample_nearest2d_cpu,
-    upsample_nearest2d_gpu
+    upsample_nearest2d_gpu,
+    upsample_nearest2d_metal
 );
@@ -1,7 +1,9 @@
 use candle_core::{
+    bail,
     quantized::{self, GgmlDType},
+    test_device,
     test_utils::to_vec2_round,
-    Device, Result, Tensor,
+    Device, Module, Result, Tensor,
 };
 use quantized::{k_quants, GgmlType};
 use rand::prelude::*;
@@ -13,16 +15,48 @@ const GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS: f32 = 0.0075;
 const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040;
 const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;

-#[test]
-fn quantized_matmul() -> Result<()> {
-    let cpu = &Device::Cpu;
+fn test_matmul(
+    device: &Device,
+    (b, m, n, k): (usize, usize, usize, usize),
+    dtype: GgmlDType,
+) -> Result<()> {
+    let lhs = (0..(m * k))
+        .map(|v| v as f32 / (m * k) as f32)
+        .collect::<Vec<_>>();
+    let rhs = (0..(k * n))
+        .map(|v| v as f32 / (n * k) as f32)
+        .collect::<Vec<_>>();
+
+    let lhs = Tensor::from_slice(&lhs, (m, k), device)?;
+    let rhs = Tensor::from_slice(&rhs, (k, n), device)?;
+    let mm = lhs.matmul(&rhs)?;
+    let qtensor = quantized::QTensor::quantize(&rhs.t()?, dtype)?;
+    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
+    let res = matmul.forward(&lhs)?;
+
+    let error: f32 = ((&mm - &res)?.abs()? / &mm.abs()?)?
+        .sum_all()?
+        .to_scalar()?;
+    let error = error / (b * m * n) as f32;
+    assert!(
+        error <= 0.02,
+        "Error {error} is too big. \nExpected:\n {mm} \nFound:\n {res}\n for {dtype:?}"
+    );
+
+    Ok(())
+}
+
+fn quantized_matmul(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
     let (m, k, n) = (3, 64, 4);
     let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
-    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
+    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
     let mut dst = vec![42.; 3 * 4];
     let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
     let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
-    let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
     k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
     k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
     assert_eq!(
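The new `test_matmul` helper scores a quantized matmul by its mean relative error: it sums `|mm - res| / |mm|` over all elements, divides by `b * m * n`, and fails above `0.02`. A usage sketch, taking the CPU device as one example of what the test macros supply:

    // Batch of 1: a 3x256 lhs against a 256x4 rhs quantized to Q4_0,
    // since the tuple is laid out as (b, m, n, k).
    test_matmul(&Device::Cpu, (1, 3, 4, 256), GgmlDType::Q4_0)?;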
@@ -32,6 +66,7 @@ fn quantized_matmul() -> Result<()> {
             341876.0, 994283.0, 1655709.0, 2301518.0
         ]
     );
+    let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
     let mm = tensor_lhs.matmul(&tensor_rhs)?;
     assert_eq!(
         mm.to_vec2::<f32>()?,
@@ -42,35 +77,49 @@ fn quantized_matmul() -> Result<()> {
         ]
     );

-    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
+    let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
     let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
     let res = matmul.forward(&tensor_lhs)?;
-    assert_eq!(
-        to_vec2_round(&res, 0)?,
-        &[
-            [85120.0, 214562.0, 345455.0, 474748.0],
-            [213475.0, 604465.0, 1000686.0, 1388317.0],
-            [341876.0, 994283.0, 1655709.0, 2301518.0]
-        ]
-    );
+    match device {
+        Device::Metal(_) => assert_eq!(
+            to_vec2_round(&res, 0)?,
+            &[
+                [84946.0, 214126.0, 344757.0, 473798.0],
+                [213458.0, 604350.0, 1000469.0, 1387990.0],
+                [341970.0, 994574.0, 1656181.0, 2302182.0]
+            ]
+        ),
+        _ => assert_eq!(
+            to_vec2_round(&res, 0)?,
+            &[
+                [85120.0, 214562.0, 345455.0, 474748.0],
+                [213475.0, 604465.0, 1000686.0, 1388317.0],
+                [341876.0, 994283.0, 1655709.0, 2301518.0]
+            ]
+        ),
+    }
+
+    test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;

     Ok(())
 }

-#[test]
-fn quantized_matmul_neg() -> Result<()> {
-    let cpu = &Device::Cpu;
+fn quantized_matmul_neg(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
     let (m, k, n) = (3, 64, 4);
     let lhs = (0..(m * k))
         .map(|v| v as f32 - (m * k) as f32 / 2.0)
         .collect::<Vec<_>>();
-    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
+    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
     let mut dst = vec![42.; 3 * 4];
     let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
     let rhs = (0..k * n)
         .map(|v| v as f32 - (k * n) as f32 / 3.0)
         .collect::<Vec<_>>();
-    let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
+    let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
     k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
     k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
     assert_eq!(
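Note the `match device` split: the Metal arm pins slightly different expected integers than the default arm, presumably because the Metal quantized kernels accumulate or round differently. The test fixes both sets of values exactly rather than loosening the comparison to a tolerance.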
@@ -90,32 +139,56 @@ fn quantized_matmul_neg() -> Result<()> {
         ]
     );

-    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
+    let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
     let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
     let res = matmul.forward(&tensor_lhs)?;
-    assert_eq!(
-        to_vec2_round(&res, 0)?,
-        &[
-            [243524.0, -19596.0, -285051.0, -549815.0],
-            [23777.0, 21651.0, 19398.0, 18367.0],
-            [-196472.0, 63012.0, 324585.0, 587902.0]
-        ]
-    );
+    match device {
+        Device::Metal(_) => assert_eq!(
+            to_vec2_round(&res, 0)?,
+            &[
+                [243666.0, -19714.0, -285433.0, -550453.0],
+                [23782.0, 21654.0, 19400.0, 18369.0],
+                [-196102.0, 63022.0, 324233.0, 587191.0]
+            ]
+        ),
+        _ => assert_eq!(
+            to_vec2_round(&res, 0)?,
+            &[
+                [243524.0, -19596.0, -285051.0, -549815.0],
+                [23777.0, 21651.0, 19398.0, 18367.0],
+                [-196472.0, 63012.0, 324585.0, 587902.0]
+            ]
+        ),
+    }

     Ok(())
 }

-#[test]
-fn quantize_q4_0() -> Result<()> {
-    use k_quants::BlockQ4_0;
+test_device!(
+    quantized_matmul,
+    quantized_matmul_cpu,
+    quantized_matmul_cuda,
+    quantized_matmul_metal
+);
+test_device!(
+    quantized_matmul_neg,
+    quantized_matmul_neg_cpu,
+    quantized_matmul_neg_cuda,
+    quantized_matmul_neg_metal
+);
+
+fn quantize_q4_0(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
     let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
-    let mut dst = vec![0f32; 32 * 4];
-    let mut quant = vec![BlockQ4_0::zeros(); 4];
-    BlockQ4_0::from_float(&src, &mut quant)?;
-    BlockQ4_0::to_float(&quant, dst.as_mut_slice())?;
+    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
+    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
+    let dst = quant.dequantize(device)?;
     assert_eq!(
-        dst,
+        dst.to_vec1::<f32>()?,
        &[
             -0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625,
             11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25,
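From here on, every `quantize_q*` test follows the same device-generic roundtrip. A condensed sketch of the pattern, assuming the `QTensor::quantize` and `dequantize` signatures used throughout this diff (the `roundtrip` helper name is ours, not part of the change):

    use candle_core::quantized::{self, GgmlDType};
    use candle_core::{Device, Result, Tensor};

    // Hypothetical helper mirroring the quantize_q4_0 .. quantize_q8k tests.
    fn roundtrip(device: &Device, dtype: GgmlDType) -> Result<Vec<f32>> {
        let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
        let src = Tensor::from_slice(&src, (32 * 4,), device)?;
        let quant = quantized::QTensor::quantize(&src, dtype)?; // f32 -> quantized blocks
        let dst = quant.dequantize(device)?; // quantized blocks -> f32
        dst.to_vec1::<f32>()
    }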
@@ -131,21 +204,21 @@ fn quantize_q4_0() -> Result<()> {
             127.0, 127.0
         ]
     );
-    ggml_quantization_error_test::<BlockQ4_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+    ggml_quantization_error_test(GgmlDType::Q4_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
     Ok(())
 }

-#[test]
-fn quantize_q4_1() -> Result<()> {
-    use k_quants::BlockQ4_1;
+fn quantize_q4_1(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
     let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
-    let mut dst = vec![0f32; 32 * 4];
-    let mut quant = vec![BlockQ4_1::zeros(); 4];
-    BlockQ4_1::from_float(&src, &mut quant)?;
-    BlockQ4_1::to_float(&quant, dst.as_mut_slice())?;
+    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
+    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
+    let dst = quant.dequantize(device)?;
     assert_eq!(
-        round_vector(&dst),
+        round_vector(&dst.to_vec1::<f32>()?),
         &[
             0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332,
             12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73,
@@ -161,21 +234,21 @@ fn quantize_q4_1() -> Result<()> {
             118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996
         ]
     );
-    ggml_quantization_error_test::<BlockQ4_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+    ggml_quantization_error_test(GgmlDType::Q4_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
     Ok(())
 }

-#[test]
-fn quantize_q5_0() -> Result<()> {
-    use k_quants::BlockQ5_0;
+fn quantize_q5_0(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
     let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
-    let mut dst = vec![0f32; 32 * 4];
-    let mut quant = vec![BlockQ5_0::zeros(); 4];
-    BlockQ5_0::from_float(&src, &mut quant)?;
-    BlockQ5_0::to_float(&quant, dst.as_mut_slice())?;
+    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
+    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
+    let dst = quant.dequantize(device)?;
     assert_eq!(
-        round_vector(&dst),
+        round_vector(&dst.to_vec1::<f32>()?),
         &[
             -0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625,
             11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313,
@@ -191,21 +264,21 @@ fn quantize_q5_0() -> Result<()> {
             119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0
         ]
     );
-    ggml_quantization_error_test::<BlockQ5_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+    ggml_quantization_error_test(GgmlDType::Q5_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
     Ok(())
 }

-#[test]
-fn quantize_q5_1() -> Result<()> {
-    use k_quants::BlockQ5_1;
+fn quantize_q5_1(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
     let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
-    let mut dst = vec![0f32; 32 * 4];
-    let mut quant = vec![BlockQ5_1::zeros(); 4];
-    BlockQ5_1::from_float(&src, &mut quant)?;
-    BlockQ5_1::to_float(&quant, dst.as_mut_slice())?;
+    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
+    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
+    let dst = quant.dequantize(device)?;
     assert_eq!(
-        dst,
+        round_vector(&dst.to_vec1::<f32>()?),
         &[
             0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
             16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
@@ -219,13 +292,11 @@ fn quantize_q5_1() -> Result<()> {
             124.0, 125.0, 126.0, 127.0
         ]
     );
-
-    ggml_quantization_error_test::<BlockQ5_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+    ggml_quantization_error_test(GgmlDType::Q5_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
     Ok(())
 }

-/// Generates a small test vector ranging from -`bound` to `bound` with `size` steps
-fn get_test_vector(bound: f32, size: usize) -> (Vec<f32>, Vec<f32>) {
+fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result<Tensor> {
     assert!(
         size % crate::quantized::k_quants::QK_K == 0,
         "size must be a multiple of {}",
@@ -235,10 +306,8 @@ fn get_test_vector(bound: f32, size: usize) -> (Vec<f32>, Vec<f32>) {
     let src = (0..size)
         .map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.))
         .collect::<Vec<_>>();
-
-    let dst = vec![0f32; size];
     assert_eq!([src[0], src[size / 2]], [-bound, 0.0]);
-    (src, dst)
+    Tensor::from_vec(src, (size,), device)
 }

 /// Round a vector
@@ -265,7 +334,8 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
     }
 }

-/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
+/// Creates a vector similar to the ones used in GGML unit tests:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
 fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
     (0..GGML_TEST_SIZE)
         .map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
@@ -284,14 +354,16 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
     sum / a.len() as f32
 }

-/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
-fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
+/// Similar to the GGML quantization unit test:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
+fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32) -> Result<()> {
     let src = create_ggml_like_vector(0.0);
-    let mut dst = vec![0.0; GGML_TEST_SIZE];
-    let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
-    let error = calculate_rmse(src.as_slice(), dst.as_slice());
+    let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
+    let quant = quantized::QTensor::quantize(&src, dtype)?;
+    let dst = quant.dequantize(device)?;
+    let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
     if error > max_error {
-        candle_core::bail!(
+        bail!(
             "Quantization error {} exceeds max error {}",
             error,
             max_error
@@ -300,19 +372,19 @@ fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
     Ok(())
 }

-fn quantize_roundtrip<T: GgmlType>(src: &[f32], dst: &mut [f32]) -> Result<Vec<T>> {
-    let mut quant = vec![T::zeros(); src.len() / T::BLCK_SIZE];
-    T::from_float(src, &mut quant)?;
-    T::to_float(&quant, dst)?;
-    Ok(quant)
-}
-
-#[test]
-fn quantize_q2k() -> Result<()> {
-    use k_quants::BlockQ2K;
+fn quantize_q2k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
+    let dtype = GgmlDType::Q2K;

-    let (src, mut dst) = get_test_vector(0.5, 1024);
-    let _quant = quantize_roundtrip::<BlockQ2K>(src.as_slice(), dst.as_mut_slice())?;
+    let src = get_test_vector2(0.5, 1024, device)?;
+    let quant = quantized::QTensor::quantize(&src, dtype)?;
+    let dst = quant.dequantize(device)?;
+
+    let src = src.to_vec1::<f32>()?;
+    let dst = dst.to_vec1::<f32>()?;
     compare_with_error(dst.as_slice(), src.as_slice(), 0.1);

     // Test some specific values
@@ -326,20 +398,30 @@ fn quantize_q2k() -> Result<()> {
         [-0.499, -0.366, -0.249, 0.0, 0.295, 0.492]
     );

-    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
-    let _quant_big = quantize_roundtrip::<BlockQ2K>(src_big.as_slice(), dst_big.as_mut_slice())?;
+    let src_big = get_test_vector2(128.0, 1024, device)?;
+    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+    let dst_big = quant_big.dequantize(device)?;
+
+    let src_big = src_big.to_vec1::<f32>()?;
+    let dst_big = dst_big.to_vec1::<f32>()?;
     compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0);

-    ggml_quantization_error_test::<BlockQ2K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
+    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
     Ok(())
 }

-#[test]
-fn quantize_q3k() -> Result<()> {
-    use k_quants::BlockQ3K;
+fn quantize_q3k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
+    let dtype = GgmlDType::Q3K;
+    let src = get_test_vector2(0.5, 1024, device)?;
+    let quant = quantized::QTensor::quantize(&src, dtype)?;
+    let dst = quant.dequantize(device)?;

-    let (src, mut dst) = get_test_vector(0.5, 1024);
-    let _quant = quantize_roundtrip::<BlockQ3K>(src.as_slice(), dst.as_mut_slice())?;
+    let src = src.to_vec1::<f32>()?;
+    let dst = dst.to_vec1::<f32>()?;
     compare_with_error(dst.as_slice(), src.as_slice(), 0.03);

     // Test some specific values
@@ -353,20 +435,30 @@ fn quantize_q3k() -> Result<()> {
         [-0.493, -0.37, -0.243, -0.0, 0.292, 0.492]
     );

-    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
-    let _quant_big = quantize_roundtrip::<BlockQ3K>(src_big.as_slice(), dst_big.as_mut_slice())?;
+    let src_big = get_test_vector2(128.0, 1024, device)?;
+    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+    let dst_big = quant_big.dequantize(device)?;
+
+    let src_big = src_big.to_vec1::<f32>()?;
+    let dst_big = dst_big.to_vec1::<f32>()?;
     compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5);

-    ggml_quantization_error_test::<BlockQ3K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
+    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
     Ok(())
 }

-#[test]
-fn quantize_q4k() -> Result<()> {
-    use k_quants::BlockQ4K;
+fn quantize_q4k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
+    let dtype = GgmlDType::Q4K;
+    let src = get_test_vector2(0.5, 1024, device)?;
+    let quant = quantized::QTensor::quantize(&src, dtype)?;
+    let dst = quant.dequantize(device)?;

-    let (src, mut dst) = get_test_vector(0.5, 1024);
-    let _quant = quantize_roundtrip::<BlockQ4K>(src.as_slice(), dst.as_mut_slice())?;
+    let src = src.to_vec1::<f32>()?;
+    let dst = dst.to_vec1::<f32>()?;
     compare_with_error(dst.as_slice(), src.as_slice(), 0.017);

     // Test some specific values
@@ -380,21 +472,31 @@ fn quantize_q4k() -> Result<()> {
         [-0.5, -0.373, -0.25, 0.0, 0.288, 0.498]
     );

-    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
-    let _quant_big = quantize_roundtrip::<BlockQ4K>(src_big.as_slice(), dst_big.as_mut_slice())?;
+    let src_big = get_test_vector2(128.0, 1024, device)?;
+    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+    let dst_big = quant_big.dequantize(device)?;
+
+    let src_big = src_big.to_vec1::<f32>()?;
+    let dst_big = dst_big.to_vec1::<f32>()?;
     compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5);

-    ggml_quantization_error_test::<BlockQ4K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
     Ok(())
 }

-#[test]
-fn quantize_q5k() -> Result<()> {
-    use k_quants::BlockQ5K;
+fn quantize_q5k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
+    let dtype = GgmlDType::Q5K;
+    let src = get_test_vector2(0.5, 1024, device)?;
+    let quant = quantized::QTensor::quantize(&src, dtype)?;
+    let dst = quant.dequantize(device)?;

-    let (src, mut dst) = get_test_vector(0.5, 1024);
-    let _quant = quantize_roundtrip::<BlockQ5K>(src.as_slice(), dst.as_mut_slice())?;
-    compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
+    let src = src.to_vec1::<f32>()?;
+    let dst = dst.to_vec1::<f32>()?;
+    compare_with_error(dst.as_slice(), src.as_slice(), 0.009);

     // Test some specific values
     assert_eq!(
@@ -404,24 +506,33 @@ fn quantize_q5k() -> Result<()> {
     let dst = round_vector(&dst);
     assert_eq!(
         [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
-        [-0.499, -0.372, -0.249, 0.001, 0.279, 0.499]
+        [-0.5, -0.373, -0.25, 0.0, 0.279, 0.499]
     );

-    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
-    let _quant_big = quantize_roundtrip::<BlockQ5K>(src_big.as_slice(), dst_big.as_mut_slice())?;
+    let src_big = get_test_vector2(128.0, 1024, device)?;
+    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+    let dst_big = quant_big.dequantize(device)?;
+
+    let src_big = src_big.to_vec1::<f32>()?;
+    let dst_big = dst_big.to_vec1::<f32>()?;
     compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5);

-    ggml_quantization_error_test::<BlockQ5K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;

     Ok(())
 }

-#[test]
-fn quantize_q6k() -> Result<()> {
-    use k_quants::BlockQ6K;
+fn quantize_q6k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
+    let dtype = GgmlDType::Q6K;
+    let src = get_test_vector2(0.5, 1024, device)?;
+    let quant = quantized::QTensor::quantize(&src, dtype)?;
+    let dst = quant.dequantize(device)?;

-    let (src, mut dst) = get_test_vector(0.5, 1024);
-    let _quant = quantize_roundtrip::<BlockQ6K>(src.as_slice(), dst.as_mut_slice())?;
+    let src = src.to_vec1::<f32>()?;
+    let dst = dst.to_vec1::<f32>()?;
     compare_with_error(dst.as_slice(), src.as_slice(), 0.008);

     // Test some specific values
@@ -435,22 +546,31 @@ fn quantize_q6k() -> Result<()> {
         [-0.497, -0.372, -0.25, -0.0, 0.284, 0.5]
     );

-    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
-    let _quant_big = quantize_roundtrip::<BlockQ6K>(src_big.as_slice(), dst_big.as_mut_slice())?;
+    let src_big = get_test_vector2(128.0, 1024, device)?;
+    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+    let dst_big = quant_big.dequantize(device)?;
+
+    let src_big = src_big.to_vec1::<f32>()?;
+    let dst_big = dst_big.to_vec1::<f32>()?;
     compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0);

-    ggml_quantization_error_test::<BlockQ6K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;

     Ok(())
 }

-#[test]
-fn quantize_q8k() -> Result<()> {
-    use k_quants::BlockQ8K;
+fn quantize_q8k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
+    let dtype = GgmlDType::Q8K;
+    let src = get_test_vector2(0.5, 1024, device)?;
+    let quant = quantized::QTensor::quantize(&src, dtype)?;
+    let dst = quant.dequantize(device)?;

-    let (src, mut dst) = get_test_vector(0.5, 1024);
-    let _quant = quantize_roundtrip::<BlockQ8K>(src.as_slice(), dst.as_mut_slice())?;
-    compare_with_error(dst.as_slice(), src.as_slice(), 0.003);
+    let src = src.to_vec1::<f32>()?;
+    let dst = dst.to_vec1::<f32>()?;
+    compare_with_error(dst.as_slice(), src.as_slice(), 0.008);

     // Test some specific values
     assert_eq!(
@@ -463,15 +583,79 @@ fn quantize_q8k() -> Result<()> {
         [-0.5, -0.375, -0.25, -0.0, 0.281, 0.499]
     );

-    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
-    let _quant_big = quantize_roundtrip::<BlockQ8K>(src_big.as_slice(), dst_big.as_mut_slice())?;
+    let src_big = get_test_vector2(128.0, 1024, device)?;
+    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+    let dst_big = quant_big.dequantize(device)?;
+
+    let src_big = src_big.to_vec1::<f32>()?;
+    let dst_big = dst_big.to_vec1::<f32>()?;
     compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6);

-    ggml_quantization_error_test::<BlockQ8K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;

     Ok(())
 }

+test_device!(
+    quantize_q4_0,
+    quantize_q4_0_cpu,
+    quantize_q4_0_cuda,
+    quantize_q4_0_metal
+);
+test_device!(
+    quantize_q4_1,
+    quantize_q4_1_cpu,
+    quantize_q4_1_cuda,
+    quantize_q4_1_metal
+);
+test_device!(
+    quantize_q5_0,
+    quantize_q5_0_cpu,
+    quantize_q5_0_cuda,
+    quantize_q5_0_metal
+);
+test_device!(
+    quantize_q5_1,
+    quantize_q5_1_cpu,
+    quantize_q5_1_cuda,
+    quantize_q5_1_metal
+);
+test_device!(
+    quantize_q2k,
+    quantize_q2k_cpu,
+    quantize_q2k_cuda,
+    quantize_q2k_metal
+);
+test_device!(
+    quantize_q3k,
+    quantize_q3k_cpu,
+    quantize_q3k_cuda,
+    quantize_q3k_metal
+);
+test_device!(
+    quantize_q4k,
+    quantize_q4k_cpu,
+    quantize_q4k_cuda,
+    quantize_q4k_metal
+);
+test_device!(
+    quantize_q5k,
+    quantize_q5k_cpu,
+    quantize_q5k_cuda,
+    quantize_q5k_metal
+);
+test_device!(
+    quantize_q6k,
+    quantize_q6k_cpu,
+    quantize_q6k_cuda,
+    quantize_q6k_metal
+);
+test_device!(
+    quantize_q8k,
+    quantize_q8k_cpu,
+    quantize_q8k_cuda,
+    quantize_q8k_metal
+);
+
 /// Very simple dot product implementation
 fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
     a.iter().zip(b).map(|(a, b)| a * b).sum()
@@ -487,54 +671,66 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
         GgmlDType::Q5K => 0.000740,
         GgmlDType::Q6K => 0.000952,
         GgmlDType::Q4_0 => 0.001143,
-        GgmlDType::Q4_1 => 0.007784,
+        GgmlDType::Q4_1 => 0.008,
         GgmlDType::Q5_0 => 0.001353,
-        GgmlDType::Q5_1 => 0.001363,
+        GgmlDType::Q5_1 => 0.00149,
         GgmlDType::Q8_0 => 0.000092,

         // Not from the ggml repo.
         GgmlDType::Q8K => 0.00065,
-        _ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
+        _ => bail!("No GGML results for quantization type {dtype:?}",),
     };
     Ok(err)
 }

-/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
+/// Similar to the GGML matmul unit test:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
 fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
     let a = create_ggml_like_vector(0.0);
     let b = create_ggml_like_vector(1.0);
+    ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 1.0)?;
+    // Another example that is more likely to trigger the overflow reported in #1526
+    let a = (0..GGML_TEST_SIZE)
+        .map(|i| i as f32 / GGML_TEST_SIZE as f32)
+        .collect::<Vec<_>>();
+    let b = (0..GGML_TEST_SIZE)
+        .map(|i| i as f32 / GGML_TEST_SIZE as f32)
+        .collect::<Vec<_>>();
+    ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 2.0)?;
+    Ok(())
+}
+
+fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Result<()> {
     let length = a.len();

     let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
     let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
-    T::from_float(&a, &mut a_quant)?;
-    T::VecDotType::from_float(&b, &mut b_quant)?;
+    T::from_float(a, &mut a_quant)?;
+    T::VecDotType::from_float(b, &mut b_quant)?;

     let result = T::vec_dot(length, &a_quant, &b_quant)?;
     let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
-    let reference_result = vec_dot_reference(&a, &b);
+    let reference_result = vec_dot_reference(a, b);

     if (result - result_unopt).abs() / length as f32 > 1e-6 {
-        candle_core::bail!(
+        bail!(
             "the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
         )
     }

     let error = (result - reference_result).abs() / length as f32;

-    let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
+    let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m;

     if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
-        candle_core::bail!(
-            "Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
-        );
+        bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",);
     }

     // We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
     // => we use a slightly higher error threshold
     const ERROR_LENIENCY: f32 = 0.00001;
     if error - ERROR_LENIENCY > ggml_error {
-        candle_core::bail!(
+        bail!(
             "Dot product error {} exceeds ggml reference error {}",
             error,
             ggml_error
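The `err_m` multiplier introduced here scales only the per-dtype ggml reference error (it is doubled for the second, overflow-prone input pair tied to #1526); the hard `GGML_MAX_DOT_PRODUCT_ERROR` cap of 0.02 stays fixed for both runs.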
@@ -543,6 +739,16 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
     Ok(())
 }

+#[test]
+fn quantized_mm() -> Result<()> {
+    ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
+    ggml_matmul_error_test::<k_quants::BlockQ4_1>()?;
+    ggml_matmul_error_test::<k_quants::BlockQ5_0>()?;
+    ggml_matmul_error_test::<k_quants::BlockQ5_1>()?;
+    ggml_matmul_error_test::<k_quants::BlockQ8_0>()?;
+    Ok(())
+}
+
 /// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
 fn get_random_tensors(
     m: usize,
@@ -566,6 +772,112 @@ fn get_random_tensors(
     Ok((lhs, rhs, mm))
 }

+#[macro_export]
+macro_rules! quantized_matmul {
+    // TODO: Switch to generating the two last arguments automatically once concat_idents is
+    // stable. https://github.com/rust-lang/rust/issues/29599
+    ($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
+        fn $fn_name(device: &Device) -> Result<()> {
+            if device.is_cuda() {
+                // TODO Enable Cuda GGML sometime maybe.
+                return Ok(());
+            }
+            test_matmul(device, (1, 3, 4, 256), $dtype)?;
+            Ok(())
+        }
+
+        test_device!($fn_name, $fn_name_cpu, $fn_name_cuda, $fn_name_metal);
+    };
+}
+
+quantized_matmul!(
+    quantized_matmul_q4_0_bis,
+    quantized_matmul_q4_0_cpu,
+    quantized_matmul_q4_0_cuda,
+    quantized_matmul_q4_0_metal,
+    GgmlDType::Q4_0
+);
+quantized_matmul!(
+    quantized_matmul_q4_1_bis,
+    quantized_matmul_q4_1_cpu,
+    quantized_matmul_q4_1_cuda,
+    quantized_matmul_q4_1_metal,
+    GgmlDType::Q4_1
+);
+quantized_matmul!(
+    quantized_matmul_q5_0_bis,
+    quantized_matmul_q5_0_cpu,
+    quantized_matmul_q5_0_cuda,
+    quantized_matmul_q5_0_metal,
+    GgmlDType::Q5_0
+);
+quantized_matmul!(
+    quantized_matmul_q5_1_bis,
+    quantized_matmul_q5_1_cpu,
+    quantized_matmul_q5_1_cuda,
+    quantized_matmul_q5_1_metal,
+    GgmlDType::Q5_1
+);
+quantized_matmul!(
+    quantized_matmul_q8_0_bis,
+    quantized_matmul_q8_0_cpu,
+    quantized_matmul_q8_0_cuda,
+    quantized_matmul_q8_0_metal,
+    GgmlDType::Q8_0
+);
+// Not implemented in Ggml
+// quantized_matmul!(
+//     quantized_matmul_q8_1_bis,
+//     quantized_matmul_q8_1_cpu,
+//     quantized_matmul_q8_1_cuda,
+//     quantized_matmul_q8_1_metal,
+//     GgmlDType::Q8_1
+// );
+// TODO This is bugged (also bugged in GGML)
+quantized_matmul!(
+    quantized_matmul_q2k_bis,
+    quantized_matmul_q2k_cpu,
+    quantized_matmul_q2k_cuda,
+    quantized_matmul_q2k_metal,
+    GgmlDType::Q2K
+);
+quantized_matmul!(
+    quantized_matmul_q3k_bis,
+    quantized_matmul_q3k_cpu,
+    quantized_matmul_q3k_cuda,
+    quantized_matmul_q3k_metal,
+    GgmlDType::Q3K
+);
+quantized_matmul!(
+    quantized_matmul_q4k_bis,
+    quantized_matmul_q4k_cpu,
+    quantized_matmul_q4k_cuda,
+    quantized_matmul_q4k_metal,
+    GgmlDType::Q4K
+);
+quantized_matmul!(
+    quantized_matmul_q5k_bis,
+    quantized_matmul_q5k_cpu,
+    quantized_matmul_q5k_cuda,
+    quantized_matmul_q5k_metal,
+    GgmlDType::Q5K
+);
+quantized_matmul!(
+    quantized_matmul_q6k_bis,
+    quantized_matmul_q6k_cpu,
+    quantized_matmul_q6k_cuda,
+    quantized_matmul_q6k_metal,
+    GgmlDType::Q6K
+);
+// Not implemented on metal
+// quantized_matmul!(
+//     quantized_matmul_q8k_bis,
+//     quantized_matmul_q8k_cpu,
+//     quantized_matmul_q8k_cuda,
+//     quantized_matmul_q8k_metal,
+//     GgmlDType::Q8K
+// );
+
 #[test]
 fn quantized_matmul_q2k() -> Result<()> {
     use k_quants::BlockQ2K;
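Design note: `quantized_matmul!` wraps `test_matmul` for a single `GgmlDType` and immediately feeds the generated function through `test_device!`, so each quantization format gets cpu/cuda/metal test entry points from one declaration. The `_bis` suffix presumably keeps the generated names clear of the hand-written `quantized_matmul_q2k`-style tests that follow.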
@@ -578,7 +890,7 @@ fn quantized_matmul_q2k() -> Result<()> {
     let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

-    let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
+    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
     let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;

@@ -604,7 +916,7 @@ fn quantized_matmul_q3k() -> Result<()> {
     let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

-    let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
+    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q3K)?;
     let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;

@ -630,7 +942,7 @@ fn quantized_matmul_q4k() -> Result<()> {
|
|||||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||||
|
|
||||||
let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
|
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q4K)?;
|
||||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||||
let mm = rhs.forward(&lhs)?;
|
let mm = rhs.forward(&lhs)?;
|
||||||
|
|
||||||
@ -656,7 +968,7 @@ fn quantized_matmul_q5k() -> Result<()> {
|
|||||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||||
|
|
||||||
let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
|
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q5K)?;
|
||||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||||
let mm = rhs.forward(&lhs)?;
|
let mm = rhs.forward(&lhs)?;
|
||||||
|
|
||||||
@ -683,7 +995,7 @@ fn quantized_matmul_q6k() -> Result<()> {
|
|||||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||||
|
|
||||||
let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
|
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q6K)?;
|
||||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||||
let mm = rhs.forward(&lhs)?;
|
let mm = rhs.forward(&lhs)?;
|
||||||
|
|
||||||
@ -708,7 +1020,7 @@ fn quantized_matmul_q8k() -> Result<()> {
|
|||||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||||
|
|
||||||
let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
|
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q8K)?;
|
||||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||||
let mm = rhs.forward(&lhs)?;
|
let mm = rhs.forward(&lhs)?;
|
||||||
|
|
||||||
|
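Each hunk above makes the same mechanical change: the quantization target moves from a compile-time block type parameter (`quantize::<BlockQ2K>`) to a runtime `GgmlDType` argument, so the dtype can be chosen from config or CLI flags. A small sketch of the new-style call; the `dequantize` round-trip is an assumption on my part and not shown in this diff:

```rust
use candle_core::quantized::{self, GgmlDType};
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let rhs = Tensor::randn(0f32, 1f32, (64, 256), &device)?;
    // Before this change: quantized::QTensor::quantize::<BlockQ2K>(&rhs)?
    let qrhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
    // Round-tripping back to f32 is lossy, but the shape is preserved
    // (assuming a `dequantize(&Device)` method on QTensor).
    let back = qrhs.dequantize(&device)?;
    assert_eq!(back.dims(), rhs.dims());
    Ok(())
}
```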
@@ -1,4 +1,4 @@
-use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
+use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D};

 fn zeros(device: &Device) -> Result<()> {
     let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@@ -29,7 +29,34 @@ fn ones(device: &Device) -> Result<()> {
         Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
         [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
     );
+    Ok(())
+}
+
+fn full(device: &Device) -> Result<()> {
+    assert_eq!(
+        Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
+        [[42, 42, 42], [42, 42, 42]],
+    );
+    Ok(())
+}
+
+fn arange(device: &Device) -> Result<()> {
+    assert_eq!(
+        Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
+        [0, 1, 2, 3, 4],
+    );
+    assert_eq!(
+        Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::<u8>()?,
+        [0, 2, 4],
+    );
+    assert_eq!(
+        Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::<u8>()?,
+        [0, 3],
+    );
+    assert_eq!(
+        Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::<i64>()?,
+        [5, 4, 3, 2, 1],
+    );
     Ok(())
 }

@@ -161,6 +188,22 @@ fn transpose(device: &Device) -> Result<()> {
     Ok(())
 }
+
+fn var(device: &Device) -> Result<()> {
+    // Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
+    let data = &[
+        [0.2035f32, 1.2959, 1.8101, -0.4644],
+        [1.5027, -0.3270, 0.5905, 0.6538],
+        [-1.5745, 1.3330, -0.5596, -0.6548],
+        [0.1264, -0.5080, 1.6420, 0.1992],
+    ];
+    let tensor = Tensor::new(data, device)?;
+    assert_eq!(
+        test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
+        &[[1.0631], [0.559], [1.4893], [0.8258]]
+    );
+    Ok(())
+}

 fn sum(device: &Device) -> Result<()> {
     let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
     let tensor = Tensor::new(data, device)?;
@@ -1035,33 +1078,61 @@ fn randn(device: &Device) -> Result<()> {
     Ok(())
 }

-test_device!(zeros, zeros_cpu, zeros_gpu);
-test_device!(ones, ones_cpu, ones_gpu);
-test_device!(add_mul, add_mul_cpu, add_mul_gpu);
-test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
-test_device!(narrow, narrow_cpu, narrow_gpu);
-test_device!(broadcast, broadcast_cpu, broadcast_gpu);
-test_device!(cat, cat_cpu, cat_gpu);
-test_device!(sum, sum_cpu, sum_gpu);
-test_device!(min, min_cpu, min_gpu);
-test_device!(max, max_cpu, max_gpu);
-test_device!(argmax, argmax_cpu, argmax_gpu);
-test_device!(argmin, argmin_cpu, argmin_gpu);
-test_device!(transpose, transpose_cpu, transpose_gpu);
-test_device!(unary_op, unary_op_cpu, unary_op_gpu);
-test_device!(binary_op, binary_op_cpu, binary_op_gpu);
-test_device!(embeddings, embeddings_cpu, embeddings_gpu);
-test_device!(cmp, cmp_cpu, cmp_gpu);
-test_device!(matmul, matmul_cpu, matmul_gpu);
-test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
-test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
-test_device!(index_select, index_select_cpu, index_select_gpu);
-test_device!(index_add, index_add_cpu, index_add_gpu);
-test_device!(gather, gather_cpu, gather_gpu);
-test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
-test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
-test_device!(randn, randn_cpu, randn_gpu);
-test_device!(clamp, clamp_cpu, clamp_gpu);
+test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
+test_device!(ones, ones_cpu, ones_gpu, ones_metal);
+test_device!(full, full_cpu, full_gpu, full_metal);
+test_device!(arange, arange_cpu, arange_gpu, arange_metal);
+test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
+test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
+test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
+test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
+test_device!(cat, cat_cpu, cat_gpu, cat_metal);
+test_device!(sum, sum_cpu, sum_gpu, sum_metal);
+test_device!(min, min_cpu, min_gpu, min_metal);
+test_device!(max, max_cpu, max_gpu, max_metal);
+test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
+test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
+test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
+test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
+test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
+test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
+test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
+test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
+test_device!(
+    broadcast_matmul,
+    broadcast_matmul_cpu,
+    broadcast_matmul_gpu,
+    broadcast_matmul_metal
+);
+test_device!(
+    broadcasting,
+    broadcasting_cpu,
+    broadcasting_gpu,
+    broadcasting_metal
+);
+test_device!(
+    index_select,
+    index_select_cpu,
+    index_select_gpu,
+    index_select_metal
+);
+test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
+test_device!(gather, gather_cpu, gather_gpu, gather_metal);
+test_device!(
+    scatter_add,
+    scatter_add_cpu,
+    scatter_add_gpu,
+    scatter_add_metal
+);
+test_device!(
+    slice_scatter,
+    slice_scatter_cpu,
+    slice_scatter_gpu,
+    slice_scatter_metal
+);
+test_device!(randn, randn_cpu, randn_gpu, randn_metal);
+test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
+test_device!(var, var_cpu, var_gpu, var_metal);

 // There was originally a bug on the CPU implementation for randn
 // https://github.com/huggingface/candle/issues/381
@@ -1089,3 +1160,108 @@ fn pad_with_same() -> Result<()> {
     );
     Ok(())
 }
+
+#[test]
+fn i64_abs() -> Result<()> {
+    let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?;
+    let t = t.abs()?;
+    assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
+    Ok(())
+}
+
+#[test]
+fn tril_triu_eye() -> Result<()> {
+    let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
+    assert_eq!(
+        t.to_vec2::<f32>()?,
+        [
+            [1.0, 0.0, 0.0, 0.0],
+            [1.0, 1.0, 0.0, 0.0],
+            [1.0, 1.0, 1.0, 0.0],
+            [1.0, 1.0, 1.0, 1.0]
+        ],
+    );
+    let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
+    assert_eq!(
+        t.to_vec2::<f32>()?,
+        [
+            [1.0, 1.0, 1.0, 1.0],
+            [0.0, 1.0, 1.0, 1.0],
+            [0.0, 0.0, 1.0, 1.0],
+            [0.0, 0.0, 0.0, 1.0]
+        ]
+    );
+    let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
+    assert_eq!(
+        t.to_vec2::<f32>()?,
+        [
+            [1.0, 0.0, 0.0, 0.0],
+            [0.0, 1.0, 0.0, 0.0],
+            [0.0, 0.0, 1.0, 0.0],
+            [0.0, 0.0, 0.0, 1.0]
+        ]
+    );
+    Ok(())
+}
+
+#[test]
+fn cumsum() -> Result<()> {
+    let t = &[3f32, 1., 4., 1., 5.];
+    let t = Tensor::new(t, &Device::Cpu)?;
+    assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]);
+    let t = t.unsqueeze(1)?;
+    assert_eq!(
+        t.cumsum(0)?.to_vec2::<f32>()?,
+        [[3.0], [4.0], [8.0], [9.0], [14.0]]
+    );
+    assert_eq!(
+        t.cumsum(1)?.to_vec2::<f32>()?,
+        [[3.0], [1.0], [4.0], [1.0], [5.0]]
+    );
+    let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
+    let t = Tensor::new(t, &Device::Cpu)?;
+    assert_eq!(
+        t.cumsum(1)?.to_vec2::<f32>()?,
+        [[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
+    );
+    assert_eq!(
+        t.cumsum(0)?.to_vec2::<f32>()?,
+        [[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
+    );
+    Ok(())
+}
+
+/// A helper function for floating point comparison. Both a and b must be 1D Tensor and contains the same amount of data.
+/// Assertion passes if the difference of all pairs of a and b is smaller than epsilon.
+fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
+    let a_vec: Vec<f64> = a.to_vec1()?;
+    let b_vec: Vec<f64> = b.to_vec1()?;
+
+    assert_eq!(a_vec.len(), b_vec.len());
+    for (a, b) in a_vec.iter().zip(b_vec.iter()) {
+        assert!((a - b).abs() < epsilon);
+    }
+    Ok(())
+}
+
+#[test]
+fn log_sum_exp() -> Result<()> {
+    let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
+    let output = input.log_sum_exp(D::Minus1)?;
+    // The expectations obtained from pytorch.
+    let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
+    assert_close(&output, &expected, 0.00001)?;
+    Ok(())
+}
+
+#[test]
+fn pow() -> Result<()> {
+    let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
+    let rhs = (&lhs - 2.)?;
+    let res = lhs.pow(&rhs)?;
+    assert_eq!(
+        test_utils::to_vec2_round(&res, 4)?,
+        [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
+    );
+    Ok(())
+}
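The added tests above double as documentation for the constructors and reductions they cover. A compact, runnable digest, using only calls taken from the diff and expected values copied from the tests:

```rust
use candle_core::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // `full` fills a shape with one value; `arange_step` accepts a signed step.
    assert_eq!(
        Tensor::full(42u32, (2, 3), &dev)?.to_vec2::<u32>()?,
        [[42, 42, 42], [42, 42, 42]]
    );
    assert_eq!(
        Tensor::arange_step(5i64, 0i64, -1, &dev)?.to_vec1::<i64>()?,
        [5, 4, 3, 2, 1]
    );
    // `log_sum_exp` over the last dim computes log(sum_j exp(x_ij)).
    let x = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &dev)?;
    let y = x.log_sum_exp(D::Minus1)?;
    println!("{y}"); // approximately [3.4076, 6.4076]
    Ok(())
}
```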
@@ -11,8 +11,8 @@ readme = "README.md"

 [dependencies]
 byteorder = { workspace = true }
-candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.3.0" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
 hf-hub = { workspace = true}
 intel-mkl-src = { workspace = true, optional = true }
 memmap2 = { workspace = true }
@@ -4,7 +4,9 @@
 //! <https://www.cs.toronto.edu/~kriz/cifar.html>
 //! The binary version of the dataset is used.
 use crate::vision::Dataset;
-use candle::{DType, Device, Result, Tensor};
+use candle::{DType, Device, Error, Result, Tensor};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use parquet::file::reader::{FileReader, SerializedFileReader};
 use std::fs::File;
 use std::io::{BufReader, Read};

@@ -60,3 +62,58 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<Dataset> {
         labels: 10,
     })
 }
+
+fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
+    let samples = parquet.metadata().file_metadata().num_rows() as usize;
+    let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 1_024);
+    let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
+    for row in parquet.into_iter().flatten() {
+        for (_name, field) in row.get_column_iter() {
+            if let parquet::record::Field::Group(subrow) = field {
+                for (_name, field) in subrow.get_column_iter() {
+                    if let parquet::record::Field::Bytes(value) = field {
+                        let image = image::load_from_memory(value.data()).unwrap();
+                        buffer_images.extend(image.to_rgb8().as_raw());
+                    }
+                }
+            } else if let parquet::record::Field::Long(label) = field {
+                buffer_labels.push(*label as u8);
+            }
+        }
+    }
+    let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
+        .to_dtype(DType::U8)?
+        / 255.)?;
+    let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
+    Ok((images, labels))
+}
+
+pub fn load() -> Result<Dataset> {
+    let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+    let dataset_id = "cifar10".to_string();
+    let repo = Repo::with_revision(
+        dataset_id,
+        RepoType::Dataset,
+        "refs/convert/parquet".to_string(),
+    );
+    let repo = api.repo(repo);
+    let test_parquet_filename = repo
+        .get("plain_text/test/0000.parquet")
+        .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+    let train_parquet_filename = repo
+        .get("plain_text/train/0000.parquet")
+        .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+    let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
+        .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
+    let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
+        .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
+    let (test_images, test_labels) = load_parquet(test_parquet)?;
+    let (train_images, train_labels) = load_parquet(train_parquet)?;
+    Ok(crate::vision::Dataset {
+        train_images,
+        train_labels,
+        test_images,
+        test_labels,
+        labels: 10,
+    })
+}
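A usage sketch for the new hub-backed loader. It assumes the enclosing crate exposes this module as `candle_datasets::vision::cifar`, and the shape comment assumes the standard CIFAR-10 split:

```rust
fn main() -> candle_core::Result<()> {
    // Downloads the parquet-converted cifar10 dataset from the hub on first use.
    let ds = candle_datasets::vision::cifar::load()?;
    // Standard split: 50k training and 10k test images of shape (3, 32, 32).
    println!("train images: {:?}", ds.train_images.shape());
    println!("test labels:  {:?}", ds.test_labels.shape());
    Ok(())
}
```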
@@ -11,17 +11,21 @@ readme = "README.md"

 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
-candle-nn = { path = "../candle-nn", version = "0.3.0" }
-candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
+candle = { workspace = true }
+candle-datasets = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
+candle-flash-attn = { workspace = true, optional = true }
+candle-onnx = { workspace = true, optional = true }
+
+csv = "1.3.0"
 cudarc = { workspace = true, optional = true }
 half = { workspace = true, optional = true }
+hf-hub = { workspace = true, features=["tokio"]}
 image = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
 num-traits = { workspace = true }
-pyo3 = { version = "0.19.0", features = ["auto-initialize"], optional = true }
+pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
 rayon = { workspace = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
@@ -32,7 +36,6 @@ tokenizers = { workspace = true, features = ["onig"] }
 anyhow = { workspace = true }
 byteorder = { workspace = true }
 clap = { workspace = true }
-hf-hub = { workspace = true, features=["tokio"]}
 imageproc = { workspace = true }
 memmap2 = { workspace = true }
 rand = { workspace = true }
@@ -46,15 +49,18 @@ tokio = "1.29.1"

 [build-dependencies]
 anyhow = { workspace = true }
+bindgen_cuda = { version = "0.1.1", optional = true }

 [features]
 default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
-cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
+cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
 cudnn = ["candle/cudnn"]
 flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
 nccl = ["cuda", "cudarc/nccl", "dep:half"]
+onnx = ["candle-onnx"]
+metal = ["candle/metal", "candle-nn/metal"]

 [[example]]
 name = "llama_multiprocess"
@@ -63,3 +69,11 @@ required-features = ["cuda", "nccl", "flash-attn"]
 [[example]]
 name = "reinforcement-learning"
 required-features = ["pyo3"]
+
+[[example]]
+name = "onnx"
+required-features = ["onnx"]
+
+[[example]]
+name = "onnx_basics"
+required-features = ["onnx"]
@@ -4,235 +4,28 @@ use std::io::Write;
 use std::path::PathBuf;

 struct KernelDirectories {
-    kernel_dir: &'static str,
+    kernel_glob: &'static str,
     rust_target: &'static str,
     include_dirs: &'static [&'static str],
 }

-const DIRS: [KernelDirectories; 1] = [KernelDirectories {
-    kernel_dir: "examples/custom-ops/kernels/",
+const KERNEL_DIRS: [KernelDirectories; 1] = [KernelDirectories {
+    kernel_glob: "examples/custom-ops/kernels/*.cu",
     rust_target: "examples/custom-ops/cuda_kernels.rs",
     include_dirs: &[],
 }];

-impl KernelDirectories {
-    fn maybe_build_ptx(
-        &self,
-        cu_file: &std::path::Path,
-        ptx_file: &std::path::Path,
-        compute_cap: usize,
-    ) -> Result<()> {
-        let should_compile = if ptx_file.exists() {
-            let ptx_modified = ptx_file.metadata()?.modified()?;
-            let cu_modified = cu_file.metadata()?.modified()?;
-            cu_modified.duration_since(ptx_modified).is_ok()
-        } else {
-            true
-        };
-        if should_compile {
-            #[cfg(feature = "cuda")]
-            {
-                let mut command = std::process::Command::new("nvcc");
-                let out_dir = ptx_file.parent().context("no parent for ptx file")?;
-                let include_dirs: Vec<String> =
-                    self.include_dirs.iter().map(|c| format!("-I{c}")).collect();
-                command
-                    .arg(format!("--gpu-architecture=sm_{compute_cap}"))
-                    .arg("--ptx")
-                    .args(["--default-stream", "per-thread"])
-                    .args(["--output-directory", out_dir.to_str().unwrap()])
-                    .arg(format!("-I/{}", self.kernel_dir))
-                    .args(include_dirs)
-                    .arg(cu_file);
-                let output = command
-                    .spawn()
-                    .context("failed spawning nvcc")?
-                    .wait_with_output()?;
-                if !output.status.success() {
-                    anyhow::bail!(
-                        "nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
-                        String::from_utf8_lossy(&output.stdout),
-                        String::from_utf8_lossy(&output.stderr)
-                    )
-                }
-            }
-            #[cfg(not(feature = "cuda"))]
-            std::fs::OpenOptions::new()
-                .create(true)
-                .write(true)
-                .open(ptx_file)?;
-        }
-        Ok(())
-    }
-    fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
-        println!("cargo:rerun-if-changed={}", self.kernel_dir);
-        let kernel_dir = PathBuf::from(self.kernel_dir);
-        let out_dir = out_dir.join(self.kernel_dir);
-        if !out_dir.exists() {
-            std::fs::create_dir_all(&out_dir)?;
-        }
-        let mut cu_files = vec![];
-        let mut cuh_files = vec![];
-        for file in std::fs::read_dir(kernel_dir)?.flatten() {
-            let file = file.path();
-            match file.extension().and_then(|v| v.to_str()) {
-                Some("cu") => cu_files.push(file),
-                Some("cuh") => cuh_files.push(file),
-                _ => {}
-            }
-        }
-
-        let mut ptx_paths = vec![];
-        for cu_file in cu_files.iter() {
-            let file_stem = cu_file
-                .file_stem()
-                .with_context(|| format!("no stem {cu_file:?}"))?;
-            let file_stem = file_stem.to_string_lossy().into_owned();
-            let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
-            self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
-            ptx_paths.push(ptx_file);
-        }
-
-        let regenerate_rs_file = true;
-        if regenerate_rs_file {
-            let mut file = std::fs::File::create(self.rust_target)?;
-            for ptx_path in ptx_paths {
-                let name = ptx_path
-                    .file_stem()
-                    .context("empty stem")?
-                    .to_string_lossy();
-                file.write_all(b"#[rustfmt::skip]\n")?;
-                let const_definition = format!(
-                    r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
-                    name.to_uppercase().replace('.', "_"),
-                    self.kernel_dir,
-                );
-                file.write_all(const_definition.as_bytes())?;
-                file.write_all(b"\n")?;
-            }
-        }
-        Ok(())
-    }
-}
-
 fn main() -> Result<()> {
     println!("cargo:rerun-if-changed=build.rs");
-    let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
-    let out_dir = PathBuf::from(out_dir);
     #[cfg(feature = "cuda")]
-    set_cuda_include_dir()?;
-    #[cfg(feature = "cuda")]
-    let compute_cap = compute_cap()?;
-    #[cfg(not(feature = "cuda"))]
-    let compute_cap = 0;
-    for d in DIRS {
-        d.process(&out_dir, compute_cap)?
-    }
+    {
+        for kdir in KERNEL_DIRS.iter() {
+            let builder = bindgen_cuda::Builder::default().kernel_paths_glob(kdir.kernel_glob);
+            println!("cargo:info={builder:?}");
+            let bindings = builder.build_ptx().unwrap();
+            bindings.write(kdir.rust_target).unwrap()
+        }
+    }
     Ok(())
 }
-
-fn set_cuda_include_dir() -> Result<()> {
-    // NOTE: copied from cudarc build.rs.
-    let env_vars = [
-        "CUDA_PATH",
-        "CUDA_ROOT",
-        "CUDA_TOOLKIT_ROOT_DIR",
-        "CUDNN_LIB",
-    ];
-    let env_vars = env_vars
-        .into_iter()
-        .map(std::env::var)
-        .filter_map(Result::ok)
-        .map(Into::<PathBuf>::into);
-
-    let roots = [
-        "/usr",
-        "/usr/local/cuda",
-        "/opt/cuda",
-        "/usr/lib/cuda",
-        "C:/Program Files/NVIDIA GPU Computing Toolkit",
-        "C:/CUDA",
-    ];
-    let roots = roots.into_iter().map(Into::<PathBuf>::into);
-    let root = env_vars
-        .chain(roots)
-        .find(|path| path.join("include").join("cuda.h").is_file())
-        .context("cannot find include/cuda.h")?;
-    println!(
-        "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
-        root.join("include").display()
-    );
-    Ok(())
-}
-
-#[allow(unused)]
-fn compute_cap() -> Result<usize> {
-    // Grab compute code from nvidia-smi
-    let mut compute_cap = {
-        let out = std::process::Command::new("nvidia-smi")
-            .arg("--query-gpu=compute_cap")
-            .arg("--format=csv")
-            .output()
-            .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
-        let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
-        let mut lines = out.lines();
-        assert_eq!(
-            lines.next().context("missing line in stdout")?,
-            "compute_cap"
-        );
-        let cap = lines
-            .next()
-            .context("missing line in stdout")?
-            .replace('.', "");
-        cap.parse::<usize>()
-            .with_context(|| format!("cannot parse as int {cap}"))?
-    };
-
-    // Grab available GPU codes from nvcc and select the highest one
-    let max_nvcc_code = {
-        let out = std::process::Command::new("nvcc")
-            .arg("--list-gpu-code")
-            .output()
-            .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
-        let out = std::str::from_utf8(&out.stdout).unwrap();
-
-        let out = out.lines().collect::<Vec<&str>>();
-        let mut codes = Vec::with_capacity(out.len());
-        for code in out {
-            let code = code.split('_').collect::<Vec<&str>>();
-            if !code.is_empty() && code.contains(&"sm") {
-                if let Ok(num) = code[1].parse::<usize>() {
-                    codes.push(num);
-                }
-            }
-        }
-        codes.sort();
-        if !codes.contains(&compute_cap) {
-            anyhow::bail!(
-                "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
-            );
-        }
-        *codes.last().unwrap()
-    };
-
-    // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
-    // then choose the highest gpu code in nvcc
-    if compute_cap > max_nvcc_code {
-        println!(
-            "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
-        );
-        compute_cap = max_nvcc_code;
-    }
-
-    println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
-
-    if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
-        compute_cap = compute_cap_str
-            .parse::<usize>()
-            .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
-        println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
-    }
-    println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
-    Ok(compute_cap)
-}
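The rewrite above drops the hand-rolled nvcc invocation, CUDA root discovery, and compute-capability probing in favor of the `bindgen_cuda` crate. A minimal `build.rs` in the same style, using only the builder calls that appear in this diff (the glob and output path are illustrative):

```rust
// Sketch of a bindgen_cuda-based build script, assuming the crate depends on
// bindgen_cuda behind the `cuda` feature as in the Cargo.toml change above.
fn main() {
    println!("cargo:rerun-if-changed=build.rs");
    #[cfg(feature = "cuda")]
    {
        let builder = bindgen_cuda::Builder::default()
            .kernel_paths_glob("examples/custom-ops/kernels/*.cu");
        // Compiles every matching .cu file to PTX and emits a Rust file with
        // `pub const` bindings that include_str! the generated PTX.
        let bindings = builder.build_ptx().unwrap();
        bindings.write("examples/custom-ops/cuda_kernels.rs").unwrap();
    }
}
```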
@@ -2,10 +2,10 @@

 Bert is a general large language model. In this example it can be used for two
 different tasks:

 - Compute sentence embeddings for a prompt.
 - Compute similarities between a set of sentences.


 ## Sentence embeddings

 Bert is used to compute the sentence embeddings for a prompt. The model weights
@@ -24,6 +24,48 @@ cargo run --example bert --release -- --prompt "Here is a test sentence"
 > Tensor[[1, 7, 384], f32]
 ```
+
+### Custom models
+
+You can specify different models, such as BGE, with the `--model-id` flag:
+
+```bash
+cargo run --example bert --release -- \
+--model-id BAAI/bge-large-zh-v1.5 \
+--prompt "Here is a test sentence"
+Loaded and encoded 435.70775ms
+[[[ 3.0944e-1, -7.8455e-5, -1.2768e0, ..., 1.3755e-2, -3.2371e-1, 2.3819e-1],
+  [-2.8506e-1, 1.9953e-1, -1.3076e0, ..., 6.9819e-2, 1.0833e-2, -1.1512e0],
+  [ 3.9892e-1, 2.0000e-1, -9.3178e-1, ..., -4.1393e-1, -4.9644e-2, -3.3786e-1],
+  ...
+  [ 6.0345e-1, 3.5744e-1, -1.2672e0, ..., -6.9165e-1, -3.4973e-3, -8.4214e-1],
+  [ 3.9218e-1, -3.2735e-1, -1.3123e0, ..., -4.9318e-1, -5.1334e-1, -3.6391e-1],
+  [ 3.0978e-1, 2.5662e-4, -1.2773e0, ..., 1.3357e-2, -3.2390e-1, 2.3858e-1]]]
+Tensor[[1, 9, 1024], f32]
+Took 176.744667ms
+```
+
+### Gelu approximation
+
+You can get a speedup by using an approximation of the gelu activation, with a
+small loss of precision, by passing the `--approximate-gelu` flag:
+
+```bash
+$ cargo run --example bert --release -- \
+--model-id BAAI/bge-large-zh-v1.5 \
+--prompt "Here is a test sentence" \
+--approximate-gelu
+Loaded and encoded 244.388042ms
+[[[ 3.1048e-1, -6.0339e-4, -1.2758e0, ..., 1.3718e-2, -3.2362e-1, 2.3775e-1],
+  [-2.8354e-1, 1.9984e-1, -1.3077e0, ..., 6.9390e-2, 9.9681e-3, -1.1531e0],
+  [ 3.9947e-1, 1.9917e-1, -9.3178e-1, ..., -4.1301e-1, -5.0719e-2, -3.3955e-1],
+  ...
+  [ 6.0499e-1, 3.5664e-1, -1.2642e0, ..., -6.9134e-1, -3.4581e-3, -8.4471e-1],
+  [ 3.9311e-1, -3.2812e-1, -1.3105e0, ..., -4.9291e-1, -5.1270e-1, -3.6543e-1],
+  [ 3.1082e-1, -2.6737e-4, -1.2762e0, ..., 1.3319e-2, -3.2381e-1, 2.3815e-1]]]
+Tensor[[1, 9, 1024], f32]
+Took 116.840791ms
+```

 ## Similarities

 In this example, Bert is used to compute the sentence embeddings for a set of
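For reference on the trade-off behind `--approximate-gelu`: the exact activation evaluates the Gaussian CDF via `erf`, while the tanh form approximates it at lower cost:

$$\mathrm{gelu}(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\big(x/\sqrt{2}\big)\right) \approx \frac{x}{2}\left(1 + \tanh\!\Big(\sqrt{2/\pi}\,\big(x + 0.044715\,x^{3}\big)\Big)\right)$$

The two runs above differ only in the fourth decimal of the embeddings, which matches the "small loss of precision" claim.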
@@ -3,7 +3,7 @@ extern crate intel_mkl_src;

 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
-use candle_transformers::models::bert::{BertModel, Config, DTYPE};
+use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};

 use anyhow::{Error as E, Result};
 use candle::Tensor;
@@ -45,6 +45,10 @@ struct Args {
     /// L2 normalization for embeddings.
     #[arg(long, default_value = "true")]
     normalize_embeddings: bool,
+
+    /// Use tanh based approximation for Gelu instead of erf implementation.
+    #[arg(long, default_value = "false")]
+    approximate_gelu: bool,
 }

 impl Args {
@@ -73,7 +77,7 @@ impl Args {
             (config, tokenizer, weights)
         };
         let config = std::fs::read_to_string(config_filename)?;
-        let config: Config = serde_json::from_str(&config)?;
+        let mut config: Config = serde_json::from_str(&config)?;
         let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

         let vb = if self.use_pth {
@@ -81,6 +85,9 @@ impl Args {
         } else {
             unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
         };
+        if self.approximate_gelu {
+            config.hidden_act = HiddenAct::GeluApproximate;
+        }
         let model = BertModel::load(vb, &config)?;
         Ok((model, tokenizer))
     }
@@ -106,17 +106,17 @@ pub fn main() -> anyhow::Result<()> {

     let config = blip::Config::image_captioning_large();

+    let device = candle_examples::device(args.cpu)?;
     let (image_embeds, device, mut model) = if args.quantized {
         let device = Device::Cpu;
         let image = load_image(args.image)?.to_device(&device)?;
         println!("loaded image {image:?}");

-        let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
+        let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?;
         let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
         let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
         (image_embeds, device, Model::Q(model))
     } else {
-        let device = candle_examples::device(args.cpu)?;
         let image = load_image(args.image)?.to_device(&device)?;
         println!("loaded image {image:?}");

@@ -149,6 +149,6 @@ pub fn main() -> anyhow::Result<()> {
     if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
         print!("{rest}");
     }
+    println!();
     Ok(())
 }
@@ -1,2 +0,0 @@
-#[rustfmt::skip]
-pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
@@ -6,7 +6,8 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;

-#[allow(unused)]
+#[rustfmt::skip]
+#[cfg(feature = "cuda")]
 mod cuda_kernels;

 use clap::Parser;
candle-examples/examples/distilbert/README.md (new file)
@@ -0,0 +1,22 @@
+# candle-distilbert
+
+DistilBert is a distiled version of the Bert model.
+
+## Sentence embeddings
+
+DistilBert is used to compute the sentence embeddings for a prompt. The model weights
+are downloaded from the hub on the first run.
+
+```bash
+cargo run --example distilbert --release -- --prompt "Here is a test sentence"
+
+> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
+>   [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
+>   [ 0.0702, -0.1311, -0.4914, ..., 0.3483, -0.6194, 0.1829],
+>   ...
+>   [ 0.2993, -0.0106, -0.4640, ..., 0.2844, -0.6732, 0.0042],
+>   [ 0.1066, -0.0081, -0.4299, ..., 0.3435, -0.7729, 0.0190],
+>   [ 0.8903, 0.2055, -0.2541, ..., 0.3208, -0.6585, 0.0586]]]
+> Tensor[[1, 7, 768], f32]
+
+```
candle-examples/examples/distilbert/main.rs (new file)
@@ -0,0 +1,135 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
+
+use anyhow::{Error as E, Result};
+use candle::{Device, Tensor};
+use candle_nn::VarBuilder;
+use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,
+
+    /// When set, compute embeddings for this prompt.
+    #[arg(long)]
+    prompt: String,
+
+    /// Use the pytorch weights rather than the safetensors ones
+    #[arg(long)]
+    use_pth: bool,
+
+    /// The number of times to run the prompt.
+    #[arg(long, default_value = "1")]
+    n: usize,
+
+    /// L2 normalization for embeddings.
+    #[arg(long, default_value = "true")]
+    normalize_embeddings: bool,
+}
+
+impl Args {
+    fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
+        let device = candle_examples::device(self.cpu)?;
+        let default_model = "distilbert-base-uncased".to_string();
+        let default_revision = "main".to_string();
+        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
+            (Some(model_id), Some(revision)) => (model_id, revision),
+            (Some(model_id), None) => (model_id, "main".to_string()),
+            (None, Some(revision)) => (default_model, revision),
+            (None, None) => (default_model, default_revision),
+        };
+
+        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+        let (config_filename, tokenizer_filename, weights_filename) = {
+            let api = Api::new()?;
+            let api = api.repo(repo);
+            let config = api.get("config.json")?;
+            let tokenizer = api.get("tokenizer.json")?;
+            let weights = if self.use_pth {
+                api.get("pytorch_model.bin")?
+            } else {
+                api.get("model.safetensors")?
+            };
+            (config, tokenizer, weights)
+        };
+        let config = std::fs::read_to_string(config_filename)?;
+        let config: Config = serde_json::from_str(&config)?;
+        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+        let vb = if self.use_pth {
+            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
+        } else {
+            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
+        };
+        let model = DistilBertModel::load(vb, &config)?;
+        Ok((model, tokenizer))
+    }
+}
+
+fn get_mask(size: usize, device: &Device) -> Tensor {
+    let mask: Vec<_> = (0..size)
+        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .collect();
+    Tensor::from_slice(&mask, (size, size), device).unwrap()
+}
+
+fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+    let _guard = if args.tracing {
+        println!("tracing...");
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
+    let device = &model.device;
+
+    let tokenizer = tokenizer
+        .with_padding(None)
+        .with_truncation(None)
+        .map_err(E::msg)?;
+    let tokens = tokenizer
+        .encode(args.prompt, true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
+    let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+    let mask = get_mask(tokens.len(), device);
+
+    println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
+    println!("mask: {:?}", mask.to_vec2::<u8>());
+
+    let ys = model.forward(&token_ids, &mask)?;
+    println!("{ys}");
+
+    Ok(())
+}
+
+pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
+    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
+}
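`get_mask` in the new example builds a strictly upper-triangular mask: entry (i, j) is 1 exactly when j > i, marking the later positions a token must not attend to. A standalone illustration of the same construction:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let size = 4;
    // Entry (i, j) is 1 exactly when j > i: a strictly upper-triangular mask.
    let mask: Vec<u8> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    let mask = Tensor::from_slice(&mask, (size, size), &Device::Cpu)?;
    assert_eq!(
        mask.to_vec2::<u8>()?,
        [[0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]
    );
    Ok(())
}
```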
@@ -165,14 +165,7 @@ fn main() -> Result<()> {
         args.revision,
     ));
     let tokenizer_filename = repo.get("tokenizer.json")?;
-    let mut filenames = vec![];
-    for rfilename in [
-        "model-00001-of-00002.safetensors",
-        "model-00002-of-00002.safetensors",
-    ] {
-        let filename = repo.get(rfilename)?;
-        filenames.push(filename);
-    }
+    let filenames = candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?;
     println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

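`hub_load_safetensors` replaces the hardcoded shard list by reading the index file, so the example keeps working when the number of shards changes. Its real implementation lives in `candle-examples` and is not part of this diff; the following is a hypothetical reconstruction from the call site, assuming the standard safetensors `weight_map` index layout:

```rust
use std::collections::HashSet;
use std::path::PathBuf;

use hf_hub::api::sync::ApiRepo;

// Hypothetical sketch: fetch the index json, collect the distinct shard
// filenames it maps to, and download each shard from the repo.
fn hub_load_safetensors(repo: &ApiRepo, json_file: &str) -> anyhow::Result<Vec<PathBuf>> {
    let json_file = repo.get(json_file)?;
    let json: serde_json::Value = serde_json::from_reader(std::fs::File::open(json_file)?)?;
    let mut shards = HashSet::new();
    if let Some(weight_map) = json.get("weight_map").and_then(|v| v.as_object()) {
        for file in weight_map.values().filter_map(|v| v.as_str()) {
            shards.insert(file.to_string());
        }
    }
    shards.into_iter().map(|f| Ok(repo.get(&f)?)).collect()
}
```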
candle-examples/examples/jina-bert/README.md (new file)
@@ -0,0 +1,45 @@
+# candle-jina-bert
+
+Jina-Bert is a general large language model with a context size of 8192, [model
+card](https://huggingface.co/jinaai/jina-embeddings-v2-base-en). In this example
+it can be used for two different tasks:
+- Compute sentence embeddings for a prompt.
+- Compute similarities between a set of sentences.
+
+
+## Sentence embeddings
+
+Jina-Bert is used to compute the sentence embeddings for a prompt. The model weights
+are downloaded from the hub on the first run.
+
+```bash
+cargo run --example jina-bert --release -- --prompt "Here is a test sentence"
+
+> [[[ 0.1595, -0.9885, 0.6494, ..., 0.3003, -0.6901, -1.2355],
+>  [ 0.0374, -0.1798, 1.3359, ..., 0.6731, 0.2133, -1.6807],
+>  [ 0.1700, -0.8534, 0.8924, ..., -0.1785, -0.0727, -1.5087],
+>  ...
+>  [-0.3113, -1.3665, 0.2027, ..., -0.2519, 0.1711, -1.5811],
+>  [ 0.0907, -1.0492, 0.5382, ..., 0.0242, -0.7077, -1.0830],
+>  [ 0.0369, -0.6343, 0.6105, ..., 0.0671, 0.3778, -1.1505]]]
+> Tensor[[1, 7, 768], f32]
+```
+
+## Similarities
+
+In this example, Jina-Bert is used to compute the sentence embeddings for a set of
+sentences (hardcoded in the examples). Then cosine similarities are computed for
+each sentence pair and they are reported by decreasing values, hence the first
+reported pair contains the two sentences that have the highest similarity score.
+The sentence embeddings are computed using average pooling through all the
+sentence tokens, including some potential padding.
+
+```bash
+cargo run --example jina-bert --release
+
+> score: 0.94 'The new movie is awesome' 'The new movie is so great'
+> score: 0.81 'The cat sits outside' 'The cat plays in the garden'
+> score: 0.78 'I love pasta' 'Do you like pizza?'
+> score: 0.68 'I love pasta' 'The new movie is awesome'
+> score: 0.67 'A man is playing guitar' 'A woman watches TV'
+```
candle-examples/examples/jina-bert/main.rs
Normal file
180
candle-examples/examples/jina-bert/main.rs
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use candle_transformers::models::jina_bert::{BertModel, Config};
|
||||||
|
|
||||||
|
use anyhow::Error as E;
|
||||||
|
use candle::{DType, Module, Tensor};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
/// When set, compute embeddings for this prompt.
|
||||||
|
#[arg(long)]
|
||||||
|
prompt: Option<String>,
|
||||||
|
|
||||||
|
/// The number of times to run the prompt.
|
||||||
|
#[arg(long, default_value = "1")]
|
||||||
|
n: usize,
|
||||||
|
|
||||||
|
/// L2 normalization for embeddings.
|
||||||
|
#[arg(long, default_value = "true")]
|
||||||
|
normalize_embeddings: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Args {
|
||||||
|
fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
|
||||||
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
let model = match &self.model {
|
||||||
|
Some(model_file) => std::path::PathBuf::from(model_file),
|
||||||
|
None => Api::new()?
|
||||||
|
.repo(Repo::new(
|
||||||
|
"jinaai/jina-embeddings-v2-base-en".to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
))
|
||||||
|
.get("model.safetensors")?,
|
||||||
|
};
|
||||||
|
let tokenizer = match &self.tokenizer {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => Api::new()?
|
||||||
|
.repo(Repo::new(
|
||||||
|
"sentence-transformers/all-MiniLM-L6-v2".to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
))
|
||||||
|
.get("tokenizer.json")?,
|
||||||
|
};
|
||||||
|
let device = candle_examples::device(self.cpu)?;
|
||||||
|
let config = Config::v2_base();
|
||||||
|
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
||||||
|
let model = BertModel::new(vb, &config)?;
|
||||||
|
Ok((model, tokenizer))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> anyhow::Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
println!("tracing...");
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
|
||||||
|
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
|
||||||
|
let device = &model.device;
|
||||||
|
|
||||||
|
if let Some(prompt) = args.prompt {
|
||||||
|
let tokenizer = tokenizer
|
||||||
|
.with_padding(None)
|
||||||
|
.with_truncation(None)
|
||||||
|
.map_err(E::msg)?;
|
||||||
|
let tokens = tokenizer
|
||||||
|
.encode(prompt, true)
|
||||||
|
.map_err(E::msg)?
|
||||||
|
.get_ids()
|
||||||
|
.to_vec();
|
||||||
|
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||||
|
println!("Loaded and encoded {:?}", start.elapsed());
|
||||||
|
for idx in 0..args.n {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let ys = model.forward(&token_ids)?;
|
||||||
|
if idx == 0 {
|
||||||
|
println!("{ys}");
|
||||||
|
}
|
||||||
|
println!("Took {:?}", start.elapsed());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let sentences = [
|
||||||
|
"The cat sits outside",
|
||||||
|
"A man is playing guitar",
|
||||||
|
"I love pasta",
|
||||||
|
"The new movie is awesome",
|
||||||
|
"The cat plays in the garden",
|
||||||
|
"A woman watches TV",
|
||||||
|
"The new movie is so great",
|
||||||
|
"Do you like pizza?",
|
||||||
|
];
|
||||||
|
let n_sentences = sentences.len();
|
||||||
|
if let Some(pp) = tokenizer.get_padding_mut() {
|
||||||
|
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
|
||||||
|
} else {
|
||||||
|
let pp = tokenizers::PaddingParams {
|
||||||
|
strategy: tokenizers::PaddingStrategy::BatchLongest,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
tokenizer.with_padding(Some(pp));
|
||||||
|
}
|
||||||
|
let tokens = tokenizer
|
||||||
|
.encode_batch(sentences.to_vec(), true)
|
||||||
|
.map_err(E::msg)?;
|
||||||
|
let token_ids = tokens
|
||||||
|
.iter()
|
||||||
|
.map(|tokens| {
|
||||||
|
let tokens = tokens.get_ids().to_vec();
|
||||||
|
Tensor::new(tokens.as_slice(), device)
|
||||||
|
})
|
||||||
|
.collect::<candle::Result<Vec<_>>>()?;
|
||||||
|
|
||||||
|
let token_ids = Tensor::stack(&token_ids, 0)?;
|
||||||
|
println!("running inference on batch {:?}", token_ids.shape());
|
||||||
|
let embeddings = model.forward(&token_ids)?;
|
||||||
|
println!("generated embeddings {:?}", embeddings.shape());
|
||||||
|
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||||
|
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
|
||||||
|
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
|
||||||
|
let embeddings = if args.normalize_embeddings {
|
||||||
|
normalize_l2(&embeddings)?
|
||||||
|
} else {
|
||||||
|
embeddings
|
||||||
|
};
|
||||||
|
println!("pooled embeddings {:?}", embeddings.shape());
|
||||||
|
|
||||||
|
let mut similarities = vec![];
|
||||||
|
for i in 0..n_sentences {
|
||||||
|
let e_i = embeddings.get(i)?;
|
||||||
|
for j in (i + 1)..n_sentences {
|
||||||
|
let e_j = embeddings.get(j)?;
|
||||||
|
let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
|
||||||
|
let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
|
||||||
|
let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
|
||||||
|
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
|
||||||
|
similarities.push((cosine_similarity, i, j))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
|
||||||
|
for &(score, i, j) in similarities[..5].iter() {
|
||||||
|
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn normalize_l2(v: &Tensor) -> candle::Result<Tensor> {
|
||||||
|
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
|
||||||
|
}
|
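The similarity loop in the new example is plain cosine similarity assembled from three dot products:

$$\cos(e_i, e_j) = \frac{\langle e_i, e_j\rangle}{\lVert e_i\rVert\,\lVert e_j\rVert} = \frac{\sum_k e_{ik}\,e_{jk}}{\sqrt{\sum_k e_{ik}^{2}}\,\sqrt{\sum_k e_{jk}^{2}}}$$

Since the embeddings are L2-normalized by default, the denominator is close to one and the score is effectively the dot product of the pooled embeddings.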
@@ -13,7 +13,7 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;
 
 use anyhow::{bail, Error as E, Result};
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 
 use candle::{DType, Tensor};
 use candle_nn::VarBuilder;
@@ -22,11 +22,21 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;
 
 use candle_transformers::models::llama as model;
-use model::{Config, Llama, LlamaConfig};
+use model::{Llama, LlamaConfig};
 
 const EOS_TOKEN: &str = "</s>";
 const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 
+#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
+enum Which {
+    V1,
+    V2,
+    #[value(name = "solar-10.7b")]
+    Solar10_7B,
+    #[value(name = "tiny-llama-1.1b-chat")]
+    TinyLlama1_1BChat,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -34,10 +44,6 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
-    /// Use npy instead of safetensors
-    #[arg(long)]
-    npy: Option<String>,
-
     /// The temperature used to generate samples.
     #[arg(long)]
     temperature: Option<f64>,
@@ -76,17 +82,13 @@ struct Args {
     #[arg(long)]
     revision: Option<String>,
 
-    #[arg(long)]
-    v1: bool,
+    /// The model size to use.
+    #[arg(long, default_value = "v2")]
+    which: Which,
 
     #[arg(long)]
     use_flash_attn: bool,
 
-    /// The folder name that contains safetensor weights and json files
-    /// (same structure as huggingface online)
-    #[arg(long)]
-    local_weights: Option<String>,
-
     /// Penalty to be applied for repeating tokens, 1. means no penalty.
     #[arg(long, default_value_t = 1.0)]
     repeat_penalty: f32,
@@ -118,65 +120,34 @@ fn main() -> Result<()> {
         Some(dtype) => bail!("Unsupported dtype {dtype}"),
         None => DType::F16,
     };
-    let (llama, tokenizer_filename, cache) = match args.npy {
-        Some(filename) => {
-            let config = if args.v1 {
-                Config::config_7b_v1(args.use_flash_attn)
-            } else {
-                Config::config_7b_v2(args.use_flash_attn)
-            };
-            let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
-            let vb = VarBuilder::from_npz(filename, dtype, &device)?;
-            let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
-            (Llama::load(vb, &cache, &config)?, tokenizer, cache)
-        }
-        None => {
-            let api = Api::new()?;
-            let model_id = args.model_id.unwrap_or_else(|| {
-                if args.v1 {
-                    "Narsil/amall-7b".to_string()
-                } else {
-                    "meta-llama/Llama-2-7b-hf".to_string()
-                }
-            });
-            println!("loading the model weights from {model_id}");
-            let revision = args.revision.unwrap_or("main".to_string());
-            let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
-
-            let tokenizer_filename = match &args.local_weights {
-                Some(path) => (path.to_owned() + "tokenizer.json").into(),
-                _ => api.get("tokenizer.json")?,
-            };
-
-            let config_filename = match &args.local_weights {
-                Some(path) => (path.to_owned() + "config.json").into(),
-                _ => api.get("config.json")?,
-            };
-            let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
-            let config = config.into_config(args.use_flash_attn);
-
-            let mut filenames = vec![];
-            for rfilename in [
-                "model-00001-of-00002.safetensors",
-                "model-00002-of-00002.safetensors",
-            ] {
-                match &args.local_weights {
-                    Some(path) => {
-                        filenames.push((path.to_owned() + rfilename).into());
-                    }
-                    _ => {
-                        let filename = api.get(rfilename)?;
-                        filenames.push(filename);
-                    }
-                }
-            }
-
-            println!("building the model");
-            let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
-            let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-            (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
-        }
-    };
+    let (llama, tokenizer_filename, cache) = {
+        let api = Api::new()?;
+        let model_id = args.model_id.unwrap_or_else(|| match args.which {
+            Which::V1 => "Narsil/amall-7b".to_string(),
+            Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
+            Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
+            Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
+        });
+        println!("loading the model weights from {model_id}");
+        let revision = args.revision.unwrap_or("main".to_string());
+        let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
+
+        let tokenizer_filename = api.get("tokenizer.json")?;
+        let config_filename = api.get("config.json")?;
+        let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
+        let config = config.into_config(args.use_flash_attn);
+
+        let filenames = match args.which {
+            Which::V1 | Which::V2 | Which::Solar10_7B => {
+                candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
+            }
+            Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
+        };
+
+        println!("building the model");
+        let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
+        (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
+    };
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
@@ -194,14 +165,14 @@ fn main() -> Result<()> {
     let mut index_pos = 0;
     let mut token_generated = 0;
     for index in 0..args.sample_len {
-        let context_size = if cache.use_kv_cache && index > 0 {
-            1
-        } else {
-            tokens.len()
-        };
+        let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
+            (1, index_pos)
+        } else {
+            (tokens.len(), 0)
+        };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
-        let logits = llama.forward(&input, index_pos)?;
+        let logits = llama.forward(&input, context_index)?;
         let logits = logits.squeeze(0)?;
         let logits = if args.repeat_penalty == 1. {
             logits
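A note on the KV-cache hunk above (illustration only, not from the diff): once the cache holds keys and values for the prefix, each decoding step only needs to feed the single newest token together with its absolute position, while the uncached path re-encodes the whole prompt from position 0. A minimal sketch of that bookkeeping:

```rust
// Returns (how many trailing tokens to feed, the position offset to pass to
// the model) -- mirroring the (context_size, context_index) pair above.
fn context_window(use_kv_cache: bool, step: usize, n_tokens: usize, index_pos: usize) -> (usize, usize) {
    if use_kv_cache && step > 0 {
        (1, index_pos) // one new token, attended at the current offset
    } else {
        (n_tokens, 0) // no cache: re-encode everything from position 0
    }
}
```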
@@ -6,9 +6,10 @@ extern crate accelerate_src;
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
-mod model;
+use candle_transformers::models::llama2_c as model;
+use candle_transformers::models::llama2_c_weights as weights;
+use candle_transformers::models::quantized_llama2_c as qmodel;
 mod training;
-mod weights;
 use clap::{Parser, Subcommand};
 
 use anyhow::{Error as E, Result};
@@ -19,6 +20,7 @@ use std::io::Write;
 use tokenizers::Tokenizer;
 
 use model::{Config, Llama};
+use qmodel::QLlama;
 use weights::TransformerWeights;
 
 #[derive(Parser, Debug, Clone)]
@@ -152,6 +154,20 @@ fn main() -> anyhow::Result<()> {
     Ok(())
 }
 
+enum Model {
+    Llama(Llama),
+    QLlama(QLlama),
+}
+
+impl Model {
+    fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
+        match self {
+            Self::Llama(l) => Ok(l.forward(xs, pos)?),
+            Self::QLlama(l) => Ok(l.forward(xs, pos)?),
+        }
+    }
+}
+
 fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
     use std::io::BufRead;
 
@@ -241,24 +257,66 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
 
     let device = candle_examples::device(common_args.cpu)?;
 
+    let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
     let is_safetensors = config_path
         .extension()
         .map_or(false, |v| v == "safetensors");
-    let (vb, config) = if is_safetensors {
-        let config = Config::tiny();
+    let (model, config) = if is_gguf {
+        let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
+        let (_vocab_size, dim) = vb
+            .get_no_shape("model.embed_tokens.weight")?
+            .shape()
+            .dims2()?;
+        let config = match dim {
+            64 => Config::tiny_260k(),
+            288 => Config::tiny_15m(),
+            512 => Config::tiny_42m(),
+            768 => Config::tiny_110m(),
+            _ => anyhow::bail!("no config for dim {dim}"),
+        };
+        let freq_cis_real = vb
+            .get(
+                (config.seq_len, config.head_size() / 2),
+                "rot.freq_cis_real",
+            )?
+            .dequantize(&device)?;
+        let freq_cis_imag = vb
+            .get(
+                (config.seq_len, config.head_size() / 2),
+                "rot.freq_cis_imag",
+            )?
+            .dequantize(&device)?;
+
+        let fake_vb = candle_nn::VarBuilder::from_tensors(
+            [
+                ("freq_cis_real".to_string(), freq_cis_real),
+                ("freq_cis_imag".to_string(), freq_cis_imag),
+            ]
+            .into_iter()
+            .collect(),
+            candle::DType::F32,
+            &device,
+        );
+        let cache = model::Cache::new(true, &config, fake_vb)?;
+        let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
+        (model, config)
+    } else if is_safetensors {
+        let config = Config::tiny_15m();
         let tensors = candle::safetensors::load(config_path, &device)?;
         let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
-        (vb, config)
+        let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
+        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
+        (model, config)
     } else {
         let mut file = std::fs::File::open(config_path)?;
         let config = Config::from_reader(&mut file)?;
         println!("{config:?}");
         let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
         let vb = weights.var_builder(&config, &device)?;
-        (vb, config)
-    };
-    let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
-    let model = Llama::load(vb, &cache, config)?;
+        let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
+        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
+        (model, config)
+    };
 
     println!("starting the inference loop");
     let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
@@ -273,7 +331,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
 
     let start_gen = std::time::Instant::now();
     for index in 0.. {
-        if tokens.len() >= model.config.seq_len {
+        if tokens.len() >= config.seq_len {
             break;
         }
         let context_size = if index > 0 { 1 } else { tokens.len() };
@@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
     );
     let varmap = candle_nn::VarMap::new();
     let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
-    let config = Config::tiny();
+    let config = Config::tiny_15m();
     let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
     let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
@@ -143,14 +143,7 @@ fn main() -> Result<()> {
     let config_filename = api.get("config.json")?;
     let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
     let tokenizer_filename = api.get("tokenizer.json")?;
-    let mut filenames = vec![];
-    for rfilename in [
-        "model-00001-of-00002.safetensors",
-        "model-00002-of-00002.safetensors",
-    ] {
-        let filename = api.get(rfilename)?;
-        filenames.push(filename);
-    }
+    let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
 
     if args.rank.is_none() {
         let children: Vec<_> = (0..args.num_shards)
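The `Model` wrapper introduced in the hunk above is a small but reusable pattern: a two-variant enum lets the float and quantized back-ends share one inference loop without trait objects. A self-contained sketch with stand-in types (hypothetical `FloatNet`/`QuantNet`, not the real candle models):

```rust
// Stand-in networks; the real code wraps Llama and QLlama.
struct FloatNet;
struct QuantNet;

impl FloatNet {
    fn forward(&self, xs: &[f32]) -> Vec<f32> {
        xs.to_vec()
    }
}
impl QuantNet {
    fn forward(&self, xs: &[f32]) -> Vec<f32> {
        xs.iter().map(|x| x.round()).collect()
    }
}

enum Model {
    Float(FloatNet),
    Quantized(QuantNet),
}

impl Model {
    // Static dispatch per variant; callers never branch on the back-end.
    fn forward(&self, xs: &[f32]) -> Vec<f32> {
        match self {
            Self::Float(m) => m.forward(xs),
            Self::Quantized(m) => m.forward(xs),
        }
    }
}

fn main() {
    let xs = [0.2f32, 0.9];
    println!("{:?}", Model::Float(FloatNet).forward(&xs));
    println!("{:?}", Model::Quantized(QuantNet).forward(&xs));
}
```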
candle-examples/examples/mamba-minimal/README.md (new file, 12 lines)
@@ -0,0 +1,12 @@
# candle-mamba-minimal: minimal implementation of Mamba

This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal).

## Running the example

```bash
$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
Mamba is the most popular and best-selling game in the world. It has been downloaded more than 1,000 times by over 1 million people worldwide since its release on March 18th 2016.

The Mamba series of games are a collection that combines elements from all genres including action, adventure, strategy & puzzle games with some unique gameplay features such as stealth and survival. The game is also known for its innovative graphics and the ability to play in a variety of different modes like single player or multiplayer.
```
candle-examples/examples/mamba-minimal/main.rs (new file, 287 lines)
@@ -0,0 +1,287 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};

mod model;
use model::{Config, Model};

use candle::{DType, Device, Module, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <|endoftext|> token"),
        };
        let start_gen = std::time::Instant::now();
        for _ in 0..sample_len {
            let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
enum Which {
    Mamba130m,
    Mamba370m,
    Mamba790m,
    Mamba1_4b,
    Mamba2_8b,
    Mamba2_8bSlimPj,
}

impl std::fmt::Display for Which {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self)
    }
}

impl Which {
    fn model_id(&self) -> &'static str {
        match self {
            Self::Mamba130m => "state-spaces/mamba-130m",
            Self::Mamba370m => "state-spaces/mamba-370m",
            Self::Mamba790m => "state-spaces/mamba-790m",
            Self::Mamba1_4b => "state-spaces/mamba-1.4b",
            Self::Mamba2_8b => "state-spaces/mamba-2.8b",
            Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj",
        }
    }

    fn revision(&self) -> &'static str {
        match self {
            Self::Mamba130m
            | Self::Mamba370m
            | Self::Mamba790m
            | Self::Mamba1_4b
            | Self::Mamba2_8bSlimPj => "refs/pr/1",
            Self::Mamba2_8b => "refs/pr/4",
        }
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 5000)]
    sample_len: usize,

    #[arg(long, default_value = "mamba130m")]
    which: Which,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        args.model_id
            .unwrap_or_else(|| args.which.model_id().to_string()),
        RepoType::Model,
        args.revision
            .unwrap_or_else(|| args.which.revision().to_string()),
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => api
            .model("EleutherAI/gpt-neox-20b".to_string())
            .get("tokenizer.json")?,
    };
    let config_filename = match args.config_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("config.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => {
            vec![repo.get("model.safetensors")?]
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
    let device = candle_examples::device(args.cpu)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
    let model = Model::new(&config, vb.pp("backbone"))?;
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
candle-examples/examples/mamba-minimal/model.rs (new file, 204 lines)
@@ -0,0 +1,204 @@
/// This follows the lines of:
/// https://github.com/johnma2006/mamba-minimal/blob/master/model.py
/// Simple, minimal implementation of Mamba in one file of PyTorch.
use candle::{IndexOp, Module, Result, Tensor, D};
use candle_nn::{RmsNorm, VarBuilder};

use candle_transformers::models::with_tracing::{linear, linear_no_bias, Linear};

#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    d_model: usize,
    n_layer: usize,
    vocab_size: usize,
    pad_vocab_size_multiple: usize,
}

impl Config {
    fn vocab_size(&self) -> usize {
        let pad = self.pad_vocab_size_multiple;
        (self.vocab_size + pad - 1) / pad * pad
    }

    fn dt_rank(&self) -> usize {
        (self.d_model + 15) / 16
    }

    fn d_conv(&self) -> usize {
        4
    }

    fn d_state(&self) -> usize {
        16
    }

    fn d_inner(&self) -> usize {
        self.d_model * 2
    }
}

// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L177
#[derive(Clone, Debug)]
pub struct MambaBlock {
    in_proj: Linear,
    conv1d: candle_nn::Conv1d,
    x_proj: Linear,
    dt_proj: Linear,
    a_log: Tensor,
    d: Tensor,
    out_proj: Linear,
    dt_rank: usize,
}

impl MambaBlock {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let d_inner = cfg.d_inner();
        let d_conv = cfg.d_conv();
        let d_state = cfg.d_state();
        let dt_rank = cfg.dt_rank();
        let in_proj = linear_no_bias(cfg.d_model, d_inner * 2, vb.pp("in_proj"))?;
        let conv_cfg = candle_nn::Conv1dConfig {
            groups: d_inner,
            padding: d_conv - 1,
            ..Default::default()
        };
        let conv1d = candle_nn::conv1d(d_inner, d_inner, d_conv, conv_cfg, vb.pp("conv1d"))?;
        let x_proj = linear_no_bias(d_inner, dt_rank + d_state * 2, vb.pp("x_proj"))?;
        let dt_proj = linear(dt_rank, d_inner, vb.pp("dt_proj"))?;
        let a_log = vb.get((d_inner, d_state), "A_log")?;
        let d = vb.get(d_inner, "D")?;
        let out_proj = linear_no_bias(d_inner, cfg.d_model, vb.pp("out_proj"))?;
        Ok(Self {
            in_proj,
            conv1d,
            x_proj,
            dt_proj,
            a_log,
            d,
            out_proj,
            dt_rank,
        })
    }

    fn ssm(&self, xs: &Tensor) -> Result<Tensor> {
        let (_d_in, n) = self.a_log.dims2()?;
        let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
        let d = self.d.to_dtype(candle::DType::F32)?;
        let x_dbl = xs.apply(&self.x_proj)?;
        let delta = x_dbl.narrow(D::Minus1, 0, self.dt_rank)?;
        let b = x_dbl.narrow(D::Minus1, self.dt_rank, n)?;
        let c = x_dbl.narrow(D::Minus1, self.dt_rank + n, n)?;
        let delta = delta.contiguous()?.apply(&self.dt_proj)?;
        // softplus without threshold
        let delta = (delta.exp()? + 1.)?.log()?;
        let ss = selective_scan(xs, &delta, &a, &b, &c, &d)?;
        Ok(ss)
    }
}

// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L275
fn selective_scan(
    u: &Tensor,
    delta: &Tensor,
    a: &Tensor,
    b: &Tensor,
    c: &Tensor,
    d: &Tensor,
) -> Result<Tensor> {
    let (b_sz, l, d_in) = u.dims3()?;
    let n = a.dim(1)?;
    let delta = delta.t()?.reshape((b_sz, d_in, l, 1))?; // b d_in l 1
    let delta_a = delta.broadcast_mul(&a.reshape((1, d_in, 1, n))?)?.exp()?;
    let delta_b_u = delta
        .broadcast_mul(&b.reshape((b_sz, 1, l, n))?)?
        .broadcast_mul(&u.t()?.reshape((b_sz, d_in, l, 1))?)?;
    let mut xs = Tensor::zeros((b_sz, d_in, n), delta_a.dtype(), delta_a.device())?;
    let mut ys = Vec::with_capacity(l);
    for i in 0..l {
        xs = ((delta_a.i((.., .., i))? * xs)? + delta_b_u.i((.., .., i))?)?;
        let y = xs.matmul(&c.i((.., i, ..))?.unsqueeze(2)?)?.squeeze(2)?;
        ys.push(y)
    }
    let ys = Tensor::stack(ys.as_slice(), 1)?;
    ys + u.broadcast_mul(d)
}

impl Module for MambaBlock {
    // https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L206
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (_b_sz, seq_len, _dim) = xs.dims3()?;
        let xs_and_res = xs.apply(&self.in_proj)?.chunk(2, D::Minus1)?;
        let (xs, res) = (&xs_and_res[0], &xs_and_res[1]);
        let xs = xs
            .t()?
            .apply(&self.conv1d)?
            .narrow(D::Minus1, 0, seq_len)?
            .t()?;
        let xs = candle_nn::ops::silu(&xs)?;
        let ys = (self.ssm(&xs)? * candle_nn::ops::silu(res))?;
        ys.apply(&self.out_proj)
    }
}

// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L143
#[derive(Clone, Debug)]
pub struct ResidualBlock {
    mixer: MambaBlock,
    norm: RmsNorm,
}

impl ResidualBlock {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let norm = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm"))?;
        let mixer = MambaBlock::new(cfg, vb.pp("mixer"))?;
        Ok(Self { mixer, norm })
    }
}

impl Module for ResidualBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.norm)?.apply(&self.mixer)? + xs
    }
}

// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L56
#[derive(Clone, Debug)]
pub struct Model {
    embedding: candle_nn::Embedding,
    layers: Vec<ResidualBlock>,
    norm_f: RmsNorm,
    lm_head: Linear,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let embedding = candle_nn::embedding(cfg.vocab_size(), cfg.d_model, vb.pp("embedding"))?;
        let mut layers = Vec::with_capacity(cfg.n_layer);
        let vb_l = vb.pp("layers");
        for layer_idx in 0..cfg.n_layer {
            let layer = ResidualBlock::new(cfg, vb_l.pp(layer_idx))?;
            layers.push(layer)
        }
        let norm_f = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
        let lm_head = Linear::from_weights(embedding.embeddings().clone(), None);
        Ok(Self {
            embedding,
            layers,
            norm_f,
            lm_head,
        })
    }
}

impl Module for Model {
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let (_b_size, seq_len) = input_ids.dims2()?;
        let mut xs = self.embedding.forward(input_ids)?;
        for layer in self.layers.iter() {
            xs = layer.forward(&xs)?
        }
        xs.narrow(1, seq_len - 1, 1)?
            .apply(&self.norm_f)?
            .apply(&self.lm_head)
    }
}
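To connect `selective_scan` above to the Mamba paper's notation (a reference note, not text from the commit), the time loop implements the discretized state-space recurrence:

```latex
x_t = \bar{A}_t \odot x_{t-1} + \bar{B}_t u_t, \qquad
y_t = C_t x_t + D \odot u_t, \qquad
\bar{A}_t = \exp(\Delta_t A), \quad
\bar{B}_t u_t = \Delta_t \odot (B_t u_t)
```

Here `delta_a` holds the per-step decay \(\bar{A}_t\), `delta_b_u` holds the per-step input \(\bar{B}_t u_t\), and the `for i in 0..l` loop carries the hidden state `xs` across the sequence one step at a time before the skip term `u.broadcast_mul(d)` is added.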
candle-examples/examples/marian-mt/README.md (new file, 38 lines)
@@ -0,0 +1,38 @@
# candle-marian-mt

`marian-mt` is a neural machine translation model. In this example it is used to
translate text from French to English. See the associated [model
card](https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en) for details on
the model itself.

## Running an example

```bash
cargo run --example marian-mt --release -- \
  --text "Demain, dès l'aube, à l'heure où blanchit la campagne, Je partirai. Vois-tu, je sais que tu m'attends. J'irai par la forêt, j'irai par la montagne. Je ne puis demeurer loin de toi plus longtemps."
```

```
<NIL> Tomorrow, at dawn, at the time when the country is whitening, I will go. See,
I know you are waiting for me. I will go through the forest, I will go through the
mountain. I cannot stay far from you any longer.</s>
```

## Generating the tokenizer.json files

You can use the following script to generate the `tokenizer.json` config files
from the hf-hub repos. This requires the `tokenizers` and `sentencepiece`
packages to be installed, and uses the `convert_slow_tokenizer.py` script from this
directory.

```python
from convert_slow_tokenizer import MarianConverter
from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save(f"tokenizer-marian-base-en.json")
```

candle-examples/examples/marian-mt/convert_slow_tokenizer.py (new file, 1385 lines; diff suppressed because it is too large)
candle-examples/examples/marian-mt/main.rs (new file, 152 lines)
@@ -0,0 +1,152 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::{Parser, ValueEnum};

use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::marian;

use tokenizers::Tokenizer;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    Base,
    Big,
}

// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    tokenizer_dec: Option<String>,

    /// Choose the variant of the model to run.
    #[arg(long, default_value = "big")]
    which: Which,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Use the quantized version of the model.
    #[arg(long)]
    quantized: bool,

    /// Text to be translated
    #[arg(long)]
    text: String,
}

pub fn main() -> anyhow::Result<()> {
    use hf_hub::api::sync::Api;
    let args = Args::parse();

    let config = match args.which {
        Which::Base => marian::Config::opus_mt_fr_en(),
        Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
    };
    let tokenizer = {
        let tokenizer = match args.tokenizer {
            Some(tokenizer) => std::path::PathBuf::from(tokenizer),
            None => {
                let name = match args.which {
                    Which::Base => "tokenizer-marian-base-fr.json",
                    Which::Big => "tokenizer-marian-fr.json",
                };
                Api::new()?
                    .model("lmz/candle-marian".to_string())
                    .get(name)?
            }
        };
        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
    };

    let tokenizer_dec = {
        let tokenizer = match args.tokenizer_dec {
            Some(tokenizer) => std::path::PathBuf::from(tokenizer),
            None => {
                let name = match args.which {
                    Which::Base => "tokenizer-marian-base-en.json",
                    Which::Big => "tokenizer-marian-en.json",
                };
                Api::new()?
                    .model("lmz/candle-marian".to_string())
                    .get(name)?
            }
        };
        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
    };
    let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);

    let device = candle_examples::device(args.cpu)?;
    let vb = {
        let model = match args.model {
            Some(model) => std::path::PathBuf::from(model),
            None => match args.which {
                Which::Base => Api::new()?
                    .repo(hf_hub::Repo::with_revision(
                        "Helsinki-NLP/opus-mt-fr-en".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/4".to_string(),
                    ))
                    .get("model.safetensors")?,
                Which::Big => Api::new()?
                    .model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
                    .get("model.safetensors")?,
            },
        };
        unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
    };
    let mut model = marian::MTModel::new(&config, vb)?;

    let mut logits_processor =
        candle_transformers::generation::LogitsProcessor::new(1337, None, None);

    let encoder_xs = {
        let mut tokens = tokenizer
            .encode(args.text, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        tokens.push(config.eos_token_id);
        let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
        model.encoder().forward(&tokens, 0)?
    };

    let mut token_ids = vec![config.decoder_start_token_id];
    for index in 0..1000 {
        let context_size = if index >= 1 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
        let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        token_ids.push(token);
        if let Some(t) = tokenizer_dec.next_token(token)? {
            use std::io::Write;
            print!("{t}");
            std::io::stdout().flush()?;
        }
        if token == config.eos_token_id || token == config.forced_eos_token_id {
            break;
        }
    }
    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    println!();
    Ok(())
}
@@ -155,8 +155,8 @@ struct Args {
     #[arg(long, short = 'n', default_value_t = 100)]
     sample_len: usize,
 
-    #[arg(long, default_value = "lmz/candle-mistral")]
-    model_id: String,
+    #[arg(long)]
+    model_id: Option<String>,
 
     #[arg(long, default_value = "main")]
     revision: String,
@@ -207,8 +207,18 @@ fn main() -> Result<()> {
 
     let start = std::time::Instant::now();
     let api = Api::new()?;
+    let model_id = match args.model_id {
+        Some(model_id) => model_id,
+        None => {
+            if args.quantized {
+                "lmz/candle-mistral".to_string()
+            } else {
+                "mistralai/Mistral-7B-v0.1".to_string()
+            }
+        }
+    };
     let repo = api.repo(Repo::with_revision(
-        args.model_id,
+        model_id,
         RepoType::Model,
         args.revision,
     ));
@@ -225,10 +235,7 @@ fn main() -> Result<()> {
             if args.quantized {
                 vec![repo.get("model-q4k.gguf")?]
             } else {
-                vec![
-                    repo.get("pytorch_model-00001-of-00002.safetensors")?,
-                    repo.get("pytorch_model-00002-of-00002.safetensors")?,
-                ]
+                candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
             }
         }
     };
@@ -237,13 +244,14 @@ fn main() -> Result<()> {
 
     let start = std::time::Instant::now();
     let config = Config::config_7b_v0_1(args.use_flash_attn);
+    let device = candle_examples::device(args.cpu)?;
     let (model, device) = if args.quantized {
         let filename = &filenames[0];
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
         let model = QMistral::new(&config, vb)?;
-        (Model::Quantized(model), Device::Cpu)
+        (Model::Quantized(model), device)
     } else {
-        let device = candle_examples::device(args.cpu)?;
         let dtype = if device.is_cuda() {
             DType::BF16
         } else {
candle-examples/examples/mixtral/README.md (new file, 25 lines)
@@ -0,0 +1,25 @@
# candle-mixtral: 8x7b LLM using a sparse mixture of experts.

Mixtral-8x7B-v0.1 is a pretrained generative LLM with 56 billion parameters.

- [Blog post](https://mistral.ai/news/mixtral-of-experts/) from Mistral announcing the model release.
- [Model card](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) on the HuggingFace Hub.

## Running the example

```bash
$ cargo run --example mixtral --release -- --prompt "def print_prime(n): "
def print_prime(n): # n is the number of prime numbers to be printed
    i = 2
    count = 0
    while (count < n):
        if (isPrime(i)):
            print(i)
            count += 1
        i += 1

def isPrime(n):
    for x in range(2, int(n**0.5)+1):
        if (n % x == 0):
...
```

candle-examples/examples/mixtral/main.rs (new file, 241 lines)
@@ -0,0 +1,241 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::mixtral::{Config, Model};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("</s>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the </s> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    use_flash_attn: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 100)]
    sample_len: usize,

    #[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]
    model_id: String,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        args.model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = Config::v0_1_8x7b(args.use_flash_attn);
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = Model::new(&config, vb)?;
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
@@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum};
 use rand::prelude::*;
 
 use candle::{DType, Result, Tensor, D};
-use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};
+use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};
 
 const IMAGE_DIM: usize = 784;
 const LABELS: usize = 10;
@@ -95,7 +95,7 @@ impl ConvNet {
             .flatten_from(1)?
             .apply(&self.fc1)?
             .relu()?;
-        self.dropout.forward(&xs, train)?.apply(&self.fc2)
+        self.dropout.forward_t(&xs, train)?.apply(&self.fc2)
     }
 }
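Why this one-line change matters (an illustrative sketch, with assumed tensor shapes): `forward_t` is the train-aware entry point of candle's `ModuleT` trait, so the dropout layer can mask activations during training and act as the identity at evaluation time.

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Dropout, ModuleT};

fn main() -> Result<()> {
    let dropout = Dropout::new(0.5);
    let xs = Tensor::ones((2, 4), DType::F32, &Device::Cpu)?;
    // train=true: roughly half the activations are zeroed (the rest rescaled).
    let at_train = dropout.forward_t(&xs, true)?;
    // train=false: dropout is a no-op.
    let at_eval = dropout.forward_t(&xs, false)?;
    println!("{at_train}\n{at_eval}");
    Ok(())
}
```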
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
candle-examples/examples/mobileone/README.md (new file, 22 lines)
@@ -0,0 +1,22 @@
# candle-mobileone

[MobileOne: An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040).

This candle implementation uses a pre-trained MobileOne network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example mobileone --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which s2

loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 79.33%
bicycle-built-for-two, tandem bicycle, tandem: 15.32%
crash helmet : 2.58%
unicycle, monocycle : 1.70%
alp : 0.21%

```

candle-examples/examples/mobileone/main.rs (new file, 96 lines)
@@ -0,0 +1,96 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::mobileone;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    S0,
    S1,
    S2,
    S3,
    S4,
}

impl Which {
    fn model_filename(&self) -> String {
        let name = match self {
            Self::S0 => "s0",
            Self::S1 => "s1",
            Self::S2 => "s2",
            Self::S3 => "s3",
            Self::S4 => "s4",
        };
        format!("timm/mobileone_{}.apple_in1k", name)
    }

    fn config(&self) -> mobileone::Config {
        match self {
            Self::S0 => mobileone::Config::s0(),
            Self::S1 => mobileone::Config::s1(),
            Self::S2 => mobileone::Config::s2(),
            Self::S3 => mobileone::Config::s3(),
            Self::S4 => mobileone::Config::s4(),
        }
    }
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    #[arg(value_enum, long, default_value_t=Which::S0)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let model_name = args.which.model_filename();
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model(model_name);
            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = mobileone::mobileone(&args.which.config(), 1000, vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
@@ -8,6 +8,7 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
 #[derive(Debug, Clone, PartialEq)]
 enum NormType {
     WeightNorm,
+    TimeGroupNorm,
     None,
 }

@@ -268,6 +269,7 @@ impl Module for EncodecConvTranspose1d {
 struct EncodecConv1d {
     causal: bool,
     conv: Conv1d,
+    norm: Option<candle_nn::GroupNorm>,
 }

 impl EncodecConv1d {

@@ -292,7 +294,7 @@ impl EncodecConv1d {
                 },
                 vb.pp("conv"),
             )?,
-            NormType::None => conv1d(
+            NormType::None | NormType::TimeGroupNorm => conv1d(
                 in_c,
                 out_c,
                 kernel_size,

@@ -305,9 +307,17 @@ impl EncodecConv1d {
                 vb.pp("conv"),
             )?,
         };
+        let norm = match cfg.norm_type {
+            NormType::None | NormType::WeightNorm => None,
+            NormType::TimeGroupNorm => {
+                let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
+                Some(gn)
+            }
+        };
         Ok(Self {
            causal: cfg.use_causal_conv,
            conv,
+           norm,
        })
     }
 }

@@ -316,8 +326,10 @@ impl Module for EncodecConv1d {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         // TODO: padding, depending on causal.
         let xs = self.conv.forward(xs)?;
-        // If we add support for NormType "time_group_norm", we should add some normalization here.
-        Ok(xs)
+        match &self.norm {
+            None => Ok(xs),
+            Some(norm) => xs.apply(norm),
+        }
     }
 }
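The `TimeGroupNorm` branch above maps to a `GroupNorm` with a single group, so normalization statistics are computed jointly over all channels and time steps of each sample. A standalone sketch of that behavior; the shapes and variable names here are illustrative, not taken from the encodec code:

```rust
use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::{group_norm, VarBuilder, VarMap};

fn time_group_norm_demo() -> Result<()> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    // One group over 8 channels: a single mean/variance per sample.
    let gn = group_norm(1, 8, 1e-5, vb.pp("norm"))?;
    // (batch, channels, time) layout, as used by the 1d convolutions.
    let xs = Tensor::randn(0f32, 1f32, (2, 8, 16), &dev)?;
    let ys = gn.forward(&xs)?;
    println!("{:?}", ys.dims()); // [2, 8, 16]
    Ok(())
}
```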
@@ -321,7 +321,7 @@ impl MusicgenDecoder {
         let positions = self.embed_positions.forward(&input)?.to_device(dev)?;
         let mut xs = inputs_embeds.broadcast_add(&positions)?;
         let attention_mask = self.prepare_decoder_attention_mask(b_sz, seq_len)?;
-        for (_layer_idx, decoder_layer) in self.layers.iter_mut().enumerate() {
+        for decoder_layer in self.layers.iter_mut() {
             xs = decoder_layer.forward(&xs, &attention_mask, None)?;
         }
         let xs = self.layer_norm.forward(&xs)?;
candle-examples/examples/onnx/README.md (new file, 10 lines)
@@ -0,0 +1,10 @@
## Using ONNX models in Candle

This example demonstrates how to run ONNX based models in Candle; the model
being used here is a small squeezenet variant.

You can run the example with the following command:

```bash
cargo run --example squeezenet-onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```
candle-examples/examples/onnx/main.rs (new file, 78 lines)
@@ -0,0 +1,78 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::{IndexOp, D};
use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    SqueezeNet,
    EfficientNet,
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    image: String,

    #[arg(long)]
    model: Option<String>,

    /// The model to be used.
    #[arg(value_enum, long, default_value_t = Which::SqueezeNet)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let image = candle_examples::imagenet::load_image224(args.image)?;
    let image = match args.which {
        Which::SqueezeNet => image,
        Which::EfficientNet => image.permute((1, 2, 0))?,
    };

    println!("loaded image {image:?}");

    let model = match args.model {
        Some(model) => std::path::PathBuf::from(model),
        None => match args.which {
            Which::SqueezeNet => hf_hub::api::sync::Api::new()?
                .model("lmz/candle-onnx".into())
                .get("squeezenet1.1-7.onnx")?,
            Which::EfficientNet => hf_hub::api::sync::Api::new()?
                .model("onnx/EfficientNet-Lite4".into())
                .get("efficientnet-lite4-11.onnx")?,
        },
    };

    let model = candle_onnx::read_file(model)?;
    let graph = model.graph.as_ref().unwrap();
    let mut inputs = std::collections::HashMap::new();
    inputs.insert(graph.input[0].name.to_string(), image.unsqueeze(0)?);
    let mut outputs = candle_onnx::simple_eval(&model, inputs)?;
    let output = outputs.remove(&graph.output[0].name).unwrap();
    let prs = match args.which {
        Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
        Which::EfficientNet => output,
    };
    let prs = prs.i(0)?.to_vec1::<f32>()?;

    // Sort the predictions and take the top 5
    let mut top: Vec<_> = prs.iter().enumerate().collect();
    top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
    let top = top.into_iter().take(5).collect::<Vec<_>>();

    // Print the top predictions
    for &(i, p) in &top {
        println!(
            "{:50}: {:.2}%",
            candle_examples::imagenet::CLASSES[i],
            p * 100.0
        );
    }

    Ok(())
}
candle-examples/examples/onnx_basics.rs (new file, 87 lines)
@@ -0,0 +1,87 @@
use anyhow::Result;
use candle::{Device, Tensor};

use clap::{Parser, Subcommand};

#[derive(Subcommand, Debug, Clone)]
enum Command {
    Print {
        #[arg(long)]
        file: String,
    },
    SimpleEval {
        #[arg(long)]
        file: String,
    },
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
    #[command(subcommand)]
    command: Command,
}

pub fn main() -> Result<()> {
    let args = Args::parse();
    match args.command {
        Command::Print { file } => {
            let model = candle_onnx::read_file(file)?;
            println!("{model:?}");
            let graph = model.graph.unwrap();
            for node in graph.node.iter() {
                println!("{node:?}");
            }
        }
        Command::SimpleEval { file } => {
            let model = candle_onnx::read_file(file)?;
            let graph = model.graph.as_ref().unwrap();
            let constants: std::collections::HashSet<_> =
                graph.initializer.iter().map(|i| i.name.as_str()).collect();
            let mut inputs = std::collections::HashMap::new();
            for input in graph.input.iter() {
                use candle_onnx::onnx::tensor_proto::DataType;
                if constants.contains(input.name.as_str()) {
                    continue;
                }

                let type_ = input.r#type.as_ref().expect("no type for input");
                let type_ = type_.value.as_ref().expect("no type.value for input");
                let value = match type_ {
                    candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
                        let dt = match DataType::try_from(tt.elem_type) {
                            Ok(dt) => match candle_onnx::dtype(dt) {
                                Some(dt) => dt,
                                None => {
                                    anyhow::bail!(
                                        "unsupported 'value' data-type {dt:?} for {}",
                                        input.name
                                    )
                                }
                            },
                            type_ => anyhow::bail!("unsupported input type {type_:?}"),
                        };
                        let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
                        let dims = shape
                            .dim
                            .iter()
                            .map(|dim| match dim.value.as_ref().expect("no dim value") {
                                candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
                                candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => Ok(42),
                            })
                            .collect::<Result<Vec<usize>>>()?;
                        Tensor::zeros(dims, dt, &Device::Cpu)?
                    }
                    type_ => anyhow::bail!("unsupported input type {type_:?}"),
                };
                println!("input {}: {value:?}", input.name);
                inputs.insert(input.name.clone(), value);
            }
            let outputs = candle_onnx::simple_eval(&model, inputs)?;
            for (name, value) in outputs.iter() {
                println!("output {name}: {value:?}")
            }
        }
    }
    Ok(())
}
@@ -1,14 +1,33 @@
-# candle-phi: 1.3b LLM with state of the art performance for <10b models.
+# candle-phi: 1.3b and 2.7b LLM with state of the art performance for <10b models.

-[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) is a language model using
-only 1.3 billion parameters but with state of the art performance compared to
+[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) and
+[Phi-2](https://huggingface.co/microsoft/phi-2) are language models using
+only 1.3 and 2.7 billion parameters but with state of the art performance compared to
 models with up to 10 billion parameters.

 The candle implementation provides both the standard version as well as a
 quantized variant.

-## Running some example
+## Running some examples

+For the v2 version.
+```bash
+$ cargo run --example phi --release -- --model 2 \
+  --prompt "A skier slides down a frictionless slope of height 40m and length 80m. What's the skier speed at the bottom?"
+
+A skier slides down a frictionless slope of height 40m and length 80m. What's the skier speed at the bottom?
+
+Solution:
+The potential energy of the skier is converted into kinetic energy as it slides down the slope. The formula for potential energy is mgh, where m is mass, g is acceleration due to gravity (9.8 m/s^2), and h is height. Since there's no friction, all the potential energy is converted into kinetic energy at the bottom of the slope. The formula for kinetic energy is 1/2mv^2, where v is velocity. We can equate these two formulas:
+mgh = 1/2mv^2
+Solving for v, we get:
+v = sqrt(2gh)
+Substituting the given values, we get:
+v = sqrt(2*9.8*40) = 28 m/s
+Therefore, the skier speed at the bottom of the slope is 28 m/s.
+```
+
+For the v1.5 version.
 ```bash
 $ cargo run --example phi --release -- --prompt "def print_prime(n): "
@@ -8,6 +8,7 @@ use anyhow::{Error as E, Result};
 use clap::{Parser, ValueEnum};

 use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
+use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
 use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;

 use candle::{DType, Device, Tensor};

@@ -18,6 +19,7 @@ use tokenizers::Tokenizer;

 enum Model {
     MixFormer(MixFormer),
+    Phi(Phi),
     Quantized(QMixFormer),
 }

@@ -84,6 +86,7 @@ impl TextGeneration {
             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
             let logits = match &mut self.model {
                 Model::MixFormer(m) => m.forward(&input)?,
+                Model::Phi(m) => m.forward(&input)?,
                 Model::Quantized(m) => m.forward(&input)?,
             };
             let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;

@@ -117,13 +120,18 @@ impl TextGeneration {
     }
 }

-#[derive(Clone, Copy, Debug, ValueEnum)]
+#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
 enum WhichModel {
     #[value(name = "1")]
     V1,
     #[value(name = "1.5")]
     V1_5,
+    #[value(name = "2")]
+    V2,
+    #[value(name = "2-old")]
+    V2Old,
     PuffinPhiV2,
+    PhiHermes,
 }

 #[derive(Parser, Debug)]

@@ -142,7 +150,10 @@ struct Args {
     verbose_prompt: bool,

     #[arg(long)]
-    prompt: String,
+    prompt: Option<String>,
+
+    #[arg(long)]
+    mmlu_dir: Option<String>,

     /// The temperature used to generate samples.
     #[arg(long)]

@@ -157,13 +168,13 @@ struct Args {
     seed: u64,

     /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 100)]
+    #[arg(long, short = 'n', default_value_t = 5000)]
     sample_len: usize,

     #[arg(long)]
     model_id: Option<String>,

-    #[arg(long, default_value = "1.5")]
+    #[arg(long, default_value = "2")]
     model: WhichModel,

     #[arg(long)]

@@ -224,7 +235,10 @@ fn main() -> Result<()> {
             match args.model {
                 WhichModel::V1 => "microsoft/phi-1".to_string(),
                 WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
-                WhichModel::PuffinPhiV2 => "lmz/candle-quantized-phi".to_string(),
+                WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
+                WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
+                    "lmz/candle-quantized-phi".to_string()
+                }
             }
         }
     }

@@ -236,9 +250,12 @@ fn main() -> Result<()> {
             "main".to_string()
         } else {
             match args.model {
-                WhichModel::V1 => "refs/pr/2".to_string(),
-                WhichModel::V1_5 => "refs/pr/18".to_string(),
-                WhichModel::PuffinPhiV2 => "main".to_string(),
+                WhichModel::V1 => "refs/pr/8".to_string(),
+                WhichModel::V1_5 => "refs/pr/73".to_string(),
+                WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
+                WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
+                    "main".to_string()
+                }
             }
         }
     }

@@ -247,23 +264,34 @@ fn main() -> Result<()> {
     let tokenizer_filename = match args.tokenizer {
         Some(file) => std::path::PathBuf::from(file),
         None => match args.model {
-            WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?,
-            WhichModel::PuffinPhiV2 => repo.get("tokenizer-puffin-phi-v2.json")?,
+            WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2Old => {
+                repo.get("tokenizer.json")?
+            }
+            WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
+                repo.get("tokenizer-puffin-phi-v2.json")?
+            }
         },
     };
-    let filename = match args.weight_file {
-        Some(weight_file) => std::path::PathBuf::from(weight_file),
+    let filenames = match args.weight_file {
+        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
         None => {
             if args.quantized {
                 match args.model {
-                    WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
-                    WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
-                    WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?,
+                    WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
+                    WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
+                    WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
+                    WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
+                    WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
                 }
             } else {
                 match args.model {
-                    WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?,
-                    WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?,
+                    WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
+                    WhichModel::V2 | WhichModel::V2Old => candle_examples::hub_load_safetensors(
+                        &repo,
+                        "model.safetensors.index.json",
+                    )?,
+                    WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
+                    WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
                 }
             }
         }

@@ -272,23 +300,52 @@ fn main() -> Result<()> {
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

     let start = std::time::Instant::now();
-    let config = match args.model {
+    let config = || match args.model {
         WhichModel::V1 => Config::v1(),
         WhichModel::V1_5 => Config::v1_5(),
+        WhichModel::V2 | WhichModel::V2Old => Config::v2(),
         WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
+        WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
     };
-    let (model, device) = if args.quantized {
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
-        let model = QMixFormer::new(&config, vb)?;
-        (Model::Quantized(model), Device::Cpu)
-    } else {
-        let device = candle_examples::device(args.cpu)?;
-        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
-        let model = MixFormer::new(&config, vb)?;
-        (Model::MixFormer(model), device)
+    let device = candle_examples::device(args.cpu)?;
+    let model = if args.quantized {
+        let config = config();
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+            &filenames[0],
+            &device,
+        )?;
+        let model = match args.model {
+            WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
+            _ => QMixFormer::new(&config, vb)?,
+        };
+        Model::Quantized(model)
+    } else {
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+        match args.model {
+            WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
+                let config_filename = repo.get("config.json")?;
+                let config = std::fs::read_to_string(config_filename)?;
+                let config: PhiConfig = serde_json::from_str(&config)?;
+                let phi = Phi::new(&config, vb)?;
+                Model::Phi(phi)
+            }
+            WhichModel::V2Old => {
+                let config = config();
+                Model::MixFormer(MixFormer::new_v2(&config, vb)?)
+            }
+            WhichModel::PhiHermes | WhichModel::PuffinPhiV2 => {
+                let config = config();
+                Model::MixFormer(MixFormer::new(&config, vb)?)
+            }
+        }
     };
     println!("loaded the model in {:?}", start.elapsed());

+    match (args.prompt, args.mmlu_dir) {
+        (None, None) | (Some(_), Some(_)) => {
+            anyhow::bail!("exactly one of --prompt and --mmlu-dir must be specified")
+        }
+        (Some(prompt), None) => {
     let mut pipeline = TextGeneration::new(
         model,
         tokenizer,

@@ -300,6 +357,93 @@ fn main() -> Result<()> {
         args.verbose_prompt,
         &device,
     );
-    pipeline.run(&args.prompt, args.sample_len)?;
+            pipeline.run(&prompt, args.sample_len)?;
+        }
+        (None, Some(mmlu_dir)) => mmlu(model, tokenizer, &device, mmlu_dir)?,
+    }
+    Ok(())
+}
+
+fn mmlu<P: AsRef<std::path::Path>>(
+    mut model: Model,
+    tokenizer: Tokenizer,
+    device: &Device,
+    mmlu_dir: P,
+) -> anyhow::Result<()> {
+    for dir_entry in mmlu_dir.as_ref().read_dir()?.flatten() {
+        let dir_entry = dir_entry.path();
+        let theme = match dir_entry.file_stem().and_then(|v| v.to_str()) {
+            None => "".to_string(),
+            Some(v) => match v.strip_suffix("_test") {
+                None => v.replace('_', " "),
+                Some(v) => v.replace('_', " "),
+            },
+        };
+        if dir_entry.extension().as_ref().and_then(|v| v.to_str()) != Some("csv") {
+            continue;
+        }
+        println!("reading {dir_entry:?}");
+        let dir_entry = std::fs::File::open(dir_entry)?;
+        let mut reader = csv::ReaderBuilder::new()
+            .has_headers(false)
+            .from_reader(dir_entry);
+        let token_a = tokenizer.token_to_id("A").unwrap();
+        let token_b = tokenizer.token_to_id("B").unwrap();
+        let token_c = tokenizer.token_to_id("C").unwrap();
+        let token_d = tokenizer.token_to_id("D").unwrap();
+        for row in reader.records() {
+            let row = match row {
+                Err(_) => continue,
+                Ok(row) => row,
+            };
+            if row.len() < 5 {
+                continue;
+            }
+            let question = row.get(0).unwrap();
+            let answer_a = row.get(1).unwrap();
+            let answer_b = row.get(2).unwrap();
+            let answer_c = row.get(3).unwrap();
+            let answer_d = row.get(4).unwrap();
+            let answer = row.get(5).unwrap();
+            let prompt = format!(
+                "{} {theme}.\n{question}\nA. {answer_a}\nB. {answer_b}\nC. {answer_c}\nD. {answer_d}\nAnswer:\n",
+                "The following are multiple choice questions (with answers) about"
+            );
+            let tokens = tokenizer.encode(prompt.as_str(), true).map_err(E::msg)?;
+            let tokens = tokens.get_ids().to_vec();
+            let input = Tensor::new(tokens, device)?.unsqueeze(0)?;
+            let logits = match &mut model {
+                Model::MixFormer(m) => {
+                    m.clear_kv_cache();
+                    m.forward(&input)?
+                }
+                Model::Phi(m) => {
+                    m.clear_kv_cache();
+                    m.forward(&input)?
+                }
+                Model::Quantized(m) => {
+                    m.clear_kv_cache();
+                    m.forward(&input)?
+                }
+            };
+            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+            let logits_v: Vec<f32> = logits.to_vec1()?;
+            let pr_a = logits_v[token_a as usize];
+            let pr_b = logits_v[token_b as usize];
+            let pr_c = logits_v[token_c as usize];
+            let pr_d = logits_v[token_d as usize];
+            let model_answer = if pr_a > pr_b && pr_a > pr_c && pr_a > pr_d {
+                "A"
+            } else if pr_b > pr_c && pr_b > pr_d {
+                "B"
+            } else if pr_c > pr_d {
+                "C"
+            } else {
+                "D"
+            };
+
+            println!("{prompt}\n -> {model_answer} vs {answer}");
+        }
+    }
     Ok(())
 }
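The A/B/C/D selection in the `mmlu` function above is written as a chain of pairwise comparisons. For reference, an equivalent formulation using an iterator max; this is only a sketch of the same logic (tie-breaking differs slightly), and the helper name is made up:

```rust
// Hypothetical helper: pick the answer letter whose token has the highest logit.
fn pick_answer(scores: &[(&'static str, f32)]) -> &'static str {
    scores
        .iter()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(name, _)| *name)
        .unwrap_or("A")
}

fn main() {
    let scores = [("A", 0.1f32), ("B", 2.3), ("C", -0.5), ("D", 1.9)];
    assert_eq!(pick_answer(&scores), "B");
}
```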
@@ -1,5 +1,7 @@
 # candle-quantized-t5

+## Seq2Seq example
+
 This example uses a quantized version of the t5 model.

 ```bash

@@ -8,6 +10,8 @@ $ cargo run --example quantized-t5 --release -- --prompt "translate to German: A
 Eine schöne Kerze.
 ```

+## Generating Quantized weight files
+
 The weight file is automatically retrieved from the hub. It is also possible to
 generate quantized weight files from the original safetensors file by using the
 `tensor-tools` command line utility via:

@@ -16,8 +20,11 @@ generate quantized weight files from the original safetensors file by using the
 $ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
 ```

-To use a different model, specify the `model-id`. For example, you can use
-quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
+## Using custom models
+
+To use a different model, specify the `model-id`.
+
+For example, for text editing, you can use quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).

 ```bash
 $ cargo run --example quantized-t5 --release -- \

@@ -26,6 +33,7 @@ $ cargo run --example quantized-t5 --release -- \
   --temperature 0
 ...
 Although their flight is weak, they run quickly through the tree canopy.
+```

 By default, it will look for `model.gguf` and `config.json`, but you can specify
 custom local or remote `weight-file` and `config-file`s:

@@ -40,3 +48,16 @@ cargo run --example quantized-t5 --release -- \
 ...
 Note that a storm surge is what forecasters consider a hurricane's most dangerous part.
 ```
+
+### [MADLAD-400](https://arxiv.org/abs/2309.04662)
+
+MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
+
+```bash
+cargo run --example quantized-t5 --release -- \
+  --model-id "jbochi/madlad400-3b-mt" --weight-file "model-q4k.gguf" \
+  --prompt "<2de> How are you, my friend?" \
+  --temperature 0
+...
+Wie geht es dir, mein Freund?
+```
@@ -132,7 +132,8 @@ impl T5ModelBuilder {
     }

     pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
-        let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
+        let device = Device::Cpu;
+        let vb = t5::VarBuilder::from_gguf(&self.weights_filename, &device)?;
         Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
     }

@@ -173,7 +174,11 @@ fn main() -> Result<()> {
         .to_vec();
     let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
     let mut model = builder.build_model()?;
-    let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
+    let mut output_token_ids = [builder
+        .config
+        .decoder_start_token_id
+        .unwrap_or(builder.config.pad_token_id) as u32]
+    .to_vec();
     let temperature = if args.temperature <= 0. {
         None
     } else {
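The change above seeds decoding with `decoder_start_token_id` when the config provides one, falling back to `pad_token_id` (T5 checkpoints conventionally use the pad token as the decoder start token). A sketch of the same fallback in isolation; the config struct here is a made-up stand-in for the fields read from `config.json`:

```rust
// Illustrative only, not the actual t5 Config type.
struct GenConfig {
    decoder_start_token_id: Option<usize>,
    pad_token_id: usize,
}

fn start_token(cfg: &GenConfig) -> u32 {
    // Prefer the explicit decoder start token, fall back to the pad token.
    cfg.decoder_start_token_id.unwrap_or(cfg.pad_token_id) as u32
}

fn main() {
    let cfg = GenConfig { decoder_start_token_id: None, pad_token_id: 0 };
    assert_eq!(start_token(&cfg), 0);
}
```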
@@ -26,6 +26,19 @@ cargo run --example quantized --release -- --prompt "The best thing about coding
 > The best thing about coding in rust is 1.) that I don’t need to worry about memory leaks, 2.) speed and 3.) my program will compile even on old machines.
 ```

+Using the mixtral sparse mixture of expert model:
+```bash
+
+$ cargo run --example quantized --release -- --which mixtral --prompt "Lebesgue's integral is superior to Riemann's because "
+> avx: true, neon: false, simd128: false, f16c: true
+> temp: 0.80 repeat-penalty: 1.10 repeat-last-n: 64
+> loaded 995 tensors (26.44GB) in 0.03s
+Lebesgue's integral is superior to Riemann's because 1. it is defined for a wider class of functions, those which are absolutely integrable; 2. the definition does not involve limits in two variables---one being computed before the other (which makes some computations more difficult); and 3. interchange of order of integration is easier to establish than with Riemann's integral. On the other hand, Lebesgue's integral applies only for bounded functions defined on finite intervals; it does not provide numerical values for improper integrals. The latter are best evaluated using Cauchy's limit definition.
+
+The reason $f(x) = x^2$ is discontinuous at the ends of its interval of definition, and Riemann's integral requires continuity on the whole of an open interval containing it (see our earlier post), sine no such function exists with this property, is that the endpoints are infinite in measure for Lebesgue's integral.
+```
+
 ## Command-line flags

 Run with `--help` to see all options.
@@ -9,9 +9,10 @@ use std::io::Write;
 use tokenizers::Tokenizer;

 use candle::quantized::{ggml_file, gguf_file};
-use candle::{Device, Tensor};
+use candle::Tensor;
 use candle_transformers::generation::LogitsProcessor;

+use candle_examples::token_output_stream::TokenOutputStream;
 use candle_transformers::models::quantized_llama as model;
 use model::ModelWeights;

@@ -24,7 +25,7 @@ enum Prompt {
     One(String),
 }

-#[derive(Clone, Debug, Copy, ValueEnum)]
+#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
 enum Which {
     #[value(name = "7b")]
     L7b,

@@ -44,12 +45,28 @@ enum Which {
     L13bCode,
     #[value(name = "32b-code")]
     L34bCode,
+    #[value(name = "7b-leo")]
+    Leo7b,
+    #[value(name = "13b-leo")]
+    Leo13b,
     #[value(name = "7b-mistral")]
     Mistral7b,
     #[value(name = "7b-mistral-instruct")]
     Mistral7bInstruct,
-    #[value(name = "7b-zephyr")]
-    Zephyr7b,
+    #[value(name = "7b-mistral-instruct-v0.2")]
+    Mistral7bInstructV02,
+    #[value(name = "7b-zephyr-a")]
+    Zephyr7bAlpha,
+    #[value(name = "7b-zephyr-b")]
+    Zephyr7bBeta,
+    #[value(name = "7b-open-chat-3.5")]
+    OpenChat35,
+    #[value(name = "7b-starling-a")]
+    Starling7bAlpha,
+    #[value(name = "mixtral")]
+    Mixtral,
+    #[value(name = "mixtral-instruct")]
+    MixtralInstruct,
 }

 impl Which {

@@ -63,8 +80,93 @@ impl Which {
             | Self::L70bChat
             | Self::L7bCode
             | Self::L13bCode
-            | Self::L34bCode => false,
-            Self::Mistral7b | Self::Mistral7bInstruct | Self::Zephyr7b => true,
+            | Self::L34bCode
+            | Self::Leo7b
+            | Self::Leo13b => false,
+            // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
+            // same way. Starling is a fine tuned version of OpenChat.
+            Self::OpenChat35
+            | Self::Starling7bAlpha
+            | Self::Zephyr7bAlpha
+            | Self::Zephyr7bBeta
+            | Self::Mixtral
+            | Self::MixtralInstruct
+            | Self::Mistral7b
+            | Self::Mistral7bInstruct
+            | Self::Mistral7bInstructV02 => true,
+        }
+    }
+
+    fn is_zephyr(&self) -> bool {
+        match self {
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode
+            | Self::Leo7b
+            | Self::Leo13b
+            | Self::Mixtral
+            | Self::MixtralInstruct
+            | Self::Mistral7b
+            | Self::Mistral7bInstruct
+            | Self::Mistral7bInstructV02
+            | Self::OpenChat35
+            | Self::Starling7bAlpha => false,
+            Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
+        }
+    }
+
+    fn is_open_chat(&self) -> bool {
+        match self {
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode
+            | Self::Leo7b
+            | Self::Leo13b
+            | Self::Mixtral
+            | Self::MixtralInstruct
+            | Self::Mistral7b
+            | Self::Mistral7bInstruct
+            | Self::Mistral7bInstructV02
+            | Self::Zephyr7bAlpha
+            | Self::Zephyr7bBeta => false,
+            Self::OpenChat35 | Self::Starling7bAlpha => true,
+        }
+    }
+
+    fn tokenizer_repo(&self) -> &'static str {
+        match self {
+            Which::L7b
+            | Which::L13b
+            | Which::L70b
+            | Which::L7bChat
+            | Which::L13bChat
+            | Which::L70bChat
+            | Which::L7bCode
+            | Which::L13bCode
+            | Which::L34bCode => "hf-internal-testing/llama-tokenizer",
+            Which::Leo7b => "LeoLM/leo-hessianai-7b",
+            Which::Leo13b => "LeoLM/leo-hessianai-13b",
+            Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
+            Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
+            Which::Mistral7b
+            | Which::Mistral7bInstruct
+            | Which::Mistral7bInstructV02
+            | Which::Zephyr7bAlpha
+            | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
+            Which::OpenChat35 => "openchat/openchat_3.5",
+            Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
         }
     }
 }

@@ -72,7 +174,7 @@ impl Which {
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
-    /// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
+    /// GGML/GGUF file to load, typically a .bin/.gguf file generated by the quantize command from llama.cpp
     #[arg(long)]
     model: Option<String>,

@@ -83,7 +185,7 @@ struct Args {
     prompt: Option<String>,

     /// The length of the sample to generate (in tokens).
-    #[arg(short = 'n', long, default_value_t = 100)]
+    #[arg(short = 'n', long, default_value_t = 1000)]
     sample_len: usize,

     /// The tokenizer config in json format.

@@ -133,11 +235,7 @@ impl Args {
             Some(config) => std::path::PathBuf::from(config),
             None => {
                 let api = hf_hub::api::sync::Api::new()?;
-                let repo = if self.which.is_mistral() {
-                    "mistralai/Mistral-7B-v0.1"
-                } else {
-                    "hf-internal-testing/llama-tokenizer"
-                };
+                let repo = self.which.tokenizer_repo();
                 let api = api.model(repo.to_string());
                 api.get("tokenizer.json")?
             }

@@ -168,6 +266,22 @@ impl Args {
                 Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
                 Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
                 Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
+                Which::Leo7b => (
+                    "TheBloke/leo-hessianai-7B-GGUF",
+                    "leo-hessianai-7b.Q4_K_M.gguf",
+                ),
+                Which::Leo13b => (
+                    "TheBloke/leo-hessianai-13B-GGUF",
+                    "leo-hessianai-13b.Q4_K_M.gguf",
+                ),
+                Which::Mixtral => (
+                    "TheBloke/Mixtral-8x7B-v0.1-GGUF",
+                    "mixtral-8x7b-v0.1.Q4_K_M.gguf",
+                ),
+                Which::MixtralInstruct => (
+                    "TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF",
+                    "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf",
+                ),
                 Which::Mistral7b => (
                     "TheBloke/Mistral-7B-v0.1-GGUF",
                     "mistral-7b-v0.1.Q4_K_S.gguf",

@@ -176,10 +290,22 @@ impl Args {
                     "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
                     "mistral-7b-instruct-v0.1.Q4_K_S.gguf",
                 ),
-                Which::Zephyr7b => (
+                Which::Mistral7bInstructV02 => (
+                    "TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
+                    "mistral-7b-instruct-v0.2.Q4_K_S.gguf",
+                ),
+                Which::Zephyr7bAlpha => (
                     "TheBloke/zephyr-7B-alpha-GGUF",
                     "zephyr-7b-alpha.Q4_K_M.gguf",
                 ),
+                Which::Zephyr7bBeta => {
+                    ("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
+                }
+                Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"),
+                Which::Starling7bAlpha => (
+                    "TheBloke/Starling-LM-7B-alpha-GGUF",
+                    "starling-lm-7b-alpha.Q4_K_M.gguf",
+                ),
             };
             let api = hf_hub::api::sync::Api::new()?;
             let api = api.model(repo.to_string());

@@ -190,31 +316,6 @@ impl Args {
     }
 }

-fn print_token(next_token: u32, tokenizer: &Tokenizer) {
-    // Extracting the last token as a string is complicated, here we just apply some simple
-    // heuristics as it seems to work well enough for this example. See the following for more
-    // details:
-    // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
-    if let Some(text) = tokenizer.id_to_token(next_token) {
-        let text = text.replace('▁', " ");
-        let ascii = text
-            .strip_prefix("<0x")
-            .and_then(|t| t.strip_suffix('>'))
-            .and_then(|t| u8::from_str_radix(t, 16).ok());
-        match ascii {
-            None => print!("{text}"),
-            Some(ascii) => {
-                if let Some(chr) = char::from_u32(ascii as u32) {
-                    if chr.is_ascii() {
-                        print!("{chr}")
-                    }
-                }
-            }
-        }
-        let _ = std::io::stdout().flush();
-    }
-}
-
 fn format_size(size_in_bytes: usize) -> String {
     if size_in_bytes < 1_000 {
         format!("{}B", size_in_bytes)

@@ -260,15 +361,16 @@ fn main() -> anyhow::Result<()> {
     let model_path = args.model()?;
     let mut file = std::fs::File::open(&model_path)?;
     let start = std::time::Instant::now();
+    let device = candle_examples::device(false)?;

     let mut model = match model_path.extension().and_then(|v| v.to_str()) {
         Some("gguf") => {
-            let model = gguf_file::Content::read(&mut file)?;
+            let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
             let mut total_size_in_bytes = 0;
             for (_, tensor) in model.tensor_infos.iter() {
                 let elem_count = tensor.shape.elem_count();
                 total_size_in_bytes +=
-                    elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
+                    elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
             }
             println!(
                 "loaded {:?} tensors ({}) in {:.2}s",

@@ -276,15 +378,16 @@ fn main() -> anyhow::Result<()> {
                 &format_size(total_size_in_bytes),
                 start.elapsed().as_secs_f32(),
             );
-            ModelWeights::from_gguf(model, &mut file)?
+            ModelWeights::from_gguf(model, &mut file, &device)?
         }
         Some("ggml" | "bin") | Some(_) | None => {
-            let model = ggml_file::Content::read(&mut file)?;
+            let model = ggml_file::Content::read(&mut file, &device)
+                .map_err(|e| e.with_path(model_path))?;
             let mut total_size_in_bytes = 0;
             for (_, tensor) in model.tensors.iter() {
                 let elem_count = tensor.shape().elem_count();
                 total_size_in_bytes +=
-                    elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
+                    elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
             }
             println!(
                 "loaded {:?} tensors ({}) in {:.2}s",

@@ -300,12 +403,20 @@ fn main() -> anyhow::Result<()> {
                 | Which::L13bChat
                 | Which::L7bCode
                 | Which::L13bCode
-                | Which::L34bCode => 1,
-                Which::Mistral7b
+                | Which::L34bCode
+                | Which::Leo7b
+                | Which::Leo13b => 1,
+                Which::Mixtral
+                | Which::MixtralInstruct
+                | Which::Mistral7b
                 | Which::Mistral7bInstruct
-                | Which::Zephyr7b
+                | Which::Mistral7bInstructV02
+                | Which::Zephyr7bAlpha
+                | Which::Zephyr7bBeta
                 | Which::L70b
-                | Which::L70bChat => 8,
+                | Which::L70bChat
+                | Which::OpenChat35
+                | Which::Starling7bAlpha => 8,
             };
             ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
         }

@@ -313,6 +424,7 @@ fn main() -> anyhow::Result<()> {
     println!("model built");

     let tokenizer = args.tokenizer()?;
+    let mut tos = TokenOutputStream::new(tokenizer);
     let prompt = match args.prompt.as_deref() {
         Some("chat") => Prompt::Chat,
         Some("interactive") => Prompt::Interactive,

@@ -321,10 +433,11 @@ fn main() -> anyhow::Result<()> {
     };

     let mut pre_prompt_tokens = vec![];
-    loop {
+    for prompt_index in 0.. {
         let prompt_str = match &prompt {
             Prompt::One(prompt) => prompt.clone(),
             Prompt::Interactive | Prompt::Chat => {
+                let is_interactive = matches!(prompt, Prompt::Interactive);
                 print!("> ");
                 std::io::stdout().flush()?;
                 let mut prompt = String::new();

@@ -335,7 +448,15 @@ fn main() -> anyhow::Result<()> {
                     prompt.pop();
                 }
                 }
-                if args.which.is_mistral() {
+                if args.which.is_open_chat() {
+                    format!("GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:")
+                } else if args.which.is_zephyr() {
+                    if prompt_index == 0 || is_interactive {
+                        format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
+                    } else {
+                        format!("<|user|>\n{prompt}</s>\n<|assistant|>")
+                    }
+                } else if args.which.is_mistral() {
                     format!("[INST] {prompt} [/INST]")
                 } else {
                     prompt

@@ -343,7 +464,8 @@ fn main() -> anyhow::Result<()> {
             }
         };
         print!("{}", &prompt_str);
-        let tokens = tokenizer
+        let tokens = tos
+            .tokenizer()
             .encode(prompt_str, true)
             .map_err(anyhow::Error::msg)?;
         if args.verbose_prompt {

@@ -366,20 +488,28 @@ fn main() -> anyhow::Result<()> {

         let start_prompt_processing = std::time::Instant::now();
         let mut next_token = {
-            let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+            let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
             let logits = model.forward(&input, 0)?;
             let logits = logits.squeeze(0)?;
             logits_processor.sample(&logits)?
         };
         let prompt_dt = start_prompt_processing.elapsed();
         all_tokens.push(next_token);
-        print_token(next_token, &tokenizer);
+        if let Some(t) = tos.next_token(next_token)? {
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }

-        let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
+        let eos_token = if args.which.is_open_chat() {
+            "<|end_of_turn|>"
+        } else {
+            "</s>"
+        };
+        let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
         let start_post_prompt = std::time::Instant::now();
+        let mut sampled = 0;
         for index in 0..to_sample {
-            let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
+            let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
             let logits = model.forward(&input, prompt_tokens.len() + index)?;
             let logits = logits.squeeze(0)?;
             let logits = if args.repeat_penalty == 1. {

@@ -394,11 +524,19 @@ fn main() -> anyhow::Result<()> {
             };
             next_token = logits_processor.sample(&logits)?;
             all_tokens.push(next_token);
-            print_token(next_token, &tokenizer);
+            if let Some(t) = tos.next_token(next_token)? {
+                print!("{t}");
+                std::io::stdout().flush()?;
+            }
+            sampled += 1;
             if next_token == eos_token {
                 break;
             };
         }
+        if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
+            print!("{rest}");
+        }
+        std::io::stdout().flush()?;
         let dt = start_post_prompt.elapsed();
         println!(
             "\n\n{:4} prompt tokens processed: {:.2} token/s",

@@ -406,9 +544,8 @@ fn main() -> anyhow::Result<()> {
             prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
         );
         println!(
-            "{:4} tokens generated: {:.2} token/s",
-            to_sample,
-            to_sample as f64 / dt.as_secs_f64(),
+            "{sampled:4} tokens generated: {:.2} token/s",
+            sampled as f64 / dt.as_secs_f64(),
         );

         match prompt {
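The switch from `print_token` to `TokenOutputStream` above moves the byte-level decoding heuristics into a reusable helper that buffers token ids and only emits text once they decode cleanly. A minimal usage sketch, assuming a `tokenizers::Tokenizer` has already been loaded; the function name and inputs here are illustrative:

```rust
use candle_examples::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;

fn stream_tokens(tokenizer: Tokenizer, token_ids: &[u32]) -> anyhow::Result<()> {
    let mut tos = TokenOutputStream::new(tokenizer);
    for &id in token_ids {
        // next_token only yields text once the buffered ids decode to a
        // valid UTF-8 chunk, so multi-byte characters print correctly.
        if let Some(text) = tos.next_token(id)? {
            print!("{text}");
        }
    }
    // Flush whatever is still buffered once generation ends.
    if let Some(rest) = tos.decode_rest()? {
        print!("{rest}");
    }
    Ok(())
}
```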
@@ -8,9 +8,16 @@ Python package with:
 pip install "gymnasium[accept-rom-license]"
 ```

-In order to run the example, use the following command. Note the additional
+In order to run the examples, use the following commands. Note the additional
 `--package` flag to ensure that there is no conflict with the `candle-pyo3`
 crate.

+For the Policy Gradient example:
 ```bash
-cargo run --example reinforcement-learning --features=pyo3 --package candle-examples
+cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- pg
+```
+
+For the Deep Deterministic Policy Gradient example:
+```bash
+cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- ddpg
 ```
@@ -78,7 +78,7 @@ class EpisodicLifeEnv(gym.Wrapper):
         # then update lives to handle bonus lives
         lives = self.env.unwrapped.ale.lives()
         if lives < self.lives and lives > 0:
-            # for Qbert somtimes we stay in lives == 0 condtion for a few frames
+            # for Qbert sometimes we stay in lives == 0 condition for a few frames
             # so its important to keep lives > 0, so that we only reset once
             # the environment advertises done.
             done = True
candle-examples/examples/reinforcement-learning/ddpg.rs (new file)
@@ -0,0 +1,556 @@
use std::collections::VecDeque;
use std::fmt::Display;

use candle::{DType, Device, Error, Module, Result, Tensor, Var};
use candle_nn::{
    func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
    VarBuilder, VarMap,
};
use rand::{distributions::Uniform, thread_rng, Rng};

use super::gym_env::GymEnv;

pub struct OuNoise {
    mu: f64,
    theta: f64,
    sigma: f64,
    state: Tensor,
}
impl OuNoise {
    pub fn new(mu: f64, theta: f64, sigma: f64, size_action: usize) -> Result<Self> {
        Ok(Self {
            mu,
            theta,
            sigma,
            state: Tensor::ones(size_action, DType::F32, &Device::Cpu)?,
        })
    }

    pub fn sample(&mut self) -> Result<Tensor> {
        let rand = Tensor::randn_like(&self.state, 0.0, 1.0)?;
        let dx = ((self.theta * (self.mu - &self.state)?)? + (self.sigma * rand)?)?;
        self.state = (&self.state + dx)?;
        Ok(self.state.clone())
    }
}
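// `sample` above is an Euler discretisation of the Ornstein-Uhlenbeck process,
// x <- x + theta * (mu - x) + sigma * eps with eps ~ N(0, 1): successive draws
// are temporally correlated, the usual exploration noise for DDPG's
// continuous actions.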

#[derive(Clone)]
struct Transition {
    state: Tensor,
    action: Tensor,
    reward: Tensor,
    next_state: Tensor,
    terminated: bool,
    truncated: bool,
}
impl Transition {
    fn new(
        state: &Tensor,
        action: &Tensor,
        reward: &Tensor,
        next_state: &Tensor,
        terminated: bool,
        truncated: bool,
    ) -> Self {
        Self {
            state: state.clone(),
            action: action.clone(),
            reward: reward.clone(),
            next_state: next_state.clone(),
            terminated,
            truncated,
        }
    }
}

pub struct ReplayBuffer {
    buffer: VecDeque<Transition>,
    capacity: usize,
    size: usize,
}
impl ReplayBuffer {
    pub fn new(capacity: usize) -> Self {
        Self {
            buffer: VecDeque::with_capacity(capacity),
            capacity,
            size: 0,
        }
    }

    pub fn push(
        &mut self,
        state: &Tensor,
        action: &Tensor,
        reward: &Tensor,
        next_state: &Tensor,
        terminated: bool,
        truncated: bool,
    ) {
        if self.size == self.capacity {
            self.buffer.pop_front();
        } else {
            self.size += 1;
        }
        self.buffer.push_back(Transition::new(
            state, action, reward, next_state, terminated, truncated,
        ));
    }

    #[allow(clippy::type_complexity)]
    pub fn random_batch(
        &self,
        batch_size: usize,
    ) -> Result<Option<(Tensor, Tensor, Tensor, Tensor, Vec<bool>, Vec<bool>)>> {
        if self.size < batch_size {
            Ok(None)
        } else {
            let transitions: Vec<&Transition> = thread_rng()
                .sample_iter(Uniform::from(0..self.size))
                .take(batch_size)
                .map(|i| self.buffer.get(i).unwrap())
                .collect();

            let states: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.state.unsqueeze(0))
                .collect::<Result<_>>()?;
            let actions: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.action.unsqueeze(0))
                .collect::<Result<_>>()?;
            let rewards: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.reward.unsqueeze(0))
                .collect::<Result<_>>()?;
            let next_states: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.next_state.unsqueeze(0))
                .collect::<Result<_>>()?;
            let terminateds: Vec<bool> = transitions.iter().map(|t| t.terminated).collect();
            let truncateds: Vec<bool> = transitions.iter().map(|t| t.truncated).collect();

            Ok(Some((
                Tensor::cat(&states, 0)?,
                Tensor::cat(&actions, 0)?,
                Tensor::cat(&rewards, 0)?,
                Tensor::cat(&next_states, 0)?,
                terminateds,
                truncateds,
            )))
        }
    }
}
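// `random_batch` draws indices uniformly with replacement and returns `None`
// until at least `batch_size` transitions are stored, so early calls to
// `DDPG::train` below are silently skipped while the buffer fills up.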

fn track(
    varmap: &mut VarMap,
    vb: &VarBuilder,
    target_prefix: &str,
    network_prefix: &str,
    dims: &[(usize, usize)],
    tau: f64,
) -> Result<()> {
    for (i, &(in_dim, out_dim)) in dims.iter().enumerate() {
        let target_w = vb.get((out_dim, in_dim), &format!("{target_prefix}-fc{i}.weight"))?;
        let network_w = vb.get((out_dim, in_dim), &format!("{network_prefix}-fc{i}.weight"))?;
        varmap.set_one(
            format!("{target_prefix}-fc{i}.weight"),
            ((tau * network_w)? + ((1.0 - tau) * target_w)?)?,
        )?;

        let target_b = vb.get(out_dim, &format!("{target_prefix}-fc{i}.bias"))?;
        let network_b = vb.get(out_dim, &format!("{network_prefix}-fc{i}.bias"))?;
        varmap.set_one(
            format!("{target_prefix}-fc{i}.bias"),
            ((tau * network_b)? + ((1.0 - tau) * target_b)?)?,
        )?;
    }
    Ok(())
}
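// `track` performs a soft ("Polyak") update of every tracked weight and bias:
// target <- tau * online + (1 - tau) * target. With tau = 1.0 it copies the
// online network outright, which is how the target networks are initialised
// in the Actor and Critic constructors below.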

struct Actor<'a> {
    varmap: VarMap,
    vb: VarBuilder<'a>,
    network: Sequential,
    target_network: Sequential,
    size_state: usize,
    size_action: usize,
    dims: Vec<(usize, usize)>,
}

impl Actor<'_> {
    fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, dtype, device);

        let dims = vec![(size_state, 400), (400, 300), (300, size_action)];

        let make_network = |prefix: &str| {
            let seq = seq()
                .add(linear(
                    dims[0].0,
                    dims[0].1,
                    vb.pp(format!("{prefix}-fc0")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[1].0,
                    dims[1].1,
                    vb.pp(format!("{prefix}-fc1")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[2].0,
                    dims[2].1,
                    vb.pp(format!("{prefix}-fc2")),
                )?)
                .add(func(|xs| xs.tanh()));
            Ok::<Sequential, Error>(seq)
        };

        let network = make_network("actor")?;
        let target_network = make_network("target-actor")?;

        // this sets the two networks to be equal to each other using tau = 1.0
        track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0)?;

        Ok(Self {
            varmap,
            vb,
            network,
            target_network,
            size_state,
            size_action,
            dims,
        })
    }

    fn forward(&self, state: &Tensor) -> Result<Tensor> {
        self.network.forward(state)
    }

    fn target_forward(&self, state: &Tensor) -> Result<Tensor> {
        self.target_network.forward(state)
    }

    fn track(&mut self, tau: f64) -> Result<()> {
        track(
            &mut self.varmap,
            &self.vb,
            "target-actor",
            "actor",
            &self.dims,
            tau,
        )
    }
}
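// Parameters are registered in the shared VarMap under prefixed names such as
// "actor-fc0.weight" and "target-actor-fc0.weight"; this naming scheme is
// what lets `track` pair each online layer with its target copy.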

struct Critic<'a> {
    varmap: VarMap,
    vb: VarBuilder<'a>,
    network: Sequential,
    target_network: Sequential,
    size_state: usize,
    size_action: usize,
    dims: Vec<(usize, usize)>,
}

impl Critic<'_> {
    fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, dtype, device);

        let dims: Vec<(usize, usize)> = vec![(size_state + size_action, 400), (400, 300), (300, 1)];

        let make_network = |prefix: &str| {
            let seq = seq()
                .add(linear(
                    dims[0].0,
                    dims[0].1,
                    vb.pp(format!("{prefix}-fc0")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[1].0,
                    dims[1].1,
                    vb.pp(format!("{prefix}-fc1")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[2].0,
                    dims[2].1,
                    vb.pp(format!("{prefix}-fc2")),
                )?);
            Ok::<Sequential, Error>(seq)
        };

        let network = make_network("critic")?;
        let target_network = make_network("target-critic")?;

        // this sets the two networks to be equal to each other using tau = 1.0
        track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0)?;

        Ok(Self {
            varmap,
            vb,
            network,
            target_network,
            size_state,
            size_action,
            dims,
        })
    }

    fn forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
        let xs = Tensor::cat(&[action, state], 1)?;
        self.network.forward(&xs)
    }

    fn target_forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
        let xs = Tensor::cat(&[action, state], 1)?;
        self.target_network.forward(&xs)
    }

    fn track(&mut self, tau: f64) -> Result<()> {
        track(
            &mut self.varmap,
            &self.vb,
            "target-critic",
            "critic",
            &self.dims,
            tau,
        )
    }
}

#[allow(clippy::upper_case_acronyms)]
pub struct DDPG<'a> {
    actor: Actor<'a>,
    actor_optim: AdamW,
    critic: Critic<'a>,
    critic_optim: AdamW,
    gamma: f64,
    tau: f64,
    replay_buffer: ReplayBuffer,
    ou_noise: OuNoise,

    size_state: usize,
    size_action: usize,
    pub train: bool,
}

impl DDPG<'_> {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        device: &Device,
        size_state: usize,
        size_action: usize,
        train: bool,
        actor_lr: f64,
        critic_lr: f64,
        gamma: f64,
        tau: f64,
        buffer_capacity: usize,
        ou_noise: OuNoise,
    ) -> Result<Self> {
        let filter_by_prefix = |varmap: &VarMap, prefix: &str| {
            varmap
                .data()
                .lock()
                .unwrap()
                .iter()
                .filter_map(|(name, var)| name.starts_with(prefix).then_some(var.clone()))
                .collect::<Vec<Var>>()
        };

        let actor = Actor::new(device, DType::F32, size_state, size_action)?;
        let actor_optim = AdamW::new(
            filter_by_prefix(&actor.varmap, "actor"),
            ParamsAdamW {
                lr: actor_lr,
                ..Default::default()
            },
        )?;

        let critic = Critic::new(device, DType::F32, size_state, size_action)?;
        let critic_optim = AdamW::new(
            filter_by_prefix(&critic.varmap, "critic"),
            ParamsAdamW {
                lr: critic_lr,
                ..Default::default()
            },
        )?;

        Ok(Self {
            actor,
            actor_optim,
            critic,
            critic_optim,
            gamma,
            tau,
            replay_buffer: ReplayBuffer::new(buffer_capacity),
            ou_noise,
            size_state,
            size_action,
            train,
        })
    }

    pub fn remember(
        &mut self,
        state: &Tensor,
        action: &Tensor,
        reward: &Tensor,
        next_state: &Tensor,
        terminated: bool,
        truncated: bool,
    ) {
        self.replay_buffer
            .push(state, action, reward, next_state, terminated, truncated)
    }

    pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
        let actions = self
            .actor
            .forward(&state.detach()?.unsqueeze(0)?)?
            .squeeze(0)?;
        let actions = if self.train {
            (actions + self.ou_noise.sample()?)?
        } else {
            actions
        };
        actions.squeeze(0)?.to_scalar::<f32>()
    }

    pub fn train(&mut self, batch_size: usize) -> Result<()> {
        let (states, actions, rewards, next_states, _, _) =
            match self.replay_buffer.random_batch(batch_size)? {
                Some(v) => v,
                _ => return Ok(()),
            };

        let q_target = self
            .critic
            .target_forward(&next_states, &self.actor.target_forward(&next_states)?)?;
        let q_target = (rewards + (self.gamma * q_target)?.detach())?;
        let q = self.critic.forward(&states, &actions)?;
        let diff = (q_target - q)?;

        let critic_loss = diff.sqr()?.mean_all()?;
        self.critic_optim.backward_step(&critic_loss)?;

        let actor_loss = self
            .critic
            .forward(&states, &self.actor.forward(&states)?)?
            .mean_all()?
            .neg()?;
        self.actor_optim.backward_step(&actor_loss)?;

        self.critic.track(self.tau)?;
        self.actor.track(self.tau)?;

        Ok(())
    }
}
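// The updates in `train` above are the standard DDPG losses:
//   critic: y = r + gamma * Q'(s', mu'(s')),  L_critic = mean((y - Q(s, a))^2)
//   actor:  L_actor = -mean(Q(s, mu(s)))
// The bootstrapped term is detached so gradients only flow through the online
// critic; the sampled terminated/truncated flags are ignored here, so the
// target bootstraps even across episode boundaries.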

// The impact of the q value of the next state on the current state's q value.
const GAMMA: f64 = 0.99;
// The weight for updating the target networks.
const TAU: f64 = 0.005;
// The capacity of the replay buffer used for sampling training data.
const REPLAY_BUFFER_CAPACITY: usize = 100_000;
// The training batch size for each training iteration.
const TRAINING_BATCH_SIZE: usize = 100;
// The total number of episodes.
const MAX_EPISODES: usize = 100;
// The maximum length of an episode.
const EPISODE_LENGTH: usize = 200;
// The number of training iterations after one episode finishes.
const TRAINING_ITERATIONS: usize = 200;

// Ornstein-Uhlenbeck process parameters.
const MU: f64 = 0.0;
const THETA: f64 = 0.15;
const SIGMA: f64 = 0.1;

const ACTOR_LEARNING_RATE: f64 = 1e-4;
const CRITIC_LEARNING_RATE: f64 = 1e-3;

pub fn run() -> Result<()> {
    let env = GymEnv::new("Pendulum-v1")?;
    println!("action space: {}", env.action_space());
    println!("observation space: {:?}", env.observation_space());

    let size_state = env.observation_space().iter().product::<usize>();
    let size_action = env.action_space();

    let mut agent = DDPG::new(
        &Device::Cpu,
        size_state,
        size_action,
        true,
        ACTOR_LEARNING_RATE,
        CRITIC_LEARNING_RATE,
        GAMMA,
        TAU,
        REPLAY_BUFFER_CAPACITY,
        OuNoise::new(MU, THETA, SIGMA, size_action)?,
    )?;

    let mut rng = rand::thread_rng();

    for episode in 0..MAX_EPISODES {
        // let mut state = env.reset(episode as u64)?;
        let mut state = env.reset(rng.gen::<u64>())?;

        let mut total_reward = 0.0;
        for _ in 0..EPISODE_LENGTH {
            let mut action = 2.0 * agent.actions(&state)?;
            action = action.clamp(-2.0, 2.0);

            let step = env.step(vec![action])?;
            total_reward += step.reward;

            agent.remember(
                &state,
                &Tensor::new(vec![action], &Device::Cpu)?,
                &Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
                &step.state,
                step.terminated,
                step.truncated,
            );

            if step.terminated || step.truncated {
                break;
            }
            state = step.state;
        }

        println!("episode {episode} with total reward of {total_reward}");

        for _ in 0..TRAINING_ITERATIONS {
            agent.train(TRAINING_BATCH_SIZE)?;
        }
    }

    println!("Testing...");
    agent.train = false;
    for episode in 0..10 {
        // let mut state = env.reset(episode as u64)?;
        let mut state = env.reset(rng.gen::<u64>())?;
        let mut total_reward = 0.0;
        for _ in 0..EPISODE_LENGTH {
            let mut action = 2.0 * agent.actions(&state)?;
            action = action.clamp(-2.0, 2.0);

            let step = env.step(vec![action])?;
            total_reward += step.reward;

            if step.terminated || step.truncated {
                break;
            }
            state = step.state;
        }
        println!("episode {episode} with total reward of {total_reward}");
    }
    Ok(())
}
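
One detail of `run` worth noting: Pendulum-v1 exposes a single torque action bounded to [-2.0, 2.0], while the actor's `tanh` head emits values in [-1.0, 1.0], so the sampled action is scaled by 2.0 and clamped before being passed to `env.step`.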
@@ -7,20 +7,22 @@ use pyo3::types::PyDict;
 /// The return value for a step.
 #[derive(Debug)]
 pub struct Step<A> {
-    pub obs: Tensor,
+    pub state: Tensor,
     pub action: A,
     pub reward: f64,
-    pub is_done: bool,
+    pub terminated: bool,
+    pub truncated: bool,
 }

 impl<A: Copy> Step<A> {
     /// Returns a copy of this step changing the observation tensor.
-    pub fn copy_with_obs(&self, obs: &Tensor) -> Step<A> {
+    pub fn copy_with_obs(&self, state: &Tensor) -> Step<A> {
         Step {
-            obs: obs.clone(),
+            state: state.clone(),
             action: self.action,
             reward: self.reward,
-            is_done: self.is_done,
+            terminated: self.terminated,
+            truncated: self.truncated,
         }
     }
 }
@@ -63,14 +65,14 @@ impl GymEnv {

     /// Resets the environment, returning the observation tensor.
     pub fn reset(&self, seed: u64) -> Result<Tensor> {
-        let obs: Vec<f32> = Python::with_gil(|py| {
+        let state: Vec<f32> = Python::with_gil(|py| {
             let kwargs = PyDict::new(py);
             kwargs.set_item("seed", seed)?;
-            let obs = self.env.call_method(py, "reset", (), Some(kwargs))?;
-            obs.as_ref(py).get_item(0)?.extract()
+            let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
+            state.as_ref(py).get_item(0)?.extract()
         })
         .map_err(w)?;
-        Tensor::new(obs, &Device::Cpu)
+        Tensor::new(state, &Device::Cpu)
     }

     /// Applies an environment step using the specified action.
@@ -78,21 +80,23 @@ impl GymEnv {
         &self,
         action: A,
     ) -> Result<Step<A>> {
-        let (obs, reward, is_done) = Python::with_gil(|py| {
+        let (state, reward, terminated, truncated) = Python::with_gil(|py| {
             let step = self.env.call_method(py, "step", (action.clone(),), None)?;
             let step = step.as_ref(py);
-            let obs: Vec<f32> = step.get_item(0)?.extract()?;
+            let state: Vec<f32> = step.get_item(0)?.extract()?;
             let reward: f64 = step.get_item(1)?.extract()?;
-            let is_done: bool = step.get_item(2)?.extract()?;
-            Ok((obs, reward, is_done))
+            let terminated: bool = step.get_item(2)?.extract()?;
+            let truncated: bool = step.get_item(3)?.extract()?;
+            Ok((state, reward, terminated, truncated))
         })
         .map_err(w)?;
-        let obs = Tensor::new(obs, &Device::Cpu)?;
+        let state = Tensor::new(state, &Device::Cpu)?;
         Ok(Step {
-            obs,
-            reward,
-            is_done,
+            state,
             action,
+            reward,
+            terminated,
+            truncated,
         })
     }

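The renaming follows the Gymnasium step API, which returns the 5-tuple `(obs, reward, terminated, truncated, info)`: `terminated` means the episode reached a terminal state of the MDP, while `truncated` means an external cutoff such as a time limit fired. The rollout loops above treat either flag as the end of an episode.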
@@ -6,70 +6,32 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
+
+use candle::Result;
+use clap::{Parser, Subcommand};
+
 mod gym_env;
 mod vec_gym_env;
-
-use candle::Result;
-use clap::Parser;
-use rand::Rng;
-
-// The total number of episodes.
-const MAX_EPISODES: usize = 100;
-// The maximum length of an episode.
-const EPISODE_LENGTH: usize = 200;
+mod ddpg;
+mod policy_gradient;

-#[derive(Parser, Debug, Clone)]
-#[command(author, version, about, long_about = None)]
+#[derive(Parser)]
 struct Args {
-    /// Run on CPU rather than on GPU.
-    #[arg(long)]
-    cpu: bool,
-
-    /// Enable tracing (generates a trace-timestamp.json file).
-    #[arg(long)]
-    tracing: bool,
+    #[command(subcommand)]
+    command: Command,
+}
+
+#[derive(Subcommand)]
+enum Command {
+    Pg,
+    Ddpg,
 }

 fn main() -> Result<()> {
-    use tracing_chrome::ChromeLayerBuilder;
-    use tracing_subscriber::prelude::*;
-
     let args = Args::parse();
-
-    let _guard = if args.tracing {
-        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
-        tracing_subscriber::registry().with(chrome_layer).init();
-        Some(guard)
-    } else {
-        None
-    };
-
-    let env = gym_env::GymEnv::new("Pendulum-v1")?;
-    println!("action space: {}", env.action_space());
-    println!("observation space: {:?}", env.observation_space());
-
-    let _num_obs = env.observation_space().iter().product::<usize>();
-    let _num_actions = env.action_space();
-
-    let mut rng = rand::thread_rng();
-
-    for episode in 0..MAX_EPISODES {
-        let mut obs = env.reset(episode as u64)?;
-
-        let mut total_reward = 0.0;
-        for _ in 0..EPISODE_LENGTH {
-            let actions = rng.gen_range(-2.0..2.0);
-
-            let step = env.step(vec![actions])?;
-            total_reward += step.reward;
-
-            if step.is_done {
-                break;
-            }
-            obs = step.obs;
-        }
-
-        println!("episode {episode} with total reward of {total_reward}");
-    }
+    match args.command {
+        Command::Pg => policy_gradient::run()?,
+        Command::Ddpg => ddpg::run()?,
+    }
     Ok(())
 }
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user