Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)

Compare commits: wasm-llama...einsum-cus (410 commits)
410 commits in this range, newest first from a910ec5993 down to 186c308d51.
.cargo/config.toml:
@@ -1,8 +1,8 @@
-[target.x86_64-unknown-linux-gnu]
-rustflags = ["-C", "target-cpu=native"]
-
-[target.aarch64-apple-darwin]
+[build]
 rustflags = ["-C", "target-cpu=native"]
 
 [target.wasm32-unknown-unknown]
 rustflags = ["-C", "target-feature=+simd128"]
+
+[target.x86_64-apple-darwin]
+rustflags = ["-C", "target-feature=-avx,-avx2"]
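Since the new `[build]` table applies `target-cpu=native` to every target, a quick sanity check (a hypothetical invocation from a checkout of the repo) is to look for the flag in cargo's verbose output:

```bash
# Each rustc invocation should carry the flag picked up from .cargo/config.toml.
cargo build -v 2>&1 | grep -m1 -- 'target-cpu=native'
```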
.github/workflows/book-cd.yml (vendored) | 2
@@ -1,7 +1,5 @@
 name: Deploy Rust book
 on:
-  # TODO put this back only when merging after this PR lands.
-  pull_request:
   push:
     branches:
       - main
.github/workflows/ci_cuda.yaml (vendored, new file) | 87
@@ -0,0 +1,87 @@
name: CI / cuda

on:
  workflow_dispatch:
  pull_request:

jobs:
  start-runner:
    name: Start self-hosted EC2 runner
    runs-on: ubuntu-latest
    env:
      AWS_REGION: us-east-1
      EC2_AMI_ID: ami-03cfed9ea28f4b002
      EC2_INSTANCE_TYPE: g5.xlarge
      EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
      EC2_SECURITY_GROUP: sg-030175c435ac141d6
    outputs:
      label: ${{ steps.start-ec2-runner.outputs.label }}
      ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
    steps:
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v1
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: ${{ env.AWS_REGION }}
      - name: Start EC2 runner
        id: start-ec2-runner
        uses: philschmid/philschmid-ec2-github-runner@main
        with:
          mode: start
          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
          ec2-image-id: ${{ env.EC2_AMI_ID }}
          ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
          subnet-id: ${{ env.EC2_SUBNET_ID }}
          security-group-id: ${{ env.EC2_SECURITY_GROUP }}
          aws-resource-tags: > # optional, requires additional permissions
            [
              {"Key": "Name", "Value": "ec2-tgi-github-runner"},
              {"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
            ]

  test-cuda:
    concurrency:
      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
      cancel-in-progress: true
    needs: start-runner # required to start the main job when the runner is ready
    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
    permissions:
      contents: write
      packages: write
      # This is used to complete the identity challenge
      # with sigstore/fulcio when running outside of PRs.
      id-token: write
      security-events: write
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3
      - name: Install Rust Stable
        run: curl https://sh.rustup.rs -sSf | sh -s -- -y
      - uses: Swatinem/rust-cache@v2
      - run: apt-get update -y && apt-get install libssl-dev -y
      - name: Test (cuda)
        run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda

  stop-runner:
    name: Stop self-hosted EC2 runner
    needs:
      - start-runner
      - test-cuda
    runs-on: ubuntu-latest
    env:
      AWS_REGION: us-east-1
    if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
    steps:
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v1
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: ${{ env.AWS_REGION }}
      - name: Stop EC2 runner
        uses: philschmid/philschmid-ec2-github-runner@main
        with:
          mode: stop
          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
          label: ${{ needs.start-runner.outputs.label }}
          ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
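The `test-cuda` job boils down to a single cargo invocation; to reproduce it locally you would run something like the following (assuming, as on the CI image, the CUDA 11.8 toolkit under `/usr/local/cuda-11.8`):

```bash
# Mirror of the workflow's "Test (cuda)" step, with a stock rustup install.
export PATH=$PATH:/usr/local/cuda-11.8/bin/
cargo test --features cuda
```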
.gitignore (vendored) | 6
@@ -20,11 +20,17 @@ Cargo.lock
 
 perf.data
 flamegraph.svg
 *.dylib
 *.so
 *.swp
+trace-*.json
 
 candle-wasm-examples/*/build
 candle-wasm-examples/*/*.bin
 candle-wasm-examples/*/*.jpeg
+candle-wasm-examples/*/*.wav
+candle-wasm-examples/*/*.safetensors
+candle-wasm-examples/*/package-lock.json
 
+.DS_Store
+.idea/*
CHANGELOG.md (new file) | 42
@@ -0,0 +1,42 @@
# Changelog
This documents the main changes to the `candle` crate.

## v0.2.1 - Unreleased

### Added

### Modified
- Dilations are now supported in conv-transpose2d.
  [671](https://github.com/huggingface/candle/pull/671).

## v0.2.0 - 2023-08-30

### Added
- Add the powf op
  [664](https://github.com/huggingface/candle/pull/664).
- Stable Diffusion XL support
  [647](https://github.com/huggingface/candle/pull/647).
- Add the conv-transpose2d op
  [635](https://github.com/huggingface/candle/pull/635).
- Refactor the VarBuilder api
  [627](https://github.com/huggingface/candle/pull/627).
- Add some quantization command
  [625](https://github.com/huggingface/candle/pull/625).
- Support more quantized types, e.g. Q2K, Q4K, Q5K...
  [586](https://github.com/huggingface/candle/pull/586).
- Add pose estimation to the yolo example
  [589](https://github.com/huggingface/candle/pull/589).
- Api to write GGUF files
  [585](https://github.com/huggingface/candle/pull/585).
- Support more quantization types
  [580](https://github.com/huggingface/candle/pull/580).
- Add EfficientNet as an example Computer Vision model
  [572](https://github.com/huggingface/candle/pull/572).
- Add a group parameter to convolutions
  [566](https://github.com/huggingface/candle/pull/566).
- New dtype: int64
  [563](https://github.com/huggingface/candle/pull/563).
- Handling of the GGUF file format.
  [559](https://github.com/huggingface/candle/pull/559).

## v0.1.2 - 2023-08-21
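To make the v0.2.1 entry concrete, here is a minimal sketch of a dilated transposed convolution. This is hedged on the `Tensor::conv_transpose2d(kernel, padding, output_padding, stride, dilation)` signature from candle-core 0.2; the tensor sizes are arbitrary illustration values:

```rust
use candle_core::{Device, Tensor};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let dev = Device::Cpu;
    // One image, 4 input channels, 8x8 spatial extent.
    let xs = Tensor::randn(0f32, 1., (1, 4, 8, 8), &dev)?;
    // Transposed-conv kernels are laid out (c_in, c_out, k_h, k_w).
    let kernel = Tensor::randn(0f32, 1., (4, 2, 3, 3), &dev)?;
    // padding = 0, output_padding = 0, stride = 1, dilation = 2.
    let ys = xs.conv_transpose2d(&kernel, 0, 0, 1, 2)?;
    println!("{:?}", ys.shape()); // expected: [1, 2, 12, 12]
    Ok(())
}
```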
Cargo.toml | 25
@@ -1,36 +1,43 @@
 [workspace]
 members = [
     "candle-core",
+    "candle-datasets",
     "candle-examples",
+    "candle-book",
     "candle-nn",
     "candle-pyo3",
     "candle-transformers",
     "candle-wasm-examples/llama2-c",
     "candle-wasm-examples/whisper",
     "candle-wasm-examples/yolo",
 ]
 exclude = [
     "candle-flash-attn",
     "candle-kernels",
 ]
 resolver = "2"
 
 [workspace.package]
-version = "0.1.0"
+version = "0.2.1"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
 keywords = ["blas", "tensor", "machine-learning"]
 categories = ["science"]
-license = "MIT/Apache-2.0"
+license = "MIT OR Apache-2.0"
 
 [workspace.dependencies]
+accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
 clap = { version = "4.2.4", features = ["derive"] }
-cudarc = { version = "0.9.13", features = ["f16"] }
+cudarc = { version = "0.9.14", features = ["f16"] }
 # TODO: Switch back to the official gemm implementation once it has caught up.
-gemm = { version = "0.15.5", package = "candle-gemm" }
-hf-hub = "0.2.0"
-half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
+gemm = { version = "0.15.6", package = "candle-gemm" }
+hf-hub = "0.3.0"
+half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
 image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
 imageproc = { version = "0.23.0", default-features = false }
 intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
 libc = { version = "0.2.147" }
 log = "0.4"
@@ -38,16 +45,20 @@ memmap2 = "0.7.1"
 num_cpus = "1.15.0"
 num-traits = "0.2.15"
 rand = "0.8.5"
 rand_distr = "0.4.3"
 rayon = "1.7.0"
+rusttype = { version = "0.9", default-features = false }
 safetensors = "0.3.1"
 serde = { version = "1.0.171", features = ["derive"] }
 serde_json = "1.0.99"
 thiserror = "1"
-tokenizers = { version = "0.13.3", default-features = false }
+tokenizers = { version = "0.13.4", default-features = false }
 tracing = "0.1.37"
 tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 zip = { version = "0.6.6", default-features = false }
+parquet = { version = "45.0.0" }
 
 [profile.release-with-debug]
 inherits = "release"
LICENSE-APACHE (new file) | 201
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
LICENSE-MIT (new file) | 23
@@ -0,0 +1,23 @@
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:

The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
Makefile | 4
@@ -1,7 +1,11 @@
 .PHONY: clean-ptx clean test
 
 clean-ptx:
 	find target -name "*.ptx" -type f -delete
+	echo "" > candle-kernels/src/lib.rs
+	touch candle-kernels/build.rs
+	touch candle-examples/build.rs
+	touch candle-flash-attn/build.rs
 
 clean:
 	cargo clean
README.md | 203
@@ -1,32 +1,71 @@
 # candle
 [](https://discord.gg/hugging-face-879548962464493619)
 [](https://crates.io/crates/candle-core)
 [](https://docs.rs/candle-core)
 
-Candle is a minimalist ML framework for Rust with a focus on easiness of use and
-on performance (including GPU support). Try our online demos:
+Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
+and ease of use. Try our online demos:
 [whisper](https://huggingface.co/spaces/lmz/candle-whisper),
-[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
+[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
+[yolo](https://huggingface.co/spaces/lmz/candle-yolo).
+
+## Get started
+
+Make sure that you have [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) correctly installed as described in [**Installation**](https://huggingface.github.io/candle/guide/installation.html).
+
+Let's see how to run a simple matrix multiplication.
+Write the following to your `myapp/src/main.rs` file:
 ```rust
-let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
-let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
+use candle_core::{Device, Tensor};
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let device = Device::Cpu;
+
+    let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
+    let b = Tensor::randn(0f32, 1., (3, 4), &device)?;
+
     let c = a.matmul(&b)?;
+    println!("{c}");
+    Ok(())
+}
 ```
+
+`cargo run` should display a tensor of shape `Tensor[[2, 4], f32]`.
+
+Having installed `candle` with Cuda support, simply define the `device` to be on GPU:
+
+```diff
+- let device = Device::Cpu;
++ let device = Device::new_cuda(0)?;
+```
+
+For more advanced examples, please have a look at the following section.
+
+## Check out our examples
+
+Check out our [examples](./candle-examples/examples/):
+
 - [Whisper](./candle-examples/examples/whisper/): speech recognition model.
-- [Llama and Llama-v2](./candle-examples/examples/llama/): general LLM.
+- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
 - [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
 - [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
   generation.
+- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
+  image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
+- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
+  using self-supervision (can be used for imagenet classification, depth
+  evaluation, segmentation).
+- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
+  the LLaMA model using the same quantization techniques as
+  [llama.cpp](https://github.com/ggerganov/llama.cpp).
+- [yolo-v3](./candle-examples/examples/yolo-v3/) and
+  [yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
+  estimation models.
+- [segment-anything](./candle-examples/examples/segment-anything/): image
+  segmentation model with prompt.
+
 Run them using the following commands:
 ```
 cargo run --example whisper --release
@@ -34,9 +73,16 @@ cargo run --example llama --release
 cargo run --example falcon --release
 cargo run --example bert --release
 cargo run --example bigcode --release
+cargo run --example stable-diffusion --release -- --prompt "a rusty robot holding a fire torch"
+cargo run --example dinov2 --release -- --image path/to/myinput.jpg
+cargo run --example quantized --release
 cargo run --example yolo-v3 --release -- myimage.jpg
+cargo run --example yolo-v8 --release -- myimage.jpg # for pose estimation, add --task pose
+cargo run --example segment-anything --release -- --image myimage.jpg
 ```
 
-In order to use **CUDA** add `--features cuda` to the example command line.
+In order to use **CUDA** add `--features cuda` to the example command line. If
+you have cuDNN installed, use `--features cudnn` for even more speedups.
 
 There are also some wasm examples for whisper and
 [llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
@@ -44,47 +90,53 @@ There are also some wasm examples for whisper and
 [whisper](https://huggingface.co/spaces/lmz/candle-whisper),
 [llama2](https://huggingface.co/spaces/lmz/candle-llama2).
 
-For llama2, run the following command to retrieve the weight files and start a
+For LLaMA2, run the following command to retrieve the weight files and start a
 test server:
 ```bash
 cd candle-wasm-examples/llama2-c
-wget https://karpathy.ai/llama2c/model.bin
-wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
-trunk serve --release --public-url /candle-llama2/ --port 8081
+wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
+wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
+trunk serve --release --port 8081
 ```
-And then browse to
-[http://localhost:8081/candle-llama2](http://localhost:8081/candle-llama2).
+And then head over to
+[http://localhost:8081/](http://localhost:8081/).
 
 <!--- ANCHOR: features --->
 
 ## Features
 
-- Simple syntax, looks and like PyTorch.
-- CPU and Cuda backends, m1, f16, bf16.
-- Enable serverless (CPU), small and fast deployments
-- WASM support, run your models in a browser.
+- Simple syntax, looks and feels like PyTorch.
 - Model training.
 - Distributed computing using NCCL.
 - Models out of the box: Llama, Whisper, Falcon, StarCoder...
-- Embed user-defined ops/kernels, such as [flash-attention
-  v2](https://github.com/LaurentMazare/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
+- Embed user-defined ops/kernels, such as [flash-attention v2](https://github.com/huggingface/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
+- Backends.
+  - Optimized CPU backend with optional MKL support for x86 and Accelerate for macs.
+  - CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
+  - WASM support, run your models in a browser.
+- Included models.
+  - LLMs: LLaMA v1 and v2, Falcon, StarCoder.
+  - Whisper (multi-lingual support).
+  - Stable Diffusion.
+  - Computer Vision: DINOv2, EfficientNet, yolo-v3, yolo-v8.
+- File formats: load models from safetensors, npz, ggml, or PyTorch files.
+- Serverless (on CPU), small and fast deployments.
+- Quantization support using the llama.cpp quantized types.
 
 <!--- ANCHOR_END: features --->
 
-## How to use ?
+## How to use
 
 <!--- ANCHOR: cheatsheet --->
 Cheatsheet:
 
 | | Using PyTorch | Using Candle |
 |------------|------------------------------------------|------------------------------------------------------------------|
-| Creation | `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(&[[1f32, 2.]], [3., 4.]], &Device::Cpu)?` |
+| Creation | `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?` |
 | Creation | `torch.zeros((2, 2))` | `Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?` |
 | Indexing | `tensor[:, :4]` | `tensor.i((.., ..4))?` |
 | Operations | `tensor.view((2, 2))` | `tensor.reshape((2, 2))?` |
 | Operations | `a.matmul(b)` | `a.matmul(&b)?` |
 | Arithmetic | `a + b` | `&a + &b` |
-| Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::Cuda(0))?` |
+| Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::new_cuda(0)?)?` |
 | Dtype | `tensor.to(dtype=torch.float16)` | `tensor.to_dtype(&DType::F16)?` |
 | Saving | `torch.save({"A": A}, "model.bin")` | `candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")?` |
 | Loading | `weights = torch.load("model.bin")` | `candle::safetensors::load("model.safetensors", &device)` |
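As a quick check of the Candle column above, here is a minimal sketch combining a few of those calls (assuming the candle-core 0.2 API; note that `i(...)` needs the `IndexOp` trait in scope, which the table does not mention):

```rust
use candle_core::{DType, Device, IndexOp, Tensor};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = Device::Cpu;
    let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &device)?;
    let b = Tensor::zeros((2, 2), DType::F32, &device)?;
    // `i` comes from the IndexOp trait; `..` keeps a dim, `..1` slices it.
    let col = a.i((.., ..1))?;
    // Arithmetic on references, then a reshape (the `view` equivalent).
    let c = (&a + &b)?.reshape((4, 1))?;
    println!("{col}\n{c}");
    Ok(())
}
```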
@@ -95,62 +147,117 @@ Cheatsheet:
 ## Structure
 
 - [candle-core](./candle-core): Core ops, devices, and `Tensor` struct definition
-- [candle-nn](./candle-nn/): Facilities to build real models
-- [candle-examples](./candle-examples/): Real-world like examples on how to use the library in real settings
+- [candle-nn](./candle-nn/): Tools to build real models
+- [candle-examples](./candle-examples/): Examples of using the library in realistic settings
 - [candle-kernels](./candle-kernels/): CUDA custom kernels
-
+- [candle-datasets](./candle-datasets/): Datasets and data loaders.
+- [candle-transformers](./candle-transformers): transformers-related utilities.
+- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
 
 ## FAQ
 
-### Why Candle?
+### Why should I use Candle?
 
-Candle stems from the need to reduce binary size in order to *enable serverless*
-possible by making the whole engine smaller than PyTorch very large library volume.
-This enables creating runtimes on a cluster much faster.
+Candle's core goal is to *make serverless inference possible*. Full machine learning frameworks like PyTorch
+are very large, which makes creating instances on a cluster slow. Candle allows deployment of lightweight
+binaries.
 
-And simply *removing Python* from production workloads.
-Python can really add overhead in more complex workflows and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-future/) is a notorious source of headaches.
+Secondly, Candle lets you *remove Python* from production workloads. Python overhead can seriously hurt performance,
+and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-future/) is a notorious source of headaches.
 
-Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers).
+Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers).
 
 ### Other ML frameworks
 
 - [dfdx](https://github.com/coreylowman/dfdx) is a formidable crate, with shapes being included
-  in types preventing a lot of headaches by getting compiler to complain about shape mismatch right off the bat
-  However we found that some features still require nightly and writing code can be a bit dauting for non rust experts.
+  in types. This prevents a lot of headaches by getting the compiler to complain about shape mismatches right off the bat.
+  However, we found that some features still require nightly, and writing code can be a bit daunting for non rust experts.
 
   We're leveraging and contributing to other core crates for the runtime so hopefully both crates can benefit from each
-  other
+  other.
 
 - [burn](https://github.com/burn-rs/burn) is a general crate that can leverage multiple backends so you can choose the best
-  engine for your workload
+  engine for your workload.
 
 - [tch-rs](https://github.com/LaurentMazare/tch-rs.git) Bindings to the torch library in Rust. Extremely versatile, but they
-  do bring in the entire torch library into the runtime. The main contributor of `tch-rs` is also involved in the development
+  bring in the entire torch library into the runtime. The main contributor of `tch-rs` is also involved in the development
   of `candle`.
 
-### Missing symbols when compiling with the mkl feature.
+### Common Errors
+
+#### Missing symbols when compiling with the mkl feature.
 
 If you get some missing symbols when compiling binaries/tests using the mkl
-features, e.g.:
+or accelerate features, e.g. for mkl you get:
 ```
 = note: /usr/bin/ld: (....o): in function `blas::sgemm':
         .../blas-0.22.0/src/lib.rs:1944: undefined reference to `sgemm_' collect2: error: ld returned 1 exit status
 
 = note: some `extern` functions couldn't be found; some native libraries may need to be installed or have their path specified
 = note: use the `-l` flag to specify native libraries to link
-= note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo (see https://doc.rust-lang.org/cargo/reference/build-scripts.html#cargorustc-link-libkindname)
+= note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo
 ```
+or for accelerate:
+```
+Undefined symbols for architecture arm64:
+    "_dgemm_", referenced from:
+        candle_core::accelerate::dgemm::h1b71a038552bcabe in libcandle_core...
+    "_sgemm_", referenced from:
+        candle_core::accelerate::sgemm::h2cf21c592cba3c47 in libcandle_core...
+ld: symbol(s) not found for architecture arm64
+```
 
-This is likely due to some missing linker flag that enable the mkl library. You
-can try adding the following at the top of your binary:
-```
+This is likely due to a missing linker flag that was needed to enable the mkl library. You
+can try adding the following for mkl at the top of your binary:
+```rust
 extern crate intel_mkl_src;
 ```
+or for accelerate:
+```rust
+extern crate accelerate_src;
+```
 
-### How to know where an error comes from.
+#### Cannot run the LLaMA examples: access to source requires login credentials
+
+```
+Error: request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401
+```
+
+This is likely because you're not permissioned for the LLaMA-v2 model. To fix
+this, you have to register on the huggingface-hub, accept the [LLaMA-v2 model
+conditions](https://huggingface.co/meta-llama/Llama-2-7b-hf), and set up your
+authentication token. See issue
+[#350](https://github.com/huggingface/candle/issues/350) for more details.
+
+#### Missing cute/cutlass headers when compiling flash-attn
+
+```
+In file included from kernels/flash_fwd_launch_template.h:11:0,
+                 from kernels/flash_fwd_hdim224_fp16_sm80.cu:5:
+kernels/flash_fwd_kernel.h:8:10: fatal error: cute/algorithm/copy.hpp: No such file or directory
+ #include <cute/algorithm/copy.hpp>
+          ^~~~~~~~~~~~~~~~~~~~~~~~~
+compilation terminated.
+Error: nvcc error while compiling:
+```
+[cutlass](https://github.com/NVIDIA/cutlass) is provided as a git submodule so you may want to run the following command to check it in properly.
+```bash
+git submodule update --init
+```
+
+#### Compiling with flash-attention fails
+
+```
+/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
+```
+
+This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
+```
+env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
+```
+
+#### Tracking down errors
 
 You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
 error is generated.
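For example, re-running one of the examples above with backtraces enabled:

```bash
RUST_BACKTRACE=1 cargo run --example bert --release
```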
candle-book/Cargo.toml (new file) | 49
@@ -0,0 +1,49 @@
[package]
name = "candle-book"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
readme = "README.md"

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
candle-nn = { path = "../candle-nn", version = "0.2.1" }
candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true, optional = true }

[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
hf-hub = { workspace = true, features=["tokio"]}
clap = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"
parquet = { workspace = true }
image = { workspace = true }

[build-dependencies]
anyhow = { workspace = true }

[features]
default = []
candle-book/src/SUMMARY.md:
@@ -12,16 +12,16 @@
 
 - [Running a model](inference/README.md)
     - [Using the hub](inference/hub.md)
-    - [Serialization](inference/serialization.md)
-    - [Advanced Cuda usage](inference/cuda/README.md)
-        - [Writing a custom kernel](inference/cuda/writing.md)
-        - [Porting a custom kernel](inference/cuda/porting.md)
 - [Error management](error_manage.md)
-- [Creating apps](apps/README.md)
-    - [Creating a WASM app](apps/wasm.md)
-    - [Creating a REST api webserver](apps/rest.md)
-    - [Creating a desktop Tauri app](apps/dekstop.md)
 - [Training](training/README.md)
     - [MNIST](training/mnist.md)
-    - [Fine-tuning](training/finetuning.md)
-    - [Using MKL](advanced/mkl.md)
+    - [Fine-tuning]()
+    - [Serialization]()
+- [Advanced Cuda usage]()
+    - [Writing a custom kernel]()
+    - [Porting a custom kernel]()
+- [Using MKL]()
+- [Creating apps]()
+    - [Creating a WASM app]()
+    - [Creating a REST api webserver]()
+    - [Creating a desktop Tauri app]()
candle-book/src/cuda/README.md (new file) | 1
@@ -0,0 +1 @@
# Advanced Cuda usage

candle-book/src/cuda/porting.md (new file) | 1
@@ -0,0 +1 @@
# Porting a custom kernel

candle-book/src/cuda/writing.md (new file) | 1
@@ -0,0 +1 @@
# Writing a custom kernel
candle-book/src/error_manage.md:
@@ -1 +1,51 @@
# Error management

You might have seen in the code base a lot of `.unwrap()` or `?`.
If you're unfamiliar with Rust, check out the [Rust book](https://doc.rust-lang.org/book/ch09-02-recoverable-errors-with-result.html)
for more information.

What's important to know, though, is that if you want to find out *where* a particular operation failed,
you can simply use `RUST_BACKTRACE=1` to get the location where the model actually failed.

Let's look at some failing code:

```rust,ignore
let x = Tensor::zeros((1, 784), DType::F32, &device)?;
let y = Tensor::zeros((1, 784), DType::F32, &device)?;
let z = x.matmul(&y)?;
```

This will print at runtime:

```bash
Error: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }
```

After adding `RUST_BACKTRACE=1`:

```bash
Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
```

Not super pretty at the moment, but we can see that the error occurred in `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`.

Another thing to note is that since Rust is compiled, it is not always easy to recover proper stacktraces, especially in release builds. We're using [`anyhow`](https://docs.rs/anyhow/latest/anyhow/) for that.
The library is still young, please [report](https://github.com/LaurentMazare/candle/issues) any issues detecting where an error is coming from.

## Cuda error management

When running a model on Cuda, you might get a stacktrace not really representing the error.
The reason is that CUDA is async by nature, and therefore the error might be caught while you were sending totally different kernels.

One way to avoid this is to use `CUDA_LAUNCH_BLOCKING=1` as an environment variable. This will force every kernel to be launched sequentially.
You might still however see the error happening on other kernels, as the faulty kernel might exit without an error but corrupt some pointer, in which case the error only surfaces when the `CudaSlice` is dropped.

If this occurs, you can use [`compute-sanitizer`](https://docs.nvidia.com/compute-sanitizer/ComputeSanitizer/index.html).
This tool is like `valgrind` but for cuda; it will help locate the errors in the kernels.
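Putting the two tips together, a typical debugging session might look like this (the example name is hypothetical; substitute your own binary):

```bash
# Force synchronous kernel launches so the failing kernel is the one reported.
CUDA_LAUNCH_BLOCKING=1 cargo run --example llama --release --features cuda
# valgrind-style checking of the cuda kernels (requires the CUDA toolkit on PATH).
compute-sanitizer target/release/examples/llama
```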
candle-book/src/guide/hello_world.md:
@@ -128,17 +128,17 @@ fn main() -> Result<()> {
 ```
 
 Now it works, it is a great way to create your own layers.
-But most of the classical layers are already implemented in [candle-nn](https://github.com/LaurentMazare/candle/tree/main/candle-nn).
+But most of the classical layers are already implemented in [candle-nn](https://github.com/huggingface/candle/tree/main/candle-nn).
 
 ## Using `candle_nn`.
 
-For instance [Linear](https://github.com/LaurentMazare/candle/blob/main/candle-nn/src/linear.rs) is already there.
+For instance [Linear](https://github.com/huggingface/candle/blob/main/candle-nn/src/linear.rs) is already there.
 This Linear is coded with PyTorch layout in mind, to reuse better existing models out there, so it uses the transpose of the weights and not the weights directly.
 
 So instead we can simplify our example:
 
 ```bash
-cargo add --git https://github.com/LaurentMazare/candle.git candle-nn
+cargo add --git https://github.com/huggingface/candle.git candle-nn
 ```
 
 And rewrite our examples using it
@@ -147,7 +147,7 @@ And rewrite our examples using it
 # extern crate candle_core;
 # extern crate candle_nn;
 use candle_core::{DType, Device, Result, Tensor};
-use candle_nn::Linear;
+use candle_nn::{Linear, Module};
 
 struct Model {
     first: Linear,
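The second hunk matters because `forward` is defined on the `Module` trait, which must be in scope to call it on a `Linear`. A minimal sketch of why (assuming the candle-nn 0.2 API; the shapes are illustrative):

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Linear, Module}; // without `Module`, `layer.forward` does not resolve

fn run() -> Result<Tensor> {
    let device = Device::Cpu;
    // PyTorch-style layout: the weight is (out_features, in_features).
    let weight = Tensor::zeros((10, 784), DType::F32, &device)?;
    let layer = Linear::new(weight, None); // no bias
    let x = Tensor::zeros((1, 784), DType::F32, &device)?;
    layer.forward(&x) // (1, 784) x (784, 10) -> (1, 10)
}
```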
@@ -1,24 +1,56 @@
# Installation

-Start by creating a new app:
+**With Cuda support**:

1. First, make sure that Cuda is correctly installed.
- `nvcc --version` should print information about your Cuda compiler driver.
- `nvidia-smi --query-gpu=compute_cap --format=csv` should print your GPU's compute capability, e.g. something
like:

```bash
compute_cap
8.9
```

If any of the above commands errors out, please make sure to update your Cuda version.

2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.

Start by creating a new cargo:

```bash
cargo new myapp
cd myapp
-cargo add --git https://github.com/LaurentMazare/candle.git candle
```

-At this point, candle will be built **without** CUDA support.
-To get CUDA support use the `cuda` feature
+Make sure to add the `candle-core` crate with the cuda feature:

```bash
-cargo add --git https://github.com/LaurentMazare/candle.git candle --features cuda
+cargo add --git https://github.com/huggingface/candle.git candle-core --features "cuda"
```

-You can check everything works properly:
+Run `cargo build` to make sure everything can be correctly built.

```bash
cargo build
```

**Without Cuda support**:

Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) as follows:

```bash
cargo new myapp
cd myapp
cargo add --git https://github.com/huggingface/candle.git candle-core
```

Finally, run `cargo build` to make sure everything can be correctly built.

```bash
cargo build
```

**With mkl support**

You can also see the `mkl` feature which could be interesting to get faster inference on CPU. [Using mkl](./advanced/mkl.md)
@@ -1 +1,7 @@
# Running a model

In order to run an existing model, you will need to download and use existing weights.
Most models are already available on https://huggingface.co/ in [`safetensors`](https://github.com/huggingface/safetensors) format.

Let's get started by running an old model: `bert-base-uncased`.
@@ -1 +1,104 @@
# Using the hub

Install the [`hf-hub`](https://github.com/huggingface/hf-hub) crate:

```bash
cargo add hf-hub
```

Then let's start by downloading the [model file](https://huggingface.co/bert-base-uncased/tree/main).

```rust
# extern crate candle_core;
# extern crate hf_hub;
use hf_hub::api::sync::Api;
use candle_core::Device;

let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());

let weights = repo.get("model.safetensors").unwrap();

let weights = candle_core::safetensors::load(weights, &Device::Cpu);
```

We now have access to all the [tensors](https://huggingface.co/bert-base-uncased?show_tensors=true) within the file.
You can check all the names of the tensors [here](https://huggingface.co/bert-base-uncased?show_tensors=true).
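The loaded weights are a map from tensor name to tensor, so listing a few names and shapes is straightforward (a small sketch building on the snippet above):

```rust
# extern crate candle_core;
# extern crate hf_hub;
# use hf_hub::api::sync::Api;
# use candle_core::Device;
# let api = Api::new().unwrap();
# let repo = api.model("bert-base-uncased".to_string());
# let weights = repo.get("model.safetensors").unwrap();
let weights = candle_core::safetensors::load(weights, &Device::Cpu).unwrap();
// `weights` is a HashMap<String, Tensor>.
for (name, tensor) in weights.iter().take(3) {
    println!("{name}: {:?}", tensor.shape());
}
```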
## Using async

`hf-hub` comes with an async API.

```bash
cargo add hf-hub --features tokio
```

```rust,ignore
# This is tested directly in examples crate because it needs external dependencies unfortunately:
# See [this](https://github.com/rust-lang/mdBook/issues/706)
{{#include ../lib.rs:book_hub_1}}
```

## Using in a real model.

Now that we have our weights, we can use them in our bert architecture:

```rust
# extern crate candle_core;
# extern crate candle_nn;
# extern crate hf_hub;
# use hf_hub::api::sync::Api;
#
# let api = Api::new().unwrap();
# let repo = api.model("bert-base-uncased".to_string());
#
# let weights = repo.get("model.safetensors").unwrap();
use candle_core::{Device, Tensor, DType};
use candle_nn::{Linear, Module};

let weights = candle_core::safetensors::load(weights, &Device::Cpu).unwrap();

let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap();
let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap();

let linear = Linear::new(weight.clone(), Some(bias.clone()));

let input_ids = Tensor::zeros((3, 768), DType::F32, &Device::Cpu).unwrap();
let output = linear.forward(&input_ids).unwrap();
```

For a full reference, you can check out the full [bert](https://github.com/LaurentMazare/candle/tree/main/candle-examples/examples/bert) example.

## Memory mapping

For more efficient loading, instead of reading the file, you could use [`memmap2`](https://docs.rs/memmap2/latest/memmap2/).

**Note**: Be careful about memory mapping: it seems to cause issues on [Windows, WSL](https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/5893)
and will definitely be slower on a network-mounted disk, because it will issue more read calls.

```rust,ignore
{{#include ../lib.rs:book_hub_2}}
```

**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
In practice model files should never be modified, and the mmaps should be mostly READONLY anyway, so the caveat most likely does not apply, but always keep it in mind.

## Tensor Parallel Sharding

When using multiple GPUs with tensor parallelism in order to get good latency, you can load only the part of the tensor you need.
For that you need to use [`safetensors`](https://crates.io/crates/safetensors) directly.

```bash
cargo add safetensors
```

```rust,ignore
{{#include ../lib.rs:book_hub_3}}
```
candle-book/src/lib.rs (new file, 193 lines)
@@ -0,0 +1,193 @@
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use candle::{DType, Device, Tensor};
    use parquet::file::reader::SerializedFileReader;

    // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
    #[rustfmt::skip]
    #[tokio::test]
    async fn book_hub_1() {
        // ANCHOR: book_hub_1
        use candle::Device;
        use hf_hub::api::tokio::Api;

        let api = Api::new().unwrap();
        let repo = api.model("bert-base-uncased".to_string());

        let weights_filename = repo.get("model.safetensors").await.unwrap();

        let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
        // ANCHOR_END: book_hub_1
        assert_eq!(weights.len(), 206);
    }

    #[rustfmt::skip]
    #[test]
    fn book_hub_2() {
        // ANCHOR: book_hub_2
        use candle::Device;
        use hf_hub::api::sync::Api;
        use memmap2::Mmap;
        use std::fs;

        let api = Api::new().unwrap();
        let repo = api.model("bert-base-uncased".to_string());
        let weights_filename = repo.get("model.safetensors").unwrap();

        let file = fs::File::open(weights_filename).unwrap();
        let mmap = unsafe { Mmap::map(&file).unwrap() };
        let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
        // ANCHOR_END: book_hub_2
        assert_eq!(weights.len(), 206);
    }

    #[rustfmt::skip]
    #[test]
    fn book_hub_3() {
        // ANCHOR: book_hub_3
        use candle::{DType, Device, Tensor};
        use hf_hub::api::sync::Api;
        use memmap2::Mmap;
        use safetensors::slice::IndexOp;
        use safetensors::SafeTensors;
        use std::fs;

        let api = Api::new().unwrap();
        let repo = api.model("bert-base-uncased".to_string());
        let weights_filename = repo.get("model.safetensors").unwrap();

        let file = fs::File::open(weights_filename).unwrap();
        let mmap = unsafe { Mmap::map(&file).unwrap() };

        // Use safetensors directly
        let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
        let view = tensors
            .tensor("bert.encoder.layer.0.attention.self.query.weight")
            .unwrap();

        // We're going to load shard with rank 1, within a world_size of 4
        // We're going to split along dimension 0 doing VIEW[start..stop, :]
        let rank = 1;
        let world_size = 4;
        let dim = 0;
        let dtype = view.dtype();
        let mut tp_shape = view.shape().to_vec();
        let size = tp_shape[0];

        if size % world_size != 0 {
            panic!("The dimension is not divisible by `world_size`");
        }
        let block_size = size / world_size;
        let start = rank * block_size;
        let stop = (rank + 1) * block_size;

        // Everything is expressed in tensor dimension
        // bytes offsets is handled automatically for safetensors.

        let iterator = view.slice(start..stop).unwrap();

        tp_shape[dim] = block_size;

        // Convert safetensors Dtype to candle DType
        let dtype: DType = dtype.try_into().unwrap();

        // TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
        let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
        // ANCHOR_END: book_hub_3
        assert_eq!(view.shape(), &[768, 768]);
        assert_eq!(tp_tensor.dims(), &[192, 768]);
    }

    #[rustfmt::skip]
    #[test]
    fn book_training_1() -> Result<()> {
        // ANCHOR: book_training_1
        use hf_hub::{api::sync::Api, Repo, RepoType};

        let dataset_id = "mnist".to_string();

        let api = Api::new()?;
        let repo = Repo::with_revision(
            dataset_id,
            RepoType::Dataset,
            "refs/convert/parquet".to_string(),
        );
        let repo = api.repo(repo);
        let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
        let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
        let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
        let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
        // ANCHOR_END: book_training_1
        // Ignore unused
        let _train = train_parquet;
        // ANCHOR: book_training_2
        for row in test_parquet {
            for (idx, (name, field)) in row?.get_column_iter().enumerate() {
                println!("Column id {idx}, name {name}, value {field}");
            }
        }
        // ANCHOR_END: book_training_2
        let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
        let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
        let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
        let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
        // ANCHOR: book_training_3

        let test_samples = 10_000;
        let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
        let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
        for row in test_parquet {
            for (_name, field) in row?.get_column_iter() {
                if let parquet::record::Field::Group(subrow) = field {
                    for (_name, field) in subrow.get_column_iter() {
                        if let parquet::record::Field::Bytes(value) = field {
                            let image = image::load_from_memory(value.data()).unwrap();
                            test_buffer_images.extend(image.to_luma8().as_raw());
                        }
                    }
                } else if let parquet::record::Field::Long(label) = field {
                    test_buffer_labels.push(*label as u8);
                }
            }
        }
        let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
        let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples,), &Device::Cpu)?;

        let train_samples = 60_000;
        let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
        let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
        for row in train_parquet {
            for (_name, field) in row?.get_column_iter() {
                if let parquet::record::Field::Group(subrow) = field {
                    for (_name, field) in subrow.get_column_iter() {
                        if let parquet::record::Field::Bytes(value) = field {
                            let image = image::load_from_memory(value.data()).unwrap();
                            train_buffer_images.extend(image.to_luma8().as_raw());
                        }
                    }
                } else if let parquet::record::Field::Long(label) = field {
                    train_buffer_labels.push(*label as u8);
                }
            }
        }
        let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
        let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples,), &Device::Cpu)?;

        let mnist = candle_datasets::vision::Dataset {
            train_images,
            train_labels,
            test_images,
            test_labels,
            labels: 10,
        };

        // ANCHOR_END: book_training_3
        assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
        assert_eq!(mnist.test_labels.dims(), &[10_000]);
        assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
        assert_eq!(mnist.train_labels.dims(), &[60_000]);
        Ok(())
    }
}
@@ -1 +1,39 @@
# Training

Training starts with data. We're going to use the Hugging Face Hub and
start with the Hello world dataset of machine learning, MNIST.

Let's start with downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist).

This requires [`hf-hub`](https://github.com/huggingface/hf-hub).
```bash
cargo add hf-hub
```
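The included snippets below also rely on the `parquet` and `image` crates (an assumption based on the imports they use, e.g. `parquet::file::reader::SerializedFileReader` and `image::load_from_memory`), so you may need those as well:

```bash
cargo add parquet image
```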
This is going to be very hands-on for now.

```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_training_1}}
```

This uses the standardized `parquet` files from the `refs/convert/parquet` branch on every dataset.
Our handles are now [`parquet::file::serialized_reader::SerializedFileReader`].

We can inspect the content of the files with:

```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_training_2}}
```

You should see something like:

```bash
Column id 1, name label, value 6
Column id 0, name image, value {bytes: [137, ....]
Column id 1, name label, value 8
Column id 0, name image, value {bytes: [137, ....]
```

So each row contains 2 columns (image, label) with the image being saved as bytes.
Let's put them into a useful struct.
@@ -1 +1,10 @@
# MNIST

So we have now downloaded the MNIST parquet files; let's put them in a simple struct.

```rust,ignore
{{#include ../lib.rs:book_training_3}}
```

The parsing of the files and putting the data into single tensors requires the whole dataset to fit in memory (for MNIST this is small: the 60,000 training images at 784 bytes each come to roughly 47 MB).
It is quite rudimentary, but simple enough for a small dataset like MNIST.
@@ -10,8 +10,9 @@ license.workspace = true
readme = "README.md"

[dependencies]
+accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
-candle-kernels = { path = "../candle-kernels", version = "0.1.0", optional = true }
+candle-kernels = { path = "../candle-kernels", version = "0.2.1", optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
@@ -21,14 +22,19 @@ memmap2 = { workspace = true }
num-traits = { workspace = true }
num_cpus = { workspace = true }
rand = { workspace = true }
rand_distr = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
zip = { workspace = true }

[dev-dependencies]
anyhow = { workspace = true }
clap = { workspace = true }

[features]
default = []
-cuda = ["dep:cudarc", "dep:candle-kernels"]
+cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
+accelerate = ["dep:libc", "dep:accelerate-src"]
@@ -1,29 +1,18 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Result;
use candle_core::{Device, Tensor};

fn main() -> Result<()> {
    let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
    let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
    let c = a.matmul(&b)?;
    println!("{a} {b} {c}");

    let data = &[[3f32, 1., 4., 1., 5.], [2., 7., 1., 8., 2.]];
    let t1 = Tensor::new(data, &Device::Cpu)?;
    let data2 = &[[5f32, 5., 5., 5., 5.], [2., 7., 1., 8., 2.]];
    let t2 = Tensor::new(data2, &Device::Cpu)?;
    assert_eq!(
        Tensor::cat(&[&t1.t()?, &t2.t()?], 1)?
            .t()?
            .to_vec2::<f32>()?,
        [
            [3.0, 1.0, 4.0, 1.0, 5.0],
            [2.0, 7.0, 1.0, 8.0, 2.0],
            [5.0, 5.0, 5.0, 5.0, 5.0],
            [2.0, 7.0, 1.0, 8.0, 2.0]
        ]
    );
    let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
    let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
    let start = std::time::Instant::now();
    let res = inp.conv2d(&w, 0, 1, 1, 1)?;
    println!("{:?}", start.elapsed());
    println!("{res:?}");
    Ok(())
}
@@ -1,3 +1,6 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

@@ -6,10 +9,21 @@ use candle_core::{Device, Tensor};

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    let t = Tensor::new(&[[1f32, 2., 3., 4.2]], &device)?;
    let sum = t.sum_keepdim(0)?;
    println!("{sum}");
    let sum = t.sum_keepdim(1)?;
    println!("{sum}");
    let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
    let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
    let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
    println!("{out_t}");
    let in_t = in_t.to_device(&Device::Cpu)?;
    let k_t = k_t.to_device(&Device::Cpu)?;
    let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
    let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
        .sqr()?
        .sum_all()?;
    println!("{diff}");

    let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
    let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
    let res = t.conv2d(&w, 1, 1, 1, 1)?;
    println!("{res:?}");
    Ok(())
}
@@ -1,6 +1,9 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use std::str::FromStr;

use anyhow::Result;
candle-core/examples/tensor-tools.rs (new file, 299 lines)
@@ -0,0 +1,299 @@
use candle_core::quantized::{gguf_file, k_quants, QTensor};
use candle_core::{Device, Result, Tensor};
use clap::{Parser, Subcommand, ValueEnum};
use rayon::prelude::*;

#[derive(ValueEnum, Debug, Clone)]
enum QuantizationMode {
    /// The default quantization includes all 2d tensors, except the output tensor which always
    /// uses Q6_K.
    Llama,
}

impl QuantizationMode {
    fn quantize(
        &self,
        name: &str,
        tensor: QTensor,
        default: fn(&Tensor) -> Result<QTensor>,
    ) -> Result<QTensor> {
        match self {
            Self::Llama => {
                // Same behavior as the llama.cpp quantization.
                let should_quantize = name.ends_with(".weight") && tensor.rank() == 2;
                if should_quantize {
                    let tensor = tensor.dequantize(&Device::Cpu)?;
                    if name == "output.weight" {
                        QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
                    } else {
                        default(&tensor)
                    }
                } else {
                    Ok(tensor)
                }
            }
        }
    }
}

#[derive(ValueEnum, Debug, Clone)]
enum Quantization {
    #[value(name = "q4_0")]
    Q4_0,
    #[value(name = "q4_1")]
    Q4_1,
    #[value(name = "q5_0")]
    Q5_0,
    #[value(name = "q5_1")]
    Q5_1,
    #[value(name = "q8_0")]
    Q8_0,
    #[value(name = "q8_1")]
    Q8_1,
    Q2k,
    Q3k,
    Q4k,
    Q5k,
    Q6k,
    Q8k,
    F16,
    F32,
}

#[derive(ValueEnum, Debug, Clone)]
enum Format {
    Safetensors,
    Npz,
    Ggml,
    Gguf,
    Pth,
    Pickle,
}

impl Format {
    fn infer<P: AsRef<std::path::Path>>(p: P) -> Option<Self> {
        p.as_ref()
            .extension()
            .and_then(|e| e.to_str())
            .and_then(|e| match e {
                // We don't infer any format for .bin as it can be used for ggml/gguf or pytorch.
                "safetensors" | "safetensor" => Some(Self::Safetensors),
                "npz" => Some(Self::Npz),
                "pth" | "pt" => Some(Self::Pth),
                "ggml" => Some(Self::Ggml),
                "gguf" => Some(Self::Gguf),
                _ => None,
            })
    }
}

#[derive(Subcommand, Debug, Clone)]
enum Command {
    Ls {
        files: Vec<std::path::PathBuf>,

        /// The file format to use, if unspecified infer from the file extension.
        #[arg(long, value_enum)]
        format: Option<Format>,

        /// Enable verbose mode.
        #[arg(short, long)]
        verbose: bool,
    },

    Quantize {
        /// The input file, in gguf format.
        in_file: std::path::PathBuf,
        /// The output file, in gguf format.
        out_file: std::path::PathBuf,

        /// The quantization schema to apply.
        #[arg(long, value_enum)]
        quantization: Quantization,

        /// Which tensor to quantize.
        #[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
        mode: QuantizationMode,
    },
}

#[derive(Parser, Debug, Clone)]
struct Args {
    #[command(subcommand)]
    command: Command,
}

fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
    let format = match format {
        Some(format) => format,
        None => match Format::infer(file) {
            Some(format) => format,
            None => {
                println!(
                    "{file:?}: cannot infer format from file extension, use the --format flag"
                );
                return Ok(());
            }
        },
    };
    match format {
        Format::Npz => {
            let tensors = candle_core::npy::NpzTensors::new(file)?;
            let mut names = tensors.names();
            names.sort();
            for name in names {
                let shape_dtype = match tensors.get_shape_and_dtype(name) {
                    Ok((shape, dtype)) => format!("[{shape:?}; {dtype:?}]"),
                    Err(err) => err.to_string(),
                };
                println!("{name}: {shape_dtype}")
            }
        }
        Format::Safetensors => {
            let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? };
            let tensors = tensors.deserialize()?;
            let mut tensors = tensors.tensors();
            tensors.sort_by(|a, b| a.0.cmp(&b.0));
            for (name, view) in tensors.iter() {
                let dtype = view.dtype();
                let dtype = match candle_core::DType::try_from(dtype) {
                    Ok(dtype) => format!("{dtype:?}"),
                    Err(_) => format!("{dtype:?}"),
                };
                let shape = view.shape();
                println!("{name}: [{shape:?}; {dtype}]")
            }
        }
        Format::Pth => {
            let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?;
            tensors.sort_by(|a, b| a.name.cmp(&b.name));
            for tensor_info in tensors.iter() {
                println!(
                    "{}: [{:?}; {:?}]",
                    tensor_info.name,
                    tensor_info.layout.shape(),
                    tensor_info.dtype,
                );
                if verbose {
                    println!(" {:?}", tensor_info);
                }
            }
        }
        Format::Pickle => {
            let file = std::fs::File::open(file)?;
            let mut reader = std::io::BufReader::new(file);
            let mut stack = candle_core::pickle::Stack::empty();
            stack.read_loop(&mut reader)?;
            for (i, obj) in stack.stack().iter().enumerate() {
                println!("{i} {obj:?}");
            }
        }
        Format::Ggml => {
            let mut file = std::fs::File::open(file)?;
            let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
            let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
            tensors.sort_by(|a, b| a.0.cmp(&b.0));
            for (name, qtensor) in tensors.iter() {
                println!("{name}: [{:?}; {:?}]", qtensor.shape(), qtensor.dtype());
            }
        }
        Format::Gguf => {
            let mut file = std::fs::File::open(file)?;
            let content = gguf_file::Content::read(&mut file)?;
            if verbose {
                let mut metadata = content.metadata.into_iter().collect::<Vec<_>>();
                metadata.sort_by(|a, b| a.0.cmp(&b.0));
                println!("metadata entries ({})", metadata.len());
                for (key, value) in metadata.iter() {
                    println!(" {key}: {value:?}");
                }
            }
            let mut tensors = content.tensor_infos.into_iter().collect::<Vec<_>>();
            tensors.sort_by(|a, b| a.0.cmp(&b.0));
            for (name, info) in tensors.iter() {
                println!("{name}: [{:?}; {:?}]", info.shape, info.ggml_dtype);
            }
        }
    }
    Ok(())
}

fn run_quantize(
    in_file: std::path::PathBuf,
    out_file: std::path::PathBuf,
    q: Quantization,
    qmode: QuantizationMode,
) -> Result<()> {
    // Open the out file early so as to fail directly on missing directories etc.
    let mut out_file = std::fs::File::create(out_file)?;
    let mut in_ = std::fs::File::open(&in_file)?;
    let content = gguf_file::Content::read(&mut in_)?;
    println!("tensors: {}", content.tensor_infos.len());

    let quantize_fn = match q {
        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
        Quantization::F16 => QTensor::quantize::<half::f16>,
        Quantization::F32 => QTensor::quantize::<f32>,
    };

    let qtensors = content
        .tensor_infos
        .par_iter()
        .map(|(name, _)| {
            println!(" quantizing {name}");
            let mut in_file = std::fs::File::open(&in_file)?;
            let tensor = content.tensor(&mut in_file, name)?;
            let tensor = qmode.quantize(name, tensor, quantize_fn)?;
            Ok((name, tensor))
        })
        .collect::<Result<Vec<_>>>()?;
    let qtensors = qtensors
        .iter()
        .map(|(k, v)| (k.as_str(), v))
        .collect::<Vec<_>>();

    let metadata = content
        .metadata
        .iter()
        .map(|(k, v)| (k.as_str(), v))
        .collect::<Vec<_>>();
    gguf_file::write(&mut out_file, metadata.as_slice(), &qtensors)?;
    Ok(())
}

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    match args.command {
        Command::Ls {
            files,
            format,
            verbose,
        } => {
            let multiple_files = files.len() > 1;
            for file in files.iter() {
                if multiple_files {
                    println!("--- {file:?} ---");
                }
                run_ls(file, format.clone(), verbose)?
            }
        }
        Command::Quantize {
            in_file,
            out_file,
            quantization,
            mode,
        } => run_quantize(in_file, out_file, quantization, mode)?,
    }
    Ok(())
}
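As a quick illustration of this example binary's CLI surface, hypothetical invocations could look like the following (the file names are made up):

```bash
# List the tensors stored in a safetensors file.
cargo run --example tensor-tools -- ls model.safetensors

# Requantize a gguf file using q4_0 with the default llama mode.
cargo run --example tensor-tools -- quantize --quantization q4_0 model.gguf quantized.gguf
```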
candle-core/src/accelerate.rs (new file, 412 lines)
@@ -0,0 +1,412 @@
#![allow(dead_code)]
use libc::{c_char, c_double, c_float, c_int, c_long, c_ulong};

mod ffi {
    use super::*;
    extern "C" {
        // It would be nice to be able to switch to the NEWLAPACK version of the function but this
        // seems to trigger some link error. Available function names can be seen here:
        // /Library/Developer/CommandLineTools/SDKs/MacOSX13.3.sdk/System/Library/Frameworks/Accelerate.framework/Versions/A/Accelerate.tbd
        #[link_name = "sgemm_"]
        pub fn sgemm_ffi(
            transa: *const c_char,
            transb: *const c_char,
            m: *const c_int,
            n: *const c_int,
            k: *const c_int,
            alpha: *const c_float,
            a: *const c_float,
            lda: *const c_int,
            b: *const c_float,
            ldb: *const c_int,
            beta: *const c_float,
            c: *mut c_float,
            ldc: *const c_int,
        );
        #[link_name = "dgemm_"]
        pub fn dgemm_ffi(
            transa: *const c_char,
            transb: *const c_char,
            m: *const c_int,
            n: *const c_int,
            k: *const c_int,
            alpha: *const c_double,
            a: *const c_double,
            lda: *const c_int,
            b: *const c_double,
            ldb: *const c_int,
            beta: *const c_double,
            c: *mut c_double,
            ldc: *const c_int,
        );

        pub fn vvexpf(dst: *mut c_float, src: *const c_float, len: *const c_int);
        pub fn vvexp(dst: *mut c_double, src: *const c_double, len: *const c_int);
        pub fn vvsqrtf(dst: *mut c_float, src: *const c_float, len: *const c_int);
        pub fn vvsqrt(dst: *mut c_double, src: *const c_double, len: *const c_int);
        pub fn vvsinf(dst: *mut c_float, src: *const c_float, len: *const c_int);
        pub fn vvsin(dst: *mut c_double, src: *const c_double, len: *const c_int);
        pub fn vvcosf(dst: *mut c_float, src: *const c_float, len: *const c_int);
        pub fn vvcos(dst: *mut c_double, src: *const c_double, len: *const c_int);
        pub fn vvlogf(dst: *mut c_float, src: *const c_float, len: *const c_int);
        pub fn vvlog(dst: *mut c_double, src: *const c_double, len: *const c_int);
        pub fn vvtanhf(dst: *mut c_float, src: *const c_float, len: *const c_int);
        pub fn vvtanh(dst: *mut c_double, src: *const c_double, len: *const c_int);

        pub fn vDSP_vaddD(
            _: *const c_double,
            _: c_long,
            _: *const c_double,
            _: c_long,
            _: *mut c_double,
            _: c_long,
            _: c_ulong,
        );
        pub fn vDSP_vadd(
            _: *const c_float,
            _: c_long,
            _: *const c_float,
            _: c_long,
            _: *mut c_float,
            _: c_long,
            _: c_ulong,
        );
        pub fn vDSP_vsubD(
            _: *const c_double,
            _: c_long,
            _: *const c_double,
            _: c_long,
            _: *mut c_double,
            _: c_long,
            _: c_ulong,
        );
        pub fn vDSP_vsub(
            _: *const c_float,
            _: c_long,
            _: *const c_float,
            _: c_long,
            _: *mut c_float,
            _: c_long,
            _: c_ulong,
        );
        pub fn vDSP_vmulD(
            _: *const c_double,
            _: c_long,
            _: *const c_double,
            _: c_long,
            _: *mut c_double,
            _: c_long,
            _: c_ulong,
        );
        pub fn vDSP_vmul(
            _: *const c_float,
            _: c_long,
            _: *const c_float,
            _: c_long,
            _: *mut c_float,
            _: c_long,
            _: c_ulong,
        );
        pub fn vDSP_vdivD(
            _: *const c_double,
            _: c_long,
            _: *const c_double,
            _: c_long,
            _: *mut c_double,
            _: c_long,
            _: c_ulong,
        );
        pub fn vDSP_vdiv(
            _: *const c_float,
            _: c_long,
            _: *const c_float,
            _: c_long,
            _: *mut c_float,
            _: c_long,
            _: c_ulong,
        );
        pub fn vDSP_vminD(
            _: *const c_double,
            _: c_long,
            _: *const c_double,
            _: c_long,
            _: *mut c_double,
            _: c_long,
            _: c_ulong,
        );
        pub fn vDSP_vmin(
            _: *const c_float,
            _: c_long,
            _: *const c_float,
            _: c_long,
            _: *mut c_float,
            _: c_long,
            _: c_ulong,
        );
        pub fn vDSP_vmaxD(
            _: *const c_double,
            _: c_long,
            _: *const c_double,
            _: c_long,
            _: *mut c_double,
            _: c_long,
            _: c_ulong,
        );
        pub fn vDSP_vmax(
            _: *const c_float,
            _: c_long,
            _: *const c_float,
            _: c_long,
            _: *mut c_float,
            _: c_long,
            _: c_ulong,
        );
    }
}

#[allow(clippy::too_many_arguments)]
#[inline]
pub unsafe fn sgemm(
    transa: u8,
    transb: u8,
    m: i32,
    n: i32,
    k: i32,
    alpha: f32,
    a: &[f32],
    lda: i32,
    b: &[f32],
    ldb: i32,
    beta: f32,
    c: &mut [f32],
    ldc: i32,
) {
    ffi::sgemm_ffi(
        &(transa as c_char),
        &(transb as c_char),
        &m,
        &n,
        &k,
        &alpha,
        a.as_ptr(),
        &lda,
        b.as_ptr(),
        &ldb,
        &beta,
        c.as_mut_ptr(),
        &ldc,
    )
}

#[allow(clippy::too_many_arguments)]
#[inline]
pub unsafe fn dgemm(
    transa: u8,
    transb: u8,
    m: i32,
    n: i32,
    k: i32,
    alpha: f64,
    a: &[f64],
    lda: i32,
    b: &[f64],
    ldb: i32,
    beta: f64,
    c: &mut [f64],
    ldc: i32,
) {
    ffi::dgemm_ffi(
        &(transa as c_char),
        &(transb as c_char),
        &m,
        &n,
        &k,
        &alpha,
        a.as_ptr(),
        &lda,
        b.as_ptr(),
        &ldb,
        &beta,
        c.as_mut_ptr(),
        &ldc,
    )
}

#[inline]
pub fn vs_exp(a: &[f32], y: &mut [f32]) {
    let a_len = a.len();
    let y_len = y.len();
    if a_len != y_len {
        panic!("a and y have different lengths {a_len} <> {y_len}")
    }
    unsafe { ffi::vvexpf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}

#[inline]
pub fn vd_exp(a: &[f64], y: &mut [f64]) {
    let a_len = a.len();
    let y_len = y.len();
    if a_len != y_len {
        panic!("a and y have different lengths {a_len} <> {y_len}")
    }
    unsafe { ffi::vvexp(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}

#[inline]
pub fn vs_sqrt(a: &[f32], y: &mut [f32]) {
    let a_len = a.len();
    let y_len = y.len();
    if a_len != y_len {
        panic!("a and y have different lengths {a_len} <> {y_len}")
    }
    unsafe { ffi::vvsqrtf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}

#[inline]
pub fn vd_sqrt(a: &[f64], y: &mut [f64]) {
    let a_len = a.len();
    let y_len = y.len();
    if a_len != y_len {
        panic!("a and y have different lengths {a_len} <> {y_len}")
    }
    unsafe { ffi::vvsqrt(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}

#[inline]
pub fn vs_sin(a: &[f32], y: &mut [f32]) {
    let a_len = a.len();
    let y_len = y.len();
    if a_len != y_len {
        panic!("a and y have different lengths {a_len} <> {y_len}")
    }
    unsafe { ffi::vvsinf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}

#[inline]
pub fn vd_sin(a: &[f64], y: &mut [f64]) {
    let a_len = a.len();
    let y_len = y.len();
    if a_len != y_len {
        panic!("a and y have different lengths {a_len} <> {y_len}")
    }
    unsafe { ffi::vvsin(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}

#[inline]
pub fn vs_cos(a: &[f32], y: &mut [f32]) {
    let a_len = a.len();
    let y_len = y.len();
    if a_len != y_len {
        panic!("a and y have different lengths {a_len} <> {y_len}")
    }
    unsafe { ffi::vvcosf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}

#[inline]
pub fn vd_cos(a: &[f64], y: &mut [f64]) {
    let a_len = a.len();
    let y_len = y.len();
    if a_len != y_len {
        panic!("a and y have different lengths {a_len} <> {y_len}")
    }
    unsafe { ffi::vvcos(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}

#[inline]
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
    let a_len = a.len();
    let y_len = y.len();
    if a_len != y_len {
        panic!("a and y have different lengths {a_len} <> {y_len}")
    }
    unsafe { ffi::vvtanhf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}

#[inline]
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
    let a_len = a.len();
    let y_len = y.len();
    if a_len != y_len {
        panic!("a and y have different lengths {a_len} <> {y_len}")
    }
    unsafe { ffi::vvtanh(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}

#[inline]
pub fn vs_ln(a: &[f32], y: &mut [f32]) {
    let a_len = a.len();
    let y_len = y.len();
    if a_len != y_len {
        panic!("a and y have different lengths {a_len} <> {y_len}")
    }
    unsafe { ffi::vvlogf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}

#[inline]
pub fn vd_ln(a: &[f64], y: &mut [f64]) {
    let a_len = a.len();
    let y_len = y.len();
    if a_len != y_len {
        panic!("a and y have different lengths {a_len} <> {y_len}")
    }
    unsafe { ffi::vvlog(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}

#[inline]
pub fn vs_sqr(a: &[f32], y: &mut [f32]) {
    let a_len = a.len();
    let y_len = y.len();
    if a_len != y_len {
        panic!("a and y have different lengths {a_len} <> {y_len}")
    }
    y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
}

#[inline]
pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
    let a_len = a.len();
    let y_len = y.len();
    if a_len != y_len {
        panic!("a and y have different lengths {a_len} <> {y_len}")
    }
    y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
}

macro_rules! binary_op {
    ($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
        #[inline]
        pub fn $fn_name(a: &[$ty], b: &[$ty], y: &mut [$ty]) {
            let a_len = a.len();
            let b_len = b.len();
            let y_len = y.len();
            if a_len != y_len || b_len != y_len {
                panic!(
                    "{} a,b,y len mismatch {a_len} {b_len} {y_len}",
                    stringify!($fn_name)
                );
            }
            unsafe {
                // Weird quirk of accelerate, the rhs comes before the lhs.
                ffi::$accelerate_name(
                    b.as_ptr(),
                    1,
                    a.as_ptr(),
                    1,
                    y.as_mut_ptr(),
                    1,
                    a_len as u64,
                )
            }
        }
    };
}
binary_op!(vs_add, f32, vDSP_vadd);
binary_op!(vd_add, f64, vDSP_vaddD);
binary_op!(vs_sub, f32, vDSP_vsub);
binary_op!(vd_sub, f64, vDSP_vsubD);
binary_op!(vs_mul, f32, vDSP_vmul);
binary_op!(vd_mul, f64, vDSP_vmulD);
binary_op!(vs_div, f32, vDSP_vdiv);
binary_op!(vd_div, f64, vDSP_vdivD);
binary_op!(vs_max, f32, vDSP_vmax);
binary_op!(vd_max, f64, vDSP_vmaxD);
binary_op!(vs_min, f32, vDSP_vmin);
binary_op!(vd_min, f64, vDSP_vminD);
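A small sketch of how these generated wrappers behave (this assumes building on macOS with the `accelerate` feature enabled; the module appears to be crate-internal, so such a check would live as a test inside `candle-core`):

```rust
#[cfg(all(test, feature = "accelerate"))]
mod accelerate_tests {
    #[test]
    fn vdsp_add() {
        let a = [1f32, 2.0, 3.0];
        let b = [4f32, 5.0, 6.0];
        let mut y = [0f32; 3];
        // Element-wise addition through vDSP_vadd (note the swapped operand
        // order in the FFI call, as mentioned in the macro above).
        crate::accelerate::vs_add(&a, &b, &mut y);
        assert_eq!(y, [5.0, 7.0, 9.0]);
    }
}
```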
@@ -15,6 +15,8 @@ pub trait BackendStorage: Sized {

    fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;

    fn powf(&self, _: &Layout, _: f64) -> Result<Self>;

    fn elu(&self, _: &Layout, _: f64) -> Result<Self>;

    fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;

@@ -37,6 +39,26 @@
        _params: &crate::conv::ParamsConv1D,
    ) -> Result<Self>;

    fn conv2d(
        &self,
        _l: &Layout,
        _kernel: &Self,
        _kernel_l: &Layout,
        _params: &crate::conv::ParamsConv2D,
    ) -> Result<Self>;

    fn conv_transpose2d(
        &self,
        _l: &Layout,
        _kernel: &Self,
        _kernel_l: &Layout,
        _params: &crate::conv::ParamsConvTranspose2D,
    ) -> Result<Self>;

    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
    fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
    fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;

    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
    fn scatter_add(
        &self,
@@ -55,6 +55,16 @@ impl Tensor {
                kernel: rhs,
                ..
            }
            | Op::Conv2D {
                arg: lhs,
                kernel: rhs,
                ..
            }
            | Op::ConvTranspose2D {
                arg: lhs,
                kernel: rhs,
                ..
            }
            | Op::CustomOp2(lhs, rhs, _)
            | Op::Binary(lhs, rhs, _)
            | Op::Gather(lhs, rhs, _)
@@ -81,6 +91,9 @@
                }
            }
            Op::Reshape(node)
            | Op::UpsampleNearest2D(node)
            | Op::AvgPool2D { arg: node, .. }
            | Op::MaxPool2D { arg: node, .. }
            | Op::Copy(node)
            | Op::Broadcast(node)
            | Op::Cmp(node, _)
@@ -88,9 +101,11 @@
            | Op::ToDType(node)
            | Op::ToDevice(node)
            | Op::Transpose(node, _, _)
            | Op::Permute(node, _)
            | Op::Narrow(node, _, _, _)
            | Op::Unary(node, _)
            | Op::Elu(node, _)
            | Op::Powf(node, _)
            | Op::CustomOp1(node, _) => {
                let (tg, nodes) = walk(node, nodes, already_seen);
                track_grad |= tg;
@@ -153,6 +168,21 @@
                    let rhs_sum_grad = grads.or_insert(rhs)?;
                    *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
                }
                Op::Binary(lhs, rhs, BinaryOp::Minimum)
                | Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
                    let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?;
                    let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?;

                    // If both masks are 1 on the same point, we want to scale the
                    // gradient by 0.5 rather than 1.
                    let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
                    let lhs_sum_grad = grads.or_insert(lhs)?;
                    *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;

                    let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
                    let rhs_sum_grad = grads.or_insert(rhs)?;
                    *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                }
                Op::WhereCond(pred, t, f) => {
                    let zeros = grad.zeros_like()?;
                    let t_sum_grad = grads.or_insert(t)?;
@@ -163,6 +193,78 @@
                    *f_sum_grad = f_sum_grad.add(&f_grad)?;
                }
                Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
                Op::Conv2D {
                    arg,
                    kernel,
                    padding,
                    stride,
                    dilation,
                } => {
                    // The output height for conv_transpose2d is:
                    // (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
                    let grad_h = grad.dim(2)?;
                    let k_h = kernel.dim(2)?;
                    let out_size =
                        (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
                    let out_padding = arg.dim(2)? - out_size;
                    let grad_arg = grad.conv_transpose2d(
                        kernel,
                        *padding,
                        out_padding,
                        *stride,
                        *dilation,
                    )?;
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.add(&grad_arg)?;

                    let grad_kernel = arg
                        .transpose(0, 1)?
                        .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                        .transpose(0, 1)?;
                    let sum_grad = grads.or_insert(kernel)?;
                    *sum_grad = sum_grad.add(&grad_kernel)?;
                }
                Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
                    op: "conv-transpose2d",
                })?,
                Op::AvgPool2D {
                    arg,
                    kernel_size,
                    stride,
                } => {
                    if kernel_size != stride {
                        crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
                    }
                    let (_n, _c, h, w) = arg.dims4()?;
                    let grad_arg = grad.upsample_nearest2d(h, w)?;
                    let grad_arg =
                        (grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.add(&grad_arg)?;
                }
                Op::MaxPool2D {
                    arg,
                    kernel_size,
                    stride,
                } => {
                    if kernel_size != stride {
                        crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
                    }
                    let (_n, _c, h, w) = arg.dims4()?;
                    // For computing the max-pool gradient, we compute a mask where a 1 means
                    // that the element is the maximum, then we apply this mask to the
                    // upsampled gradient (taking into account that multiple max may exist so
                    // we scale the gradient for this case).
                    let node_upsampled = node.upsample_nearest2d(h, w)?;
                    let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
                    let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
                    let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.add(&grad_arg)?;
                }
                Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
                    op: "upsample-nearest2d",
                })?,
                Op::Gather(arg, indexes, dim) => {
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
@@ -277,6 +379,11 @@
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
                }
                Op::Unary(arg, UnaryOp::Tanh) => {
                    let sum_grad = grads.or_insert(arg)?;
                    // d/dx tanh(x) = 1 - tanh(x)^2, i.e. -(node^2 - 1).
                    let minus_dtanh = (node.sqr()? - 1.)?;
                    *sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)?
                }
                Op::Unary(arg, UnaryOp::Abs) => {
                    let sum_grad = grads.or_insert(arg)?;
                    let ones = arg.ones_like()?;
@@ -291,6 +398,11 @@
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.sub(&grad)?
                }
                Op::Unary(arg, UnaryOp::Recip) => {
                    let sum_grad = grads.or_insert(arg)?;
                    // d/dx (1/x) = -1/x^2.
                    let grad = (grad / arg.sqr()?)?;
                    *sum_grad = sum_grad.sub(&grad)?
                }
                &Op::Narrow(ref arg, dim, start_idx, len) => {
                    let arg_dims = arg.dims();
                    let left_pad = if start_idx == 0 {
@@ -331,6 +443,11 @@
                    *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
                }
                Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
                Op::Powf(arg, e) => {
                    // d/dx x^e = e * x^(e - 1).
                    let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.add(&arg_grad)?
                }
                Op::CustomOp1(arg, c) => {
                    if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
                        let sum_grad = grads.or_insert(arg)?;
@@ -384,6 +501,15 @@
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.add(&arg_grad)?
                }
                Op::Permute(arg, dims) => {
                    // Invert the permutation to route the gradient back.
                    let mut inv_dims = vec![0; dims.len()];
                    for (i, &dim_idx) in dims.iter().enumerate() {
                        inv_dims[dim_idx] = i
                    }
                    let arg_grad = grad.permute(inv_dims)?;
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.add(&arg_grad)?
                }
            };
        }
    }
@@ -1,6 +1,8 @@
use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
-    pub(crate) b_size: Option<usize>,
+    pub(crate) b_size: usize,
    // Maybe we should have a version without l_in as this bit depends on the input and not only on
    // the weights.
    pub(crate) l_in: usize,
@@ -9,19 +11,240 @@ pub struct ParamsConv1D {
    pub(crate) k_size: usize,
    pub(crate) padding: usize,
    pub(crate) stride: usize,
+    pub(crate) dilation: usize,
}

impl ParamsConv1D {
    pub(crate) fn l_out(&self) -> usize {
-        let dilation = 1;
-        (self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
+        (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
    }

    pub(crate) fn out_dims(&self) -> Vec<usize> {
        let l_out = self.l_out();
-        match self.b_size {
-            None => vec![self.c_out, l_out],
-            Some(n) => vec![n, self.c_out, l_out],
-        }
+        vec![self.b_size, self.c_out, l_out]
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
    pub(crate) b_size: usize,
    pub(crate) i_h: usize,
    pub(crate) i_w: usize,
    pub(crate) k_h: usize,
    pub(crate) k_w: usize,
    pub(crate) c_out: usize,
    pub(crate) c_in: usize,
    pub(crate) padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
}

impl ParamsConv2D {
    pub(crate) fn out_h(&self) -> usize {
        (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
    }

    pub(crate) fn out_w(&self) -> usize {
        (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
    }

    pub(crate) fn out_dims(&self) -> Vec<usize> {
        vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose2D {
    pub(crate) b_size: usize,
    pub(crate) i_h: usize,
    pub(crate) i_w: usize,
    pub(crate) k_h: usize,
    pub(crate) k_w: usize,
    pub(crate) c_out: usize,
    pub(crate) c_in: usize,
    pub(crate) padding: usize,
    pub(crate) output_padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
}

impl ParamsConvTranspose2D {
    pub(crate) fn out_h(&self) -> usize {
        (self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
            - 2 * self.padding
    }

    pub(crate) fn out_w(&self) -> usize {
        (self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
            - 2 * self.padding
    }

    pub(crate) fn out_dims(&self) -> Vec<usize> {
        vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
    }
}

impl Tensor {
    fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
        let storage =
            self.storage()
                .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
            arg,
            kernel,
            padding: params.padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }

    /// Applies a 1D convolution over the input tensor.
    pub fn conv1d(
        &self,
        kernel: &Self,
        padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        let (c_out, c_in_k, k_size) = kernel.dims3()?;
        let (b_size, c_in, l_in) = self.dims3()?;
        if c_in != c_in_k * groups {
            Err(Error::Conv1dInvalidArgs {
                inp_shape: self.shape().clone(),
                k_shape: kernel.shape().clone(),
                padding,
                stride,
                msg: "the number of in-channels on the input doesn't match the kernel size",
            }
            .bt())?
        }

        let params = ParamsConv1D {
            b_size,
            l_in,
            c_out: c_out / groups,
            c_in: c_in / groups,
            k_size,
            padding,
            stride,
            dilation,
        };
        if groups == 1 {
            self.conv1d_single_group(kernel, &params)
        } else {
            let blocks = self.chunk(groups, 1)?;
            let kernel = kernel.chunk(groups, 0)?;
            let blocks = blocks
                .iter()
                .zip(&kernel)
                .map(|(block, kernel)| block.conv1d_single_group(kernel, &params))
                .collect::<Result<Vec<_>>>()?;
            Tensor::cat(&blocks, 1)
        }
    }

    fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
        let storage =
            self.storage()
                .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
            arg,
            kernel,
            padding: params.padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }

    /// Applies a 2D convolution over the input tensor.
    pub fn conv2d(
        &self,
        kernel: &Self,
        padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        let (b_size, c_in, i_h, i_w) = self.dims4()?;
        let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
        if c_in != c_in_k * groups {
            crate::bail!(
                "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
            )
        }
        let params = ParamsConv2D {
            b_size,
            i_h,
            i_w,
            k_h,
            k_w,
            c_out: c_out / groups,
            c_in: c_in / groups,
            padding,
            stride,
            dilation,
        };
        if groups == 1 {
            self.conv2d_single_group(kernel, &params)
        } else {
            let blocks = self.chunk(groups, 1)?;
            let kernel = kernel.chunk(groups, 0)?;
            let blocks = blocks
                .iter()
                .zip(&kernel)
                .map(|(block, kernel)| block.conv2d_single_group(kernel, &params))
                .collect::<Result<Vec<_>>>()?;
            Tensor::cat(&blocks, 1)
        }
    }

    /// Applies a 2D transposed convolution over the input tensor.
    pub fn conv_transpose2d(
        &self,
        kernel: &Self,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    ) -> Result<Self> {
        let (b_size, c_in, i_h, i_w) = self.dims4()?;
        let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
        if c_in != c_in_k {
            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
        }
        let params = ParamsConvTranspose2D {
            b_size,
            i_h,
            i_w,
            k_h,
            k_w,
            c_out,
            c_in,
            padding,
            output_padding,
            stride,
            dilation,
        };
        let storage = self.storage().conv_transpose2d(
            self.layout(),
            &kernel.storage(),
            kernel.layout(),
            &params,
        )?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose2D {
            arg,
            kernel,
            padding: params.padding,
            output_padding: params.output_padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }
}
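Following the grouped-convolution code above, here is a minimal usage sketch (all shapes invented for illustration): with `groups = 2`, the input channels are split in two and each half is convolved with its own half of the kernels before concatenating along the channel dimension.

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Input: batch 1, 4 channels, length 8. Kernel: 6 out-channels, 4 / 2 = 2 in-channels per group, size 3.
    let x = Tensor::randn(0f32, 1.0, (1, 4, 8), &dev)?;
    let k = Tensor::randn(0f32, 1.0, (6, 2, 3), &dev)?;
    // padding 1, stride 1, dilation 1, groups 2.
    let y = x.conv1d(&k, 1, 1, 1, 2)?;
    // l_out = (8 + 2 * 1 - 1 * (3 - 1) - 1) / 1 + 1 = 8.
    assert_eq!(y.dims(), &[1, 6, 8]);
    Ok(())
}
```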
@@ -1,6 +1,6 @@
//! Implement conversion traits for tensors
use crate::{Device, Error, Tensor, WithDType};
use half::{bf16, f16};
use crate::{DType, Device, Error, Tensor, WithDType};
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::convert::TryFrom;

impl<T: WithDType> TryFrom<&Tensor> for Vec<T> {
@@ -92,5 +92,54 @@ from_tensor!(f64);
from_tensor!(f32);
from_tensor!(f16);
from_tensor!(bf16);
from_tensor!(i64);
from_tensor!(u32);
from_tensor!(u8);

impl Tensor {
    pub fn write_bytes<W: std::io::Write>(&self, f: &mut W) -> crate::Result<()> {
        use byteorder::{LittleEndian, WriteBytesExt};

        let vs = self.flatten_all()?;
        match self.dtype() {
            DType::BF16 => {
                let vs = vs.to_vec1::<bf16>()?;
                for &v in vs.reinterpret_cast() {
                    f.write_u16::<LittleEndian>(v)?
                }
            }
            DType::F16 => {
                let vs = vs.to_vec1::<f16>()?;
                for &v in vs.reinterpret_cast() {
                    f.write_u16::<LittleEndian>(v)?
                }
            }
            DType::F32 => {
                // TODO: Avoid using a buffer when data is already on the CPU.
                for v in vs.to_vec1::<f32>()? {
                    f.write_f32::<LittleEndian>(v)?
                }
            }
            DType::F64 => {
                for v in vs.to_vec1::<f64>()? {
                    f.write_f64::<LittleEndian>(v)?
                }
            }
            DType::U32 => {
                for v in vs.to_vec1::<u32>()? {
                    f.write_u32::<LittleEndian>(v)?
                }
            }
            DType::I64 => {
                for v in vs.to_vec1::<i64>()? {
                    f.write_i64::<LittleEndian>(v)?
                }
            }
            DType::U8 => {
                let vs = vs.to_vec1::<u8>()?;
                f.write_all(&vs)?;
            }
        }
        Ok(())
    }
}
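A small usage sketch of `write_bytes` (flattened order, little-endian), assuming the crate is consumed as `candle_core`:

    use candle_core::{Device, Result, Tensor};

    fn dump() -> Result<()> {
        let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
        let mut buf: Vec<u8> = Vec::new();
        t.write_bytes(&mut buf)?;
        // 3 f32 elements * 4 bytes each.
        assert_eq!(buf.len(), 12);
        Ok(())
    }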
candle-core/src/cpu/avx.rs (new file, 148 lines)
@@ -0,0 +1,148 @@
use super::{Cpu, CpuF16};
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

use half::f16;

pub struct CurrentCpu {}

const STEP: usize = 32;
const EPR: usize = 8;
const ARR: usize = STEP / EPR;

impl Cpu<ARR> for CurrentCpu {
    type Unit = __m256;
    type Array = [__m256; ARR];

    const STEP: usize = STEP;
    const EPR: usize = EPR;

    fn n() -> usize {
        ARR
    }

    unsafe fn zero() -> Self::Unit {
        _mm256_setzero_ps()
    }

    unsafe fn zero_array() -> Self::Array {
        [Self::zero(); ARR]
    }

    unsafe fn from_f32(v: f32) -> Self::Unit {
        _mm256_set1_ps(v)
    }

    unsafe fn load(mem_addr: *const f32) -> Self::Unit {
        _mm256_loadu_ps(mem_addr)
    }

    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
        _mm256_add_ps(a, b)
    }

    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
        _mm256_add_ps(_mm256_mul_ps(b, c), a)
    }

    unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
        _mm256_storeu_ps(mem_addr, a);
    }

    unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
        for i in 0..ARR / 2 {
            x[2 * i] = _mm256_add_ps(x[2 * i], x[2 * i + 1]);
        }
        for i in 0..ARR / 4 {
            x[4 * i] = _mm256_add_ps(x[4 * i], x[4 * i + 2]);
        }
        #[allow(clippy::reversed_empty_ranges)]
        for i in 0..ARR / 8 {
            x[8 * i] = _mm256_add_ps(x[8 * i], x[8 * i + 4]);
        }
        let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
        let t1 = _mm_hadd_ps(t0, t0);
        *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
    }
}

pub struct CurrentCpuF16 {}
impl CpuF16<ARR> for CurrentCpuF16 {
    type Unit = __m256;
    type Array = [__m256; ARR];

    const STEP: usize = STEP;
    const EPR: usize = EPR;

    fn n() -> usize {
        ARR
    }

    unsafe fn zero() -> Self::Unit {
        _mm256_setzero_ps()
    }

    unsafe fn zero_array() -> Self::Array {
        [Self::zero(); ARR]
    }

    unsafe fn from_f32(v: f32) -> Self::Unit {
        _mm256_set1_ps(v)
    }

    #[cfg(target_feature = "f16c")]
    unsafe fn load(mem_addr: *const f16) -> Self::Unit {
        _mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
    }

    #[cfg(not(target_feature = "f16c"))]
    unsafe fn load(mem_addr: *const f16) -> Self::Unit {
        let mut tmp = [0.0f32; 8];
        for i in 0..8 {
            tmp[i] = (*mem_addr.add(i)).to_f32();
        }
        _mm256_loadu_ps(tmp.as_ptr())
    }

    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
        _mm256_add_ps(a, b)
    }

    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
        _mm256_add_ps(_mm256_mul_ps(b, c), a)
    }

    #[cfg(target_feature = "f16c")]
    unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
        _mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
    }

    #[cfg(not(target_feature = "f16c"))]
    unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
        let mut tmp = [0.0f32; 8];
        _mm256_storeu_ps(tmp.as_mut_ptr(), a);
        for i in 0..8 {
            *mem_addr.add(i) = f16::from_f32(tmp[i]);
        }
    }

    unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
        let mut offset = ARR >> 1;
        for i in 0..offset {
            x[i] = _mm256_add_ps(x[i], x[offset + i]);
        }
        offset >>= 1;
        for i in 0..offset {
            x[i] = _mm256_add_ps(x[i], x[offset + i]);
        }
        offset >>= 1;
        for i in 0..offset {
            x[i] = _mm256_add_ps(x[i], x[offset + i]);
        }
        let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
        let t1 = _mm_hadd_ps(t0, t0);
        *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
    }
}
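To make the constants above concrete: each `__m256` holds `EPR = 8` f32 lanes, the kernels consume `STEP = 32` elements per outer iteration, and `ARR = STEP / EPR = 4` independent accumulators hide FMA latency before the final tree reduction. A scalar model of the same accumulation pattern:

    // Scalar sketch of the 4-accumulator pattern (illustrative only).
    fn sum_with_accumulators(xs: &[f32]) -> f32 {
        let mut acc = [0f32; 4];
        for chunk in xs.chunks_exact(4) {
            for (a, &v) in acc.iter_mut().zip(chunk) {
                *a += v;
            }
        }
        // Scalar tail for the leftover elements, as in the SIMD kernels.
        let tail: f32 = xs.chunks_exact(4).remainder().iter().sum();
        acc.iter().sum::<f32>() + tail
    }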
candle-core/src/cpu/kernels.rs (new file, 191 lines)
@@ -0,0 +1,191 @@
pub trait VecOps: num_traits::NumAssign + Copy {
    fn min(self, rhs: Self) -> Self;
    fn max(self, rhs: Self) -> Self;

    /// Dot-product of two vectors.
    ///
    /// # Safety
    ///
    /// The length of `lhs` and `rhs` have to be at least `len`. `res` has to point to a valid
    /// element.
    #[inline(always)]
    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
        *res = Self::zero();
        for i in 0..len {
            *res += *lhs.add(i) * *rhs.add(i)
        }
    }

    /// Sum of all elements in a vector.
    ///
    /// # Safety
    ///
    /// The length of `xs` must be at least `len`. `res` has to point to a valid
    /// element.
    #[inline(always)]
    unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
        *res = Self::zero();
        for i in 0..len {
            *res += *xs.add(i)
        }
    }

    /// Maximum element in a non-empty vector.
    ///
    /// # Safety
    ///
    /// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
    /// element.
    #[inline(always)]
    unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
        *res = *xs;
        for i in 1..len {
            *res = (*res).max(*xs.add(i))
        }
    }

    /// Minimum element in a non-empty vector.
    ///
    /// # Safety
    ///
    /// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
    /// element.
    #[inline(always)]
    unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
        *res = *xs;
        for i in 1..len {
            *res = (*res).min(*xs.add(i))
        }
    }
}
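A safe wrapper over the default `vec_dot`, illustrating the length contract from the safety comments; the import path is an assumption that depends on how `cpu::kernels` is re-exported:

    use candle_core::cpu::kernels::VecOps; // assumed re-export path

    fn dot(lhs: &[f64], rhs: &[f64]) -> f64 {
        assert_eq!(lhs.len(), rhs.len());
        let mut res = 0f64;
        // SAFETY: both pointers are valid for `len` reads and `res` is valid.
        unsafe { VecOps::vec_dot(lhs.as_ptr(), rhs.as_ptr(), &mut res, lhs.len()) };
        res
    }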

impl VecOps for f32 {
    #[inline(always)]
    fn min(self, other: Self) -> Self {
        Self::min(self, other)
    }

    #[inline(always)]
    fn max(self, other: Self) -> Self {
        Self::max(self, other)
    }

    #[inline(always)]
    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
        super::vec_dot_f32(lhs, rhs, res, len)
    }

    #[inline(always)]
    unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
        super::vec_sum(xs, res, len)
    }
}

impl VecOps for half::f16 {
    #[inline(always)]
    fn min(self, other: Self) -> Self {
        Self::min(self, other)
    }

    #[inline(always)]
    fn max(self, other: Self) -> Self {
        Self::max(self, other)
    }

    #[inline(always)]
    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
        let mut res_f32 = 0f32;
        super::vec_dot_f16(lhs, rhs, &mut res_f32, len);
        *res = half::f16::from_f32(res_f32);
    }
}

impl VecOps for f64 {
    #[inline(always)]
    fn min(self, other: Self) -> Self {
        Self::min(self, other)
    }

    #[inline(always)]
    fn max(self, other: Self) -> Self {
        Self::max(self, other)
    }
}
impl VecOps for half::bf16 {
    #[inline(always)]
    fn min(self, other: Self) -> Self {
        Self::min(self, other)
    }

    #[inline(always)]
    fn max(self, other: Self) -> Self {
        Self::max(self, other)
    }
}
impl VecOps for u8 {
    #[inline(always)]
    fn min(self, other: Self) -> Self {
        <Self as Ord>::min(self, other)
    }

    #[inline(always)]
    fn max(self, other: Self) -> Self {
        <Self as Ord>::max(self, other)
    }
}
impl VecOps for u32 {
    #[inline(always)]
    fn min(self, other: Self) -> Self {
        <Self as Ord>::min(self, other)
    }

    #[inline(always)]
    fn max(self, other: Self) -> Self {
        <Self as Ord>::max(self, other)
    }
}
impl VecOps for i64 {
    #[inline(always)]
    fn min(self, other: Self) -> Self {
        <Self as Ord>::min(self, other)
    }

    #[inline(always)]
    fn max(self, other: Self) -> Self {
        <Self as Ord>::max(self, other)
    }
}

#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
    if n_threads == 1 {
        func(0)
    } else {
        rayon::scope(|s| {
            for thread_idx in 0..n_threads {
                let func = &func;
                s.spawn(move |_| func(thread_idx));
            }
        })
    }
}

#[inline(always)]
pub fn par_range(lo: usize, up: usize, n_threads: usize, func: impl Fn(usize) + Send + Sync) {
    if n_threads == 1 {
        for i in lo..up {
            func(i)
        }
    } else {
        rayon::scope(|s| {
            for thread_idx in 0..n_threads {
                let func = &func;
                s.spawn(move |_| {
                    for i in (thread_idx..up).step_by(n_threads) {
                        func(i)
                    }
                });
            }
        })
    }
}
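Usage sketch for `par_range`: each of the `n_threads` workers handles the indices congruent to its `thread_idx` modulo `n_threads`. Note that the striped path starts at `thread_idx` rather than `lo + thread_idx`, so `lo` only takes effect on the single-threaded path. The import path is an assumption:

    use candle_core::cpu::kernels::par_range; // assumed re-export path
    use std::sync::atomic::{AtomicUsize, Ordering};

    fn par_sum_demo() {
        let total = AtomicUsize::new(0);
        // Thread t visits the i in 0..100 with i % 4 == t.
        par_range(0, 100, 4, |i| {
            total.fetch_add(i, Ordering::Relaxed);
        });
        assert_eq!(total.load(Ordering::Relaxed), 4950);
    }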
candle-core/src/cpu/mod.rs (new file, 179 lines)
@@ -0,0 +1,179 @@
pub mod kernels;

trait Cpu<const ARR: usize> {
    type Unit;
    type Array;
    const STEP: usize;
    const EPR: usize;

    fn n() -> usize;
    unsafe fn zero() -> Self::Unit;
    unsafe fn zero_array() -> Self::Array;
    unsafe fn load(mem_addr: *const f32) -> Self::Unit;
    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
    unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
    unsafe fn from_f32(v: f32) -> Self::Unit;
    unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
}

trait CpuF16<const ARR: usize> {
    type Unit;
    type Array;
    const STEP: usize;
    const EPR: usize;

    fn n() -> usize;
    unsafe fn zero() -> Self::Unit;
    unsafe fn zero_array() -> Self::Array;
    unsafe fn load(mem_addr: *const f16) -> Self::Unit;
    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
    unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
    unsafe fn from_f32(v: f32) -> Self::Unit;
    unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit);
}
use half::f16;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub mod avx;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub use avx::{CurrentCpu, CurrentCpuF16};

#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
pub mod simd128;
#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
pub use simd128::CurrentCpu;

#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(target_feature = "neon")]
pub use neon::CurrentCpu;

#[cfg(any(
    target_feature = "neon",
    target_feature = "avx",
    target_feature = "simd128"
))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
    let np = k & !(CurrentCpu::STEP - 1);

    let mut sum = CurrentCpu::zero_array();
    let mut ax = CurrentCpu::zero_array();
    let mut ay = CurrentCpu::zero_array();

    for i in (0..np).step_by(CurrentCpu::STEP) {
        for j in 0..CurrentCpu::n() {
            ax[j] = CurrentCpu::load(a_row.add(i + j * CurrentCpu::EPR));
            ay[j] = CurrentCpu::load(b_row.add(i + j * CurrentCpu::EPR));

            sum[j] = CurrentCpu::vec_fma(sum[j], ax[j], ay[j]);
        }
    }

    CurrentCpu::vec_reduce(sum, c);

    // leftovers
    for i in np..k {
        *c += *a_row.add(i) * (*b_row.add(i));
    }
}
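Here `np` is `k` rounded down to a multiple of `STEP` via bit masking, which is valid because `STEP` is a power of two; the remaining `k - np` elements fall through to the scalar leftover loop:

    const STEP: usize = 32; // as in the AVX path above

    // k = 100 -> np = 96, leaving 4 leftover elements for the scalar loop.
    fn rounded_down(k: usize) -> usize {
        k & !(STEP - 1)
    }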

#[cfg(not(any(
    target_feature = "neon",
    target_feature = "avx",
    target_feature = "simd128"
)))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
    // leftovers
    for i in 0..k {
        *c += *a_row.add(i) * (*b_row.add(i));
    }
}

#[cfg(any(
    target_feature = "neon",
    target_feature = "avx",
    target_feature = "simd128"
))]
#[inline(always)]
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
    let np = k & !(CurrentCpu::STEP - 1);

    let mut sum = CurrentCpu::zero_array();
    let mut x = CurrentCpu::zero_array();

    for i in (0..np).step_by(CurrentCpu::STEP) {
        for j in 0..CurrentCpu::n() {
            x[j] = CurrentCpu::load(row.add(i + j * CurrentCpu::EPR));
            sum[j] = CurrentCpu::vec_add(sum[j], x[j]);
        }
    }

    CurrentCpu::vec_reduce(sum, b);

    // leftovers
    for i in np..k {
        *b += *row.add(i)
    }
}

#[cfg(not(any(
    target_feature = "neon",
    target_feature = "avx",
    target_feature = "simd128"
)))]
#[inline(always)]
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
    *b = 0f32;
    for i in 0..k {
        *b += *row.add(i)
    }
}

#[cfg(target_feature = "avx")]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
    let mut sumf = 0.0f32;
    let np = k & !(CurrentCpuF16::STEP - 1);

    let mut sum = CurrentCpuF16::zero_array();
    let mut ax = CurrentCpuF16::zero_array();
    let mut ay = CurrentCpuF16::zero_array();

    for i in (0..np).step_by(CurrentCpuF16::STEP) {
        for j in 0..CurrentCpuF16::n() {
            ax[j] = CurrentCpuF16::load(a_row.add(i + j * CurrentCpuF16::EPR));
            ay[j] = CurrentCpuF16::load(b_row.add(i + j * CurrentCpuF16::EPR));

            sum[j] = CurrentCpuF16::vec_fma(sum[j], ax[j], ay[j]);
        }
    }

    CurrentCpuF16::vec_reduce(sum, &mut sumf);

    // leftovers
    for i in np..k {
        sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
    }
    *c = sumf;
}

#[cfg(not(target_feature = "avx"))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
    // leftovers
    let mut sum = 0.0;
    for i in 0..k {
        sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
    }
    *c = sum;
}
candle-core/src/cpu/neon.rs (new file, 74 lines)
@@ -0,0 +1,74 @@
use super::Cpu;
#[cfg(target_arch = "arm")]
use core::arch::arm::*;

#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;

pub struct CurrentCpu {}

const STEP: usize = 16;
const EPR: usize = 4;
const ARR: usize = STEP / EPR;

impl CurrentCpu {
    #[cfg(target_arch = "aarch64")]
    unsafe fn reduce_one(x: float32x4_t) -> f32 {
        vaddvq_f32(x)
    }

    #[cfg(target_arch = "arm")]
    unsafe fn reduce_one(x: float32x4_t) -> f32 {
        vgetq_lane_f32(x, 0) + vgetq_lane_f32(x, 1) + vgetq_lane_f32(x, 2) + vgetq_lane_f32(x, 3)
    }
}

impl Cpu<ARR> for CurrentCpu {
    type Unit = float32x4_t;
    type Array = [float32x4_t; ARR];

    const STEP: usize = STEP;
    const EPR: usize = EPR;

    fn n() -> usize {
        ARR
    }

    unsafe fn zero() -> Self::Unit {
        vdupq_n_f32(0.0)
    }

    unsafe fn from_f32(x: f32) -> Self::Unit {
        vdupq_n_f32(x)
    }

    unsafe fn zero_array() -> Self::Array {
        [Self::zero(); ARR]
    }

    unsafe fn load(mem_addr: *const f32) -> Self::Unit {
        vld1q_f32(mem_addr)
    }

    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
        vaddq_f32(a, b)
    }

    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
        vfmaq_f32(a, b, c)
    }

    unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
        vst1q_f32(mem_addr, a);
    }

    unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
        for i in 0..ARR / 2 {
            x[2 * i] = vaddq_f32(x[2 * i], x[2 * i + 1]);
        }
        for i in 0..ARR / 4 {
            x[4 * i] = vaddq_f32(x[4 * i], x[4 * i + 2]);
        }
        *y = Self::reduce_one(x[0]);
    }
}
candle-core/src/cpu/simd128.rs (new file, 64 lines)
@@ -0,0 +1,64 @@
use super::Cpu;
use core::arch::wasm32::*;

pub struct CurrentCpu {}

const STEP: usize = 16;
const EPR: usize = 4;
const ARR: usize = STEP / EPR;

impl Cpu<ARR> for CurrentCpu {
    type Unit = v128;
    type Array = [v128; ARR];

    const STEP: usize = STEP;
    const EPR: usize = EPR;

    fn n() -> usize {
        ARR
    }

    unsafe fn zero() -> Self::Unit {
        f32x4_splat(0.0)
    }

    unsafe fn zero_array() -> Self::Array {
        [Self::zero(); ARR]
    }

    unsafe fn from_f32(v: f32) -> Self::Unit {
        f32x4_splat(v)
    }

    unsafe fn load(mem_addr: *const f32) -> Self::Unit {
        v128_load(mem_addr as *mut v128)
    }

    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
        f32x4_add(a, b)
    }

    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
        f32x4_add(f32x4_mul(b, c), a)
    }

    unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
        v128_store(mem_addr as *mut v128, a);
    }

    unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
        for i in 0..ARR / 2 {
            x[2 * i] = f32x4_add(x[2 * i], x[2 * i + 1]);
        }
        for i in 0..ARR / 4 {
            x[4 * i] = f32x4_add(x[4 * i], x[4 * i + 2]);
        }
        for i in 0..ARR / 8 {
            x[8 * i] = f32x4_add(x[8 * i], x[8 * i + 4]);
        }
        *y = f32x4_extract_lane::<0>(x[0])
            + f32x4_extract_lane::<1>(x[0])
            + f32x4_extract_lane::<2>(x[0])
            + f32x4_extract_lane::<3>(x[0]);
    }
}
File diff suppressed because it is too large
@@ -1,7 +1,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
use candle_kernels as kernels;
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{
@@ -64,7 +64,7 @@ impl From<CudaError> for crate::Error {

/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub(crate) struct DeviceId(usize);
pub struct DeviceId(usize);

impl DeviceId {
    fn new() -> Self {
@@ -111,6 +111,14 @@ impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
}

impl CudaDevice {
    pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
        self.device.clone()
    }

    pub fn id(&self) -> DeviceId {
        self.id
    }

    fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
@@ -131,6 +139,14 @@ impl CudaDevice {
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::U32(data)
            }
            DType::I64 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
                let params = (&data, v as i64, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::I64(data)
            }
            DType::BF16 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
@@ -228,6 +244,10 @@ impl BackendDevice for CudaDevice {
                let data = self.alloc_zeros::<u32>(elem_count).w()?;
                CudaStorageSlice::U32(data)
            }
            DType::I64 => {
                let data = self.alloc_zeros::<i64>(elem_count).w()?;
                CudaStorageSlice::I64(data)
            }
            DType::BF16 => {
                let data = self.alloc_zeros::<bf16>(elem_count).w()?;
                CudaStorageSlice::BF16(data)
@@ -257,11 +277,13 @@ impl BackendDevice for CudaDevice {
        let slice = match dtype {
            // TODO: Add support for F16 and BF16 though this is likely to require some upstream
            // cudarc changes.
            DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
            DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
                Err(CudaError::UnsupportedDtype {
                    dtype,
                    op: "rand_uniform",
                })
                .w()?,
                .w()?
            }
            DType::F32 => {
                let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
                curand.0.fill_with_uniform(&mut data).w()?;
@@ -273,10 +295,12 @@ impl BackendDevice for CudaDevice {
                CudaStorageSlice::F64(data)
            }
        };
        if lo != 0.0 || up != 1.0 {
        let slice = if lo == 0. && up == 1.0 {
            slice
        } else {
            let layout = Layout::contiguous(shape);
            Affine(up - lo, lo).map(&slice, self, &layout)?;
        }
            Affine(up - lo, lo).map(&slice, self, &layout)?
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
@@ -289,11 +313,13 @@ impl BackendDevice for CudaDevice {
        let elem_count = shape.elem_count();
        let curand = self.curand.lock().unwrap();
        let slice = match dtype {
            DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
            DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
                Err(CudaError::UnsupportedDtype {
                    dtype,
                    op: "rand_normal",
                })
                .w()?,
                .w()?
            }
            DType::F32 => {
                let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
                curand
@@ -328,6 +354,10 @@ impl BackendDevice for CudaDevice {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U32(data)
            }
            CpuStorage::I64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::I64(data)
            }
            CpuStorage::BF16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::BF16(data)
@@ -353,9 +383,10 @@ impl BackendDevice for CudaDevice {
}

#[derive(Debug)]
enum CudaStorageSlice {
pub enum CudaStorageSlice {
    U8(CudaSlice<u8>),
    U32(CudaSlice<u32>),
    I64(CudaSlice<i64>),
    BF16(CudaSlice<bf16>),
    F16(CudaSlice<f16>),
    F32(CudaSlice<f32>),
@@ -363,7 +394,7 @@ enum CudaStorageSlice {
}
type S = CudaStorageSlice;

trait Map1 {
pub trait Map1 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src: &CudaSlice<T>,
@@ -375,6 +406,7 @@ trait Map1 {
        let out = match s {
            S::U8(s) => S::U8(self.f(s, d, l)?),
            S::U32(s) => S::U32(self.f(s, d, l)?),
            S::I64(s) => S::I64(self.f(s, d, l)?),
            S::BF16(s) => S::BF16(self.f(s, d, l)?),
            S::F16(s) => S::F16(self.f(s, d, l)?),
            S::F32(s) => S::F32(self.f(s, d, l)?),
@@ -384,7 +416,7 @@ trait Map1 {
    }
}

trait Map2 {
pub trait Map2 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
@@ -398,6 +430,7 @@ trait Map2 {
        let out = match (s1, s2) {
            (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
            (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
            (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
            (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
            (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
            (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
@@ -408,7 +441,7 @@ trait Map2 {
    }
}

trait Map2InPlace {
pub trait Map2InPlace {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        dst: &mut CudaSlice<T>,
@@ -429,6 +462,7 @@ trait Map2InPlace {
        match (dst, src) {
            (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
@@ -438,7 +472,7 @@ trait Map2InPlace {
    }
}

trait Map1Any {
pub trait Map1Any {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
        &self,
        src: &CudaSlice<T>,
@@ -451,6 +485,7 @@ trait Map1Any {
        let out = match s {
            S::U8(s) => self.f(s, d, l, S::U8)?,
            S::U32(s) => self.f(s, d, l, S::U32)?,
            S::I64(s) => self.f(s, d, l, S::I64)?,
            S::BF16(s) => self.f(s, d, l, S::BF16)?,
            S::F16(s) => self.f(s, d, l, S::F16)?,
            S::F32(s) => self.f(s, d, l, S::F32)?,
@@ -460,7 +495,7 @@ trait Map1Any {
    }
}

trait Map2Any {
pub trait Map2Any {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
@@ -474,6 +509,7 @@ trait Map2Any {
        let out = match (s1, s2) {
            (S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
@@ -496,7 +532,7 @@ impl Map1 for Clone {
    }
}

fn kernel_name<T: WithDType>(root: &str) -> String {
pub fn kernel_name<T: WithDType>(root: &str) -> String {
    let dtype = T::DTYPE.as_str();
    format!("{root}_{dtype}")
}
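`kernel_name` simply suffixes the dtype string, which is how the generic maps pick the right dtype-specialized PTX symbol; for instance (kernel root names taken from this diff):

    // kernel_name::<f32>("ucopy")        -> "ucopy_f32"
    // kernel_name::<half::bf16>("upowf") -> "upowf_bf16"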
@@ -557,6 +593,30 @@ impl Map1 for Elu {
    }
}

struct Powf(f64);
impl Map1 for Powf {
    fn f<T: DeviceRepr + WithDType>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>> {
        let shape = layout.shape();
        let dims = shape.dims();
        let el = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(el as u32);
        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
        let src = &src.slice(layout.start_offset()..);
        let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(el) }.w()?;
        let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out);
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
    }
}

struct Sum<'a>(&'a [usize]);
impl<'a> Map1 for Sum<'a> {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -706,6 +766,9 @@ impl<'a> Map1 for IndexSelect<'a> {
            CudaStorageSlice::U8(slice) => {
                ("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
            }
            CudaStorageSlice::I64(slice) => {
                ("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr())
            }
            _ => Err(CudaError::UnexpectedDType {
                msg: "index_select ids should be u8/u32/i64",
                expected: DType::U32,
@@ -765,8 +828,11 @@ impl<'a> Map1 for Gather<'a> {
                ("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
            }
            CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
            CudaStorageSlice::I64(slice) => {
                ("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr())
            }
            _ => Err(CudaError::UnexpectedDType {
                msg: "gather ids should be u8 or u32",
                msg: "gather ids should be u8/u32/i64",
                expected: DType::U32,
                got: ids.dtype(),
            })?,
@@ -812,9 +878,10 @@ impl<'a> Map2InPlace for IndexAdd<'a> {
        };
        let (name, ids) = match &ids.slice {
            CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
            CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
            CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
            _ => Err(CudaError::UnexpectedDType {
                msg: "index-add ids should be u8 or u32",
                msg: "index-add ids should be u8/u32/i64",
                expected: DType::U32,
                got: ids.dtype(),
            })?,
@@ -859,9 +926,10 @@ impl<'a> Map2InPlace for ScatterAdd<'a> {
        };
        let (name, ids) = match &ids.slice {
            CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
            CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
            CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
            _ => Err(CudaError::UnexpectedDType {
                msg: "scatter-add ids should be u8 or u32",
                msg: "scatter-add ids should be u8/u32/i64",
                expected: DType::U32,
                got: ids.dtype(),
            })?,
@@ -897,14 +965,13 @@ impl<'a> Map2 for Conv1D<'a> {
        // Kernel shape: (c_out, c_in_k, k_size)
        // Input shape: (b_size, c_in, l_in) or (c_in, l_in)
        let p = &self.0;

        let inp = &inp.slice(inp_l.start_offset()..);
        let k = &k.slice(k_l.start_offset()..);
        let shape = inp_l.shape();
        let dims = shape.dims();
        let el = shape.elem_count();
        let l_out = p.l_out();
        let dst_el = p.c_out * l_out * p.b_size.unwrap_or(1);
        let dst_el = p.c_out * l_out * p.b_size;
        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
        let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), kernels::CONV)?;
        // SAFETY: Set later by running the kernel.
@@ -914,10 +981,193 @@ impl<'a> Map2 for Conv1D<'a> {
        } else if dims.len() == 2 {
            [&[1], dims, &[1], inp_l.stride(), k_l.dims(), k_l.stride()].concat()
        } else {
            panic!("unexpected input shape for conv1d {dims:?}")
            crate::bail!("unexpected input shape for conv1d {dims:?}")
        };
        let ds = dev.htod_copy(ds).w()?;
        let params = (el, l_out, p.stride, &ds, inp, k, &out);
        let params = (
            el, l_out, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
        );
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
    }
}

struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        inp: &CudaSlice<T>,
        inp_l: &Layout,
        k: &CudaSlice<T>,
        k_l: &Layout,
        dev: &CudaDevice,
    ) -> Result<CudaSlice<T>> {
        // Kernel shape: (c_out, c_in_k, h_k, w_k)
        // Input shape: (b_size, c_in, h_in, w_in)
        let p = &self.0;
        let (out_w, out_h) = (p.out_w(), p.out_h());
        let dst_el = p.c_out * out_w * out_h * p.b_size;
        let inp = &inp.slice(inp_l.start_offset()..);
        let k = &k.slice(k_l.start_offset()..);
        let shape = inp_l.shape();
        let dims = shape.dims();
        let el = shape.elem_count();

        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
        let func = dev.get_or_load_func(&kernel_name::<T>("conv2d"), kernels::CONV)?;
        let ds = if dims.len() == 4 {
            [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
        } else {
            crate::bail!("unexpected input shape for conv2d {dims:?}")
        };
        let ds = dev.htod_copy(ds).w()?;
        let params = (
            el, out_w, out_h, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
        );
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
    }
}

struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
impl<'a> Map2 for ConvTranspose2D<'a> {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        inp: &CudaSlice<T>,
        inp_l: &Layout,
        k: &CudaSlice<T>,
        k_l: &Layout,
        dev: &CudaDevice,
    ) -> Result<CudaSlice<T>> {
        // Kernel shape: (c_in_k, c_out, h_k, w_k)
        // Input shape: (b_size, c_in, h_in, w_in)
        let p = &self.0;
        let (out_w, out_h) = (p.out_w(), p.out_h());
        let dst_el = p.c_out * out_w * out_h * p.b_size;
        let inp = &inp.slice(inp_l.start_offset()..);
        let k = &k.slice(k_l.start_offset()..);
        let shape = inp_l.shape();
        let dims = shape.dims();
        let el = shape.elem_count();

        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
        let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), kernels::CONV)?;
        let ds = if dims.len() == 4 {
            [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
        } else {
            crate::bail!("unexpected input shape for conv_transpose2d {dims:?}")
        };
        let ds = dev.htod_copy(ds).w()?;
        let params = (
            el,
            out_w,
            out_h,
            p.stride,
            p.padding,
            p.output_padding,
            p.dilation,
            &ds,
            inp,
            k,
            &out,
        );
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
    }
}

enum PoolOp {
    Max,
    Avg,
}

struct Pool2D {
    w_k: usize,
    h_k: usize,
    w_stride: usize,
    h_stride: usize,
    op: PoolOp,
}

impl Map1 for Pool2D {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        inp: &CudaSlice<T>,
        dev: &CudaDevice,
        inp_l: &Layout,
    ) -> Result<CudaSlice<T>> {
        // Input shape: (b_size, c, h, w)
        let inp = &inp.slice(inp_l.start_offset()..);
        let shape = inp_l.shape();
        let dims = shape.dims();
        let ds = if dims.len() == 4 {
            [dims, inp_l.stride()].concat()
        } else {
            crate::bail!("unexpected input shape for pool {dims:?}")
        };
        let el = shape.elem_count();
        let out_w = (dims[2] - self.w_k) / self.w_stride + 1;
        let out_h = (dims[3] - self.h_k) / self.h_stride + 1;
        let dst_el = out_w * out_h * dims[0] * dims[1];
        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
        let kname = match self.op {
            PoolOp::Max => "max_pool2d",
            PoolOp::Avg => "avg_pool2d",
        };
        let func = dev.get_or_load_func(&kernel_name::<T>(kname), kernels::CONV)?;
        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
        let ds = dev.htod_copy(ds).w()?;
        let params = (
            el,
            self.w_k,
            self.h_k,
            self.w_stride,
            self.h_stride,
            &ds,
            inp,
            &out,
        );
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
    }
}
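The pooling output sizes above follow the valid-window rule (this kernel applies no padding): `out = (in - k) / stride + 1` per spatial dim. For example, a 2x2 window with stride 2 over a 32x32 input yields 16x16:

    fn pool_out_dims() {
        let (h_in, w_in, k, stride) = (32usize, 32usize, 2usize, 2usize);
        assert_eq!((w_in - k) / stride + 1, 16);
        assert_eq!((h_in - k) / stride + 1, 16);
    }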

struct UpsampleNearest2D(usize, usize);
impl Map1 for UpsampleNearest2D {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        inp: &CudaSlice<T>,
        dev: &CudaDevice,
        inp_l: &Layout,
    ) -> Result<CudaSlice<T>> {
        // Input shape: (b_size, c, h, w)
        let inp = &inp.slice(inp_l.start_offset()..);
        let shape = inp_l.shape();
        let dims = shape.dims();
        let ds = if dims.len() == 4 {
            [dims, inp_l.stride()].concat()
        } else {
            crate::bail!("unexpected input shape for upsample {dims:?}")
        };
        let (out_w, out_h) = (self.0, self.1);
        let dst_el = out_w * out_h * dims[0] * dims[1];
        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
        let func = dev.get_or_load_func(&kernel_name::<T>("upsample_nearest2d"), kernels::CONV)?;
        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
        let ds = dev.htod_copy(ds).w()?;
        let scale_w = dims[2] as f64 / out_w as f64;
        let scale_h = dims[3] as f64 / out_h as f64;
        let params = (out_w, out_h, scale_w, scale_h, &ds, inp, &out);
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
@@ -944,8 +1194,12 @@ impl<'a> Map2 for WhereCond<'a> {
                let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
                (ptr, "where_u32")
            }
            CudaStorageSlice::I64(slice) => {
                let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
                (ptr, "where_i64")
            }
            _ => Err(CudaError::UnexpectedDType {
                msg: "where conditions should be u8 or u32",
                msg: "where conditions should be u8/u32/i64",
                expected: DType::U32,
                got: self.0.dtype(),
            })
@@ -1056,8 +1310,8 @@ fn slice_src_and_dst<'a, T>(

#[derive(Debug)]
pub struct CudaStorage {
    slice: CudaStorageSlice,
    device: CudaDevice,
    pub slice: CudaStorageSlice,
    pub device: CudaDevice,
}

pub trait CudaDType: Sized {
@@ -1089,6 +1343,7 @@ macro_rules! cuda_dtype {
}
cuda_dtype!(u8, U8);
cuda_dtype!(u32, U32);
cuda_dtype!(i64, I64);
cuda_dtype!(f16, F16);
cuda_dtype!(bf16, BF16);
cuda_dtype!(f32, F32);
@@ -1202,6 +1457,7 @@ impl BackendStorage for CudaStorage {
        match self.slice {
            CudaStorageSlice::U8(_) => DType::U8,
            CudaStorageSlice::U32(_) => DType::U32,
            CudaStorageSlice::I64(_) => DType::I64,
            CudaStorageSlice::BF16(_) => DType::BF16,
            CudaStorageSlice::F16(_) => DType::F16,
            CudaStorageSlice::F32(_) => DType::F32,
@@ -1227,6 +1483,7 @@ impl BackendStorage for CudaStorage {
        let inp = match &self.slice {
            CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(),
            CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
            CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(),
            CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
            CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
            CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(),
@@ -1249,6 +1506,12 @@ impl BackendStorage for CudaStorage {
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::U32(out)
            }
            DType::I64 => {
                let out = unsafe { dev.alloc::<i64>(el) }.w()?;
                let params = (el, dims.len(), &ds, *inp, &out);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::I64(out)
            }
            DType::BF16 => {
                let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
                let params = (el, dims.len(), &ds, *inp, &out);
@@ -1286,6 +1549,12 @@ impl BackendStorage for CudaStorage {
        Ok(Self { slice, device })
    }

    fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
        let device = self.device().clone();
        let slice = Powf(e).map(&self.slice, &device, layout)?;
        Ok(Self { slice, device })
    }

    fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
        let device = self.device().clone();
        let slice = Elu(alpha).map(&self.slice, &device, layout)?;
@@ -1333,6 +1602,11 @@ impl BackendStorage for CudaStorage {
                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
                Ok(CpuStorage::U32(cpu_storage))
            }
            CudaStorageSlice::I64(slice) => {
                let dev = slice.device();
                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
                Ok(CpuStorage::I64(cpu_storage))
            }
            CudaStorageSlice::BF16(slice) => {
                let dev = slice.device();
                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
@@ -1381,6 +1655,127 @@ impl BackendStorage for CudaStorage {
        Ok(Self { slice, device })
    }

    #[cfg(not(feature = "cudnn"))]
    fn conv2d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConv2D,
    ) -> Result<Self> {
        let device = self.device().clone();
        let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
        Ok(Self { slice, device })
    }

    #[cfg(feature = "cudnn")]
    fn conv2d(
        &self,
        inp_l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConv2D,
    ) -> Result<Self> {
        let device = self.device().clone();
        if !kernel_l.is_contiguous() {
            let slice = Conv2D(params).map(&self.slice, inp_l, &kernel.slice, kernel_l, &device)?;
            return Ok(Self { slice, device });
        }
        let (out_w, out_h) = (params.out_w(), params.out_h());
        let dst_el = params.c_out * out_w * out_h * params.b_size;
        let slice = match (&self.slice, &kernel.slice) {
            (S::U8(inp), S::U8(k)) => {
                let inp = &inp.slice(inp_l.start_offset()..);
                let k = &k.slice(kernel_l.start_offset()..);
                let mut out = unsafe { device.alloc::<u8>(dst_el) }.w()?;
                crate::cudnn::launch_conv2d::<u8>(inp, inp_l, k, &mut out, params, &device)
                    .map_err(crate::Error::wrap)?;
                S::U8(out)
            }
            (S::BF16(inp), S::BF16(k)) => {
                let inp = &inp.slice(inp_l.start_offset()..);
                let k = &k.slice(kernel_l.start_offset()..);
                let mut out = unsafe { device.alloc::<bf16>(dst_el) }.w()?;
                crate::cudnn::launch_conv2d::<bf16>(inp, inp_l, k, &mut out, params, &device)
                    .map_err(crate::Error::wrap)?;
                S::BF16(out)
            }
            (S::F16(inp), S::F16(k)) => {
                let inp = &inp.slice(inp_l.start_offset()..);
                let k = &k.slice(kernel_l.start_offset()..);
                let mut out = unsafe { device.alloc::<f16>(dst_el) }.w()?;
                crate::cudnn::launch_conv2d::<f16>(inp, inp_l, k, &mut out, params, &device)
                    .map_err(crate::Error::wrap)?;
                S::F16(out)
            }
            (S::F32(inp), S::F32(k)) => {
                let inp = &inp.slice(inp_l.start_offset()..);
                let k = &k.slice(kernel_l.start_offset()..);
                let mut out = unsafe { device.alloc::<f32>(dst_el) }.w()?;
                crate::cudnn::launch_conv2d::<f32>(inp, inp_l, k, &mut out, params, &device)
                    .map_err(crate::Error::wrap)?;
                S::F32(out)
            }
            (S::F64(inp), S::F64(k)) => {
                let inp = &inp.slice(inp_l.start_offset()..);
                let k = &k.slice(kernel_l.start_offset()..);
                let mut out = unsafe { device.alloc::<f64>(dst_el) }.w()?;
                crate::cudnn::launch_conv2d::<f64>(inp, inp_l, k, &mut out, params, &device)
                    .map_err(crate::Error::wrap)?;
                S::F64(out)
            }
            (S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?,
            (S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?,
            _ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?,
        };
        Ok(Self { slice, device })
    }

    fn conv_transpose2d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConvTranspose2D,
    ) -> Result<Self> {
        let device = self.device().clone();
        let slice =
            ConvTranspose2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
        Ok(Self { slice, device })
    }

    fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
        let device = self.device().clone();
        let slice = Pool2D {
            w_k: k.0,
            h_k: k.1,
            w_stride: stride.0,
            h_stride: stride.1,
            op: PoolOp::Avg,
        }
        .map(&self.slice, &device, l)?;
        Ok(Self { slice, device })
    }

    fn max_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
        let device = self.device().clone();
        let slice = Pool2D {
            w_k: k.0,
            h_k: k.1,
            w_stride: stride.0,
            h_stride: stride.1,
            op: PoolOp::Max,
        }
        .map(&self.slice, &device, l)?;
        Ok(Self { slice, device })
    }

    fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
        let device = self.device().clone();
        let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?;
        Ok(Self { slice, device })
    }

    fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
        let device = self.device().clone();
        let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
@@ -1494,6 +1889,9 @@ impl BackendStorage for CudaStorage {
        let src_shape = src_l.shape();
        let dims = src_shape.dims();
        let el_count = src_shape.elem_count();
        if el_count == 0 {
            return Ok(());
        }
        let cfg = LaunchConfig::for_num_elems(el_count as u32);
        let dev = &self.device;
        let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
@@ -1558,6 +1956,18 @@ impl BackendStorage for CudaStorage {
                    unsafe { func.launch(cfg, params) }.w()?
                }
            }
            (CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => {
                let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                if src_l.is_contiguous() {
                    dev.dtod_copy(&src, &mut dst).w()?
                } else {
                    let func = dev.get_or_load_func("ucopy_i64", kernels::UNARY)?;
                    // SAFETY: Set later by running the kernel.
                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
                    // SAFETY: ffi.
                    unsafe { func.launch(cfg, params) }.w()?
                }
            }
            (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
                let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                if src_l.is_contiguous() {
candle-core/src/cudnn.rs (new file, 107 lines)
@@ -0,0 +1,107 @@
use crate::WithDType;
use cudarc;
use cudarc::cudnn::safe::{Conv2dForward, Cudnn};
use cudarc::driver::{CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits};
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::Arc;

// The cudnn handles are stored per thread here rather than on the CudaDevice as they are neither
// send nor sync.
thread_local! {
    static CUDNN: RefCell<HashMap<crate::cuda_backend::DeviceId, Arc<Cudnn>>> = HashMap::new().into();
}

impl From<cudarc::cudnn::CudnnError> for crate::Error {
    fn from(err: cudarc::cudnn::CudnnError) -> Self {
        crate::Error::wrap(err)
    }
}

impl From<cudarc::driver::DriverError> for crate::Error {
    fn from(err: cudarc::driver::DriverError) -> Self {
        crate::Error::wrap(err)
    }
}

pub(crate) fn launch_conv2d<
    T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
>(
    src: &CudaView<T>,
    src_l: &crate::Layout,
    filter: &CudaView<T>,
    dst: &mut CudaSlice<T>,
    params: &crate::conv::ParamsConv2D,
    dev: &crate::cuda_backend::CudaDevice,
) -> crate::Result<()> {
    let device_id = dev.id();
    let cudnn = CUDNN.with(|cudnn| {
        if let Some(cudnn) = cudnn.borrow().get(&device_id) {
            return Ok(cudnn.clone());
        }
        let c = Cudnn::new(dev.cuda_device());
        if let Ok(c) = &c {
            cudnn.borrow_mut().insert(device_id, c.clone());
        }
        c
    })?;
    let conv = cudnn.create_conv2d::<T>(
        /* pad */ [params.padding as i32, params.padding as i32],
        /* stride */ [params.stride as i32, params.stride as i32],
        /* dilation */ [params.dilation as i32, params.dilation as i32],
        cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
    )?;
    let x_shape = [
        params.b_size as i32,
        params.c_in as i32,
        params.i_h as i32,
        params.i_w as i32,
    ];
    // Note that `src` already starts at the proper offset.
    let x = if src_l.is_contiguous() {
        cudnn.create_4d_tensor(
            cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
            x_shape,
        )?
    } else {
        let s = src_l.stride();
        cudnn.create_4d_tensor_ex(
            x_shape,
            [s[0] as i32, s[1] as i32, s[2] as i32, s[3] as i32],
        )?
    };
    let w = cudnn.create_4d_filter(
        cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
        [
            params.c_out as i32,
            params.c_in as i32,
            params.k_h as i32,
            params.k_w as i32,
        ],
    )?;
    let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
    let y = cudnn.create_4d_tensor(
        cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
        [params.b_size as i32, params.c_out as i32, h_out, w_out],
    )?;
    let conv2d = Conv2dForward {
        conv: &conv,
        x: &x,
        w: &w,
        y: &y,
    };
    let alg = conv2d.pick_algorithm()?;
    let workspace_size = conv2d.get_workspace_size(alg)?;
    let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
    unsafe {
        conv2d.launch::<CudaSlice<u8>, _, _, _>(
            alg,
            Some(&mut workspace),
            (T::one(), T::zero()),
            src,
            filter,
            dst,
        )?;
    }
    Ok(())
}
@@ -16,7 +16,6 @@ pub enum Device {
    Cuda(crate::CudaDevice),
}

// TODO: Should we back the cpu implementation using the NdArray crate or similar?
pub trait NdArray {
    fn shape(&self) -> Result<Shape>;

@@ -81,6 +80,49 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
    }
}

impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
    for &[[[[S; N4]; N3]; N2]; N1]
{
    fn shape(&self) -> Result<Shape> {
        Ok(Shape::from((N1, N2, N3, N4)))
    }

    fn to_cpu_storage(&self) -> CpuStorage {
        let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
        for i1 in 0..N1 {
            for i2 in 0..N2 {
                for i3 in 0..N3 {
                    vec.extend(self[i1][i2][i3])
                }
            }
        }
        S::to_cpu_storage_owned(vec)
    }
}

impl<S: NdArray> NdArray for Vec<S> {
    fn shape(&self) -> Result<Shape> {
        if self.is_empty() {
            crate::bail!("empty array")
        }
        let shape0 = self[0].shape()?;
        let n = self.len();
        for v in self.iter() {
            let shape = v.shape()?;
            if shape != shape0 {
                crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
            }
        }
        Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
    }

    fn to_cpu_storage(&self) -> CpuStorage {
        // This allocates intermediary memory and shouldn't be necessary.
        let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
        CpuStorage::concat(storages.as_slice()).unwrap()
    }
}
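With these impls, a `Vec` prefixes its length to the shape of its elements, and ragged inputs are rejected by the shape check; a sketch (assuming `Tensor::new` routes through `NdArray` and the `candle_core` crate name):

    use candle_core::{Device, Result, Tensor};

    fn nested_shapes() -> Result<()> {
        // Vec<Vec<f32>> -> shape (2, 2).
        let t = Tensor::new(vec![vec![1f32, 2.], vec![3., 4.]], &Device::Cpu)?;
        assert_eq!(t.dims(), &[2, 2]);
        // Rows with mismatched shapes fail the check above.
        assert!(Tensor::new(vec![vec![1f32], vec![2., 3.]], &Device::Cpu).is_err());
        Ok(())
    }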

impl Device {
    pub fn new_cuda(ordinal: usize) -> Result<Self> {
        Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
@@ -101,6 +143,13 @@ impl Device {
        }
    }

    pub fn is_cpu(&self) -> bool {
        match self {
            Self::Cpu => true,
            Self::Cuda(_) => false,
        }
    }

    pub fn is_cuda(&self) -> bool {
        match self {
            Self::Cpu => false,
@@ -9,11 +9,14 @@ impl Tensor {
        &self,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
-        let prefix = match self.device() {
-            crate::Device::Cpu => "Cpu",
-            crate::Device::Cuda(_) => "Cuda",
+        let device_str = match self.device().location() {
+            crate::DeviceLocation::Cpu => "".to_owned(),
+            crate::DeviceLocation::Cuda { gpu_id } => {
+                format!(", cuda:{}", gpu_id)
+            }
        };
-        write!(f, "{prefix}Tensor[")?;
+        write!(f, "Tensor[")?;
        match self.dims() {
            [] => {
                if let Ok(v) = self.to_scalar::<T>() {
@@ -40,7 +43,7 @@ impl Tensor {
                }
            }
        }
-        write!(f, "; {}]", self.dtype().as_str())
+        write!(f, "; {}{}]", self.dtype().as_str(), device_str)
    }
}

@@ -49,6 +52,7 @@ impl std::fmt::Debug for Tensor {
        match self.dtype() {
            DType::U8 => self.fmt_dt::<u8>(f),
            DType::U32 => self.fmt_dt::<u32>(f),
            DType::I64 => self.fmt_dt::<i64>(f),
            DType::BF16 => self.fmt_dt::<bf16>(f),
            DType::F16 => self.fmt_dt::<f16>(f),
            DType::F32 => self.fmt_dt::<f32>(f),
@@ -431,6 +435,12 @@ impl std::fmt::Display for Tensor {
                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                writeln!(f)?;
            }
            DType::I64 => {
                let tf: IntFormatter<i64> = IntFormatter::new();
                let max_w = tf.max_width(&to_display);
                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                writeln!(f)?;
            }
            DType::BF16 => {
                if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
                    let max_w = tf.max_width(&to_display);
@@ -460,6 +470,20 @@ impl std::fmt::Display for Tensor {
                }
            }
        };
-        write!(f, "Tensor[{:?}, {}]", self.dims(), self.dtype().as_str())
+        let device_str = match self.device().location() {
+            crate::DeviceLocation::Cpu => "".to_owned(),
+            crate::DeviceLocation::Cuda { gpu_id } => {
+                format!(", cuda:{}", gpu_id)
+            }
+        };
+        write!(
+            f,
+            "Tensor[{:?}, {}{}]",
+            self.dims(),
+            self.dtype().as_str(),
+            device_str
+        )
    }
}
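With the device suffix wired in above, the compact form renders as e.g. `Tensor[[2, 3], f32, cuda:0]` for a tensor on the first GPU, while CPU tensors keep the shorter `Tensor[[2, 3], f32]` since the CPU suffix is the empty string.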

@@ -1,13 +1,24 @@
//! Types for elements that can be stored and manipulated using tensors.
#![allow(clippy::redundant_closure_call)]
use crate::backend::BackendStorage;
use crate::{CpuStorage, Error, Result};

/// The different types of elements allowed in tensors.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
    // Unsigned 8 bits integer.
    U8,
    // Unsigned 32 bits integer.
    U32,
    // Signed 64 bits integer.
    I64,
    // Brain floating-point using half precision (16 bits).
    BF16,
    // Floating-point using half precision (16 bits).
    F16,
    // Floating-point using single precision (32 bits).
    F32,
    // Floating-point using double precision (64 bits).
    F64,
}

@@ -20,6 +31,7 @@ impl std::str::FromStr for DType {
        match s {
            "u8" => Ok(Self::U8),
            "u32" => Ok(Self::U32),
            "i64" => Ok(Self::I64),
            "bf16" => Ok(Self::BF16),
            "f16" => Ok(Self::F16),
            "f32" => Ok(Self::F32),
@@ -30,10 +42,12 @@ impl std::str::FromStr for DType {
}

impl DType {
    /// String representation for dtypes.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::U8 => "u8",
            Self::U32 => "u32",
            Self::I64 => "i64",
            Self::BF16 => "bf16",
            Self::F16 => "f16",
            Self::F32 => "f32",
@@ -41,10 +55,12 @@ impl DType {
        }
    }

    /// The size used by each element in bytes, i.e. 1 for `U8`, 4 for `F32`.
    pub fn size_in_bytes(&self) -> usize {
        match self {
-            Self::U8 => 4,
+            Self::U8 => 1,
            Self::U32 => 4,
            Self::I64 => 8,
            Self::BF16 => 2,
            Self::F16 => 2,
            Self::F32 => 4,
@@ -53,7 +69,17 @@ impl DType {
    }
}

-pub trait WithDType: Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + 'static {
+pub trait WithDType:
+    Sized
+    + Copy
+    + num_traits::NumAssign
+    + std::cmp::PartialOrd
+    + std::fmt::Display
+    + 'static
+    + Send
+    + Sync
+    + crate::cpu::kernels::VecOps
+{
    const DTYPE: DType;

    fn from_f64(v: f64) -> Self;
@@ -115,6 +141,7 @@ use half::{bf16, f16};

with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64);
with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64);
with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
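A small sketch tying the pieces above together; the expected values follow directly from the match arms shown:

    let dt: DType = "bf16".parse().unwrap();
    assert_eq!(dt.as_str(), "bf16");
    assert_eq!(dt.size_in_bytes(), 2); // half precision: two bytes per element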

@@ -125,6 +152,15 @@ pub trait IntDType: WithDType {
    fn as_usize(&self) -> usize;
}

impl IntDType for i64 {
    fn is_true(&self) -> bool {
        *self != 0
    }
    fn as_usize(&self) -> usize {
        *self as usize
    }
}

impl IntDType for u32 {
    fn is_true(&self) -> bool {
        *self != 0

@@ -37,6 +37,10 @@ impl crate::backend::BackendStorage for CudaStorage {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }
@@ -75,6 +79,26 @@ impl crate::backend::BackendStorage for CudaStorage {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn conv2d(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &crate::conv::ParamsConv2D,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn conv_transpose2d(
        &self,
        _l: &Layout,
        _kernel: &Self,
        _kernel_l: &Layout,
        _params: &crate::conv::ParamsConvTranspose2D,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }
@@ -119,6 +143,18 @@ impl crate::backend::BackendStorage for CudaStorage {
    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }
}

impl crate::backend::BackendDevice for CudaDevice {

@@ -30,7 +30,7 @@ pub enum Error {
    UnsupportedDTypeForOp(DType, &'static str),

    // === Dimension Index Errors ===
-    #[error("{op}: dimension index {dim} out of range for {shape:?}")]
+    #[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
    DimOutOfRange {
        shape: Shape,
        dim: i32,
@@ -185,6 +185,13 @@ pub enum Error {
    #[error(transparent)]
    Wrapped(Box<dyn std::error::Error + Send + Sync>),

    /// Adding path information to an error.
    #[error("path: {path:?} {inner}")]
    WithPath {
        inner: Box<Self>,
        path: std::path::PathBuf,
    },

    #[error("{inner}\n{backtrace}")]
    WithBacktrace {
        inner: Box<Self>,
@@ -200,7 +207,11 @@ pub type Result<T> = std::result::Result<T, Error>;

impl Error {
    pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
-        Self::Wrapped(Box::new(err))
+        Self::Wrapped(Box::new(err)).bt()
    }

    pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
        Self::Msg(err.to_string()).bt()
    }

    pub fn bt(self) -> Self {
@@ -214,6 +225,13 @@ impl Error {
            },
        }
    }

    pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self {
        Self::WithPath {
            inner: Box::new(self),
            path: p.as_ref().to_path_buf(),
        }
    }
}
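A hedged usage sketch for the new `WithPath` variant (the file name is purely illustrative):

    // Attach the offending file to an error while loading weights.
    let inner = Error::Msg("failed to read tensor".to_string());
    let err = inner.with_path("weights/model.safetensors");
    // Displays as: path: "weights/model.safetensors" failed to read tensor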

#[macro_export]

@@ -9,6 +9,14 @@ pub struct Layout {
}

impl Layout {
    pub fn new(shape: Shape, stride: Vec<usize>, start_offset: usize) -> Self {
        Self {
            shape,
            stride,
            start_offset,
        }
    }

    pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
        let shape = shape.into();
        let stride = shape.stride_contiguous();
@@ -112,6 +120,31 @@ impl Layout {
        })
    }

    pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
        let is_permutation =
            idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
        if !is_permutation {
            crate::bail!(
                "dimension mismatch in permute, tensor {:?}, dims: {:?}",
                self.dims(),
                idxs
            )
        }
        let stride = self.stride();
        let dims = self.shape().dims();
        let mut perm_stride = stride.to_vec();
        let mut perm_dims = dims.to_vec();
        for (i, &idx) in idxs.iter().enumerate() {
            perm_stride[i] = stride[idx];
            perm_dims[i] = dims[idx];
        }
        Ok(Self {
            shape: Shape::from(perm_dims),
            stride: perm_stride,
            start_offset: self.start_offset,
        })
    }
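To make the stride bookkeeping concrete, here is what `permute` computes for a contiguous 2x3x4 layout (a sketch usable from inside the crate, since `permute` is `pub(crate)`):

    let l = Layout::contiguous_with_offset((2, 3, 4), 0); // stride [12, 4, 1]
    let p = l.permute(&[2, 0, 1])?;
    assert_eq!(p.dims(), &[4, 2, 3]);
    assert_eq!(p.stride(), &[1, 12, 4]); // strides are carried along, no data moves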

    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
        let shape = shape.into();
        if shape.rank() < self.shape().rank() {

@@ -33,13 +33,18 @@
//!
//! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)

#[cfg(feature = "accelerate")]
mod accelerate;
pub mod backend;
pub mod backprop;
mod conv;
mod convert;
pub mod cpu;
pub mod cpu_backend;
#[cfg(feature = "cuda")]
pub mod cuda_backend;
#[cfg(feature = "cudnn")]
pub mod cudnn;
mod device;
pub mod display;
mod dtype;
@@ -51,11 +56,15 @@ pub mod layout;
mod mkl;
pub mod npy;
mod op;
pub mod pickle;
pub mod quantized;
pub mod safetensors;
pub mod scalar;
pub mod shape;
mod storage;
mod strided_index;
mod tensor;
pub mod test_utils;
pub mod utils;
mod variable;

@@ -80,3 +89,39 @@ pub use dummy_cuda_backend::{CudaDevice, CudaStorage};

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

pub trait ToUsize2 {
    fn to_usize2(self) -> (usize, usize);
}

impl ToUsize2 for usize {
    fn to_usize2(self) -> (usize, usize) {
        (self, self)
    }
}

impl ToUsize2 for (usize, usize) {
    fn to_usize2(self) -> (usize, usize) {
        self
    }
}
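These two impls let size-like arguments accept either a single `usize` for a square kernel or an explicit (height, width) pair:

    assert_eq!(3usize.to_usize2(), (3, 3));
    assert_eq!((2usize, 4usize).to_usize2(), (2, 4));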

// A simple trait defining a module with forward method using a single argument.
pub trait Module: std::fmt::Debug {
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;

    /// Change the module to use training mode vs eval mode.
    ///
    /// The default implementation does nothing as this is only used for a couple modules such as
    /// dropout or batch-normalization.
    fn set_training(&mut self, _training: bool) {}
}

impl Module for quantized::QMatMul {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.forward(xs)
    }
}
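A minimal sketch of implementing `Module` for a user-defined layer (the `Scale` type here is hypothetical, not part of the crate):

    #[derive(Debug)]
    struct Scale(f64);

    impl Module for Scale {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            // affine(mul, add) scales every element; set_training keeps its default no-op.
            xs.affine(self.0, 0.)
        }
    }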

@@ -25,6 +25,10 @@ mod ffi {
        pub fn vdMul(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
        pub fn vsDiv(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
        pub fn vdDiv(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
        pub fn vsFmax(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
        pub fn vdFmax(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
        pub fn vsFmin(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
        pub fn vdFmin(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);

        pub fn sgemm_(
            transa: *const c_char,
@@ -297,7 +301,7 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
}

#[inline]
-fn vs_tanh(a: &[f32], y: &mut [f32]) {
+pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
    let a_len = a.len();
    let y_len = y.len();
    if a_len != y_len {
@@ -307,7 +311,7 @@ fn vs_tanh(a: &[f32], y: &mut [f32]) {
}

#[inline]
-fn vd_tanh(a: &[f64], y: &mut [f64]) {
+pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
    let a_len = a.len();
    let y_len = y.len();
    if a_len != y_len {
@@ -376,3 +380,7 @@ binary_op!(vs_mul, f32, vsMul);
binary_op!(vd_mul, f64, vdMul);
binary_op!(vs_div, f32, vsDiv);
binary_op!(vd_div, f64, vdDiv);
binary_op!(vs_max, f32, vsFmax);
binary_op!(vd_max, f64, vdFmax);
binary_op!(vs_min, f32, vsFmin);
binary_op!(vd_min, f64, vdFmin);

@@ -26,7 +26,7 @@
//! values = np.load("test.npz")
//! ```
use crate::{DType, Device, Error, Result, Shape, Tensor};
-use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
+use byteorder::{LittleEndian, ReadBytesExt};
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::collections::HashMap;
use std::fs::File;
@@ -85,6 +85,7 @@ impl Header {
            DType::F16 => "f2",
            DType::F32 => "f4",
            DType::F64 => "f8",
            DType::I64 => "i8",
            DType::U32 => "u4",
            DType::U8 => "u1",
        };
@@ -160,7 +161,7 @@ impl Header {
            "f" | "f4" => DType::F32,
            "d" | "f8" => DType::F64,
            // "i" | "i4" => DType::S32,
-            // "q" | "i8" => DType::S64,
+            "q" | "i8" => DType::I64,
            // "h" | "i2" => DType::S16,
            // "b" | "i1" => DType::S8,
            "B" | "u1" => DType::U8,
@@ -196,7 +197,11 @@ impl Header {

impl Tensor {
    // TODO: Add the possibility to read directly to a device?
-    fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
+    pub(crate) fn from_reader<R: std::io::Read>(
+        shape: Shape,
+        dtype: DType,
+        reader: &mut R,
+    ) -> Result<Self> {
        let elem_count = shape.elem_count();
        match dtype {
            DType::BF16 => {
@@ -229,6 +234,11 @@ impl Tensor {
                reader.read_u32_into::<LittleEndian>(&mut data_t)?;
                Tensor::from_vec(data_t, shape, &Device::Cpu)
            }
            DType::I64 => {
                let mut data_t = vec![0i64; elem_count];
                reader.read_i64_into::<LittleEndian>(&mut data_t)?;
                Tensor::from_vec(data_t, shape, &Device::Cpu)
            }
        }
    }

@@ -307,42 +317,7 @@ impl Tensor {
        header.push('\n');
        f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
        f.write_all(header.as_bytes())?;
-        let elem_count = self.elem_count();
-        match self.dtype() {
-            DType::BF16 => {
-                let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
-                for &v in vs.reinterpret_cast() {
-                    f.write_u16::<LittleEndian>(v)?
-                }
-            }
-            DType::F16 => {
-                let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
-                for &v in vs.reinterpret_cast() {
-                    f.write_u16::<LittleEndian>(v)?
-                }
-            }
-            DType::F32 => {
-                // TODO: Avoid using a buffer when data is already on the CPU.
-                for v in self.reshape(elem_count)?.to_vec1::<f32>()? {
-                    f.write_f32::<LittleEndian>(v)?
-                }
-            }
-            DType::F64 => {
-                for v in self.reshape(elem_count)?.to_vec1::<f64>()? {
-                    f.write_f64::<LittleEndian>(v)?
-                }
-            }
-            DType::U32 => {
-                for v in self.reshape(elem_count)?.to_vec1::<u32>()? {
-                    f.write_u32::<LittleEndian>(v)?
-                }
-            }
-            DType::U8 => {
-                let data = self.reshape(elem_count)?.to_vec1::<u8>()?;
-                f.write_all(&data)?;
-            }
-        }
-        Ok(())
+        self.write_bytes(f)
    }

    /// Writes a multi-dimensional array in the npy format.
@@ -373,7 +348,7 @@ pub struct NpzTensors {
    index_per_name: HashMap<String, usize>,
    path: std::path::PathBuf,
    // We do not store a zip reader as it needs mutable access to extract data. Instead we
-    // re-create a zip reader each time.
+    // re-create a zip reader for each tensor.
}

impl NpzTensors {
@@ -396,6 +371,25 @@ impl NpzTensors {
        })
    }

    pub fn names(&self) -> Vec<&String> {
        self.index_per_name.keys().collect()
    }

    /// This only returns the shape and dtype for a named tensor. Compared to `get`, this avoids
    /// reading the whole tensor data.
    pub fn get_shape_and_dtype(&self, name: &str) -> Result<(Shape, DType)> {
        let index = match self.index_per_name.get(name) {
            None => crate::bail!("cannot find tensor {name}"),
            Some(index) => *index,
        };
        let zip_reader = BufReader::new(File::open(&self.path)?);
        let mut zip = zip::ZipArchive::new(zip_reader)?;
        let mut reader = zip.by_index(index)?;
        let header = read_header(&mut reader)?;
        let header = Header::parse(&header)?;
        Ok((header.shape(), header.descr))
    }
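A usage sketch for the metadata accessor above (the archive name is hypothetical; `NpzTensors::new` is the constructor whose body is elided by the hunk):

    let npz = NpzTensors::new("weights.npz")?;
    for name in npz.names() {
        let (shape, dtype) = npz.get_shape_and_dtype(name)?;
        println!("{name}: {shape:?} {dtype:?}");
    }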

    pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
        let index = match self.index_per_name.get(name) {
            None => return Ok(None),

@@ -1,3 +1,4 @@
#![allow(clippy::redundant_closure_call)]
use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
use half::{bf16, f16};
use num_traits::float::Float;
@@ -40,6 +41,8 @@ pub enum BinaryOp {
    Mul,
    Sub,
    Div,
    Maximum,
    Minimum,
}

// Unary ops with no argument
@@ -51,10 +54,12 @@ pub enum UnaryOp {
    Cos,
    Abs,
    Neg,
    Recip,
    Sqr,
    Sqrt,
    Gelu,
    Relu,
    Tanh,
}

#[derive(Clone)]
@@ -77,8 +82,42 @@ pub enum Op {
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    Conv2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    ConvTranspose2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    },

    AvgPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    MaxPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    UpsampleNearest2D(Tensor),

    Cat(Vec<Tensor>, usize),

    #[allow(dead_code)] // add is currently unused.
@@ -94,14 +133,25 @@ pub enum Op {
    Reshape(Tensor),
    ToDevice(Tensor),
    Transpose(Tensor, usize, usize),
    Permute(Tensor, Vec<usize>),
    Elu(Tensor, f64),
-    CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
-    CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>),
-    CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>),
+    Powf(Tensor, f64),
+    CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
+    CustomOp2(
+        Tensor,
+        Tensor,
+        std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
+    ),
+    CustomOp3(
+        Tensor,
+        Tensor,
+        Tensor,
+        std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
+    ),
}

/// Unary ops that can be defined in user-land.
-pub trait CustomOp1: Send + Sync {
+pub trait CustomOp1 {
    // Box<dyn> does not support const yet, so use a function to get the name.
    fn name(&self) -> &'static str;

@@ -125,7 +175,7 @@ pub trait CustomOp1: Send + Sync {
    }
}

-pub trait CustomOp2: Send + Sync {
+pub trait CustomOp2 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
@@ -163,7 +213,7 @@ pub trait CustomOp2: Send + Sync {
    }
}

-pub trait CustomOp3: Send + Sync {
+pub trait CustomOp3 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
@@ -216,6 +266,7 @@ pub trait UnaryOpT {
    fn f64(v1: f64) -> f64;
    fn u8(v1: u8) -> u8;
    fn u32(v1: u32) -> u32;
    fn i64(v1: i64) -> i64;

    // There is no very good way to represent optional function in traits so we go for an explicit
    // boolean flag to mark the function as existing.
@@ -239,6 +290,7 @@ pub trait BinaryOpT {
    fn f64(v1: f64, v2: f64) -> f64;
    fn u8(v1: u8, v2: u8) -> u8;
    fn u32(v1: u32, v2: u32) -> u32;
    fn i64(v1: i64, v2: i64) -> i64;

    const BF16_VEC: bool = false;
    fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
@@ -252,22 +304,28 @@ pub trait BinaryOpT {
    fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
    const U32_VEC: bool = false;
    fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
    const I64_VEC: bool = false;
    fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
}

pub(crate) struct Add;
pub(crate) struct Div;
pub(crate) struct Mul;
pub(crate) struct Sub;
pub(crate) struct Maximum;
pub(crate) struct Minimum;
pub(crate) struct Exp;
pub(crate) struct Log;
pub(crate) struct Sin;
pub(crate) struct Cos;
pub(crate) struct Abs;
pub(crate) struct Neg;
pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
pub(crate) struct Relu;
pub(crate) struct Tanh;

macro_rules! bin_op {
    ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@@ -299,6 +357,10 @@ macro_rules! bin_op {
            fn u32(v1: u32, v2: u32) -> u32 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn i64(v1: i64, v2: i64) -> i64 {
                $e(v1, v2)
            }

            #[cfg(feature = "mkl")]
            const F32_VEC: bool = true;
@@ -314,6 +376,21 @@ macro_rules! bin_op {
            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
                crate::mkl::$f64_vec(xs1, xs2, ys)
            }

            #[cfg(feature = "accelerate")]
            const F32_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            const F64_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
                crate::accelerate::$f32_vec(xs1, xs2, ys)
            }
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
                crate::accelerate::$f64_vec(xs1, xs2, ys)
            }
        }
    };
}
@@ -322,7 +399,22 @@ bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
bin_op!(
    Minimum,
    "minimum",
    |v1, v2| if v1 > v2 { v2 } else { v1 },
    vs_min,
    vd_min
);
bin_op!(
    Maximum,
    "maximum",
    |v1, v2| if v1 < v2 { v2 } else { v1 },
    vs_max,
    vd_max
);
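One subtlety worth noting in these closures: `if v1 < v2 { v2 } else { v1 }` is not symmetric for NaN inputs on the scalar path; when `v1` is NaN the comparison is false and the NaN propagates:

    let v1 = f32::NAN;
    let v2 = 1f32;
    let max = if v1 < v2 { v2 } else { v1 };
    assert!(max.is_nan()); // whereas C's fmaxf(NAN, 1.0) returns 1.0, as in the Fmax-based vector paths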

#[allow(clippy::redundant_closure_call)]
macro_rules! unary_op {
    ($op: ident, $name: literal, $a: ident, $e: expr) => {
        impl UnaryOpT for $op {
@@ -353,6 +445,10 @@ macro_rules! unary_op {
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }
        }
    };

@@ -385,6 +481,10 @@ macro_rules! unary_op {
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }

            #[cfg(feature = "mkl")]
            const F32_VEC: bool = true;
@@ -400,6 +500,21 @@ macro_rules! unary_op {
            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                crate::mkl::$f64_vec(xs, ys)
            }

            #[cfg(feature = "accelerate")]
            const F32_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            const F64_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
                crate::accelerate::$f32_vec(xs, ys)
            }
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                crate::accelerate::$f64_vec(xs, ys)
            }
        }
    };
}
@@ -408,8 +523,10 @@ unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);

@@ -460,6 +577,10 @@ impl UnaryOpT for Gelu {
    fn u32(_: u32) -> u32 {
        0
    }
    #[inline(always)]
    fn i64(_: i64) -> i64 {
        0
    }
    const KERNEL: &'static str = "ugelu";

    #[cfg(feature = "mkl")]
@@ -509,6 +630,10 @@ impl UnaryOpT for Relu {
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v
    }
}

/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
candle-core/src/pickle.rs (new file, 725 lines)
@@ -0,0 +1,725 @@
// Just enough pickle support to be able to read PyTorch checkpoints.
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
// composable/tensor agnostic at some point.
use crate::{DType, Error as E, Layout, Result, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
use std::io::BufRead;

const VERBOSE: bool = false;

// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/
#[repr(u8)]
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum OpCode {
    // https://github.com/python/cpython/blob/ed25f097160b5cbb0c9a1f9a746d2f1bbc96515a/Lib/pickletools.py#L2123
    Proto = 0x80,
    Global = b'c',
    BinPut = b'q',
    LongBinPut = b'r',
    EmptyTuple = b')',
    Reduce = b'R',
    Mark = b'(',
    BinUnicode = b'X',
    BinInt = b'J',
    Tuple = b't',
    BinPersId = b'Q',
    BinInt1 = b'K',
    BinInt2 = b'M',
    Tuple1 = 0x85,
    Tuple2 = 0x86,
    Tuple3 = 0x87,
    NewTrue = 0x88,
    NewFalse = 0x89,
    None = b'N',
    BinGet = b'h',
    LongBinGet = b'j',
    SetItem = b's',
    SetItems = b'u',
    EmptyDict = b'}',
    Dict = b'd',
    Build = b'b',
    Stop = b'.',
    NewObj = 0x81,
    // BINFLOAT is the byte b'G' in the pickle protocol, matching the TryFrom arm below.
    BinFloat = b'G',
    Append = b'a',
    Appends = b'e',
}

// Avoid using FromPrimitive so as not to drag another dependency.
impl TryFrom<u8> for OpCode {
    type Error = u8;
    fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
        match value {
            0x80 => Ok(Self::Proto),
            b'c' => Ok(Self::Global),
            b'q' => Ok(Self::BinPut),
            b'r' => Ok(Self::LongBinPut),
            b')' => Ok(Self::EmptyTuple),
            b'R' => Ok(Self::Reduce),
            b'(' => Ok(Self::Mark),
            b'X' => Ok(Self::BinUnicode),
            b'J' => Ok(Self::BinInt),
            b't' => Ok(Self::Tuple),
            b'Q' => Ok(Self::BinPersId),
            b'K' => Ok(Self::BinInt1),
            b'M' => Ok(Self::BinInt2),
            b'N' => Ok(Self::None),
            0x85 => Ok(Self::Tuple1),
            0x86 => Ok(Self::Tuple2),
            0x87 => Ok(Self::Tuple3),
            0x88 => Ok(Self::NewTrue),
            0x89 => Ok(Self::NewFalse),
            b'h' => Ok(Self::BinGet),
            b'j' => Ok(Self::LongBinGet),
            b's' => Ok(Self::SetItem),
            b'u' => Ok(Self::SetItems),
            b'}' => Ok(Self::EmptyDict),
            b'd' => Ok(Self::Dict),
            b'b' => Ok(Self::Build),
            b'.' => Ok(Self::Stop),
            0x81 => Ok(Self::NewObj),
            b']' => Ok(Self::EmptyList),
            b'G' => Ok(Self::BinFloat),
            b'a' => Ok(Self::Append),
            b'e' => Ok(Self::Appends),
            value => Err(value),
        }
    }
}

fn read_to_newline<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
    let mut data: Vec<u8> = Vec::with_capacity(32);
    r.read_until(b'\n', &mut data)?;
    data.pop();
    if data.last() == Some(&b'\r') {
        data.pop();
    }
    Ok(data)
}

#[derive(Debug, Clone, PartialEq)]
pub enum Object {
    Class {
        module_name: String,
        class_name: String,
    },
    Int(i32),
    Float(f64),
    Unicode(String),
    Bool(bool),
    None,
    Tuple(Vec<Object>),
    List(Vec<Object>),
    Mark,
    Dict(Vec<(Object, Object)>),
    Reduce {
        callable: Box<Object>,
        args: Box<Object>,
    },
    Build {
        callable: Box<Object>,
        args: Box<Object>,
    },
    PersistentLoad(Box<Object>),
}

type OResult<T> = std::result::Result<T, Object>;

impl Object {
    pub fn unicode(self) -> OResult<String> {
        match self {
            Self::Unicode(t) => Ok(t),
            _ => Err(self),
        }
    }

    pub fn reduce(self) -> OResult<(Self, Self)> {
        match self {
            Self::Reduce { callable, args } => Ok((*callable, *args)),
            _ => Err(self),
        }
    }

    pub fn none(self) -> OResult<()> {
        match self {
            Self::None => Ok(()),
            _ => Err(self),
        }
    }

    pub fn persistent_load(self) -> OResult<Self> {
        match self {
            Self::PersistentLoad(t) => Ok(*t),
            _ => Err(self),
        }
    }

    pub fn bool(self) -> OResult<bool> {
        match self {
            Self::Bool(t) => Ok(t),
            _ => Err(self),
        }
    }

    pub fn int(self) -> OResult<i32> {
        match self {
            Self::Int(t) => Ok(t),
            _ => Err(self),
        }
    }

    pub fn tuple(self) -> OResult<Vec<Self>> {
        match self {
            Self::Tuple(t) => Ok(t),
            _ => Err(self),
        }
    }

    pub fn dict(self) -> OResult<Vec<(Self, Self)>> {
        match self {
            Self::Dict(t) => Ok(t),
            _ => Err(self),
        }
    }

    pub fn class(self) -> OResult<(String, String)> {
        match self {
            Self::Class {
                module_name,
                class_name,
            } => Ok((module_name, class_name)),
            _ => Err(self),
        }
    }
}

impl TryFrom<Object> for String {
    type Error = Object;
    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
        match value {
            Object::Unicode(s) => Ok(s),
            other => Err(other),
        }
    }
}

impl TryFrom<Object> for usize {
    type Error = Object;
    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
        match value {
            Object::Int(s) if s >= 0 => Ok(s as usize),
            other => Err(other),
        }
    }
}

impl<T: TryFrom<Object, Error = Object>> TryFrom<Object> for Vec<T> {
    type Error = Object;
    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
        match value {
            Object::Tuple(values) => {
                // This does not return the appropriate value in the error case but instead return
                // the object related to the first error.
                values
                    .into_iter()
                    .map(|v| T::try_from(v))
                    .collect::<std::result::Result<Vec<T>, Self::Error>>()
            }
            other => Err(other),
        }
    }
}

#[derive(Debug)]
pub struct Stack {
    stack: Vec<Object>,
    memo: HashMap<u32, Object>,
}

impl Stack {
    pub fn empty() -> Self {
        Self {
            stack: Vec::with_capacity(512),
            memo: HashMap::new(),
        }
    }

    pub fn stack(&self) -> &[Object] {
        self.stack.as_slice()
    }

    pub fn read_loop<R: BufRead>(&mut self, r: &mut R) -> Result<()> {
        loop {
            if self.read(r)? {
                break;
            }
        }
        Ok(())
    }

    pub fn finalize(mut self) -> Result<Object> {
        self.pop()
    }

    fn push(&mut self, obj: Object) {
        self.stack.push(obj)
    }

    fn pop(&mut self) -> Result<Object> {
        match self.stack.pop() {
            None => crate::bail!("unexpected empty stack"),
            Some(obj) => Ok(obj),
        }
    }

    // https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/#Pickle.OpCodes.BUILD
    fn build(&mut self) -> Result<()> {
        let args = self.pop()?;
        let obj = self.pop()?;
        let obj = match (obj, args) {
            (Object::Dict(mut obj), Object::Dict(mut args)) => {
                obj.append(&mut args);
                Object::Dict(obj)
            }
            (obj, args) => Object::Build {
                callable: Box::new(obj),
                args: Box::new(args),
            },
        };
        self.push(obj);
        Ok(())
    }

    fn reduce(&mut self) -> Result<()> {
        let args = self.pop()?;
        let callable = self.pop()?;
        #[allow(clippy::single_match)]
        let reduced = match &callable {
            Object::Class {
                module_name,
                class_name,
            } => {
                if module_name == "collections" && class_name == "OrderedDict" {
                    // TODO: have a separate ordered dict.
                    Some(Object::Dict(vec![]))
                } else {
                    None
                }
            }
            _ => None,
        };
        let reduced = reduced.unwrap_or_else(|| Object::Reduce {
            callable: Box::new(callable),
            args: Box::new(args),
        });
        self.push(reduced);
        Ok(())
    }

    fn last(&mut self) -> Result<&mut Object> {
        match self.stack.last_mut() {
            None => crate::bail!("unexpected empty stack"),
            Some(obj) => Ok(obj),
        }
    }

    fn memo_get(&self, id: u32) -> Result<Object> {
        match self.memo.get(&id) {
            None => crate::bail!("missing object in memo {id}"),
            Some(obj) => {
                // Maybe we should use refcounting rather than doing potential large clones here.
                Ok(obj.clone())
            }
        }
    }

    fn memo_put(&mut self, id: u32) -> Result<()> {
        let obj = self.last()?.clone();
        self.memo.insert(id, obj);
        Ok(())
    }

    fn persistent_load(&self, id: Object) -> Result<Object> {
        Ok(Object::PersistentLoad(Box::new(id)))
    }

    fn new_obj(&self, class: Object, args: Object) -> Result<Object> {
        Ok(Object::Reduce {
            callable: Box::new(class),
            args: Box::new(args),
        })
    }

    fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
        let mut mark_idx = None;
        for (idx, obj) in self.stack.iter().enumerate().rev() {
            if obj == &Object::Mark {
                mark_idx = Some(idx);
                break;
            }
        }
        match mark_idx {
            Some(mark_idx) => {
                let objs = self.stack.split_off(mark_idx + 1);
                self.stack.pop();
                Ok(objs)
            }
            None => {
                crate::bail!("marker object not found")
            }
        }
    }

    pub fn read<R: BufRead>(&mut self, r: &mut R) -> Result<bool> {
        let op_code = match OpCode::try_from(r.read_u8()?) {
            Ok(op_code) => op_code,
            Err(op_code) => {
                crate::bail!("unknown op-code {op_code}")
            }
        };
        // println!("op: {op_code:?}");
        // println!("{:?}", self.stack);
        match op_code {
            OpCode::Proto => {
                let version = r.read_u8()?;
                if VERBOSE {
                    println!("proto {version}");
                }
            }
            OpCode::Global => {
                let module_name = read_to_newline(r)?;
                let class_name = read_to_newline(r)?;
                let module_name = String::from_utf8_lossy(&module_name).to_string();
                let class_name = String::from_utf8_lossy(&class_name).to_string();
                self.push(Object::Class {
                    module_name,
                    class_name,
                })
            }
            OpCode::BinInt1 => {
                let arg = r.read_u8()?;
                self.push(Object::Int(arg as i32))
            }
            OpCode::BinInt2 => {
                let arg = r.read_u16::<LittleEndian>()?;
                self.push(Object::Int(arg as i32))
            }
            OpCode::BinInt => {
                let arg = r.read_i32::<LittleEndian>()?;
                self.push(Object::Int(arg))
            }
            OpCode::BinFloat => {
                let arg = r.read_f64::<LittleEndian>()?;
                self.push(Object::Float(arg))
            }
            OpCode::BinUnicode => {
                let len = r.read_u32::<LittleEndian>()?;
                let mut data = vec![0u8; len as usize];
                r.read_exact(&mut data)?;
                let data = String::from_utf8(data).map_err(E::wrap)?;
                self.push(Object::Unicode(data))
            }
            OpCode::BinPersId => {
                let id = self.pop()?;
                let obj = self.persistent_load(id)?;
                self.push(obj)
            }
            OpCode::Tuple => {
                let objs = self.pop_to_marker()?;
                self.push(Object::Tuple(objs))
            }
            OpCode::Tuple1 => {
                let obj = self.pop()?;
                self.push(Object::Tuple(vec![obj]))
            }
            OpCode::Tuple2 => {
                let obj2 = self.pop()?;
                let obj1 = self.pop()?;
                self.push(Object::Tuple(vec![obj1, obj2]))
            }
            OpCode::Tuple3 => {
                let obj3 = self.pop()?;
                let obj2 = self.pop()?;
                let obj1 = self.pop()?;
                self.push(Object::Tuple(vec![obj1, obj2, obj3]))
            }
            OpCode::NewTrue => self.push(Object::Bool(true)),
            OpCode::NewFalse => self.push(Object::Bool(false)),
            OpCode::Append => {
                let value = self.pop()?;
                let pylist = self.last()?;
                if let Object::List(d) = pylist {
                    d.push(value)
                } else {
                    crate::bail!("expected a list, got {pylist:?}")
                }
            }
            OpCode::Appends => {
                let objs = self.pop_to_marker()?;
                let pylist = self.last()?;
                if let Object::List(d) = pylist {
                    d.extend(objs)
                } else {
                    crate::bail!("expected a list, got {pylist:?}")
                }
            }
            OpCode::SetItem => {
                let value = self.pop()?;
                let key = self.pop()?;
                let pydict = self.last()?;
                if let Object::Dict(d) = pydict {
                    d.push((key, value))
                } else {
                    crate::bail!("expected a dict, got {pydict:?}")
                }
            }
            OpCode::SetItems => {
                let mut objs = self.pop_to_marker()?;
                let pydict = self.last()?;
                if let Object::Dict(d) = pydict {
                    if objs.len() % 2 != 0 {
                        crate::bail!("setitems: not an even number of objects")
                    }
                    while let Some(value) = objs.pop() {
                        let key = objs.pop().unwrap();
                        d.push((key, value))
                    }
                } else {
                    crate::bail!("expected a dict, got {pydict:?}")
                }
            }
            OpCode::None => self.push(Object::None),
            OpCode::Stop => {
                return Ok(true);
            }
            OpCode::Build => self.build()?,
            OpCode::EmptyDict => self.push(Object::Dict(vec![])),
            OpCode::Dict => {
                let mut objs = self.pop_to_marker()?;
                let mut pydict = vec![];
                if objs.len() % 2 != 0 {
                    crate::bail!("setitems: not an even number of objects")
                }
                while let Some(value) = objs.pop() {
                    let key = objs.pop().unwrap();
                    pydict.push((key, value))
                }
                self.push(Object::Dict(pydict))
            }
            OpCode::Mark => self.push(Object::Mark),
            OpCode::Reduce => self.reduce()?,
            OpCode::EmptyTuple => self.push(Object::Tuple(vec![])),
            OpCode::EmptyList => self.push(Object::List(vec![])),
            OpCode::BinGet => {
                let arg = r.read_u8()?;
                let obj = self.memo_get(arg as u32)?;
                self.push(obj)
            }
            OpCode::LongBinGet => {
                let arg = r.read_u32::<LittleEndian>()?;
                let obj = self.memo_get(arg)?;
                self.push(obj)
            }
            OpCode::BinPut => {
                let arg = r.read_u8()?;
                self.memo_put(arg as u32)?
            }
            OpCode::LongBinPut => {
                let arg = r.read_u32::<LittleEndian>()?;
                self.memo_put(arg)?
            }
            OpCode::NewObj => {
                let args = self.pop()?;
                let class = self.pop()?;
                let obj = self.new_obj(class, args)?;
                self.push(obj)
            }
        }
        Ok(false)
    }
}
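Putting the stack machine together, a minimal end-to-end sketch that mirrors how `read_pth_tensor_info` below drives it (the file name is hypothetical):

    let file = std::fs::File::open("archive/data.pkl")?;
    let mut reader = std::io::BufReader::new(file);
    let mut stack = Stack::empty();
    stack.read_loop(&mut reader)?; // consume op-codes until Stop
    let obj = stack.finalize()?; // the object left on top of the stack
    println!("{obj:?}");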

impl From<Object> for E {
    fn from(value: Object) -> Self {
        E::Msg(format!("conversion error on {value:?}"))
    }
}

// https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198
// Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks
fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
    let mut args = args.tuple()?;
    let stride = Vec::<usize>::try_from(args.remove(3))?;
    let size = Vec::<usize>::try_from(args.remove(2))?;
    let offset = args.remove(1).int()? as usize;
    let storage = args.remove(0).persistent_load()?;
    let mut storage = storage.tuple()?;
    let storage_size = storage.remove(4).int()? as usize;
    let path = storage.remove(2).unicode()?;
    let (_module_name, class_name) = storage.remove(1).class()?;
    let dtype = match class_name.as_str() {
        "FloatStorage" => DType::F32,
        "DoubleStorage" => DType::F64,
        "HalfStorage" => DType::F16,
        "BFloat16Storage" => DType::BF16,
        "ByteStorage" => DType::U8,
        other => {
            crate::bail!("unsupported storage type {other}")
        }
    };
    let layout = Layout::new(crate::Shape::from(size), stride, offset);
    Ok((layout, dtype, path, storage_size))
}

#[derive(Debug, Clone)]
pub struct TensorInfo {
    pub name: String,
    pub dtype: DType,
    pub layout: Layout,
    pub path: String,
    pub storage_size: usize,
}

pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
    file: P,
    verbose: bool,
) -> Result<Vec<TensorInfo>> {
    let file = std::fs::File::open(file)?;
    let zip_reader = std::io::BufReader::new(file);
    let mut zip = zip::ZipArchive::new(zip_reader)?;
    let zip_file_names = zip
        .file_names()
        .map(|f| f.to_string())
        .collect::<Vec<String>>();

    let mut tensor_infos = vec![];
    for file_name in zip_file_names.iter() {
        if !file_name.ends_with("data.pkl") {
            continue;
        }
        let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
        let reader = zip.by_name(file_name)?;
        let mut reader = std::io::BufReader::new(reader);
        let mut stack = Stack::empty();
        stack.read_loop(&mut reader)?;
        let obj = stack.finalize()?;
        if VERBOSE || verbose {
            println!("{obj:?}");
        }
        let obj = match obj {
            Object::Build { callable, args } => match *callable {
                Object::Reduce { callable, args: _ } => match *callable {
                    Object::Class {
                        module_name,
                        class_name,
                    } if module_name == "__torch__" && class_name == "Module" => *args,
                    _ => continue,
                },
                _ => continue,
            },
            obj => obj,
        };
        if let Object::Dict(key_values) = obj {
            for (name, value) in key_values.into_iter() {
                let name = match name.unicode() {
                    Ok(name) => name,
                    Err(_) => continue,
                };
                let (callable, args) = match value.reduce() {
                    Ok(callable_args) => callable_args,
                    _ => continue,
                };
                let (callable, args) = match callable {
                    Object::Class {
                        module_name,
                        class_name,
                    } if module_name == "torch._tensor"
                        && class_name == "_rebuild_from_type_v2" =>
                    {
                        let mut args = args.tuple()?;
                        let callable = args.remove(0);
                        let args = args.remove(1);
                        (callable, args)
                    }
                    _ => (callable, args),
                };
                match callable {
                    Object::Class {
                        module_name,
                        class_name,
                    } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
                    _ => continue,
                };
                match rebuild_args(args) {
                    Ok((layout, dtype, file_path, storage_size)) => {
                        let mut path = dir_name.clone();
                        path.push(file_path);
                        tensor_infos.push(TensorInfo {
                            name,
                            dtype,
                            layout,
                            path: path.to_string_lossy().into_owned(),
                            storage_size,
                        })
                    }
                    Err(err) => {
                        eprintln!("skipping {name}: {err:?}")
                    }
                }
            }
        }
    }
    Ok(tensor_infos)
}

/// Lazy tensor loader.
pub struct PthTensors {
    tensor_infos: HashMap<String, TensorInfo>,
    path: std::path::PathBuf,
    // We do not store a zip reader as it needs mutable access to extract data. Instead we
    // re-create a zip reader for each tensor.
}

impl PthTensors {
    pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
        let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
        let tensor_infos = tensor_infos
            .into_iter()
            .map(|ti| (ti.name.to_string(), ti))
            .collect();
        let path = path.as_ref().to_owned();
        Ok(Self { tensor_infos, path })
    }

    pub fn tensor_infos(&self) -> &HashMap<String, TensorInfo> {
        &self.tensor_infos
    }

    pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
        let tensor_info = match self.tensor_infos.get(name) {
            None => return Ok(None),
            Some(tensor_info) => tensor_info,
        };
        // We hope that the file has not changed since first reading it.
        let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
        let mut zip = zip::ZipArchive::new(zip_reader)?;
        let mut reader = zip.by_name(&tensor_info.path)?;

        // Reading the data is a bit tricky as it can be strided, use an offset, etc.
        // For now only support the basic case.
        if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
            crate::bail!(
                "cannot retrieve non-contiguous tensors {:?}",
                tensor_info.layout
            )
        }
        let tensor = Tensor::from_reader(
            tensor_info.layout.shape().clone(),
            tensor_info.dtype,
            &mut reader,
        )?;
        Ok(Some(tensor))
    }
}
candle-core/src/quantized/avx.rs (new file, 640 lines)
@@ -0,0 +1,640 @@
use super::k_quants::{
    BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

#[inline(always)]
pub(crate) unsafe fn sum_i16_pairs_float(x: __m256i) -> __m256 {
    let ones = _mm256_set1_epi16(1);
    let summed_pairs = _mm256_madd_epi16(ones, x);
    _mm256_cvtepi32_ps(summed_pairs)
}

#[inline(always)]
pub(crate) unsafe fn mul_sum_us8_pairs_float(ax: __m256i, sy: __m256i) -> __m256 {
    let dot = _mm256_maddubs_epi16(ax, sy);
    sum_i16_pairs_float(dot)
}

#[inline(always)]
pub(crate) unsafe fn hsum_float_8(x: __m256) -> f32 {
    let res = _mm256_extractf128_ps(x, 1);
    let res = _mm_add_ps(res, _mm256_castps256_ps128(x));
    let res = _mm_add_ps(res, _mm_movehl_ps(res, res));
    let res = _mm_add_ss(res, _mm_movehdup_ps(res));
    _mm_cvtss_f32(res)
}

#[inline(always)]
pub(crate) unsafe fn bytes_from_nibbles_32(rsi: *const u8) -> __m256i {
    let tmp = _mm_loadu_si128(rsi as *const __m128i);
    let bytes = _mm256_insertf128_si256::<1>(_mm256_castsi128_si256(tmp), _mm_srli_epi16(tmp, 4));
    let low_mask = _mm256_set1_epi8(0xF);
    _mm256_and_si256(low_mask, bytes)
}
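For reference, a scalar sketch of the same nibble expansion: low nibbles of the 16 input bytes land in lanes 0..16 and high nibbles in lanes 16..32:

    fn bytes_from_nibbles_32_scalar(qs: &[u8; 16]) -> [u8; 32] {
        let mut out = [0u8; 32];
        for i in 0..16 {
            out[i] = qs[i] & 0xF; // low nibble
            out[i + 16] = qs[i] >> 4; // high nibble
        }
        out
    }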

#[inline(always)]
pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
    let ax = _mm256_sign_epi8(x, x);
    let sy = _mm256_sign_epi8(y, x);
    mul_sum_us8_pairs_float(ax, sy)
}

#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    let nb = n / qk;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }
    if nb % 2 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
    }

    unsafe {
        let mut acc = _mm256_setzero_ps();
        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
            let bx = bytes_from_nibbles_32(x.qs.as_ptr());
            let off = _mm256_set1_epi8(8);
            let bx = _mm256_sub_epi8(bx, off);
            let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
            let q = mul_sum_i8_pairs_float(bx, by);
            acc = _mm256_fmadd_ps(d, q, acc);
        }
        Ok(hsum_float_8(acc))
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
    }
    unsafe {
        let mut acc = _mm256_setzero_ps();
        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
            let bx = _mm256_loadu_si256(x.qs.as_ptr() as *const __m256i);
            let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
            let q = mul_sum_i8_pairs_float(bx, by);
            acc = _mm256_fmadd_ps(d, q, acc);
        }
        Ok(hsum_float_8(acc))
    }
}
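A scalar reference for the block dot product above, handy for checking the intrinsics; it assumes the `BlockQ8_0` layout the AVX path already relies on (an `f16` scale `d` and 32 `i8` quants `qs`):

    fn vec_dot_q8_0_q8_0_scalar(xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> f32 {
        let mut acc = 0f32;
        for (x, y) in xs.iter().zip(ys.iter()) {
            // integer dot product of the two 32-wide quant blocks
            let sum: i32 = x
                .qs
                .iter()
                .zip(y.qs.iter())
                .map(|(&a, &b)| a as i32 * b as i32)
                .sum();
            // scale by the product of the per-block f16 scales
            acc += f16::to_f32(x.d) * f16::to_f32(y.d) * sum as f32;
        }
        acc
    }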
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn get_scale_shuffle(i: usize) -> __m128i {
|
||||
const K_SHUFFLE: [u8; 128] = [
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3,
|
||||
3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7,
|
||||
7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10,
|
||||
11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13,
|
||||
13, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15,
|
||||
];
|
||||
_mm_loadu_si128((K_SHUFFLE.as_ptr() as *const __m128i).add(i))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn get_scale_shuffle_k4(i: usize) -> __m256i {
|
||||
const K_SHUFFLE: [u8; 256] = [
|
||||
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
|
||||
0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
|
||||
2, 3, 2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
|
||||
4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
|
||||
6, 7, 6, 7, 6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
|
||||
8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10,
|
||||
11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13,
|
||||
12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
|
||||
13, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
|
||||
14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
|
||||
];
|
||||
_mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn get_scale_shuffle_q3k(i: usize) -> __m256i {
|
||||
const K_SHUFFLE: [u8; 128] = [
|
||||
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
|
||||
2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
|
||||
6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11,
|
||||
10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
|
||||
13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
|
||||
];
|
||||
_mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
let qk = QK_K;
|
||||
if n % qk != 0 {
|
||||
crate::bail!("vec_dot_q6k_8k: {n} is not divisible by {qk}")
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let m4 = _mm256_set1_epi8(0xF);
|
||||
let m2 = _mm256_set1_epi8(3);
|
||||
let m32s = _mm256_set1_epi8(32);
|
||||
let mut acc = _mm256_setzero_ps();
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let d = y.d * x.d.to_f32();
|
||||
let mut q4 = x.ql.as_ptr();
|
||||
let mut qh = x.qh.as_ptr();
|
||||
let mut q8 = y.qs.as_ptr();
|
||||
|
||||
let scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i);
|
||||
let mut sumi = _mm256_setzero_si256();
|
||||
|
||||
for j in 0..QK_K / 128 {
|
||||
let is = j * 4;
|
||||
let scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is));
|
||||
let scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
|
||||
let scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
|
||||
let scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
|
||||
|
||||
let q4bits1 = _mm256_loadu_si256(q4 as *const __m256i);
|
||||
q4 = q4.add(32);
|
||||
let q4bits2 = _mm256_loadu_si256(q4 as *const __m256i);
|
||||
q4 = q4.add(32);
|
||||
let q4bits_h = _mm256_loadu_si256(qh as *const __m256i);
|
||||
qh = qh.add(32);
|
||||
|
||||
let q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bits_h, m2), 4);
|
||||
let q4h_1 =
|
||||
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 2), m2), 4);
|
||||
let q4h_2 =
|
||||
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 4), m2), 4);
|
||||
let q4h_3 =
|
||||
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 6), m2), 4);
|
||||
|
||||
let q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
|
||||
let q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
|
||||
let q4_2 =
|
||||
_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
|
||||
let q4_3 =
|
||||
_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
|
||||
|
||||
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
|
||||
let q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
|
||||
let q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
|
||||
let q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
|
||||
let q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
|
||||
|
||||
let p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
|
||||
let p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
|
||||
let p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
|
||||
let p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
|
||||
|
||||
let p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
|
||||
let p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
|
||||
let p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
|
||||
let p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
|
||||
|
||||
let p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
|
||||
let p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
|
||||
let p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
|
||||
let p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
|
||||
|
||||
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
|
||||
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
|
||||
}
|
||||
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
|
||||
}
|
||||
Ok(hsum_float_8(acc))
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn mm256_set_m128i(a: __m128i, b: __m128i) -> __m256i {
|
||||
_mm256_insertf128_si256(_mm256_castsi128_si256(b), a, 1)
|
||||
}
|
||||
|
||||
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
    }

    unsafe {
        let m3 = _mm256_set1_epi8(3);
        let m4 = _mm_set1_epi8(0xF);

        let mut acc = _mm256_setzero_ps();

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = -y.d * x.dmin.to_f32();

            let mut q2 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();
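            // Each byte of `x.scales` packs a 4-bit sub-block scale (low nibble) and a
            // 4-bit min (high nibble); the mins are folded in through the precomputed
            // q8 block sums, with the sign carried by the negative `dmin` above.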
            let mins_and_scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i);
            let scales8 = _mm_and_si128(mins_and_scales, m4);
            let mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
            let mins = _mm256_cvtepi8_epi16(mins8);
            let prod =
                _mm256_madd_epi16(mins, _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i));

            acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);

            let all_scales = _mm256_cvtepi8_epi16(scales8);
            let l_scales = _mm256_extracti128_si256(all_scales, 0);
            let h_scales = _mm256_extracti128_si256(all_scales, 1);
            let scales = [
                mm256_set_m128i(l_scales, l_scales),
                mm256_set_m128i(h_scales, h_scales),
            ];

            let mut sumi = _mm256_setzero_si256();
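            // Two passes of 128 values each; the 2-bit quants come out of `q2bits`
            // with shifts of 0, 2, 4 and 6 followed by the `m3` mask.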
            for scale in scales {
                let q2bits = _mm256_loadu_si256(q2 as *const __m256i);
                q2 = q2.add(32);

                let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);

                let q2_0 = _mm256_and_si256(q2bits, m3);
                let q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
                let q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
                let q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);

                let p0 = _mm256_maddubs_epi16(q2_0, q8_0);
                let p1 = _mm256_maddubs_epi16(q2_1, q8_1);
                let p2 = _mm256_maddubs_epi16(q2_2, q8_2);
                let p3 = _mm256_maddubs_epi16(q2_3, q8_3);

                let p0 =
                    _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(0)), p0);
                let p1 =
                    _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(1)), p1);
                let p2 =
                    _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(2)), p2);
                let p3 =
                    _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(3)), p3);

                let p0 = _mm256_add_epi32(p0, p1);
                let p2 = _mm256_add_epi32(p2, p3);

                sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
            }
            acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
        }

        Ok(hsum_float_8(acc))
    }
}
#[inline(always)]
pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
    }

    const KMASK1: u32 = 0x03030303;
    const KMASK2: u32 = 0x0f0f0f0f;

    let mut aux = [0u32; 3];

    unsafe {
        let m3 = _mm256_set1_epi8(3);
        let mone = _mm256_set1_epi8(1);
        let m32 = _mm_set1_epi8(32);

        let mut acc = _mm256_setzero_ps();
        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();

            let mut q3 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();
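            // Unpack the twelve bytes of `x.scales` into sixteen 6-bit scales: the low
            // 4 bits live in aux[0]/aux[1] and the high 2 bits in aux[2]; subtracting
            // 32 re-centers them as signed values.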
            LittleEndian::read_u32_into(&x.scales, &mut aux);
            let scales128 = _mm_set_epi32(
                (((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4)) as i32,
                (((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4)) as i32,
                ((aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4)) as i32,
                ((aux[0] & KMASK2) | (((aux[2]) & KMASK1) << 4)) as i32,
            );
            let scales128 = _mm_sub_epi8(scales128, m32);
            let all_scales = _mm256_cvtepi8_epi16(scales128);
            let l_scales = _mm256_extracti128_si256(all_scales, 0);
            let h_scales = _mm256_extracti128_si256(all_scales, 1);
            let scales = [
                mm256_set_m128i(l_scales, l_scales),
                mm256_set_m128i(h_scales, h_scales),
            ];

            // high bit
            let hbits = _mm256_loadu_si256(x.hmask.as_ptr() as *const __m256i);

            let mut sumi = _mm256_setzero_si256();
            for (j, scale) in scales.iter().enumerate() {
                // load low 2 bits
                let q3bits = _mm256_loadu_si256(q3 as *const __m256i);
                q3 = q3.add(32);

                // Prepare low and high bits
                // We hardcode the shifts here to avoid loading them into a separate register
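                // `andnot(hbits, mone << bit)` selects the positions where the high bit
                // is *not* set, so after the shift-left by 2 each q3h lane is 4 exactly
                // where the high bit is unset; subtracting it below yields (q - 4) * q8.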
                let q3l_0 = _mm256_and_si256(q3bits, m3);
                let q3h_0 = if j == 0 {
                    _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
                } else {
                    _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 4)), 4)
                };
                let q3h_0 = _mm256_slli_epi16(q3h_0, 2);

                let q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
                let q3h_1 = if j == 0 {
                    _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 1)), 1)
                } else {
                    _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 5)), 5)
                };
                let q3h_1 = _mm256_slli_epi16(q3h_1, 2);

                let q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
                let q3h_2 = if j == 0 {
                    _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 2)), 2)
                } else {
                    _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 6)), 6)
                };
                let q3h_2 = _mm256_slli_epi16(q3h_2, 2);

                let q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
                let q3h_3 = if j == 0 {
                    _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 3)), 3)
                } else {
                    _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 7)), 7)
                };
                let q3h_3 = _mm256_slli_epi16(q3h_3, 2);

                // load Q8 quants
                let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                // Dot product: we multiply the 2 low bits and 1 high bit part
                // separately, so we can use _mm256_maddubs_epi16 and then subtract.
                // The high bit part already carries the offset: q3h is 4 where the high
                // bit was not set and 0 where it was, so the subtraction applies the -4
                // of the signed 3-bit range.
                let q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
                let q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
                let q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
                let q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);

                let p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
                let p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
                let p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
                let p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);

                let p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
                let p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
                let p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
                let p16_3 = _mm256_sub_epi16(p16_3, q8s_3);

                // multiply with scales
                let p16_0 =
                    _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(0)), p16_0);
                let p16_1 =
                    _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(1)), p16_1);
                let p16_2 =
                    _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(2)), p16_2);
                let p16_3 =
                    _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(3)), p16_3);

                // accumulate
                let p16_0 = _mm256_add_epi32(p16_0, p16_1);
                let p16_2 = _mm256_add_epi32(p16_2, p16_3);
                sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
            }

            // multiply with block scale and accumulate
            acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
        }
        Ok(hsum_float_8(acc))
    }
}
#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut utmp = [0u32; 4];
    const KMASK1: u32 = 0x3f3f3f3f;
    const KMASK2: u32 = 0x0f0f0f0f;
    const KMASK3: u32 = 0x03030303;

    unsafe {
        let m4 = _mm256_set1_epi8(0xF);

        let mut acc = _mm256_setzero_ps();
        let mut acc_m = _mm_setzero_ps();

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = -y.d * x.dmin.to_f32();
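            // The 12 scale bytes pack eight 6-bit scales and eight 6-bit mins; the
            // KMASK shuffle below rearranges them so that utmp[0..2] hold the scales
            // and utmp[2..4] the mins.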
            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
            let uaux = utmp[1] & KMASK1;
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[2] = uaux;
            utmp[0] &= KMASK1;

            let mut q4 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
                utmp[3] as i32,
                utmp[2] as i32,
                utmp[1] as i32,
                utmp[0] as i32,
            ));

            let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
            let q8s = _mm_hadd_epi16(
                _mm256_extracti128_si256(q8sums, 0),
                _mm256_extracti128_si256(q8sums, 1),
            );
            let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
            acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);

            let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
            let scales = mm256_set_m128i(sc128, sc128);

            let mut sumi = _mm256_setzero_si256();
            for j in 0..QK_K / 64 {
                let scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
                let scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));

                let q4bits = _mm256_loadu_si256(q4 as *const __m256i);
                q4 = q4.add(32);
                let q4l = _mm256_and_si256(q4bits, m4);
                let q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);

                let q8l = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let p16l = _mm256_maddubs_epi16(q4l, q8l);
                let p16l = _mm256_madd_epi16(scale_l, p16l);
                sumi = _mm256_add_epi32(sumi, p16l);

                let q8h = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let p16h = _mm256_maddubs_epi16(q4h, q8h);
                let p16h = _mm256_madd_epi16(scale_h, p16h);
                sumi = _mm256_add_epi32(sumi, p16h);
            }

            let vd = _mm256_set1_ps(d);
            acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
        }
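        // Horizontal reduction of the 4-lane mins accumulator: movehl adds the top
        // half onto the bottom half, movehdup then folds in the remaining odd lane.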
        let acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
        let acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));

        Ok(hsum_float_8(acc) + _mm_cvtss_f32(acc_m))
    }
}
#[inline(always)]
pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut utmp = [0u32; 4];
    const KMASK1: u32 = 0x3f3f3f3f;
    const KMASK2: u32 = 0x0f0f0f0f;
    const KMASK3: u32 = 0x03030303;

    unsafe {
        let m4 = _mm256_set1_epi8(0xF);
        let mzero = _mm_setzero_si128();
        let mone = _mm256_set1_epi8(1);

        let mut acc = _mm256_setzero_ps();
        let mut summs = 0.0;

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = -y.d * x.dmin.to_f32();

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
            let uaux = utmp[1] & KMASK1;
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[2] = uaux;
            utmp[0] &= KMASK1;

            let mut q5 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
                utmp[3] as i32,
                utmp[2] as i32,
                utmp[1] as i32,
                utmp[0] as i32,
            ));

            let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
            let q8s = _mm_hadd_epi16(
                _mm256_extracti128_si256(q8sums, 0),
                _mm256_extracti128_si256(q8sums, 1),
            );
            let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
            let hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
            summs += dmin * _mm_extract_epi32(hsum, 0) as f32;

            let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
            let scales = mm256_set_m128i(sc128, sc128);
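            // `hmask` starts at bit 0 and is shifted left after each use, walking
            // through the eight 5th-bit planes stored in `x.qh`.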
            let hbits = _mm256_loadu_si256(x.qh.as_ptr() as *const __m256i);
            let mut hmask = mone;

            let mut sumi = _mm256_setzero_si256();

            for j in 0..QK_K / 64 {
                let scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
                let scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));

                let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
                q5 = q5.add(32);

                // Similar to q3k, we hardcode the shifts here to avoid loading them into a separate register
                let q5l_0 = _mm256_and_si256(q5bits, m4);
                let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
                let q5l_0_right_shift = match j {
                    0 => _mm256_srli_epi16(q5l_0_shift_input, 0),
                    1 => _mm256_srli_epi16(q5l_0_shift_input, 2),
                    2 => _mm256_srli_epi16(q5l_0_shift_input, 4),
                    3 => _mm256_srli_epi16(q5l_0_shift_input, 6),
                    _ => unreachable!(),
                };
                let q5h_0 = _mm256_slli_epi16(q5l_0_right_shift, 4);
                let q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
                hmask = _mm256_slli_epi16(hmask, 1);

                let q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
                let q5l_1_shift_input = _mm256_and_si256(hbits, hmask);
                let q5l_1_right_shift = match j {
                    0 => _mm256_srli_epi16(q5l_1_shift_input, 1),
                    1 => _mm256_srli_epi16(q5l_1_shift_input, 3),
                    2 => _mm256_srli_epi16(q5l_1_shift_input, 5),
                    3 => _mm256_srli_epi16(q5l_1_shift_input, 7),
                    _ => unreachable!(),
                };

                let q5h_1 = _mm256_slli_epi16(q5l_1_right_shift, 4);
                let q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
                hmask = _mm256_slli_epi16(hmask, 1);

                let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);

                let p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
                let p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);

                let p16_0 = _mm256_madd_epi16(scale_0, p16_0);
                let p16_1 = _mm256_madd_epi16(scale_1, p16_1);

                sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
            }
            let vd = _mm256_set1_ps(d);
            acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
        }
        Ok(hsum_float_8(acc) + summs)
    }
}
candle-core/src/quantized/ggml_file.rs (new file, 225 lines)
@@ -0,0 +1,225 @@
//! Support for the GGML file format.

use super::{k_quants, GgmlDType};
use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;

// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Magic {
    Ggjt,
    Ggla,
    Ggmf,
    Ggml,
    Ggsn,
}

impl TryFrom<u32> for Magic {
    type Error = crate::Error;
    fn try_from(value: u32) -> Result<Self> {
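        // Each magic is the ASCII string "ggjt", "ggla", "ggmf", "ggml" or "ggsn"
        // packed into a u32.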
        let magic = match value {
            0x67676a74 => Self::Ggjt,
            0x67676c61 => Self::Ggla,
            0x67676d66 => Self::Ggmf,
            0x67676d6c => Self::Ggml,
            0x6767736e => Self::Ggsn,
            _ => crate::bail!("unknown magic {value:08x}"),
        };
        Ok(magic)
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VersionedMagic {
    GgmlUnversioned,
    GgmfV1,
    GgjtV1,
    GgjtV2,
    GgjtV3,
}

impl VersionedMagic {
    fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
        let magic = reader.read_u32::<LittleEndian>()?;
        let magic = Magic::try_from(magic)?;
        if magic == Magic::Ggml {
            return Ok(Self::GgmlUnversioned);
        }
        let version = reader.read_u32::<LittleEndian>()?;
        let versioned_magic = match (magic, version) {
            (Magic::Ggmf, 1) => Self::GgmfV1,
            (Magic::Ggjt, 1) => Self::GgjtV1,
            (Magic::Ggjt, 2) => Self::GgjtV2,
            (Magic::Ggjt, 3) => Self::GgjtV3,
            _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
        };
        Ok(versioned_magic)
    }

    fn align32(&self) -> bool {
        match self {
            Self::GgmlUnversioned | Self::GgmfV1 => false,
            Self::GgjtV1 | Self::GgjtV2 | Self::GgjtV3 => true,
        }
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HParams {
    pub n_vocab: u32,
    pub n_embd: u32,
    pub n_mult: u32,
    pub n_head: u32,
    pub n_layer: u32,
    pub n_rot: u32,
    pub ftype: u32,
}

impl HParams {
    fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
        let n_vocab = reader.read_u32::<LittleEndian>()?;
        let n_embd = reader.read_u32::<LittleEndian>()?;
        let n_mult = reader.read_u32::<LittleEndian>()?;
        let n_head = reader.read_u32::<LittleEndian>()?;
        let n_layer = reader.read_u32::<LittleEndian>()?;
        let n_rot = reader.read_u32::<LittleEndian>()?;
        let ftype = reader.read_u32::<LittleEndian>()?;
        Ok(Self {
            n_vocab,
            n_embd,
            n_mult,
            n_head,
            n_layer,
            n_rot,
            ftype,
        })
    }
}

#[derive(Debug, Clone, PartialEq)]
pub struct Vocab {
    pub token_score_pairs: Vec<(Vec<u8>, f32)>,
}

impl Vocab {
    fn read<R: std::io::Read>(reader: &mut R, n_vocab: usize) -> Result<Self> {
        // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L556
        let mut token_score_pairs = Vec::with_capacity(n_vocab);
        for _index in 0..n_vocab {
            let len = reader.read_u32::<LittleEndian>()? as usize;
            let mut word = vec![0u8; len];
            reader.read_exact(&mut word)?;
            let score = reader.read_f32::<LittleEndian>()?;
            token_score_pairs.push((word, score))
        }
        Ok(Self { token_score_pairs })
    }
}

fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
    raw_data: &[u8],
    size_in_bytes: usize,
    dims: Vec<usize>,
) -> Result<super::QTensor> {
    let raw_data_ptr = raw_data.as_ptr();
    let n_blocks = size_in_bytes / std::mem::size_of::<T>();
    let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
    super::QTensor::new(data.to_vec(), dims)
}

/// Creates a [super::QTensor] from raw GGML tensor data.
pub fn qtensor_from_ggml(
    ggml_dtype: GgmlDType,
    raw_data: &[u8],
    dims: Vec<usize>,
) -> Result<super::QTensor> {
    let tensor_elems = dims.iter().product::<usize>();
    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
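    // e.g. a 4096 x 4096 Q4_0 tensor: type_size = 18 bytes per 32-element block,
    // so size_in_bytes = 4096 * 4096 * 18 / 32 = 9_437_184 bytes.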

    match ggml_dtype {
        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
        GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
        GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
        GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
        GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
        GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
        GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
        GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
        GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
        GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
        GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
        _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
    }
}

fn read_one_tensor<R: std::io::Seek + std::io::Read>(
    reader: &mut R,
    magic: VersionedMagic,
) -> Result<(String, super::QTensor)> {
    let n_dims = reader.read_u32::<LittleEndian>()?;
    let name_len = reader.read_u32::<LittleEndian>()?;
    let ggml_dtype = reader.read_u32::<LittleEndian>()?;
    let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
    let mut dims = vec![0u32; n_dims as usize];
    reader.read_u32_into::<LittleEndian>(&mut dims)?;
    // The dimensions are stored in reverse order, see for example:
    // https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/convert.py#L969
    dims.reverse();
    let mut name = vec![0u8; name_len as usize];
    reader.read_exact(&mut name)?;
    let name = String::from_utf8_lossy(&name).into_owned();

    if magic.align32() {
        let pos = reader.stream_position()?;
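        // Advance to the next 32-byte boundary; (32 - pos % 32) % 32 is the number of
        // padding bytes to skip (zero when already aligned).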
        reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?;
    }
    let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
    let tensor_elems = dims.iter().product::<usize>();
    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
    // TODO: Mmap version to avoid copying the data around?
    let mut raw_data = vec![0u8; size_in_bytes];
    reader.read_exact(&mut raw_data)?;
    match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
        Ok(tensor) => Ok((name, tensor)),
        Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
    }
}

pub struct Content {
    pub magic: VersionedMagic,
    pub hparams: HParams,
    pub vocab: Vocab,
    pub tensors: HashMap<String, super::QTensor>,
}

impl Content {
    pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
        // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
        let last_position = reader.seek(std::io::SeekFrom::End(0))?;
        reader.seek(std::io::SeekFrom::Start(0))?;
        let magic = VersionedMagic::read(reader)?;
        let hparams = HParams::read(reader)?;
        let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
        let mut tensors = HashMap::new();

        while reader.stream_position()? != last_position {
            let (name, tensor) = read_one_tensor(reader, magic)?;
            tensors.insert(name, tensor);
        }
        Ok(Self {
            magic,
            hparams,
            vocab,
            tensors,
        })
    }

    pub fn remove(&mut self, name: &str) -> Result<super::QTensor> {
        match self.tensors.remove(name) {
            None => crate::bail!("cannot find tensor with name '{name}'"),
            Some(tensor) => Ok(tensor),
        }
    }
}
candle-core/src/quantized/gguf_file.rs (new file, 513 lines)
@@ -0,0 +1,513 @@
//! Support for the GGUF file format.
//!
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md

use super::{GgmlDType, QTensor};
use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;

pub const DEFAULT_ALIGNMENT: u64 = 32;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Magic {
    Gguf,
}

impl TryFrom<u32> for Magic {
    type Error = crate::Error;
    fn try_from(value: u32) -> Result<Self> {
        let magic = match value {
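            // "GGUF" in either byte order, so files written with both conventions load.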
            0x46554747 | 0x47475546 => Self::Gguf,
            _ => crate::bail!("unknown magic 0x{value:08x}"),
        };
        Ok(magic)
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VersionedMagic {
    GgufV1,
    GgufV2,
}

impl VersionedMagic {
    fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
        let magic = reader.read_u32::<LittleEndian>()?;
        let magic = Magic::try_from(magic)?;
        let version = reader.read_u32::<LittleEndian>()?;
        let versioned_magic = match (magic, version) {
            (Magic::Gguf, 1) => Self::GgufV1,
            (Magic::Gguf, 2) => Self::GgufV2,
            _ => crate::bail!("gguf: unsupported magic/version {magic:?}/{version}"),
        };
        Ok(versioned_magic)
    }
}

#[derive(Debug)]
pub struct TensorInfo {
    pub ggml_dtype: GgmlDType,
    pub shape: crate::Shape,
    pub offset: u64,
}

impl TensorInfo {
    pub fn read<R: std::io::Seek + std::io::Read>(
        &self,
        reader: &mut R,
        tensor_data_offset: u64,
    ) -> Result<QTensor> {
        let tensor_elems = self.shape.elem_count();
        let size_in_bytes =
            tensor_elems * self.ggml_dtype.type_size() / self.ggml_dtype.blck_size();
        let mut raw_data = vec![0u8; size_in_bytes];
        reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
        reader.read_exact(&mut raw_data)?;
        super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
    }
}

#[derive(Debug)]
pub struct Content {
    pub magic: VersionedMagic,
    pub metadata: HashMap<String, Value>,
    pub tensor_infos: HashMap<String, TensorInfo>,
    pub tensor_data_offset: u64,
}

fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
    let len = match magic {
        VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
        VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
    };
    let mut v = vec![0u8; len];
    reader.read_exact(&mut v)?;
    // GGUF strings are supposed to be non-null-terminated, but in practice some
    // files do include trailing nulls, so strip them.
    while let Some(0) = v.last() {
        v.pop();
    }
    // GGUF strings are supposed to be utf8 encoded, but some files contain invalid
    // utf8, hence the lossy conversion.
    Ok(String::from_utf8_lossy(&v).into_owned())
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValueType {
    // The value is an 8-bit unsigned integer.
    U8,
    // The value is an 8-bit signed integer.
    I8,
    // The value is a 16-bit unsigned little-endian integer.
    U16,
    // The value is a 16-bit signed little-endian integer.
    I16,
    // The value is a 32-bit unsigned little-endian integer.
    U32,
    // The value is a 32-bit signed little-endian integer.
    I32,
    // The value is a 64-bit unsigned little-endian integer.
    U64,
    // The value is a 64-bit signed little-endian integer.
    I64,
    // The value is a 32-bit IEEE754 floating point number.
    F32,
    // The value is a 64-bit IEEE754 floating point number.
    F64,
    // The value is a boolean.
    // 1-byte value where 0 is false and 1 is true.
    // Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy.
    Bool,
    // The value is a UTF-8 non-null-terminated string, with length prepended.
    String,
    // The value is an array of other values, with the length and type prepended.
    // Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
    Array,
}
#[derive(Debug, Clone)]
pub enum Value {
    U8(u8),
    I8(i8),
    U16(u16),
    I16(i16),
    U32(u32),
    I32(i32),
    U64(u64),
    I64(i64),
    F32(f32),
    F64(f64),
    Bool(bool),
    String(String),
    Array(Vec<Value>),
}

impl Value {
    pub fn value_type(&self) -> ValueType {
        match self {
            Self::U8(_) => ValueType::U8,
            Self::I8(_) => ValueType::I8,
            Self::U16(_) => ValueType::U16,
            Self::I16(_) => ValueType::I16,
            Self::U32(_) => ValueType::U32,
            Self::I32(_) => ValueType::I32,
            Self::U64(_) => ValueType::U64,
            Self::I64(_) => ValueType::I64,
            Self::F32(_) => ValueType::F32,
            Self::F64(_) => ValueType::F64,
            Self::Bool(_) => ValueType::Bool,
            Self::String(_) => ValueType::String,
            Self::Array(_) => ValueType::Array,
        }
    }

    pub fn to_u8(&self) -> Result<u8> {
        match self {
            Self::U8(v) => Ok(*v),
            v => crate::bail!("not a u8 {v:?}"),
        }
    }

    pub fn to_i8(&self) -> Result<i8> {
        match self {
            Self::I8(v) => Ok(*v),
            v => crate::bail!("not an i8 {v:?}"),
        }
    }

    pub fn to_u16(&self) -> Result<u16> {
        match self {
            Self::U16(v) => Ok(*v),
            v => crate::bail!("not a u16 {v:?}"),
        }
    }

    pub fn to_i16(&self) -> Result<i16> {
        match self {
            Self::I16(v) => Ok(*v),
            v => crate::bail!("not an i16 {v:?}"),
        }
    }

    pub fn to_u32(&self) -> Result<u32> {
        match self {
            Self::U32(v) => Ok(*v),
            v => crate::bail!("not a u32 {v:?}"),
        }
    }

    pub fn to_i32(&self) -> Result<i32> {
        match self {
            Self::I32(v) => Ok(*v),
            v => crate::bail!("not an i32 {v:?}"),
        }
    }

    pub fn to_u64(&self) -> Result<u64> {
        match self {
            Self::U64(v) => Ok(*v),
            v => crate::bail!("not a u64 {v:?}"),
        }
    }

    pub fn to_i64(&self) -> Result<i64> {
        match self {
            Self::I64(v) => Ok(*v),
            v => crate::bail!("not an i64 {v:?}"),
        }
    }

    pub fn to_f32(&self) -> Result<f32> {
        match self {
            Self::F32(v) => Ok(*v),
            v => crate::bail!("not an f32 {v:?}"),
        }
    }

    pub fn to_f64(&self) -> Result<f64> {
        match self {
            Self::F64(v) => Ok(*v),
            v => crate::bail!("not an f64 {v:?}"),
        }
    }

    pub fn to_bool(&self) -> Result<bool> {
        match self {
            Self::Bool(v) => Ok(*v),
            v => crate::bail!("not a bool {v:?}"),
        }
    }

    pub fn to_vec(&self) -> Result<&Vec<Value>> {
        match self {
            Self::Array(v) => Ok(v),
            v => crate::bail!("not a vec {v:?}"),
        }
    }

    pub fn to_string(&self) -> Result<&String> {
        match self {
            Self::String(v) => Ok(v),
            v => crate::bail!("not a string {v:?}"),
        }
    }

    fn read<R: std::io::Read>(
        reader: &mut R,
        value_type: ValueType,
        magic: &VersionedMagic,
    ) -> Result<Self> {
        let v = match value_type {
            ValueType::U8 => Self::U8(reader.read_u8()?),
            ValueType::I8 => Self::I8(reader.read_i8()?),
            ValueType::U16 => Self::U16(reader.read_u16::<LittleEndian>()?),
            ValueType::I16 => Self::I16(reader.read_i16::<LittleEndian>()?),
            ValueType::U32 => Self::U32(reader.read_u32::<LittleEndian>()?),
            ValueType::I32 => Self::I32(reader.read_i32::<LittleEndian>()?),
            ValueType::U64 => Self::U64(reader.read_u64::<LittleEndian>()?),
            ValueType::I64 => Self::I64(reader.read_i64::<LittleEndian>()?),
            ValueType::F32 => Self::F32(reader.read_f32::<LittleEndian>()?),
            ValueType::F64 => Self::F64(reader.read_f64::<LittleEndian>()?),
            ValueType::Bool => match reader.read_u8()? {
                0 => Self::Bool(false),
                1 => Self::Bool(true),
                b => crate::bail!("unexpected bool value {b}"),
            },
            ValueType::String => Self::String(read_string(reader, magic)?),
            ValueType::Array => {
                let value_type = reader.read_u32::<LittleEndian>()?;
                let value_type = ValueType::from_u32(value_type)?;
                let len = match magic {
                    VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
                    VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
                };
                let mut vs = Vec::with_capacity(len);
                for _ in 0..len {
                    vs.push(Value::read(reader, value_type, magic)?)
                }
                Self::Array(vs)
            }
        };
        Ok(v)
    }

    fn write<W: std::io::Write>(&self, w: &mut W) -> Result<()> {
        match self {
            &Self::U8(v) => w.write_u8(v)?,
            &Self::I8(v) => w.write_i8(v)?,
            &Self::U16(v) => w.write_u16::<LittleEndian>(v)?,
            &Self::I16(v) => w.write_i16::<LittleEndian>(v)?,
            &Self::U32(v) => w.write_u32::<LittleEndian>(v)?,
            &Self::I32(v) => w.write_i32::<LittleEndian>(v)?,
            &Self::U64(v) => w.write_u64::<LittleEndian>(v)?,
            &Self::I64(v) => w.write_i64::<LittleEndian>(v)?,
            &Self::F32(v) => w.write_f32::<LittleEndian>(v)?,
            &Self::F64(v) => w.write_f64::<LittleEndian>(v)?,
            &Self::Bool(v) => w.write_u8(u8::from(v))?,
            Self::String(v) => write_string(w, v.as_str())?,
            Self::Array(v) => {
                // The `Value` type does not enforce that all the values in an Array have
                // the same type.
                let value_type = if v.is_empty() {
                    // Doesn't matter, the array is empty.
                    ValueType::U32
                } else {
                    let value_type: std::collections::HashSet<_> =
                        v.iter().map(|elem| elem.value_type()).collect();
                    if value_type.len() != 1 {
                        crate::bail!("multiple value-types in the same array {value_type:?}")
                    }
                    value_type.into_iter().next().unwrap()
                };
                w.write_u32::<LittleEndian>(value_type.to_u32())?;
                w.write_u64::<LittleEndian>(v.len() as u64)?;
                for elem in v.iter() {
                    elem.write(w)?
                }
            }
        }
        Ok(())
    }
}

impl ValueType {
    fn from_u32(v: u32) -> Result<Self> {
        let v = match v {
            0 => Self::U8,
            1 => Self::I8,
            2 => Self::U16,
            3 => Self::I16,
            4 => Self::U32,
            5 => Self::I32,
            6 => Self::F32,
            7 => Self::Bool,
            8 => Self::String,
            9 => Self::Array,
            10 => Self::U64,
            11 => Self::I64,
            12 => Self::F64,
            v => crate::bail!("unrecognized value-type {v:#08x}"),
        };
        Ok(v)
    }

    fn to_u32(self) -> u32 {
        match self {
            Self::U8 => 0,
            Self::I8 => 1,
            Self::U16 => 2,
            Self::I16 => 3,
            Self::U32 => 4,
            Self::I32 => 5,
            Self::F32 => 6,
            Self::Bool => 7,
            Self::String => 8,
            Self::Array => 9,
            Self::U64 => 10,
            Self::I64 => 11,
            Self::F64 => 12,
        }
    }
}

impl Content {
    pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Self> {
        let magic = VersionedMagic::read(reader)?;

        let tensor_count = match magic {
            VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
            VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
        };
        let metadata_kv_count = match magic {
            VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
            VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
        };

        let mut metadata = HashMap::new();
        for _idx in 0..metadata_kv_count {
            let key = read_string(reader, &magic)?;
            let value_type = reader.read_u32::<LittleEndian>()?;
            let value_type = ValueType::from_u32(value_type)?;
            let value = Value::read(reader, value_type, &magic)?;
            metadata.insert(key, value);
        }
        let mut tensor_infos = HashMap::new();
        for _idx in 0..tensor_count {
            let tensor_name = read_string(reader, &magic)?;
            let n_dimensions = reader.read_u32::<LittleEndian>()?;

            let mut dimensions: Vec<usize> = match magic {
                VersionedMagic::GgufV1 => {
                    let mut dimensions = vec![0; n_dimensions as usize];
                    reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
                    dimensions.into_iter().map(|c| c as usize).collect()
                }
                VersionedMagic::GgufV2 => {
                    let mut dimensions = vec![0; n_dimensions as usize];
                    reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
                    dimensions.into_iter().map(|c| c as usize).collect()
                }
            };

            dimensions.reverse();
            let ggml_dtype = reader.read_u32::<LittleEndian>()?;
            let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
            let offset = reader.read_u64::<LittleEndian>()?;
            tensor_infos.insert(
                tensor_name,
                TensorInfo {
                    shape: crate::Shape::from(dimensions),
                    offset,
                    ggml_dtype,
                },
            );
        }
        let position = reader.stream_position()?;
        let alignment = match metadata.get("general.alignment") {
            Some(Value::U8(v)) => *v as u64,
            Some(Value::U16(v)) => *v as u64,
            Some(Value::U32(v)) => *v as u64,
            Some(Value::I8(v)) if *v >= 0 => *v as u64,
            Some(Value::I16(v)) if *v >= 0 => *v as u64,
            Some(Value::I32(v)) if *v >= 0 => *v as u64,
            _ => DEFAULT_ALIGNMENT,
        };
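        // Round the current position up to the next multiple of `alignment`.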
        let tensor_data_offset = (position + alignment - 1) / alignment * alignment;
        Ok(Self {
            magic,
            metadata,
            tensor_infos,
            tensor_data_offset,
        })
    }

    pub fn tensor<R: std::io::Seek + std::io::Read>(
        &self,
        reader: &mut R,
        name: &str,
    ) -> Result<QTensor> {
        let tensor_info = match self.tensor_infos.get(name) {
            Some(tensor_info) => tensor_info,
            None => crate::bail!("cannot find tensor info for {name}"),
        };
        tensor_info.read(reader, self.tensor_data_offset)
    }
}
fn write_string<W: std::io::Write>(w: &mut W, str: &str) -> Result<()> {
    let bytes = str.as_bytes();
    w.write_u64::<LittleEndian>(bytes.len() as u64)?;
    w.write_all(bytes)?;
    Ok(())
}

pub fn write<W: std::io::Seek + std::io::Write>(
    w: &mut W,
    metadata: &[(&str, &Value)],
    tensors: &[(&str, &QTensor)],
) -> Result<()> {
    w.write_u32::<LittleEndian>(0x46554747)?;
    w.write_u32::<LittleEndian>(2)?; // version 2.
    w.write_u64::<LittleEndian>(tensors.len() as u64)?;
    w.write_u64::<LittleEndian>(metadata.len() as u64)?;
    for (name, value) in metadata.iter() {
        write_string(w, name)?;
        w.write_u32::<LittleEndian>(value.value_type().to_u32())?;
        value.write(w)?;
    }
    let mut offset = 0usize;
    let mut offsets = Vec::with_capacity(tensors.len());
    for (name, tensor) in tensors.iter() {
        write_string(w, name)?;
        let dims = tensor.shape().dims();
        w.write_u32::<LittleEndian>(dims.len() as u32)?;
        for &dim in dims.iter().rev() {
            w.write_u64::<LittleEndian>(dim as u64)?;
        }
        w.write_u32::<LittleEndian>(tensor.dtype().to_u32())?;
        w.write_u64::<LittleEndian>(offset as u64)?;
        offsets.push(offset);
        let size_in_bytes = tensor.storage_size_in_bytes();
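        // Bytes needed to pad `size_in_bytes` up to a 32-byte boundary
        // (equivalent to (32 - size_in_bytes % 32) % 32).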
        let padding = 31 - (31 + size_in_bytes) % 32;
        offset += size_in_bytes + padding;
    }
    let pos = w.stream_position()? as usize;
    let padding = 31 - (31 + pos) % 32;
    w.write_all(&vec![0u8; padding])?;
    let tensor_start_pos = w.stream_position()? as usize;
    for (offset, (_name, tensor)) in offsets.iter().zip(tensors.iter()) {
        let pos = w.stream_position()? as usize;
        if tensor_start_pos + offset != pos {
            crate::bail!(
                "internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
            )
        }
        let data_ptr = tensor.as_ptr();
        let size_in_bytes = tensor.storage_size_in_bytes();
        let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
        w.write_all(data)?;
        let padding = 31 - (31 + size_in_bytes) % 32;
        w.write_all(&vec![0u8; padding])?;
    }
    Ok(())
}
candle-core/src/quantized/k_quants.rs (new file, 1873 lines)
(diff suppressed because the file is too large)
candle-core/src/quantized/mod.rs (new file, 292 lines)
@@ -0,0 +1,292 @@
use crate::{Device, Result, Shape, Tensor};

#[cfg(target_feature = "avx")]
pub mod avx;
pub mod ggml_file;
pub mod gguf_file;
pub mod k_quants;
#[cfg(target_feature = "neon")]
pub mod neon;
pub mod utils;

pub use k_quants::GgmlType;

pub struct QTensor {
    data: Box<dyn QuantizedType>,
    shape: Shape,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType {
    F32,
    F16,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
}

impl GgmlDType {
    pub(crate) fn from_u32(u: u32) -> Result<Self> {
        let dtype = match u {
            0 => Self::F32,
            1 => Self::F16,
            2 => Self::Q4_0,
            3 => Self::Q4_1,
            6 => Self::Q5_0,
            7 => Self::Q5_1,
            8 => Self::Q8_0,
            9 => Self::Q8_1,
            10 => Self::Q2K,
            11 => Self::Q3K,
            12 => Self::Q4K,
            13 => Self::Q5K,
            14 => Self::Q6K,
            15 => Self::Q8K,
            _ => crate::bail!("unknown dtype for tensor {u}"),
        };
        Ok(dtype)
    }

    pub(crate) fn to_u32(self) -> u32 {
        match self {
            Self::F32 => 0,
            Self::F16 => 1,
            Self::Q4_0 => 2,
            Self::Q4_1 => 3,
            Self::Q5_0 => 6,
            Self::Q5_1 => 7,
            Self::Q8_0 => 8,
            Self::Q8_1 => 9,
            Self::Q2K => 10,
            Self::Q3K => 11,
            Self::Q4K => 12,
            Self::Q5K => 13,
            Self::Q6K => 14,
            Self::Q8K => 15,
        }
    }

    /// The type size for blocks in bytes.
    pub fn type_size(&self) -> usize {
        use k_quants::*;
        match self {
            Self::F32 => 4,
            Self::F16 => 2,
            Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
            Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
            Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
            Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
            // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
            Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
            Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
            Self::Q2K => std::mem::size_of::<BlockQ2K>(),
            Self::Q3K => std::mem::size_of::<BlockQ3K>(),
            Self::Q4K => std::mem::size_of::<BlockQ4K>(),
            Self::Q5K => std::mem::size_of::<BlockQ5K>(),
            Self::Q6K => std::mem::size_of::<BlockQ6K>(),
            Self::Q8K => std::mem::size_of::<BlockQ8K>(),
        }
    }

    /// The block size, i.e. the number of elements stored in each block.
    pub fn blck_size(&self) -> usize {
        match self {
            Self::F32 => 1,
            Self::F16 => 1,
            Self::Q4_0 => k_quants::QK4_0,
            Self::Q4_1 => k_quants::QK4_1,
            Self::Q5_0 => k_quants::QK5_0,
            Self::Q5_1 => k_quants::QK5_1,
            Self::Q8_0 => k_quants::QK8_0,
            Self::Q8_1 => k_quants::QK8_1,
            Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
        }
    }
}

// A version of GgmlType without `vec_dot` so that it can be dyn boxed.
pub trait QuantizedType: Send + Sync {
    fn dtype(&self) -> GgmlDType;
    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
    fn to_float(&self, ys: &mut [f32]) -> Result<()>;
    fn storage_size_in_bytes(&self) -> usize;
    fn as_ptr(&self) -> *const u8;
}

impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
        k_quants::matmul(mkn, lhs, self.as_slice(), dst)
    }

    fn dtype(&self) -> GgmlDType {
        T::DTYPE
    }

    fn to_float(&self, ys: &mut [f32]) -> Result<()> {
        T::to_float(self.as_slice(), ys)
    }

    fn storage_size_in_bytes(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    fn as_ptr(&self) -> *const u8 {
        self.as_ptr() as *const u8
    }
}

impl std::fmt::Debug for QTensor {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
    }
}

fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
    let dims = shape.dims();
    if dims.is_empty() {
        crate::bail!("scalar tensor cannot be quantized {shape:?}")
    }
    if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
        crate::bail!(
            "quantized tensors must have their last dim divisible by the block size {shape:?} {}",
            T::BLCK_SIZE
        )
    }
    Ok(())
}

impl QTensor {
    pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
        data: Vec<T>,
        shape: S,
    ) -> Result<Self> {
        let shape = shape.into();
        check_shape::<T>(&shape)?;
        Ok(Self {
            data: Box::new(data),
            shape,
        })
    }

    pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
        let shape = src.shape();
        check_shape::<T>(shape)?;
        let src = src
            .to_dtype(crate::DType::F32)?
            .flatten_all()?
            .to_vec1::<f32>()?;
        if src.len() % T::BLCK_SIZE != 0 {
            crate::bail!(
                "tensor size ({shape:?}) is not divisible by block size {}",
                T::BLCK_SIZE
            )
        }
        let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
        T::from_float(&src, &mut data)?;
        Ok(Self {
            data: Box::new(data),
            shape: shape.clone(),
        })
    }

    pub fn dtype(&self) -> GgmlDType {
        self.data.dtype()
    }

    pub fn rank(&self) -> usize {
        self.shape.rank()
    }

    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
        let mut f32_data = vec![0f32; self.shape.elem_count()];
        self.data.to_float(&mut f32_data)?;
        Tensor::from_vec(f32_data, &self.shape, device)
    }

    pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
        self.data.matmul_t(mkn, lhs, dst)
    }

    pub fn storage_size_in_bytes(&self) -> usize {
        self.data.storage_size_in_bytes()
    }

    pub fn as_ptr(&self) -> *const u8 {
        self.data.as_ptr()
    }
}

#[derive(Debug)]
pub struct QMatMul(std::sync::Arc<QTensor>);

impl QMatMul {
    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
        Self(qtensor)
    }

    pub fn from_qtensor(qtensor: QTensor) -> Self {
        Self(std::sync::Arc::new(qtensor))
    }

    pub fn inner(&self) -> &std::sync::Arc<QTensor> {
        &self.0
    }
}
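// A minimal usage sketch (assuming a 2D f32 weight tensor `w` and an input `xs`):
//     let qt = QTensor::quantize::<k_quants::BlockQ4_0>(&w)?;
//     let mm = QMatMul::from_qtensor(qt);
//     let ys = mm.forward(&xs)?;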
impl crate::CustomOp1 for QTensor {
    fn name(&self) -> &'static str {
        "qmatmul"
    }

    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, Shape)> {
        if !layout.is_contiguous() {
            crate::bail!("input tensor is not contiguous {layout:?}")
        }
        let src_shape = layout.shape();
        // self is transposed so n is first then k.
        let (n, k) = self.shape.dims2()?;
        if src_shape.rank() < 2 {
            crate::bail!("input tensor has only one dimension {layout:?}")
        }
        let mut dst_shape = src_shape.dims().to_vec();
        let last_k = dst_shape.pop().unwrap();
        if last_k != k {
            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
        }
        dst_shape.push(n);
        let dst_shape = Shape::from(dst_shape);
        let storage = storage.as_slice::<f32>()?;
        let storage =
            &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
        let mut dst_storage = vec![0f32; dst_shape.elem_count()];
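        // mkn = (m, k, n): m rows of f32 activations times the transposed quantized
        // weight of shape (n, k).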
        self.matmul_t(
            (dst_shape.elem_count() / n, k, n),
            storage,
            &mut dst_storage,
        )?;
        Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
    }
}

impl QMatMul {
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply_op1_no_bwd(self.0.as_ref())
    }
}
candle-core/src/quantized/neon.rs (new file, 727 lines)
@@ -0,0 +1,727 @@
use super::k_quants::{
    BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};

#[allow(unused_imports)]
#[cfg(target_arch = "arm")]
use core::arch::arm::*;

#[allow(unused_imports)]
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;

#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    let nb = n / qk;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }
    if nb % 2 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
    }

    unsafe {
        let mut sumv0 = vdupq_n_f32(0.0f32);
        let mut sumv1 = vdupq_n_f32(0.0f32);
        for i in (0..nb).step_by(2) {
            let x0 = &xs[i];
            let x1 = &xs[i + 1];
            let y0 = &ys[i];
            let y1 = &ys[i + 1];

            let m4b = vdupq_n_u8(0x0F);
            let s8b = vdupq_n_s8(0x8);

            let v0_0 = vld1q_u8(x0.qs.as_ptr());
            let v0_1 = vld1q_u8(x1.qs.as_ptr());

            // 4-bit -> 8-bit
            let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
            let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
            let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
            let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

            // sub 8
            let v0_0ls = vsubq_s8(v0_0l, s8b);
            let v0_0hs = vsubq_s8(v0_0h, s8b);
            let v0_1ls = vsubq_s8(v0_1l, s8b);
            let v0_1hs = vsubq_s8(v0_1h, s8b);

            // load y
            let v1_0l = vld1q_s8(y0.qs.as_ptr());
            let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
            let v1_1l = vld1q_s8(y1.qs.as_ptr());
            let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16));

            // TODO: Support dotprod when it's available outside of nightly.
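            // vmull_s8 widens each i8*i8 product to i16; vpaddlq_s16 then pairwise-adds
            // adjacent i16 lanes into i32, together emulating the missing sdot.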
            let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
            let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
            let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
            let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));

            let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l));
            let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
            let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h));
            let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));

            let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
            let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
            let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
            let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));

            sumv0 = vmlaq_n_f32(
                sumv0,
                vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
                x0.d.to_f32() * y0.d.to_f32(),
            );
            sumv1 = vmlaq_n_f32(
                sumv1,
                vcvtq_f32_s32(vaddq_s32(pl1, ph1)),
                x1.d.to_f32() * y1.d.to_f32(),
            );
        }
        Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
    }
    let nb = n / QK8_0;
    if nb % 2 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
    }
    unsafe {
        let mut sumv0 = vdupq_n_f32(0.0f32);
        let mut sumv1 = vdupq_n_f32(0.0f32);
        for i in (0..nb).step_by(2) {
            let x0 = &xs[i];
            let x1 = &xs[i + 1];
            let y0 = &ys[i];
            let y1 = &ys[i + 1];

            let x0_0 = vld1q_s8(x0.qs.as_ptr());
            let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
            let x1_0 = vld1q_s8(x1.qs.as_ptr());
            let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16));

            // load y
            let y0_0 = vld1q_s8(y0.qs.as_ptr());
            let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
            let y1_0 = vld1q_s8(y1.qs.as_ptr());
            let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16));
            // TODO: Use dotprod once the intrinsics are available outside of nightly.
            let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
            let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
            let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
            let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));

            let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0));
            let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
            let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1));
            let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));

            let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
            let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
            let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
            let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));

            sumv0 = vmlaq_n_f32(
                sumv0,
                vcvtq_f32_s32(vaddq_s32(p0, p1)),
                x0.d.to_f32() * y0.d.to_f32(),
            );
            sumv1 = vmlaq_n_f32(
                sumv1,
                vcvtq_f32_s32(vaddq_s32(p2, p3)),
                x1.d.to_f32() * y1.d.to_f32(),
            );
        }
        Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sum = 0f32;
    unsafe {
        let m4b = vdupq_n_u8(0xF);

        let mone = vdupq_n_u8(3);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d_all = x.d.to_f32();

            let mut q6 = x.ql.as_ptr();
            let mut qh = x.qh.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mut scale = x.scales.as_ptr();
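            // Q6K values carry a +32 bias; the per-16-value q8 sums dotted with the
            // scales (`isum_mins` below) provide the correction term removed when the
            // block result is accumulated.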
            let q8sums = vld1q_s16_x2(y.bsums.as_ptr());
            let scales = vld1q_s8(scale);
            let q6scales = int16x8x2_t(
                vmovl_s8(vget_low_s8(scales)),
                vmovl_s8(vget_high_s8(scales)),
            );

            let prod = vaddq_s32(
                vaddq_s32(
                    vmull_s16(vget_low_s16(q8sums.0), vget_low_s16(q6scales.0)),
                    vmull_s16(vget_high_s16(q8sums.0), vget_high_s16(q6scales.0)),
                ),
                vaddq_s32(
                    vmull_s16(vget_low_s16(q8sums.1), vget_low_s16(q6scales.1)),
                    vmull_s16(vget_high_s16(q8sums.1), vget_high_s16(q6scales.1)),
                ),
            );
            let isum_mins = vaddvq_s32(prod);

            let mut isum = 0i32;

            for _j in 0..QK_K / 128 {
                let qhbits = vld1q_u8_x2(qh);
                qh = qh.add(32);
                let q6bits = vld1q_u8_x4(q6);
                q6 = q6.add(64);
                let q8bytes = vld1q_s8_x4(q8);
                q8 = q8.add(64);

                let q6h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4);
                let q6h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4);
                let shifted = vshrq_n_u8(qhbits.0, 2);
                let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
                let shifted = vshrq_n_u8(qhbits.1, 2);
                let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4);

                let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.0, m4b), q6h_0));
                let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.1, m4b), q6h_1));
                let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
                let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));

                // TODO: dotprod

                let p0 = vaddq_s16(
                    vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
                    vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
                );
                let p1 = vaddq_s16(
                    vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
                    vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
                );
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
                isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
                scale = scale.add(2);

                let p2 = vaddq_s16(
                    vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
                    vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
                );
                let p3 = vaddq_s16(
                    vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
                    vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
                );
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
                isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
                scale = scale.add(2);

                let q8bytes = vld1q_s8_x4(q8);
                q8 = q8.add(64);

                let shifted = vshrq_n_u8(qhbits.0, 4);
                let q6h_0 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
                let shifted = vshrq_n_u8(qhbits.1, 4);
                let q6h_1 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
                let shifted = vshrq_n_u8(qhbits.0, 6);
                let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
                let shifted = vshrq_n_u8(qhbits.1, 6);
                let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4);

                let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.0, 4), q6h_0));
                let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.1, 4), q6h_1));
                let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
                let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));

                // TODO: dotprod case.
                let p0 = vaddq_s16(
                    vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
                    vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
                );
                let p1 = vaddq_s16(
                    vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
                    vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
                );
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
                isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
                scale = scale.add(2);

                let p2 = vaddq_s16(
                    vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
|
||||
scale = scale.add(2);
|
||||
}
|
||||
sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
|
||||
}
|
||||
}
|
||||
Ok(sum)
|
||||
}
|
||||
|
||||
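Reader aid for the bit plumbing above (not part of the diff): each 6-bit q6k value is stored as a low nibble in `ql` plus two high bits in `qh`, and the implicit -32 offset is folded in once at the end via `isum - 32 * isum_mins`. A hypothetical scalar helper makes the layout explicit:

// Hypothetical helper: reassemble one q6k value from its low nibble and the
// two high bits selected by `shift` (0, 2, 4 or 6); the result is in 0..64.
fn unpack_q6(ql_nibble: u8, qh: u8, shift: u32) -> i32 {
    let lo = (ql_nibble & 0xF) as i32;
    let hi = ((qh >> shift) & 0x3) as i32;
    (hi << 4) | lo
}
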
#[inline(always)]
pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sumf = 0f32;
    let mut utmp = [0u32; 4];
    const KMASK1: u32 = 0x3f3f3f3f;
    const KMASK2: u32 = 0x0f0f0f0f;
    const KMASK3: u32 = 0x03030303;

    unsafe {
        let m4b = vdupq_n_u8(0xF);
        let mone = vdupq_n_u8(1);
        let mtwo = vdupq_n_u8(2);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = y.d * x.dmin.to_f32();

            let q8sums = vpaddq_s16(
                vld1q_s16(y.bsums.as_ptr()),
                vld1q_s16(y.bsums.as_ptr().add(8)),
            );

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
            let uaux = utmp[1] & KMASK1;
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[2] = uaux;
            utmp[0] &= KMASK1;

            let mins8 = vld1_u8((utmp.as_ptr() as *const u8).add(8));
            let mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
            let prod = vaddq_s32(
                vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
                vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
            );
            let sumi_mins = vaddvq_s32(prod);

            let mut scales = utmp.as_ptr() as *const u8;

            let mut q5 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mut qhbits = vld1q_u8_x2(x.qh.as_ptr());

            let mut sumi = 0i32;

            for _j in 0..QK_K / 64 {
                let q5bits = vld1q_u8_x2(q5);
                q5 = q5.add(32);
                let q8bytes = vld1q_s8_x4(q8);
                q8 = q8.add(64);

                let q5h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4);
                let q5h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4);
                let q5h_2 = vshlq_n_u8(vandq_u8(mtwo, qhbits.0), 3);
                let q5h_3 = vshlq_n_u8(vandq_u8(mtwo, qhbits.1), 3);
                qhbits.0 = vshrq_n_u8(qhbits.0, 2);
                qhbits.1 = vshrq_n_u8(qhbits.1, 2);

                let q5bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.0, m4b), q5h_0));
                let q5bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.1, m4b), q5h_1));
                let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
                let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));

                // TODO: dotprod

                let p0 = vaddq_s16(
                    vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
                    vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
                );
                let p1 = vaddq_s16(
                    vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
                    vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
                );
                sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
                scales = scales.add(1);

                let p2 = vaddq_s16(
                    vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
                    vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
                );
                let p3 = vaddq_s16(
                    vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
                    vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
                );
                sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
                scales = scales.add(1);
            }
            sumf += d * sumi as f32 - dmin * sumi_mins as f32;
        }
    }
    Ok(sumf)
}
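The q5k layout follows the same idea (reader aid, not part of the diff): a low nibble in `qs` plus a single high bit in `qh`, with `qhbits` shifted right by two after each 64-value chunk so that the next pair of high bits lines up with the `mone`/`mtwo` masks.

// Hypothetical helper: a 5-bit q5k value is a low nibble plus one high bit,
// giving a value in 0..32.
fn unpack_q5(qs_nibble: u8, qh_bit: u8) -> i32 {
    ((qh_bit as i32) << 4) | (qs_nibble & 0xF) as i32
}
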
#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sumf = 0f32;
    let mut utmp = [0u32; 4];
    let mut scales = [0u8; 16];
    const KMASK1: u32 = 0x3f3f3f3f;
    const KMASK2: u32 = 0x0f0f0f0f;
    const KMASK3: u32 = 0x03030303;

    unsafe {
        let m4b = vdupq_n_u8(0xF);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = y.d * x.dmin.to_f32();

            let q8sums = vpaddq_s16(
                vld1q_s16(y.bsums.as_ptr()),
                vld1q_s16(y.bsums.as_ptr().add(8)),
            );

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            let mins8 = vld1_u32(
                [
                    utmp[1] & KMASK1,
                    ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4),
                ]
                .as_ptr(),
            );
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[0] &= KMASK1;

            let mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
            let prod = vaddq_s32(
                vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
                vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
            );
            sumf -= dmin * vaddvq_s32(prod) as f32;

            LittleEndian::write_u32_into(&utmp, &mut scales);

            let mut q4 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mut sumi1 = 0i32;
            let mut sumi2 = 0i32;

            for j in 0..QK_K / 64 {
                let q4bits = vld1q_u8_x2(q4);
                q4 = q4.add(32);
                // TODO: dotprod
                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                let q4bytes = int8x16x2_t(
                    vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
                    vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
                );
                let p0 = vaddq_s16(
                    vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
                    vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
                );
                let p1 = vaddq_s16(
                    vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
                    vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
                );
                sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                let q4bytes = int8x16x2_t(
                    vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
                    vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
                );
                let p2 = vaddq_s16(
                    vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
                    vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
                );
                let p3 = vaddq_s16(
                    vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
                    vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
                );
                sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
            }
            sumf += d * (sumi1 + sumi2) as f32;
        }
    }
    Ok(sumf)
}
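In equation form, the q4k loop above implements the usual k-quant decomposition (sketch, not in the source): per 256-value super-block, sumf += d * Σ_j scale_j * (q4 · q8)_j − dmin * (mins · bsums). The mins term is subtracted up front through `sumf -= dmin * vaddvq_s32(prod) as f32`, so the inner loop only has to accumulate the scaled integer dot products `sumi1` and `sumi2`.
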
#[inline(always)]
pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sumf = 0f32;
    let mut utmp = [0u32; 4];
    let mut aux = [0u32; 3];
    const KMASK1: u32 = 0x03030303;
    const KMASK2: u32 = 0x0f0f0f0f;

    unsafe {
        let m3b = vdupq_n_u8(0x3);
        let m0 = vdupq_n_u8(1);
        let m1 = vshlq_n_u8(m0, 1);
        let m2 = vshlq_n_u8(m0, 2);
        let m3 = vshlq_n_u8(m0, 3);
        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let mut q3 = x.qs.as_ptr();
            let qh = x.hmask.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mut qhbits = vld1q_u8_x2(qh);

            let mut isum = 0i32;

            // Set up scales
            LittleEndian::read_u32_into(&x.scales, &mut aux);

            utmp[3] = ((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4);
            utmp[2] = ((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4);
            utmp[1] = (aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4);
            utmp[0] = (aux[0] & KMASK2) | ((aux[2] & KMASK1) << 4);

            let mut scale = utmp.as_mut_ptr() as *mut i8;
            for j in 0..16 {
                *scale.add(j) -= 32i8
            }

            for j in 0..QK_K / 128 {
                let q3bits = vld1q_u8_x2(q3);
                q3 = q3.add(32);
                let q8bytes_1 = vld1q_s8_x4(q8);
                q8 = q8.add(64);
                let q8bytes_2 = vld1q_s8_x4(q8);
                q8 = q8.add(64);

                let q3h_0 = vshlq_n_u8(vbicq_u8(m0, qhbits.0), 2);
                let q3h_1 = vshlq_n_u8(vbicq_u8(m0, qhbits.1), 2);
                let q3h_2 = vshlq_n_u8(vbicq_u8(m1, qhbits.0), 1);
                let q3h_3 = vshlq_n_u8(vbicq_u8(m1, qhbits.1), 1);

                let q3bytes_0 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(q3bits.0, m3b)),
                    vreinterpretq_s8_u8(q3h_0),
                );
                let q3bytes_1 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(q3bits.1, m3b)),
                    vreinterpretq_s8_u8(q3h_1),
                );
                let q3bytes_2 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 2), m3b)),
                    vreinterpretq_s8_u8(q3h_2),
                );
                let q3bytes_3 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 2), m3b)),
                    vreinterpretq_s8_u8(q3h_3),
                );

                // TODO: dotprod
                let p0 = vaddq_s16(
                    vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
                    vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
                );
                let p1 = vaddq_s16(
                    vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
                    vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
                );
                let p2 = vaddq_s16(
                    vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
                    vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
                );
                let p3 = vaddq_s16(
                    vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
                    vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
                );
                isum += vaddvq_s16(p0) as i32 * *scale as i32
                    + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
                    + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
                    + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
                scale = scale.add(4);

                let q3h_0 = vbicq_u8(m2, qhbits.0);
                let q3h_1 = vbicq_u8(m2, qhbits.1);
                let q3h_2 = vshrq_n_u8(vbicq_u8(m3, qhbits.0), 1);
                let q3h_3 = vshrq_n_u8(vbicq_u8(m3, qhbits.1), 1);

                let q3bytes_0 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 4), m3b)),
                    vreinterpretq_s8_u8(q3h_0),
                );
                let q3bytes_1 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 4), m3b)),
                    vreinterpretq_s8_u8(q3h_1),
                );
                let q3bytes_2 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 6), m3b)),
                    vreinterpretq_s8_u8(q3h_2),
                );
                let q3bytes_3 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 6), m3b)),
                    vreinterpretq_s8_u8(q3h_3),
                );

                // TODO: dotprod
                let p0 = vaddq_s16(
                    vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
                    vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
                );
                let p1 = vaddq_s16(
                    vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
                    vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
                );
                let p2 = vaddq_s16(
                    vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
                    vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
                );
                let p3 = vaddq_s16(
                    vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
                    vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
                );
                isum += vaddvq_s16(p0) as i32 * *scale as i32
                    + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
                    + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
                    + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
                scale = scale.add(4);

                if j == 0 {
                    qhbits.0 = vshrq_n_u8(qhbits.0, 4);
                    qhbits.1 = vshrq_n_u8(qhbits.1, 4);
                }
            }
            sumf += d * isum as f32;
        }
    }
    Ok(sumf)
}
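One subtlety in the q3k path above (reader aid): `vbicq_u8(m, qhbits)` computes `m & !qhbits`, i.e. it selects the complement of the high bit, so each `q3h` mask encodes `4 * (1 - h)` at the right bit position. Subtracting it from the 2-bit base via `vsubq_s8` therefore yields `(low | (h << 2)) - 4`, mapping the stored bits straight onto the signed range -4..=3 without a separate recentering pass.
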
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sumf = 0f32;
    let mut aux = [0u8; 16];

    unsafe {
        let m3 = vdupq_n_u8(0x3);
        let m4 = vdupq_n_u8(0xF);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = -y.d * x.dmin.to_f32();

            let mut q2 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();
            let sc = x.scales.as_ptr();

            let mins_and_scales = vld1q_u8(sc);
            let scales = vandq_u8(mins_and_scales, m4);
            vst1q_u8(aux.as_mut_ptr(), scales);

            let mins = vshrq_n_u8(mins_and_scales, 4);
            let q8sums = vld1q_s16_x2(y.bsums.as_ptr());
            let mins16 = int16x8x2_t(
                vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))),
                vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins))),
            );
            let s0 = vaddq_s32(
                vmull_s16(vget_low_s16(mins16.0), vget_low_s16(q8sums.0)),
                vmull_s16(vget_high_s16(mins16.0), vget_high_s16(q8sums.0)),
            );
            let s1 = vaddq_s32(
                vmull_s16(vget_low_s16(mins16.1), vget_low_s16(q8sums.1)),
                vmull_s16(vget_high_s16(mins16.1), vget_high_s16(q8sums.1)),
            );
            sumf += dmin * vaddvq_s32(vaddq_s32(s0, s1)) as f32;

            let mut isum = 0i32;
            let mut is = 0usize;

            // TODO: dotprod

            for _j in 0..QK_K / 128 {
                let q2bits = vld1q_u8_x2(q2);
                q2 = q2.add(32);

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                let mut q2bytes = int8x16x2_t(
                    vreinterpretq_s8_u8(vandq_u8(q2bits.0, m3)),
                    vreinterpretq_s8_u8(vandq_u8(q2bits.1, m3)),
                );
                isum += multiply_accum_with_scale(&aux, is, 0, q2bytes, q8bytes);

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 2), m3));
                q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 2), m3));
                isum += multiply_accum_with_scale(&aux, is, 2, q2bytes, q8bytes);

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 4), m3));
                q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 4), m3));
                isum += multiply_accum_with_scale(&aux, is, 4, q2bytes, q8bytes);

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 6), m3));
                q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 6), m3));
                isum += multiply_accum_with_scale(&aux, is, 6, q2bytes, q8bytes);

                is += 8;
            }
            sumf += d * isum as f32;
        }
    }
    Ok(sumf)
}

#[inline(always)]
unsafe fn multiply_accum_with_scale(
    aux: &[u8; 16],
    is: usize,
    index: usize,
    q2bytes: int8x16x2_t,
    q8bytes: int8x16x2_t,
) -> i32 {
    let p1 = vaddq_s16(
        vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
        vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
    );
    let p2 = vaddq_s16(
        vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
        vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
    );
    vaddvq_s16(p1) as i32 * aux[is + index] as i32
        + vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
}
candle-core/src/quantized/utils.rs (new file, 326 lines)
@@ -0,0 +1,326 @@
use crate::Result;

pub(super) fn nearest_int(v: f32) -> i32 {
    v.round() as i32
}

/// Validates that the input and output are the right size and returns a vector pairing each
/// `T::BLCK_SIZE`-long input region in `xs` with its corresponding output block in `ys`.
pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
    xs: &'b [f32],
    ys: &'a mut [T],
) -> Result<Vec<(&'a mut T, &'b [f32])>> {
    let block_size = T::BLCK_SIZE;
    let dtype = T::DTYPE;

    let expected_blocks = xs.len() / block_size;
    let actual_blocks = ys.len();

    // Validate that the input is the right size.
    if expected_blocks != actual_blocks {
        crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
    }

    Ok(ys.iter_mut().zip(xs.chunks_exact(block_size)).collect())
}
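A minimal usage sketch for the helper above (hypothetical call site; `quantize_block` stands in for the per-block quantization routine of the concrete `GgmlType`):

// Pair each BLCK_SIZE-long f32 region with its output block and quantize it.
for (block, region) in group_for_quantization(xs, ys)? {
    quantize_block(region, block); // region.len() == T::BLCK_SIZE
}
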
/// Validates that the input and output are the right size and returns a vector pairing each
/// input block in `xs` with its corresponding output region in `ys`. Each output region is
/// guaranteed to be `T::BLCK_SIZE` long.
pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(
    xs: &'a [T],
    ys: &'b mut [f32],
) -> Result<Vec<(&'a T, &'b mut [f32])>> {
    let block_size = T::BLCK_SIZE;
    let dtype = T::DTYPE;

    let actual_output_len = ys.len();
    let expected_output_len = xs.len() * block_size;
    // Validate that the output is the right size.
    if expected_output_len != actual_output_len {
        crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!")
    }

    // Zip the blocks and outputs together.
    Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect())
}

pub(super) fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
    if j < 4 {
        let d = q[j] & 63;
        let m = q[j + 4] & 63;
        (d, m)
    } else {
        let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
        (d, m)
    }
}
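Worked example for the 6-bit packing above (reader aid): for `j < 4` the scale and min are simply the low six bits of `q[j]` and `q[j + 4]`; for `j >= 4` the otherwise-unused top two bits of the earlier bytes supply the high bits. With `j = 5`, `d = (q[9] & 0xF) | ((q[1] >> 6) << 4)`: the low nibble comes from `q[9]` and bits 6..8 of `q[1]` become bits 4..6 of the scale.
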
pub(super) unsafe fn make_qx_quants(
    n: usize,
    nmax: i32,
    x: *const f32,
    ls: *mut i8,
    rmse_type: i32,
) -> f32 {
    let mut max = 0f32;
    let mut amax = 0f32;
    for i in 0..n {
        let x = *x.add(i);
        let ax = x.abs();
        if ax > amax {
            amax = ax;
            max = x;
        }
    }
    if amax == 0. {
        // all zero
        for i in 0..n {
            *ls.add(i) = 0;
        }
        return 0.;
    }
    let mut iscale = -(nmax as f32) / max;
    if rmse_type == 0 {
        for i in 0..n {
            let x = *x.add(i);
            let l = nearest_int(iscale * x);
            *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
        }
        return 1.0 / iscale;
    }
    let weight_type = rmse_type % 2;
    let mut sumlx = 0f32;
    let mut suml2 = 0f32;
    for i in 0..n {
        let x = *x.add(i);
        let l = nearest_int(iscale * x);
        let l = l.clamp(-nmax, nmax - 1);
        *ls.add(i) = (l + nmax) as i8;
        let w = if weight_type == 1 { x * x } else { 1.0 };
        let l = l as f32;
        sumlx += w * x * l;
        suml2 += w * l * l;
    }
    let mut scale = sumlx / suml2;
    let mut best = scale * sumlx;
    for _itry in 0..3 {
        let iscale = 1.0 / scale;
        let mut slx = 0f32;
        let mut sl2 = 0f32;
        let mut changed = false;
        for i in 0..n {
            let x = *x.add(i);
            let l = nearest_int(iscale * x);
            let l = l.clamp(-nmax, nmax - 1);
            if l + nmax != *ls.add(i) as i32 {
                changed = true;
            }
            let w = if weight_type == 1 { x * x } else { 1f32 };
            let l = l as f32;
            slx += w * x * l;
            sl2 += w * l * l;
        }
        if !changed || sl2 == 0.0 || slx * slx <= best * sl2 {
            break;
        }
        for i in 0..n {
            let x = *x.add(i);
            let l = nearest_int(iscale * x);
            *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
        }
        sumlx = slx;
        suml2 = sl2;
        scale = sumlx / suml2;
        best = scale * sumlx;
    }
    for _itry in 0..5 {
        let mut n_changed = 0;
        for i in 0..n {
            let x = *x.add(i);
            let w = if weight_type == 1 { x * x } else { 1. };
            let l = *ls.add(i) as i32 - nmax;
            let mut slx = sumlx - w * x * l as f32;
            if slx > 0. {
                let mut sl2 = suml2 - w * l as f32 * l as f32;
                let new_l = nearest_int(x * sl2 / slx);
                let new_l = new_l.clamp(-nmax, nmax - 1);
                if new_l != l {
                    slx += w * x * new_l as f32;
                    sl2 += w * new_l as f32 * new_l as f32;
                    if sl2 > 0. && slx * slx * suml2 > sumlx * sumlx * sl2 {
                        *ls.add(i) = (nmax + new_l) as i8;
                        sumlx = slx;
                        suml2 = sl2;
                        scale = sumlx / suml2;
                        best = scale * sumlx;
                        n_changed += 1;
                    }
                }
            }
        }
        if n_changed == 0 {
            break;
        }
    }
    if rmse_type < 3 {
        return scale;
    }
    for is in -4..4 {
        if is == 0 {
            continue;
        }
        iscale = -(nmax as f32 + 0.1f32 * is as f32) / max;
        let mut sumlx = 0.;
        let mut suml2 = 0.;
        for i in 0..n {
            let x = *x.add(i);
            let l = nearest_int(iscale * x);
            let l = l.clamp(-nmax, nmax - 1);
            let w = if weight_type == 1 { x * x } else { 1. };
            let l = l as f32;
            sumlx += w * x * l;
            suml2 += w * l * l;
        }
        if suml2 > 0. && sumlx * sumlx > best * suml2 {
            for i in 0..n {
                let x = *x.add(i);
                let l = nearest_int(iscale * x);
                *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
            }
            scale = sumlx / suml2;
            best = scale * sumlx;
        }
    }
    scale
}
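The update rules above are a weighted least-squares fit (sketch of the math, not in the source): for fixed integer levels $l_i$, the scale minimizing $\sum_i w_i (x_i - s\,l_i)^2$ is $s = \sum_i w_i x_i l_i / \sum_i w_i l_i^2$, which is exactly `sumlx / suml2`, and the objective at that optimum is $(\sum_i w_i x_i l_i)^2 / \sum_i w_i l_i^2$. The acceptance test `slx * slx * suml2 > sumlx * sumlx * sl2` is this optimum compared cross-multiplied, avoiding the divisions.
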
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L224
pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32) {
    let n = x.len();
    let mut l = vec![0; n];
    // Get min/max
    let min = *x
        .iter()
        .take(n)
        .min_by(|a, b| a.total_cmp(b))
        .unwrap_or(&x[0]);
    let max = *x.iter().max_by(|a, b| a.total_cmp(b)).unwrap_or(&x[0]);

    // If min == max, all values are the same => nothing to do here
    if max == min {
        return (0.0, 0.0);
    }

    // Ensure min <= 0.0
    let mut min = min.min(0.);

    // Compute scale and inverse scale
    let mut iscale = nmax as f32 / (max - min);
    let mut scale = 1.0 / iscale;

    for _ in 0..ntry {
        let mut sumlx = 0.0;
        let mut suml2 = 0;
        let mut did_change = false;

        for (i, value) in x.iter().enumerate().take(n) {
            let li = nearest_int(iscale * (value - min)).clamp(0, nmax);
            let clamped_li = li as u8;
            if clamped_li != l[i] {
                l[i] = clamped_li;
                did_change = true;
            }
            sumlx += (value - min) * li as f32;
            suml2 += li * li;
        }
        scale = sumlx / suml2 as f32;

        let sum: f32 = x
            .iter()
            .take(n)
            .zip(l.iter().take(n))
            .map(|(xi, &li)| xi - scale * li as f32)
            .sum();

        min = sum / n as f32;
        if min > 0.0 {
            min = 0.0;
        }
        iscale = 1.0 / scale;
        if !did_change {
            break;
        }
    }
    (scale, -min)
}

// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L165
pub(super) fn make_q3_quants(x: &[f32], nmax: i32, do_rmse: bool) -> f32 {
    let n = x.len();
    let mut l = vec![0i8; n];

    let mut max = 0.0;
    let mut amax = 0.0;
    for &xi in x.iter().take(n) {
        let ax = xi.abs();
        if ax > amax {
            amax = ax;
            max = xi;
        }
    }

    if amax == 0.0 {
        return 0.0;
    }

    let iscale = -(nmax as f32) / max;
    if do_rmse {
        let mut sumlx = 0.0;
        let mut suml2 = 0.0;
        for i in 0..n {
            let li = (iscale * x[i]).round() as i32;
            let li = li.clamp(-nmax, nmax - 1);
            l[i] = li as i8;
            let w = x[i] * x[i];
            sumlx += w * x[i] * li as f32;
            suml2 += w * (li * li) as f32;
        }
        for _ in 0..5 {
            let mut n_changed = 0;
            for i in 0..n {
                let w = x[i] * x[i];
                let mut slx = sumlx - w * x[i] * l[i] as f32;
                if slx > 0.0 {
                    let mut sl2 = suml2 - w * (l[i] as i32 * l[i] as i32) as f32;
                    let mut new_l = (x[i] * sl2 / slx).round() as i32;
                    new_l = new_l.clamp(-nmax, nmax - 1);
                    if new_l != l[i] as i32 {
                        slx += w * x[i] * new_l as f32;
                        sl2 += w * (new_l * new_l) as f32;
                        if sl2 > 0.0 && slx * slx * suml2 > sumlx * sumlx * sl2 {
                            l[i] = new_l as i8;
                            sumlx = slx;
                            suml2 = sl2;
                            n_changed += 1;
                        }
                    }
                }
            }
            if n_changed == 0 {
                break;
            }
        }
        for li in l.iter_mut() {
            *li += nmax as i8;
        }
        return sumlx / suml2;
    }
    for i in 0..n {
        let li = (iscale * x[i]).round() as i32;
        l[i] = (li.clamp(-nmax, nmax - 1) + nmax) as i8;
    }
    1.0 / iscale
}
@@ -10,6 +10,7 @@ impl From<DType> for st::Dtype {
        match value {
            DType::U8 => st::Dtype::U8,
            DType::U32 => st::Dtype::U32,
            DType::I64 => st::Dtype::I64,
            DType::BF16 => st::Dtype::BF16,
            DType::F16 => st::Dtype::F16,
            DType::F32 => st::Dtype::F32,
@@ -24,6 +25,7 @@ impl TryFrom<st::Dtype> for DType {
        match value {
            st::Dtype::U8 => Ok(DType::U8),
            st::Dtype::U32 => Ok(DType::U32),
            st::Dtype::I64 => Ok(DType::I64),
            st::Dtype::BF16 => Ok(DType::BF16),
            st::Dtype::F16 => Ok(DType::F16),
            st::Dtype::F32 => Ok(DType::F32),
@@ -189,6 +191,7 @@ impl Tensor {
        match dtype {
            DType::U8 => convert_slice::<u8>(data, shape, device),
            DType::U32 => convert_slice::<u32>(data, shape, device),
            DType::I64 => convert_slice::<i64>(data, shape, device),
            DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
            DType::F16 => convert_slice::<half::f16>(data, shape, device),
            DType::F32 => convert_slice::<f32>(data, shape, device),
@@ -205,24 +208,15 @@ fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
            convert_with_cast_::<u16, u32, _>(view, device, conv)
        }
        st::Dtype::U32 => convert_::<u32>(view, device),
        st::Dtype::I32 => {
            let conv = |x| Ok(i64::from(x));
            convert_with_cast_::<i32, i64, _>(view, device, conv)
        }
        st::Dtype::I64 => convert_::<i64>(view, device),
        st::Dtype::BF16 => convert_::<half::bf16>(view, device),
        st::Dtype::F16 => convert_::<half::f16>(view, device),
        st::Dtype::F32 => convert_::<f32>(view, device),
        st::Dtype::F64 => convert_::<f64>(view, device),
        st::Dtype::I32 => {
            let conv = |x| {
                u32::try_from(x)
                    .map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
            };
            convert_with_cast_::<i32, u32, _>(view, device, conv)
        }
        st::Dtype::I64 => {
            let conv = |x| {
                u32::try_from(x)
                    .map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
            };
            convert_with_cast_::<i64, u32, _>(view, device, conv)
        }
        dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
    }
}
@@ -233,6 +227,7 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
    match tensor.dtype() {
        DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
        DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
        DType::I64 => Ok(convert_back_::<i64>(tensor.to_vec1()?)),
        DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
        DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
        DType::F32 => Ok(convert_back_::<f32>(tensor.to_vec1()?)),
@@ -242,18 +237,28 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {

pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
    let data = std::fs::read(filename.as_ref())?;
    let st = safetensors::SafeTensors::deserialize(&data)?;
    load_buffer(&data[..], device)
}

pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>> {
    let st = safetensors::SafeTensors::deserialize(data)?;
    st.tensors()
        .into_iter()
        .map(|(name, view)| Ok((name, view.load(device)?)))
        .collect()
}

pub fn save<P: AsRef<Path>>(tensors: &HashMap<&str, Tensor>, filename: P) -> Result<()> {
pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
    tensors: &HashMap<K, Tensor>,
    filename: P,
) -> Result<()> {
    Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
}

pub struct MmapedFile(memmap2::Mmap);
pub struct MmapedFile {
    path: std::path::PathBuf,
    inner: memmap2::Mmap,
}

impl MmapedFile {
    /// Creates a wrapper around a memory mapped file from which you can retrieve
@@ -263,13 +268,20 @@ impl MmapedFile {
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`].
    pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
        let file = std::fs::File::open(p)?;
        let mmap = memmap2::MmapOptions::new().map(&file)?;
        Ok(Self(mmap))
        let p = p.as_ref();
        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
        let inner = memmap2::MmapOptions::new()
            .map(&file)
            .map_err(|e| Error::from(e).with_path(p))?;
        Ok(Self {
            inner,
            path: p.to_path_buf(),
        })
    }

    pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
        let st = safetensors::SafeTensors::deserialize(&self.0)?;
        let st = safetensors::SafeTensors::deserialize(&self.inner)
            .map_err(|e| Error::from(e).with_path(&self.path))?;
        Ok(st)
    }
}
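Hypothetical usage of the new `load_buffer` entry point (file name illustrative):

use candle_core::{safetensors, Device};

// From disk, or from bytes already in memory - both yield the same map.
let tensors = safetensors::load("model.safetensors", &Device::Cpu)?;
let bytes = std::fs::read("model.safetensors")?;
let same = safetensors::load_buffer(&bytes, &Device::Cpu)?;
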
candle-core/src/scalar.rs (new file, 23 lines)
@@ -0,0 +1,23 @@
use crate::{Result, Tensor, WithDType};

pub enum TensorScalar {
    Tensor(Tensor),
    Scalar(Tensor),
}

pub trait TensorOrScalar {
    fn to_tensor_scalar(self) -> Result<TensorScalar>;
}

impl TensorOrScalar for &Tensor {
    fn to_tensor_scalar(self) -> Result<TensorScalar> {
        Ok(TensorScalar::Tensor(self.clone()))
    }
}

impl<T: WithDType> TensorOrScalar for T {
    fn to_tensor_scalar(self) -> Result<TensorScalar> {
        let scalar = Tensor::new(self, &crate::Device::Cpu)?;
        Ok(TensorScalar::Scalar(scalar))
    }
}
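A sketch of the intent (inferred from the trait itself plus the `use crate::scalar::TensorOrScalar;` import that appears in tensor.rs below): operations that accept either a tensor or a bare number can take a single generic parameter, so a hypothetical signature like

pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Tensor>

lets callers write both `a.ge(&b)?` and `a.ge(0.5f32)?`, the scalar being lifted to a 0-rank tensor via `to_tensor_scalar`.
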
@@ -1,3 +1,5 @@
//! The shape of a tensor is a tuple with the size of each of its dimensions.
#![allow(clippy::redundant_closure_call)]
use crate::{Error, Result};

#[derive(Clone, PartialEq, Eq)]
@@ -71,6 +73,14 @@ impl From<(usize, usize, usize, usize, usize)> for Shape {
    }
}

impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
    fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
        Self(vec![
            d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5,
        ])
    }
}

impl From<Vec<usize>> for Shape {
    fn from(dims: Vec<usize>) -> Self {
        Self(dims)
@@ -79,20 +89,25 @@ impl From<Vec<usize>> for Shape {

macro_rules! extract_dims {
    ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
        impl Shape {
            pub fn $fn_name(&self) -> Result<$out_type> {
                if self.0.len() != $cnt {
        pub fn $fn_name(dims: &[usize]) -> Result<$out_type> {
            if dims.len() != $cnt {
                Err(Error::UnexpectedNumberOfDims {
                    expected: $cnt,
                    got: self.0.len(),
                    shape: self.clone(),
                    got: dims.len(),
                    shape: Shape::from(dims),
                }
                .bt())
            } else {
                Ok($dims(&self.0))
                Ok($dims(dims))
            }
        }

        impl Shape {
            pub fn $fn_name(&self) -> Result<$out_type> {
                $fn_name(self.0.as_slice())
            }
        }

        impl crate::Tensor {
            pub fn $fn_name(&self) -> Result<$out_type> {
                self.shape().$fn_name()
@@ -113,6 +128,7 @@ impl Shape {
        Self(dims.to_vec())
    }

    /// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc.
    pub fn rank(&self) -> usize {
        self.0.len()
    }
@@ -121,10 +137,12 @@ impl Shape {
        self.0
    }

    /// The dimensions as a slice of `usize`.
    pub fn dims(&self) -> &[usize] {
        &self.0
    }

    /// The total number of elements, this is the product of all dimension sizes.
    pub fn elem_count(&self) -> usize {
        self.0.iter().product()
    }
@@ -176,10 +194,75 @@ impl Shape {
        true
    }

    /// Modifies the shape by adding a list of additional dimensions at the end of the existing
    /// dimensions.
    pub fn extend(mut self, additional_dims: &[usize]) -> Self {
        self.0.extend(additional_dims);
        self
    }

    /// Check whether the two shapes are compatible for broadcast, and if it is the case return the
    /// broadcasted shape. This is to be used for binary pointwise ops.
    pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
        let lhs = self;
        let lhs_dims = lhs.dims();
        let rhs_dims = rhs.dims();
        let lhs_ndims = lhs_dims.len();
        let rhs_ndims = rhs_dims.len();
        let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
        let mut bcast_dims = vec![0; bcast_ndims];
        for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
            let rev_idx = bcast_ndims - idx;
            let l_value = if lhs_ndims < rev_idx {
                1
            } else {
                lhs_dims[lhs_ndims - rev_idx]
            };
            let r_value = if rhs_ndims < rev_idx {
                1
            } else {
                rhs_dims[rhs_ndims - rev_idx]
            };
            *bcast_value = if l_value == r_value {
                l_value
            } else if l_value == 1 {
                r_value
            } else if r_value == 1 {
                l_value
            } else {
                Err(Error::ShapeMismatchBinaryOp {
                    lhs: lhs.clone(),
                    rhs: rhs.clone(),
                    op,
                }
                .bt())?
            }
        }
        Ok(Shape::from(bcast_dims))
    }

    pub(crate) fn broadcast_shape_matmul(&self, rhs: &Self) -> Result<(Shape, Shape)> {
        let lhs = self;
        let lhs_dims = lhs.dims();
        let rhs_dims = rhs.dims();
        if lhs_dims.len() < 2 || rhs_dims.len() < 2 {
            crate::bail!("only 2d matrixes are supported {lhs:?} {rhs:?}")
        }
        let (m, lhs_k) = (lhs_dims[lhs_dims.len() - 2], lhs_dims[lhs_dims.len() - 1]);
        let (rhs_k, n) = (rhs_dims[rhs_dims.len() - 2], rhs_dims[rhs_dims.len() - 1]);
        if lhs_k != rhs_k {
            crate::bail!("different inner dimensions in broadcast matmul {lhs:?} {rhs:?}")
        }

        let lhs_b = Self::from(&lhs_dims[..lhs_dims.len() - 2]);
        let rhs_b = Self::from(&rhs_dims[..rhs_dims.len() - 2]);
        let bcast = lhs_b.broadcast_shape_binary_op(&rhs_b, "broadcast_matmul")?;
        let bcast_dims = bcast.dims();

        let bcast_lhs = [bcast_dims, &[m, lhs_k]].concat();
        let bcast_rhs = [bcast_dims, &[rhs_k, n]].concat();
        Ok((Shape::from(bcast_lhs), Shape::from(bcast_rhs)))
    }
}
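A quick illustration of the broadcasting rule implemented above (illustrative shapes): dimensions are compared right to left, a missing dimension counts as 1, and a 1 stretches to match the other side.

// (2, 3) op (3,)       -> (2, 3)
// (4, 1, 5) op (3, 1)  -> (4, 3, 5)
// (2, 3) op (4,)       -> Error::ShapeMismatchBinaryOp
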
pub trait Dim {
@@ -340,7 +423,28 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
    }
}

extract_dims!(dims0, 0, |_: &Vec<usize>| (), ());
impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
        let d0 = self.0.to_index(shape, op)?;
        let d1 = self.1.to_index(shape, op)?;
        let d2 = self.2.to_index(shape, op)?;
        let d3 = self.3.to_index(shape, op)?;
        Ok(vec![d0, d1, d2, d3])
    }
}

impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
        let d0 = self.0.to_index(shape, op)?;
        let d1 = self.1.to_index(shape, op)?;
        let d2 = self.2.to_index(shape, op)?;
        let d3 = self.3.to_index(shape, op)?;
        let d4 = self.4.to_index(shape, op)?;
        Ok(vec![d0, d1, d2, d3, d4])
    }
}

extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
extract_dims!(
@@ -378,3 +482,171 @@ mod tests {
        assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
    }
}

pub trait ShapeWithOneHole {
    fn into_shape(self, el_count: usize) -> Result<Shape>;
}

impl<S: Into<Shape>> ShapeWithOneHole for S {
    fn into_shape(self, _el_count: usize) -> Result<Shape> {
        Ok(self.into())
    }
}

impl ShapeWithOneHole for ((),) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        Ok(el_count.into())
    }
}

impl ShapeWithOneHole for ((), usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let ((), d1) = self;
        if el_count % d1 != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
        }
        Ok((el_count / d1, d1).into())
    }
}
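The `()` hole lets one dimension be inferred from the element count, much like `-1` in PyTorch reshapes; the remaining impls below extend the same pattern to every hole position up to rank 5. A usage sketch, assuming `Tensor::reshape` is the consumer of this trait as the `into_shape(el_count)` signature suggests:

let t = Tensor::arange(0f32, 12f32, &Device::Cpu)?; // 12 elements
let a = t.reshape(((), 3))?; // hole filled: shape (4, 3)
let b = t.reshape((2, ()))?; // hole filled: shape (2, 6)
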
impl ShapeWithOneHole for (usize, ()) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, ()) = self;
        if el_count % d1 != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
        }
        Ok((d1, el_count / d1).into())
    }
}

impl ShapeWithOneHole for ((), usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let ((), d1, d2) = self;
        let d = d1 * d2;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((el_count / d, d1, d2).into())
    }
}

impl ShapeWithOneHole for (usize, (), usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, (), d2) = self;
        let d = d1 * d2;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, el_count / d, d2).into())
    }
}

impl ShapeWithOneHole for (usize, usize, ()) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, ()) = self;
        let d = d1 * d2;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, el_count / d).into())
    }
}

impl ShapeWithOneHole for ((), usize, usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let ((), d1, d2, d3) = self;
        let d = d1 * d2 * d3;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((el_count / d, d1, d2, d3).into())
    }
}

impl ShapeWithOneHole for (usize, (), usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, (), d2, d3) = self;
        let d = d1 * d2 * d3;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, el_count / d, d2, d3).into())
    }
}

impl ShapeWithOneHole for (usize, usize, (), usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, (), d3) = self;
        let d = d1 * d2 * d3;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, el_count / d, d3).into())
    }
}

impl ShapeWithOneHole for (usize, usize, usize, ()) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, ()) = self;
        let d = d1 * d2 * d3;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, d3, el_count / d).into())
    }
}

impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let ((), d1, d2, d3, d4) = self;
        let d = d1 * d2 * d3 * d4;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((el_count / d, d1, d2, d3, d4).into())
    }
}

impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, (), d2, d3, d4) = self;
        let d = d1 * d2 * d3 * d4;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, el_count / d, d2, d3, d4).into())
    }
}

impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, (), d3, d4) = self;
        let d = d1 * d2 * d3 * d4;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, el_count / d, d3, d4).into())
    }
}

impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, (), d4) = self;
        let d = d1 * d2 * d3 * d4;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, d3, el_count / d, d4).into())
    }
}

impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, d4, ()) = self;
        let d = d1 * d2 * d3 * d4;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, d3, d4, el_count / d).into())
    }
}
@@ -68,6 +68,19 @@ impl Storage {
        }
    }

    pub(crate) fn powf(&self, layout: &Layout, alpha: f64) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {
                let storage = storage.powf(layout, alpha)?;
                Ok(Self::Cpu(storage))
            }
            Self::Cuda(storage) => {
                let storage = storage.powf(layout, alpha)?;
                Ok(Self::Cuda(storage))
            }
        }
    }

    pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {
@@ -138,7 +151,7 @@ impl Storage {
        }
    }

    pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
    pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
        match self {
            Self::Cpu(storage) => {
                let (storage, shape) = c.cpu_fwd(storage, l)?;
@@ -151,7 +164,7 @@ impl Storage {
        }
    }

    pub(crate) fn custom_op2(
    pub(crate) fn apply_op2(
        &self,
        l1: &Layout,
        t2: &Self,
@@ -172,7 +185,7 @@ impl Storage {
        }
    }

    pub(crate) fn custom_op3(
    pub(crate) fn apply_op3(
        &self,
        l1: &Layout,
        t2: &Self,
@@ -266,6 +279,109 @@ impl Storage {
        }
    }

    pub(crate) fn conv2d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConv2D,
    ) -> Result<Self> {
        self.same_device(kernel, "conv2d")?;
        self.same_dtype(kernel, "conv2d")?;
        match (self, &kernel) {
            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
                let s = inp.conv2d(l, kernel, kernel_l, params)?;
                Ok(Self::Cpu(s))
            }
            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
                let s = inp.conv2d(l, kernel, kernel_l, params)?;
                Ok(Self::Cuda(s))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
                op: "conv2d",
            }
            .bt()),
        }
    }

    pub(crate) fn conv_transpose2d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConvTranspose2D,
    ) -> Result<Self> {
        self.same_device(kernel, "conv_transpose2d")?;
        self.same_dtype(kernel, "conv_transpose2d")?;
        match (self, &kernel) {
            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
                Ok(Self::Cpu(s))
            }
            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
                Ok(Self::Cuda(s))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
                op: "conv_transpose2d",
            }
            .bt()),
        }
    }

    pub(crate) fn avg_pool2d(
        &self,
        layout: &Layout,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    ) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {
                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
                Ok(Self::Cpu(storage))
            }
            Self::Cuda(storage) => {
                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
                Ok(Self::Cuda(storage))
            }
        }
    }

    pub(crate) fn max_pool2d(
        &self,
        layout: &Layout,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    ) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {
                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
                Ok(Self::Cpu(storage))
            }
            Self::Cuda(storage) => {
                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
                Ok(Self::Cuda(storage))
            }
        }
    }

    pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {
                let storage = storage.upsample_nearest2d(layout, h, w)?;
                Ok(Self::Cpu(storage))
            }
            Self::Cuda(storage) => {
                let storage = storage.upsample_nearest2d(layout, h, w)?;
                Ok(Self::Cuda(storage))
            }
        }
    }

    pub(crate) fn where_cond(
        &self,
        layout: &Layout,

@@ -1,7 +1,10 @@
//! Tensors are N-dimensional matrixes of elements using a single data type.
#![allow(clippy::redundant_closure_call)]
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{
    BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@@ -106,7 +109,9 @@ macro_rules! broadcast_binary_op {
    ($fn_name:ident, $inner_fn_name:ident) => {
        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
            let lhs = self;
            let shape = lhs.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
            let shape = lhs
                .shape()
                .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
            let l_broadcast = shape != *lhs.shape();
            let r_broadcast = shape != *rhs.shape();
            match (l_broadcast, r_broadcast) {
@@ -122,7 +127,7 @@ macro_rules! broadcast_binary_op {
}

/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides.
fn from_storage<S: Into<Shape>>(
pub(crate) fn from_storage<S: Into<Shape>>(
    storage: Storage,
    shape: S,
    op: BackpropOp,
@@ -269,6 +274,10 @@ impl Tensor {
        Self::rand_impl(lo, up, s, device, false)
    }

    pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self> {
        Tensor::rand_f64_impl(lo, up, self.shape(), self.dtype(), self.device(), false)
    }

    pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
        mean: T,
        std: T,
@@ -296,6 +305,17 @@ impl Tensor {
        Ok(from_storage(storage, s, none, is_variable))
    }

    pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self> {
        Tensor::randn_f64_impl(
            mean,
            stdev,
            self.shape(),
            self.dtype(),
            self.device(),
            false,
        )
    }
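Usage sketch for the new `*_like` constructors (illustrative):

let x = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
let u = x.rand_like(0., 1.)?;  // uniform samples, same shape/dtype/device as x
let g = x.randn_like(0., 1.)?; // normal samples with mean 0 and stddev 1
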
/// Creates a new tensor initialized with values sampled from a normal distribution with the
|
||||
/// specified `mean` and standard deviation `std`.
|
||||
pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
|
||||
@ -400,48 +420,6 @@ impl Tensor {
|
||||
Self::new_impl(array, shape.into(), device, false)
|
||||
}
|
||||
|
||||
pub(crate) fn broadcast_shape_binary_op<'a>(
|
||||
&'a self,
|
||||
rhs: &'a Self,
|
||||
op: &'static str,
|
||||
) -> Result<Shape> {
|
||||
let lhs = self;
|
||||
let lhs_dims = lhs.shape().dims();
|
||||
let rhs_dims = rhs.shape().dims();
|
||||
let lhs_ndims = lhs_dims.len();
|
||||
let rhs_ndims = rhs_dims.len();
|
||||
let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
|
||||
let mut bcast_dims = vec![0; bcast_ndims];
|
||||
for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
|
||||
let rev_idx = bcast_ndims - idx;
|
||||
let l_value = if lhs_ndims < rev_idx {
|
||||
1
|
||||
} else {
|
||||
lhs_dims[lhs_ndims - rev_idx]
|
||||
};
|
||||
let r_value = if rhs_ndims < rev_idx {
|
||||
1
|
||||
} else {
|
||||
rhs_dims[rhs_ndims - rev_idx]
|
||||
};
|
||||
*bcast_value = if l_value == r_value {
|
||||
l_value
|
||||
} else if l_value == 1 {
|
||||
r_value
|
||||
} else if r_value == 1 {
|
||||
l_value
|
||||
} else {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: self.shape().clone(),
|
||||
rhs: rhs.shape().clone(),
|
||||
op,
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
}
|
||||
Ok(Shape::from(bcast_dims))
|
||||
}
|
||||
|
||||
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
|
||||
let lhs = self.shape();
|
||||
let rhs = rhs.shape();
|
||||
@ -469,16 +447,22 @@ impl Tensor {
|
||||
binary_op!(mul, Mul);
|
||||
binary_op!(sub, Sub);
|
||||
binary_op!(div, Div);
|
||||
binary_op!(maximum, Maximum);
|
||||
binary_op!(minimum, Minimum);
|
||||
broadcast_binary_op!(broadcast_add, add);
|
||||
broadcast_binary_op!(broadcast_mul, mul);
|
||||
broadcast_binary_op!(broadcast_sub, sub);
|
||||
broadcast_binary_op!(broadcast_div, div);
|
||||
broadcast_binary_op!(broadcast_maximum, maximum);
|
||||
broadcast_binary_op!(broadcast_minimum, minimum);
|
||||
|
||||
unary_op!(recip, Recip);
|
||||
unary_op!(neg, Neg);
|
||||
unary_op!(exp, Exp);
|
||||
unary_op!(log, Log);
|
||||
unary_op!(sin, Sin);
|
||||
unary_op!(cos, Cos);
|
||||
unary_op!(tanh, Tanh);
|
||||
unary_op!(abs, Abs);
|
||||
unary_op!(sqr, Sqr);
|
||||
unary_op!(sqrt, Sqrt);
|
||||
@ -511,6 +495,25 @@ impl Tensor {
|
||||
self.to_scalar::<S>()
|
||||
}
|
||||

    /// Repeat this tensor along the specified dimensions.
    pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
        // Similar to PyTorch, we extend the number of dimensions of self if needed.
        let repeats = shape.into();
        let repeats = repeats.dims();
        let mut inp = if self.rank() < repeats.len() {
            let shape = [vec![1; repeats.len() - self.rank()], self.dims().to_vec()].concat();
            self.reshape(shape)?
        } else {
            self.clone()
        };
        for (idx, &repeat) in repeats.iter().enumerate() {
            if repeat > 1 {
                inp = Tensor::cat(&vec![&inp; repeat], idx)?
            }
        }
        Ok(inp)
    }
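A short `repeat` sketch (input values assumed):

```rust
use candle_core::{Device, Result, Tensor};

fn repeat_demo() -> Result<()> {
    let t = Tensor::new(&[1u32, 2], &Device::Cpu)?;
    // Tile the single dimension three times.
    let r = t.repeat(3)?;
    assert_eq!(r.to_vec1::<u32>()?, &[1, 2, 1, 2, 1, 2]);
    Ok(())
}
```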
    /// This operation multiplies the input tensor by `mul` then adds `add` and returns the result.
    /// The input values `mul` and `add` are cast to the appropriate type so some rounding might
    /// be performed.
@@ -535,6 +538,13 @@ impl Tensor {
        Ok(from_storage(storage, self.shape(), op, false))
    }

    /// Raise the tensor to some float exponent `e`.
    pub fn powf(&self, e: f64) -> Result<Self> {
        let storage = self.storage().powf(self.layout(), e)?;
        let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
        Ok(from_storage(storage, self.shape(), op, false))
    }

    fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
        if dim >= self.dims().len() {
            Err(Error::DimOutOfRange {
@@ -548,6 +558,32 @@ impl Tensor {
        }
    }

    /// Split a tensor into the specified number of chunks; this may return fewer chunks than
    /// specified.
    pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
        let dim = dim.to_index(self.shape(), "chunk")?;
        let size = self.dim(dim)?;
        if size < chunks {
            (0..size).map(|i| self.narrow(dim, i, 1)).collect()
        } else {
            let chunk_size = size / chunks;
            let cnt_additional = size % chunks;
            let mut tensors = vec![];
            let mut sum_chunk_size = 0;
            for i in 0..chunks {
                let chunk_size = if i < cnt_additional {
                    chunk_size + 1
                } else {
                    chunk_size
                };
                let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
                tensors.push(tensor);
                sum_chunk_size += chunk_size
            }
            Ok(tensors)
        }
    }
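An illustrative `chunk` call: with seven elements and three chunks, the remainder is spread over the first chunks, giving sizes 3, 2, 2 (values assumed):

```rust
use candle_core::{Device, Result, Tensor};

fn chunk_demo() -> Result<()> {
    let t = Tensor::arange(0u32, 7u32, &Device::Cpu)?;
    let chunks = t.chunk(3, 0)?;
    let sizes: Vec<usize> = chunks.iter().map(|c| c.dims()[0]).collect();
    assert_eq!(sizes, vec![3, 2, 2]);
    Ok(())
}
```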

    /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
    /// ranges from `start` to `start + len`.
    pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
@@ -663,18 +699,58 @@ impl Tensor {
        self.sum_impl(sum_dims, false)
    }

    /// Returns the mean of all elements in the input tensor. The mean is performed over the
    /// dimensions specified in `mean_dims`.
    ///
    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
    /// that the number of elements for each dimension index in `mean_dims` is 1.
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
    /// let s = a.mean_keepdim(0)?;
    /// assert_eq!(s.to_vec2::<f32>()?, &[[1., 2.]]);
    /// let s = a.mean_keepdim(1)?;
    /// assert_eq!(s.to_vec2::<f32>()?, &[[0.5], [2.5]]);
    /// let s = a.mean_keepdim((0, 1))?;
    /// assert_eq!(s.to_vec2::<f32>()?, &[[1.5]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```
    pub fn mean_keepdim<D: Dims>(&self, mean_dims: D) -> Result<Self> {
        let mean_dims = mean_dims.to_indexes(self.shape(), "mean-keepdim")?;
        let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
        let scale = 1f64 / (reduced_dim as f64);
        self.sum_impl(mean_dims, true)? * scale
    }

    /// Returns the mean of all elements in the input tensor. The mean is performed over the
    /// given dimensions; compared to `mean_keepdim` these dimensions are squeezed rather than
    /// kept.
    pub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self> {
        let mean_dims = mean_dims.to_indexes(self.shape(), "mean")?;
        let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
        let scale = 1f64 / (reduced_dim as f64);
        self.sum_impl(mean_dims, false)? * scale
    }

    /// Gathers the maximum value across the selected dimension. The resulting shape has the same
    /// number of dimensions as the original tensor and the selected dimension has a single element.
    pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
        self.reduce_impl(dim, true, ReduceOp::Max)
    }

    /// Similar to `max_keepdim` but the target dimension is squeezed.
    pub fn max<D: Dim>(&self, dim: D) -> Result<Self> {
        self.reduce_impl(dim, false, ReduceOp::Max)
    }

    /// Gathers the minimum value across the selected dimension. The resulting shape has the same
    /// number of dimensions as the original tensor and the selected dimension has a single element.
    pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
        self.reduce_impl(dim, true, ReduceOp::Min)
    }

    /// Similar to `min_keepdim` but the target dimension is squeezed.
    pub fn min<D: Dim>(&self, dim: D) -> Result<Self> {
        self.reduce_impl(dim, false, ReduceOp::Min)
    }
@@ -683,6 +759,7 @@ impl Tensor {
        self.reduce_impl(dim, true, ReduceOp::ArgMax)
    }

    /// Similar to `argmax_keepdim` but the target dimension is squeezed.
    pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> {
        self.reduce_impl(dim, false, ReduceOp::ArgMax)
    }
@@ -691,12 +768,24 @@ impl Tensor {
        self.reduce_impl(dim, true, ReduceOp::ArgMin)
    }

    /// Similar to `argmin_keepdim` but the target dimension is squeezed.
    pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> {
        self.reduce_impl(dim, false, ReduceOp::ArgMin)
    }

    pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
        let shape = self.same_shape_binary_op(rhs, "cmp")?;
    /// Element-wise comparison between two tensors, e.g. equality, greater than, ... The actual
    /// comparison operation is specified by the `op` argument.
    ///
    /// The returned tensor has the same shape as the original tensors and uses `u8` elements.
    pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
        let rhs = match rhs.to_tensor_scalar()? {
            crate::scalar::TensorScalar::Tensor(rhs) => rhs,
            crate::scalar::TensorScalar::Scalar(rhs) => rhs
                .to_dtype(self.dtype())?
                .to_device(self.device())?
                .broadcast_as(self.shape())?,
        };
        let shape = self.same_shape_binary_op(&rhs, "cmp")?;
        let storage = self
            .storage()
            .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
@@ -704,75 +793,122 @@ impl Tensor {
        Ok(from_storage(storage, shape.dims(), op, false))
    }

    pub fn eq(&self, rhs: &Self) -> Result<Self> {
    /// Element-wise equality.
    pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
        self.cmp(rhs, CmpOp::Eq)
    }

    pub fn ne(&self, rhs: &Self) -> Result<Self> {
    /// Element-wise non-equality.
    pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
        self.cmp(rhs, CmpOp::Ne)
    }

    pub fn lt(&self, rhs: &Self) -> Result<Self> {
    /// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
    /// rhs` and 0 otherwise.
    pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
        self.cmp(rhs, CmpOp::Lt)
    }

    pub fn gt(&self, rhs: &Self) -> Result<Self> {
    /// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
    /// rhs` and 0 otherwise.
    pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
        self.cmp(rhs, CmpOp::Gt)
    }

    pub fn ge(&self, rhs: &Self) -> Result<Self> {
    /// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
    /// rhs` and 0 otherwise.
    pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
        self.cmp(rhs, CmpOp::Ge)
    }

    pub fn le(&self, rhs: &Self) -> Result<Self> {
    /// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
    /// rhs` and 0 otherwise.
    pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
        self.cmp(rhs, CmpOp::Le)
    }
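With the new `TensorOrScalar` bound, comparisons also accept a plain scalar, which gets cast and broadcast to match `self`. A small sketch (values assumed):

```rust
use candle_core::{Device, Result, Tensor};

fn cmp_demo() -> Result<()> {
    let t = Tensor::new(&[1u32, 5, 3], &Device::Cpu)?;
    // The result uses u8 elements: 1 where the predicate holds, 0 elsewhere.
    let mask = t.gt(2u32)?;
    assert_eq!(mask.to_vec1::<u8>()?, &[0, 1, 1]);
    Ok(())
}
```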

    /// Applies a 1D convolution over the input tensor.
    pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
        let (c_out, c_in_k, k_size) = kernel.dims3()?;
        let (b_size, c_in, l_in) = match *self.dims() {
            [b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
            [c_in, l_in] => (None, c_in, l_in),
            _ => Err(Error::Conv1dInvalidArgs {
                inp_shape: self.shape().clone(),
                k_shape: kernel.shape().clone(),
                padding,
                stride,
                msg: "input rank is not 2 or 3",
            }
            .bt())?,
        };
        if c_in != c_in_k {
            Err(Error::Conv1dInvalidArgs {
                inp_shape: self.shape().clone(),
                k_shape: kernel.shape().clone(),
                padding,
                stride,
                msg: "the number of in-channels on the input doesn't match the kernel size",
            }
            .bt())?
        }
        let params = crate::conv::ParamsConv1D {
            b_size,
            l_in,
            c_out,
            c_in,
            k_size,
            padding,
            stride,
        };
        let storage =
            self.storage()
                .conv1d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
            arg,
            kernel,
            padding,
            stride,
        });
        let out_dims = params.out_dims();
        Ok(from_storage(storage, out_dims, op, false))
    }

    /// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the
    /// nearest element.
    ///
    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
    /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
    pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
        let (n, c, _h, _w) = self.dims4()?;
        let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
        let storage = self
            .storage()
            .upsample_nearest2d(self.layout(), target_h, target_w)?;
        Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
    }

    /// 2D average pooling over an input tensor with multiple channels.
    ///
    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
    /// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on
    /// the two last dimensions using a kernel of size `sz`. The returned element is the average
    /// value over the kernel window.
    pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
        let sz = sz.to_usize2();
        self.avg_pool2d_with_stride(sz, sz)
    }

    /// Same as `avg_pool2d` but with a `stride` that can be set to a value different from the
    /// kernel size.
    pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
        &self,
        kernel_size: T,
        stride: T,
    ) -> Result<Self> {
        let kernel_size = kernel_size.to_usize2();
        let stride = stride.to_usize2();
        let (n, c, h, w) = self.dims4()?;
        // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
        let h_out = (h - kernel_size.0) / stride.0 + 1;
        let w_out = (w - kernel_size.1) / stride.1 + 1;
        let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
            arg,
            kernel_size,
            stride,
        });
        let storage = self
            .storage()
            .avg_pool2d(self.layout(), kernel_size, stride)?;
        Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
    }

    /// 2D max pooling over an input tensor with multiple channels.
    ///
    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
    /// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on
    /// the two last dimensions using a kernel of size `sz`, the returned element is the maximum
    /// value over the kernel window.
    pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
        let sz = sz.to_usize2();
        self.max_pool2d_with_stride(sz, sz)
    }

    /// Same as `max_pool2d` but with a `stride` that can be set to a value different from the
    /// kernel size.
    pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
        &self,
        kernel_size: T,
        stride: T,
    ) -> Result<Self> {
        let kernel_size = kernel_size.to_usize2();
        let stride = stride.to_usize2();
        let (n, c, h, w) = self.dims4()?;
        // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
        let h_out = (h - kernel_size.0) / stride.0 + 1;
        let w_out = (w - kernel_size.1) / stride.1 + 1;
        let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
            arg,
            kernel_size,
            stride,
        });
        let storage = self
            .storage()
            .max_pool2d(self.layout(), kernel_size, stride)?;
        Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
    }

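A pooling sketch matching the output-size formulas above (input values assumed):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn pool_demo() -> Result<()> {
    // (batch=1, channels=1, h=4, w=4); a 2x2 kernel with the default
    // stride (= kernel size) halves each spatial dimension.
    let t = Tensor::arange(0u32, 16u32, &Device::Cpu)?
        .to_dtype(DType::F32)?
        .reshape((1, 1, 4, 4))?;
    let avg = t.avg_pool2d(2)?;
    assert_eq!(avg.dims(), &[1, 1, 2, 2]);
    let max = t.max_pool2d(2)?;
    assert_eq!(max.dims(), &[1, 1, 2, 2]);
    Ok(())
}
```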

    /// Returns the matrix-multiplication of the input tensor with the other provided tensor.
@@ -825,6 +961,28 @@ impl Tensor {
        Ok(from_storage(storage, c_shape, op, false))
    }

    /// Matrix-multiplication with broadcasting support.
    ///
    /// Compared to `matmul` the two matrices are allowed to have different dimensions as long as
    /// they are compatible for broadcast. E.g. if `self` has shape `(j, 1, n, k)` and `rhs` has
    /// shape `(l, k, m)`, the output will have shape `(j, l, n, m)`.
    pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self> {
        let lhs = self;
        let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?;
        let l_broadcast = l_shape != *lhs.shape();
        let r_broadcast = r_shape != *rhs.shape();
        // TODO: Avoid concretising the broadcasted matrices via contiguous.
        match (l_broadcast, r_broadcast) {
            (true, true) => lhs
                .broadcast_as(&l_shape)?
                .contiguous()?
                .matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
            (false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
            (true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs),
            (false, false) => lhs.matmul(rhs),
        }
    }
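A shape-only sketch of `broadcast_matmul`, mirroring the shapes in the doc comment (dimensions assumed):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn bmm_demo() -> Result<()> {
    // (j=2, 1, n=3, k=4) x (l=5, k=4, m=6) -> (2, 5, 3, 6)
    let lhs = Tensor::zeros((2, 1, 3, 4), DType::F32, &Device::Cpu)?;
    let rhs = Tensor::zeros((5, 4, 6), DType::F32, &Device::Cpu)?;
    let out = lhs.broadcast_matmul(&rhs)?;
    assert_eq!(out.dims(), &[2, 5, 3, 6]);
    Ok(())
}
```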

    /// Returns a tensor with the same shape as the input tensor, the values are taken from
    /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
    /// input tensor is equal to zero.
@@ -917,6 +1075,7 @@ impl Tensor {
        Ok(from_storage(storage, self.shape(), op, false))
    }

    /// Accumulate elements from `source` at indexes `indexes` and add them to `self`.
    pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "index-add")?;
        let source_dims = source.dims();
@@ -965,6 +1124,17 @@ impl Tensor {
        Ok(from_storage(storage, self.shape(), op, false))
    }

    /// Gather values across the target dimension.
    ///
    /// # Arguments
    ///
    /// * `self` - The input tensor.
    /// * `indexes` - The indices of elements to gather, this should have the same shape as `self`
    ///   but can have a different number of elements on the target dimension.
    /// * `dim` - the target dimension.
    ///
    /// The resulting tensor has the same shape as `indexes` and uses values from `self` indexed on
    /// dimension `dim` by the values in `indexes`.
    pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "gather")?;
        let self_dims = self.dims();
@@ -995,6 +1165,13 @@ impl Tensor {
        Ok(from_storage(storage, indexes.shape(), op, false))
    }
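A small `gather` sketch (index values assumed):

```rust
use candle_core::{Device, Result, Tensor};

fn gather_demo() -> Result<()> {
    let t = Tensor::new(&[[1u32, 2], [3, 4]], &Device::Cpu)?;
    // For each row, pick the column given by `idx`: out[i][j] = t[i][idx[i][j]].
    let idx = Tensor::new(&[[1u32], [0]], &Device::Cpu)?;
    let picked = t.gather(&idx, 1)?;
    assert_eq!(picked.to_vec2::<u32>()?, &[[2], [3]]);
    Ok(())
}
```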

    /// Select values for the input tensor at the target indexes across the specified dimension.
    ///
    /// The `indexes` argument is an integer tensor with a single dimension.
    /// The output has the same number of dimensions as the `self` input. The target dimension of
    /// the output has the same length as `indexes` and the values are taken from `self` using
    /// the index from `indexes`. Other dimensions have the same number of elements as the input
    /// tensor.
    pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "index-select")?;
        let indexes_len = match indexes.dims() {
@@ -1202,6 +1379,10 @@ impl Tensor {
        self.sum(dims)
    }

    pub fn mean_all(&self) -> Result<Tensor> {
        self.sum_all()? / self.elem_count() as f64
    }
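An `index_select` sketch (rows assumed):

```rust
use candle_core::{Device, Result, Tensor};

fn index_select_demo() -> Result<()> {
    let t = Tensor::new(&[[1u32, 2], [3, 4], [5, 6]], &Device::Cpu)?;
    // Pick rows 2 and 0, in that order.
    let idx = Tensor::new(&[2u32, 0], &Device::Cpu)?;
    let rows = t.index_select(&idx, 0)?;
    assert_eq!(rows.to_vec2::<u32>()?, &[[5, 6], [1, 2]]);
    Ok(())
}
```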

    fn flatten_<D1: Dim, D2: Dim>(
        &self,
        start_dim: Option<D1>,
@@ -1323,6 +1504,42 @@ impl Tensor {
        Ok(Tensor(Arc::new(tensor_)))
    }

    /// Returns a tensor with the same data as the input where the dimensions have been permuted.
    /// `dims` must be a permutation, i.e. include each dimension index exactly once.
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let tensor = Tensor::arange(0u32, 120u32, &Device::Cpu)?.reshape((2, 3, 4, 5))?;
    /// assert_eq!(tensor.dims(), &[2, 3, 4, 5]);
    /// let tensor = tensor.permute((2, 3, 1, 0))?;
    /// assert_eq!(tensor.dims(), &[4, 5, 3, 2]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```
    pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
        let dims = dims.to_indexes(self.shape(), "permute")?;
        // O(n^2) permutation check but these arrays are small.
        let is_permutation =
            dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
        if !is_permutation {
            crate::bail!(
                "dimension mismatch in permute, tensor {:?}, dims: {:?}",
                self.dims(),
                dims
            )
        }
        let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
        let tensor_ = Tensor_ {
            id: TensorId::new(),
            storage: self.storage.clone(),
            layout: self.layout.permute(&dims)?,
            op,
            is_variable: false,
            dtype: self.dtype,
            device: self.device.clone(),
        };
        Ok(Tensor(Arc::new(tensor_)))
    }

    /// Returns true if the data is stored in a C contiguous (aka row major) way.
    pub fn is_contiguous(&self) -> bool {
        self.layout.is_contiguous()
@@ -1476,12 +1693,15 @@ impl Tensor {
        Ok(from_storage(storage, shape, BackpropOp::none(), true))
    }

    // TODO: Do we want to allow target shape using -1 on some dimensions?
    /// Reshape returns a tensor with the target shape provided that the number of elements of the
    /// original tensor is the same.
    /// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
    /// a new storage and copies the data over, the returned tensor is always contiguous.
    ///
    /// The shape can be specified using a tuple of `usize` and at most one `()` in which case
    /// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so
    /// as to match the number of elements in the tensor.
    ///
    /// ```rust
    /// # use candle_core::{Tensor, DType, Device, D};
    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
@@ -1491,10 +1711,14 @@ impl Tensor {
    ///
    /// let c = a.reshape((3, 2))?;
    /// assert_eq!(c.shape().dims(), &[3, 2]);
    ///
    /// let c = a.reshape((2, (), 1))?;
    /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
    ///
    /// # Ok::<(), candle_core::Error>(())
    /// ```
    pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
        let shape = shape.into();
    pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
        let shape = s.into_shape(self.elem_count())?;
        if shape.elem_count() != self.elem_count() {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: self.shape().clone(),
@@ -1717,7 +1941,40 @@ impl Tensor {
        Ok(from_storage(storage, shape, op, false))
    }

    fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
    /// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
    /// input tensor values and `right` elements after.
    pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
        if left == 0 && right == 0 {
            Ok(self.clone())
        } else if left == 0 {
            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
            let mut dims = self.dims().to_vec();
            dims[dim] = right;
            let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
            Tensor::cat(&[self, &right], dim)
        } else if right == 0 {
            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
            let mut dims = self.dims().to_vec();
            dims[dim] = left;
            let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
            Tensor::cat(&[&left, self], dim)
        } else {
            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
            let mut dims = self.dims().to_vec();
            dims[dim] = left;
            let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
            dims[dim] = right;
            let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
            Tensor::cat(&[&left, self, &right], dim)
        }
    }

    /// Run the `forward` method of `m` on `self`.
    pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
        m.forward(self)
    }

    pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
        self.storage.read().unwrap()
    }

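A `pad_with_zeros` sketch (values assumed):

```rust
use candle_core::{Device, Result, Tensor};

fn pad_demo() -> Result<()> {
    let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    // One zero on the left, two on the right, along dim 0.
    let p = t.pad_with_zeros(0, 1, 2)?;
    assert_eq!(p.to_vec1::<f32>()?, &[0., 1., 2., 3., 0., 0.]);
    Ok(())
}
```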
@@ -1742,22 +1999,53 @@ impl Tensor {
        std::ptr::eq(lhs, rhs)
    }

    /// Applies a unary custom op without backward support
    pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a binary custom op without backward support
    pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
        let (storage, shape) =
            self.storage()
                .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a ternary custom op without backward support
    pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c,
        )?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a unary custom op.
    pub fn custom_op1_arc(&self, c: Arc<Box<dyn CustomOp1>>) -> Result<Self> {
    pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
        let (storage, shape) = self
            .storage()
            .custom_op1(self.layout(), c.as_ref().as_ref())?;
            .apply_op1(self.layout(), c.as_ref().as_ref())?;
        let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn custom_op1<C: 'static + CustomOp1>(&self, c: C) -> Result<Self> {
        self.custom_op1_arc(Arc::new(Box::new(c)))
    pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
        self.apply_op1_arc(Arc::new(Box::new(c)))
    }

    /// Applies a binary custom op.
    pub fn custom_op2_arc(&self, rhs: &Self, c: Arc<Box<dyn CustomOp2>>) -> Result<Self> {
        let (storage, shape) = self.storage().custom_op2(
    pub fn apply_op2_arc(
        &self,
        rhs: &Self,
        c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
    ) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op2(
            self.layout(),
            &rhs.storage(),
            rhs.layout(),
@@ -1767,13 +2055,18 @@ impl Tensor {
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn custom_op2<C: 'static + CustomOp2>(&self, r: &Self, c: C) -> Result<Self> {
        self.custom_op2_arc(r, Arc::new(Box::new(c)))
    pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
        self.apply_op2_arc(r, Arc::new(Box::new(c)))
    }

    /// Applies a ternary custom op.
    pub fn custom_op3_arc(&self, t2: &Self, t3: &Self, c: Arc<Box<dyn CustomOp3>>) -> Result<Self> {
        let (storage, shape) = self.storage().custom_op3(
    pub fn apply_op3_arc(
        &self,
        t2: &Self,
        t3: &Self,
        c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
    ) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
@@ -1787,8 +2080,13 @@ impl Tensor {
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn custom_op3<C: 'static + CustomOp3>(&self, t2: &Self, t3: &Self, c: C) -> Result<Self> {
        self.custom_op3_arc(t2, t3, Arc::new(Box::new(c)))
    pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
        &self,
        t2: &Self,
        t3: &Self,
        c: C,
    ) -> Result<Self> {
        self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
    }
}
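A minimal custom-op sketch using the renamed `apply_op1_no_bwd` entry point. The op itself (a no-grad, CPU-only ReLU) is a hypothetical example, not part of the diff:

```rust
use candle_core::{CpuStorage, CustomOp1, Device, Layout, Result, Shape, Tensor};

// Hypothetical example op: element-wise ReLU with only a CPU kernel.
struct Relu;

impl CustomOp1 for Relu {
    fn name(&self) -> &'static str {
        "relu"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        let data = storage.as_slice::<f32>()?;
        let data = match layout.contiguous_offsets() {
            Some((o1, o2)) => &data[o1..o2],
            None => candle_core::bail!("input must be contiguous"),
        };
        let out: Vec<f32> = data.iter().map(|&v| v.max(0.)).collect();
        Ok((CpuStorage::F32(out), layout.shape().clone()))
    }
}

fn relu_demo() -> Result<()> {
    let t = Tensor::new(&[-1f32, 2., -3.], &Device::Cpu)?;
    let r = t.apply_op1_no_bwd(&Relu)?;
    assert_eq!(r.to_vec1::<f32>()?, &[0., 2., 0.]);
    Ok(())
}
```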

@@ -1810,6 +2108,22 @@ macro_rules! bin_trait {
        }
    }

    impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Tensor> for Result<B> {
        type Output = Result<Tensor>;

        fn $fn1(self, rhs: Tensor) -> Self::Output {
            Tensor::$fn1(self?.borrow(), &rhs)
        }
    }

    impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<&Tensor> for Result<B> {
        type Output = Result<Tensor>;

        fn $fn1(self, rhs: &Tensor) -> Self::Output {
            Tensor::$fn1(self?.borrow(), rhs)
        }
    }

    impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for Tensor {
        type Output = Result<Tensor>;

@@ -1848,3 +2162,69 @@ bin_trait!(Add, add, |_| 1., |v| v);
bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
bin_trait!(Mul, mul, |v| v, |_| 0.);
bin_trait!(Div, div, |v| 1. / v, |_| 0.);

impl std::ops::Add<Tensor> for f64 {
    type Output = Result<Tensor>;

    fn add(self, rhs: Tensor) -> Self::Output {
        rhs + self
    }
}

impl std::ops::Add<&Tensor> for f64 {
    type Output = Result<Tensor>;

    fn add(self, rhs: &Tensor) -> Self::Output {
        rhs + self
    }
}

impl std::ops::Mul<Tensor> for f64 {
    type Output = Result<Tensor>;

    fn mul(self, rhs: Tensor) -> Self::Output {
        rhs * self
    }
}

impl std::ops::Mul<&Tensor> for f64 {
    type Output = Result<Tensor>;

    fn mul(self, rhs: &Tensor) -> Self::Output {
        rhs * self
    }
}

impl std::ops::Sub<Tensor> for f64 {
    type Output = Result<Tensor>;

    fn sub(self, rhs: Tensor) -> Self::Output {
        rhs.affine(-1., self)
    }
}

impl std::ops::Sub<&Tensor> for f64 {
    type Output = Result<Tensor>;

    fn sub(self, rhs: &Tensor) -> Self::Output {
        rhs.affine(-1., self)
    }
}

impl std::ops::Div<Tensor> for f64 {
    type Output = Result<Tensor>;

    #[allow(clippy::suspicious_arithmetic_impl)]
    fn div(self, rhs: Tensor) -> Self::Output {
        rhs.recip()? * self
    }
}

impl std::ops::Div<&Tensor> for f64 {
    type Output = Result<Tensor>;

    #[allow(clippy::suspicious_arithmetic_impl)]
    fn div(self, rhs: &Tensor) -> Self::Output {
        rhs.recip()? * self
    }
}
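These impls let a plain `f64` sit on the left-hand side of an operator; a quick sketch (values assumed):

```rust
use candle_core::{Device, Result, Tensor};

fn scalar_lhs_demo() -> Result<()> {
    let t = Tensor::new(&[1f64, 2., 4.], &Device::Cpu)?;
    // `1.0 - t` maps to `t.affine(-1., 1.0)`, i.e. 1 - x element-wise.
    let y = (1.0 - &t)?;
    assert_eq!(y.to_vec1::<f64>()?, &[0., -1., -3.]);
    // `1.0 / t` maps to `t.recip()? * 1.0`.
    let z = (1.0 / &t)?;
    assert_eq!(z.to_vec1::<f64>()?, &[1., 0.5, 0.25]);
    Ok(())
}
```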
@@ -1,6 +1,4 @@
#![allow(dead_code)]

use candle_core::{Result, Tensor};
use crate::{Result, Tensor};

#[macro_export]
macro_rules! test_device {
@@ -20,6 +18,12 @@ macro_rules! test_device {
    };
}

pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
    let b = 10f32.powi(digits);
    let t = t.to_vec0::<f32>()?;
    Ok(f32::round(t * b) / b)
}

pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
    let b = 10f32.powi(digits);
    let t = t.to_vec1::<f32>()?;
@@ -37,7 +41,7 @@ pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
    Ok(t)
}

pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
pub fn to_vec3_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
    let b = 10f32.powi(digits);
    let t = t.to_vec3::<f32>()?;
    let t = t
@@ -11,16 +11,30 @@ pub fn get_num_threads() -> usize {
    }
}

pub fn has_accelerate() -> bool {
    cfg!(feature = "accelerate")
}

pub fn has_mkl() -> bool {
    #[cfg(feature = "mkl")]
    return true;
    #[cfg(not(feature = "mkl"))]
    return false;
    cfg!(feature = "mkl")
}

pub fn cuda_is_available() -> bool {
    #[cfg(feature = "cuda")]
    return true;
    #[cfg(not(feature = "cuda"))]
    return false;
    cfg!(feature = "cuda")
}

pub fn with_avx() -> bool {
    cfg!(target_feature = "avx")
}

pub fn with_neon() -> bool {
    cfg!(target_feature = "neon")
}

pub fn with_simd128() -> bool {
    cfg!(target_feature = "simd128")
}

pub fn with_f16c() -> bool {
    cfg!(target_feature = "f16c")
}
candle-core/tests/conv_tests.rs (new file, 495 lines)
@@ -0,0 +1,495 @@
use anyhow::Result;
use candle_core::{test_device, test_utils, Device, IndexOp, Tensor};

/* This test is based on the following script.
import torch
torch.manual_seed(4242)

t = torch.randn((1, 4, 5))
w = torch.randn((2, 4, 3))
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv1d(t, w)
print(res.flatten())
res = torch.nn.functional.conv1d(t, w, padding=1)
print(res.flatten())
*/
fn conv1d(dev: &Device) -> Result<()> {
    let t = Tensor::new(
        &[
            0.4056f32, -0.8689, -0.0773, -1.5630, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145,
            1.8025, -0.1536, 2.2013, -0.6836, 0.2477, 1.3127, -0.6957, 0.3278, -1.0124, 0.5599,
        ],
        dev,
    )?
    .reshape((1, 4, 5))?;
    let w = Tensor::new(
        &[
            -0.8404f32, -0.3490, 0.0130, 1.3123, 0.1763, -1.9249, 1.4270, 0.9421, 0.8670, -0.7181,
            -1.1111, 0.8869, -1.2429, 1.8357, 1.6052, -1.3844, 0.3951, -1.2036, 0.6686, 1.6261,
            -0.6451, -0.0840, -1.4247, 0.5512,
        ],
        dev,
    )?
    .reshape((2, 4, 3))?;
    let res = t.conv1d(&w, 0, 1, 1, 1)?;
    assert_eq!(res.dims(), [1, 2, 3]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
    );
    let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
    assert_eq!(res.dims(), [1, 2, 5]);
    // Same as pytorch default padding: use zeros.
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
    );
    Ok(())
}

fn conv1d_small(dev: &Device) -> Result<()> {
    let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
    let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
    let res = t.conv1d(&w, 0, 1, 1, 1)?;
    assert_eq!(res.dims(), [1, 1, 2]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [0.4056, -0.8689]
    );
    let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
    assert_eq!(res.dims(), [1, 1, 4]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [0.0, 0.4056, -0.8689, -0.0773],
    );
    Ok(())
}

/* This test is based on the following script.
import torch
torch.manual_seed(4242)

t = torch.randn((1, 4, 5, 5))
w = torch.randn((2, 4, 3, 3))
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())

w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose2d(t, w_t)
print(res.shape)
print(res)

res = torch.nn.functional.conv2d(t, w, dilation=2)
print(res.shape)
print(res[0])

res = torch.nn.functional.conv_transpose2d(t, w_t, dilation=2)
print(res.shape)
print(res)
*/
fn conv2d(dev: &Device) -> Result<()> {
    let t = Tensor::new(
        &[
            0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
            1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
            1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
            0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
            1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
            0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
            0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
            0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
            -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
            -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
        ],
        dev,
    )?;
    let w = Tensor::new(
        &[
            -0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
            -2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
            -0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
            0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
            0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
            -0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
            1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
            0.5583, 0.4623, 0.6026,
        ],
        dev,
    )?;
    let t = t.reshape((1, 4, 5, 5))?;
    let w = w.reshape((2, 4, 3, 3))?;
    let res = t.conv2d(&w, 0, 1, 1, 1)?;
    assert_eq!(res.dims(), [1, 2, 3, 3]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [
            -4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
            10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
        ]
    );
    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
    assert_eq!(res.dims(), [1, 2, 7, 7]);
    assert_eq!(
        test_utils::to_vec3_round(&res.i(0)?, 4)?,
        [
            [
                [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
                [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
                [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
                [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
                [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
                [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
                [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
            ],
            [
                [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
                [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
                [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
                [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
                [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
                [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
                [-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
            ]
        ]
    );
    // Dilations.
    let res = t.conv2d(&w, 0, 1, 2, 1)?;
    assert_eq!(res.dims(), [1, 2, 1, 1]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [2.45, -2.3504],
    );

    // Transpose and dilations.
    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
    assert_eq!(res.dims(), [1, 2, 9, 9]);
    assert_eq!(
        test_utils::to_vec3_round(&res.i(0)?, 4)?,
        [
            [
                [-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
                [2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
                [-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
                [-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
                [-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
                [0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
                [-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024],
                [4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
                [5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
            ],
            [
                [1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
                [-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
                [1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
                [1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
                [1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
                [3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
                [5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
                [-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
                [-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171]
            ]
        ]
    );
    Ok(())
}

/* This test is based on the following script.
import torch
torch.manual_seed(4242)

t = torch.randn((1, 2, 3, 3))
w = torch.randn((1, 2, 1, 1))
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())

w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose2d(t, w_t)
print(res.shape)
print(res.flatten())

t_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose2d(t_t, w)
print(res.shape)
print(res.flatten())
*/
fn conv2d_small(dev: &Device) -> Result<()> {
    let t = Tensor::new(
        &[
            0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145,
            -0.6266, 0.3529, 2.2013, -0.6836, 0.2477, 1.3127, -0.6957, 0.3278,
        ],
        dev,
    )?;
    let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?;
    let t = t.reshape((1, 2, 3, 3))?;
    let w = w.reshape((1, 2, 1, 1))?;
    let res = t.conv2d(&w, 0, 1, 1, 1)?;
    assert_eq!(res.dims(), [1, 1, 3, 3]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539]
    );
    let res = t.conv2d(&w, 2, 1, 1, 1)?;
    assert_eq!(res.dims(), [1, 1, 7, 7]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [
            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1640, -0.0111, -0.1742, 0.0000, 0.0000,
            0.0000, 0.0000, 2.6437, -2.0268, 1.1823, 0.0000, 0.0000, 0.0000, 0.0000, 3.2855,
            -1.0324, 0.2539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
        ]
    );
    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
    assert_eq!(res.dims(), [1, 1, 3, 3]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
    );
    let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1)?;
    assert_eq!(res.dims(), [2, 2, 3, 3]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [
            -0.3755, 0.8045, -0.6336, -0.2218, -1.1369, 0.8599, 1.5768, -0.1268, -0.1728, 0.528,
            -1.131, 0.8908, 0.3118, 1.5984, -1.2089, -2.2168, 0.1783, 0.2429, -0.3838, 0.5802,
            -0.3268, -2.0382, 0.6329, -0.2293, -1.2154, 0.6441, -0.3035, 0.5396, -0.8156, 0.4594,
            2.8654, -0.8898, 0.3224, 1.7087, -0.9056, 0.4267
        ]
    );
    Ok(())
}

fn conv2d_smaller(dev: &Device) -> Result<()> {
    let t = Tensor::new(
        &[
            0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866,
        ],
        dev,
    )?;
    let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?;
    let t = t.reshape((1, 1, 3, 3))?;
    let w = w.reshape((1, 1, 3, 3))?;
    let res = t.conv2d(&w, 0, 1, 1, 1)?;
    assert_eq!(res.dims(), [1, 1, 1, 1]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [-0.6197]
    );
    Ok(())
}

/* This test is based on the following script.
import torch
torch.manual_seed(4242)

t = torch.randn((1, 2, 4, 2))
w = torch.randn((1, 2, 1, 1))
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
*/
fn conv2d_non_square(dev: &Device) -> Result<()> {
    let t = Tensor::new(
        &[
            0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
            1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699,
        ],
        dev,
    )?;
    let w = Tensor::new(&[-1.1351f32, 1.3841], dev)?;
    let t = t.reshape((1, 2, 4, 2))?;
    let w = w.reshape((1, 2, 1, 1))?;
    let res = t.conv2d(&w, 0, 1, 1, 1)?;
    assert_eq!(res.dims(), [1, 1, 4, 2]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [0.2312, 5.2238, 2.3772, 1.9076, 2.0256, -0.5776, -1.6028, -1.467]
    );
    Ok(())
}

/*
import torch
torch.manual_seed(4242)

t = torch.randn((1, 4, 5, 5), requires_grad=True)
w = torch.randn((2, 4, 3, 3), requires_grad=True)
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
loss = (res ** 2).sum()
print(loss)
loss.backward()
print(t.grad.shape)
print(t.grad.flatten())
print(w.grad.shape)
print(w.grad.flatten())

t.grad.zero_()
w.grad.zero_()
res = torch.nn.functional.conv2d(t, w, stride=2)
print(res.flatten())
loss = (res ** 2).sum()
print(loss)
loss.backward()
print(t.grad.shape)
print(t.grad[0])
print(w.grad.shape)
print(w.grad[0])
*/
fn conv2d_grad(dev: &Device) -> Result<()> {
    use candle_core::Var;
    let t = Var::from_slice(
        &[
            0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
            1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
            1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
            0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
            1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
            0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
            0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
            0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
            -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
            -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
        ],
        (1, 4, 5, 5),
        dev,
    )?;
    let w = Var::from_slice(
        &[
            -0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
            -2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
            -0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
            0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
            0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
            -0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
            1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
            0.5583, 0.4623, 0.6026,
        ],
        (2, 4, 3, 3),
        dev,
    )?;
    let res = t.conv2d(&w, 0, 1, 1, 1)?;
    let loss = res.sqr()?.sum_all()?;
    assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 741.12f32);
    let grads = loss.backward()?;
    let grad_t = grads.get(&t).unwrap();
    let grad_w = grads.get(&w).unwrap();
    assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
    assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
    assert_eq!(
        test_utils::to_vec1_round(&grad_t.flatten_all()?, 2)?,
        [
            9.29, -2.84, -5.71, 3.38, -7.71, -19.15, 7.02, 29.1, 9.34, 34.73, -22.87, 24.35,
            -39.88, -14.01, 21.08, 9.94, 13.63, -34.68, 11.21, -6.26, 7.72, -6.32, -16.64, -1.08,
            -20.22, 21.73, -0.37, -4.06, 5.82, -3.65, -30.73, 14.55, 87.7, 31.6, 4.53, -89.78,
            -75.37, -57.43, -7.56, 92.96, 18.79, -4.63, -159.75, -42.47, -47.26, 52.88, 37.32,
            49.0, 12.82, 2.01, -8.98, 20.18, 16.62, 12.06, 15.38, 20.0, 2.57, -15.22, 72.62,
            -10.75, 2.25, -31.2, 3.75, -0.2, 9.76, -0.68, 5.21, -40.44, -22.59, -61.61, 17.28,
            20.41, 37.55, 5.23, 6.81, 23.54, 23.62, -9.99, -9.13, 4.87, -35.06, -26.1, 63.48,
            25.81, -39.21, -70.68, -46.96, 2.33, 41.81, 82.42, -28.63, -11.78, -35.33, -10.28,
            -28.57, -9.13, 7.21, -9.05, -9.62, -11.25
        ]
    );
    assert_eq!(
        test_utils::to_vec1_round(&grad_w.flatten_all()?, 2)?,
        [
            -28.92, -22.88, -141.23, 73.35, 61.07, 47.81, -20.0, -73.71, -41.82, -13.59, 21.5,
            28.72, 28.57, -46.85, -90.19, 143.61, 16.68, 7.43, 18.88, -90.81, -20.29, 54.79, 82.63,
            22.94, 77.81, -16.39, -13.2, 9.34, -40.39, -26.62, 5.33, -60.91, 9.09, -59.37, 7.08,
            58.64, 5.55, 20.52, 2.5, -17.25, -6.8, 22.21, 30.15, -7.52, -37.46, 5.67, 22.58, 9.03,
            47.05, 17.61, 37.31, -98.13, -14.61, -4.8, -6.36, 44.69, 23.34, 8.37, -13.52, 80.05,
            -34.24, -16.36, -12.31, 1.92, -33.62, -14.1, -49.23, -7.39, 11.5, -9.98, 9.66, 29.6
        ]
    );

    // Same as before but with stride.
    let res = t.conv2d(&w, 0, 2, 1, 1)?;
    let loss = res.sqr()?.sum_all()?;
    assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 277.16f32);
    let grads = loss.backward()?;
    let grad_t = grads.get(&t).unwrap();
    let grad_w = grads.get(&w).unwrap();
    assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
    assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
    assert_eq!(
        test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
        [
            [
                [9.29, -7.03, 0.94, 3.49, -7.71],
                [-1.8, -7.82, 8.9, 8.46, 7.43],
                [-25.84, 22.09, -19.27, -0.22, 1.69],
                [4.02, 18.53, -18.37, 2.3, -24.51],
                [7.72, -9.68, -12.34, 5.6, -20.22]
            ],
            [
                [21.73, 3.39, -18.27, 3.86, -3.65],
                [8.25, 3.73, 30.73, -8.61, -11.93],
                [-72.15, -15.36, -17.53, -12.32, -1.61],
                [-22.32, -7.79, -91.82, 6.44, -37.69],
                [52.88, 14.44, 42.75, 9.88, 2.01]
            ],
            [
                [-8.98, 9.91, 6.75, -4.68, 15.38],
                [4.93, -0.33, 9.94, -1.46, 14.78],
                [13.62, -30.63, 3.96, -3.58, -4.48],
                [-14.13, 1.19, -34.43, 3.08, -33.83],
                [17.28, 12.94, 31.83, -3.35, 6.81]
            ],
            [
                [23.54, 6.98, -24.52, 0.52, 4.87],
                [9.65, 6.18, 1.71, -25.23, -4.93],
                [-54.99, -23.66, 3.19, -3.73, 18.58],
                [-21.35, -10.39, -39.88, 28.73, -30.76],
                [-9.13, 11.12, -14.0, -8.23, -11.25]
            ]
        ]
    );
    assert_eq!(
        test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
        [
            [
                [28.34, -7.91, -45.75],
                [21.03, 3.86, 29.86],
                [0.72, -36.58, -35.28]
            ],
            [
                [-16.04, 11.53, -16.38],
                [29.62, -16.32, -48.35],
                [57.5, 28.29, 25.81]
            ],
            [
                [2.93, -19.6, 1.57],
                [27.15, 53.88, -24.64],
                [12.74, -22.6, -26.2]
            ],
            [
                [-0.18, -14.86, -6.82],
                [-19.55, -2.72, 45.9],
                [-2.54, 36.97, 27.11]
            ]
        ]
    );
    Ok(())
}

test_device!(conv1d, conv1d_cpu, conv1d_gpu);
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
test_device!(
    conv2d_non_square,
    conv2d_non_square_cpu,
    conv2d_non_square_gpu
);
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);
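For reference, `test_device!` pairs each generic test with per-device `#[test]` wrappers; roughly (a sketch, not the exact macro expansion):

```rust
// Approximately what `test_device!(conv1d, conv1d_cpu, conv1d_gpu)` generates.
#[test]
fn conv1d_cpu() -> anyhow::Result<()> {
    conv1d(&candle_core::Device::Cpu)
}

#[cfg(feature = "cuda")]
#[test]
fn conv1d_gpu() -> anyhow::Result<()> {
    conv1d(&candle_core::Device::new_cuda(0)?)
}
```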
@@ -1,10 +1,8 @@
use candle_core::backend::BackendStorage;
use candle_core::cpu_backend;
use candle_core::test_utils::to_vec1_round;
use candle_core::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};

mod test_utils;
use test_utils::to_vec1_round;

fn fwd<T: num_traits::Float>(v: T, alpha: f64) -> T {
    if v.is_sign_positive() {
        v
@@ -39,7 +37,7 @@ fn custom_op1_no_backward() -> Result<()> {
    let cpu = &Device::Cpu;
    let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
    let t = (t - 5.)?;
    let elu_t = t.custom_op1(Elu { alpha: 1. })?;
    let elu_t = t.apply_op1_no_bwd(&Elu { alpha: 1. })?;
    assert_eq!(
        to_vec1_round(&elu_t, 4)?,
        &[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
@@ -96,7 +94,7 @@ impl CustomOp1 for EluWithBackward {

    fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
        let alpha = self.0.alpha;
        let bwd = arg.custom_op1(EluBackward { alpha })?;
        let bwd = arg.apply_op1(EluBackward { alpha })?;
        Ok(Some(grad_res.mul(&bwd)?))
    }
}
@@ -105,7 +103,7 @@ impl CustomOp1 for EluWithBackward {
fn custom_op1_with_backward() -> Result<()> {
    let cpu = &Device::Cpu;
    let t = candle_core::Var::new(&[-2f32, 0f32, 2f32], cpu)?;
    let elu_t = t.custom_op1(EluWithBackward::new(2.))?;
    let elu_t = t.apply_op1(EluWithBackward::new(2.))?;
    assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]);

    let grads = elu_t.backward()?;
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{Device, Shape, Tensor, Var};
|
||||
mod test_utils;
|
||||
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
|
||||
|
||||
fn simple_grad(device: &Device) -> Result<()> {
|
||||
let x = Var::new(&[3f32, 1., 4.], device)?;
|
||||
@ -85,8 +84,14 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
let y = (x.log()? + 1.)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(y.to_vec1::<f32>()?, [2.0986123, 1.0, 2.3862944, -0.89712]);
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[2.0986, 1.0, 2.3863, -0.8971]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[0.3333, 1.0, 0.25, 6.6667]
|
||||
);
|
||||
let y = x.exp()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
@ -141,7 +146,7 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(y.to_vec1::<f32>()?, [3.0, 1.0, 4.0, 0.15]);
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, [1.0, 1.0, 1.0, 1.0]);
|
||||
assert_eq!(test_utils::to_vec1_round(grad_x, 4)?, [1.0, 1.0, 1.0, 1.0]);
|
||||
let y = x.neg()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
@ -155,7 +160,10 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
let y = Tensor::new(1f32, device)?.broadcast_div(x)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(y.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[0.3333, 1.0, 0.25, 6.6667]
|
||||
);
|
||||
assert_eq!(
|
||||
grad_x.to_vec1::<f32>()?,
|
||||
[-0.11111111, -1.0, -0.0625, -44.444443],
|
||||
@ -165,6 +173,51 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(y.to_vec1::<f32>()?, [6., 2., 8., 0.3]);
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, [2., 2., 2., 2.]);
|
||||
|
||||
let x = Var::new(&[3f32, 1., 4., 0.15], device)?;
|
||||
let y = x.powf(2.5)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(test_utils::to_vec1_round(&y, 2)?, [15.59, 1.0, 32.0, 0.01]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 2)?,
|
||||
[12.99, 2.5, 20.0, 0.15]
|
||||
);
|
||||
|
||||
let y = x.tanh()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(test_utils::to_vec1_round(&y, 2)?, [1.0, 0.76, 1.0, 0.15]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 2)?,
|
||||
[0.01, 0.42, 0.0, 0.98],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn binary_grad(device: &Device) -> Result<()> {
|
||||
let x = Var::new(&[3f32, 1., -4., -1.], device)?;
|
||||
let x = x.as_tensor();
|
||||
// leaky relu
|
||||
let y = x.maximum(&(x * 0.1)?)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(x.to_vec1::<f32>()?, [3., 1., -4., -1.]);
|
||||
assert_eq!(y.to_vec1::<f32>()?, [3., 1., -0.4, -0.1]);
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 0.1, 0.1]);
|
||||
|
||||
let y = x.minimum(&(x * 0.1)?)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(y.to_vec1::<f32>()?, [0.3, 0.1, -4., -1.]);
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, [0.1, 0.1, 1., 1.]);
|
||||
|
||||
// This one is easy to mess up, we want the gradient to be one as it is the identity function.
|
||||
let y = x.minimum(x)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]);
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -173,3 +226,4 @@ test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
|
||||
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
|
||||
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
|
||||
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
|
||||
test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);
|
||||
|
@@ -1,8 +1,6 @@
use anyhow::Result;
use candle_core::{Device, IndexOp, Tensor};

mod test_utils;

#[test]
fn integer_index() -> Result<()> {
    let dev = Device::Cpu;

@@ -1,5 +1,4 @@
mod test_utils;
use candle::{Device, IndexOp, Result, Tensor};
use candle::{test_device, Device, IndexOp, Result, Tensor};
use candle_core as candle;

fn contiguous(device: &Device) -> Result<()> {
112 candle-core/tests/pool_tests.rs Normal file
@@ -0,0 +1,112 @@
use candle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor};

// https://github.com/huggingface/candle/issues/364
fn avg_pool2d(dev: &Device) -> Result<()> {
    let data: Vec<f32> = vec![
        1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
    ];
    let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
    let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
    assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);

    let data: Vec<f32> = vec![
        1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
    ];
    let t = Tensor::from_vec(data, (1, 1, 2, 8), dev)?;
    let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
    assert_eq!(pool.to_vec2::<f32>()?, [[5. / 4., 6. / 4., 6. / 4., 1.]]);
    Ok(())
}
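As a sanity check on the expected values above: with the default stride equal to the kernel size, avg_pool2d(2) averages each non-overlapping 2x2 window, so the top-left window of the first input, [1., 1., 0., 0.], yields (1. + 1. + 0. + 0.) / 4. = 0.5. A minimal standalone sketch of that arithmetic (assuming only candle-core as a dependency):

use candle_core::{Device, Result, Tensor};

fn avg_pool_sanity_check() -> Result<()> {
    // A single 2x2 window holding [1., 1., 0., 0.] averages to 0.5.
    let t = Tensor::from_vec(vec![1f32, 1., 0., 0.], (1, 1, 2, 2), &Device::Cpu)?;
    let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
    assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32]]);
    Ok(())
}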
fn max_pool2d(dev: &Device) -> Result<()> {
    let data: Vec<f32> = vec![
        1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
    ];
    let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;

    let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?;
    assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);

    let t = t.reshape((1, 1, 2, 8))?;
    let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?;
    assert_eq!(pool.to_vec2::<f32>()?, [[2.0, 3.0, 5.0, 1.0]]);
    Ok(())
}
/* This test corresponds to the following PyTorch script.
import torch
torch.manual_seed(4242)

t = torch.randn((1, 2, 4, 4))
print(t.flatten())
res = torch.nn.functional.avg_pool2d(t, 2)
print(res)
*/
fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
    let t = Tensor::new(
        &[
            0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
            1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
            1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
            0.2477, 1.3127,
        ],
        dev,
    )?
    .reshape((1, 2, 4, 4))?;
    let pool = t.avg_pool2d(2)?.squeeze(0)?;
    assert_eq!(
        test_utils::to_vec3_round(&pool, 4)?,
        [
            [[-1.1926, -0.0395], [0.2688, 0.1871]],
            [[0.1835, -0.1606], [0.6249, 0.3217]]
        ]
    );
    let pool = t.avg_pool2d(3)?.squeeze(0)?;
    assert_eq!(
        test_utils::to_vec3_round(&pool, 4)?,
        [[[0.085]], [[0.0078]]]
    );

    let t = t.reshape((1, 1, 4, 8))?;
    let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
    assert_eq!(
        test_utils::to_vec2_round(&pool, 4)?,
        [
            [0.7745, 0.0276, -1.6983, 0.12],
            [0.3542, 0.1625, 0.4542, -0.0014]
        ]
    );
    Ok(())
}
fn upsample_nearest2d(dev: &Device) -> Result<()> {
    let t = Tensor::arange(0f32, 6f32, dev)?.reshape((1, 1, 2, 3))?;
    let upsampled = t.upsample_nearest2d(4, 6)?.i(0)?.i(0)?;
    assert_eq!(
        t.i(0)?.i(0)?.to_vec2::<f32>()?,
        [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
    );
    assert_eq!(
        upsampled.to_vec2::<f32>()?,
        [
            [0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
            [0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
            [3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
            [3.0, 3.0, 4.0, 4.0, 5.0, 5.0]
        ]
    );
    Ok(())
}
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
test_device!(
    avg_pool2d_pytorch,
    avg_pool2d_pytorch_cpu,
    avg_pool2d_pytorch_gpu
);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
test_device!(
    upsample_nearest2d,
    upsample_nearest2d_cpu,
    upsample_nearest2d_gpu
);
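Each test_device! invocation above turns a device-generic fn(dev: &Device) -> Result<()> helper into one CPU test and one CUDA-gated test. The macro lives in the crate's test utilities; a rough sketch of the kind of expansion involved (illustrative, not the exact definition in candle):

// Hypothetical shape of the macro: two #[test] wrappers that call the same
// device-generic function, the CUDA one behind the `cuda` feature flag.
macro_rules! test_device {
    ($fn_name: ident, $test_cpu: ident, $test_gpu: ident) => {
        #[test]
        fn $test_cpu() -> Result<()> {
            $fn_name(&Device::Cpu)
        }

        #[cfg(feature = "cuda")]
        #[test]
        fn $test_gpu() -> Result<()> {
            $fn_name(&Device::new_cuda(0)?)
        }
    };
}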
689 candle-core/tests/quantized_tests.rs Normal file
@@ -0,0 +1,689 @@
use candle_core::{
    quantized::{self, GgmlDType},
    test_utils::to_vec2_round,
    Device, Result, Tensor,
};
use quantized::{k_quants, GgmlType};
use rand::prelude::*;

const GGML_TEST_SIZE: usize = 32 * 128;

const GGML_MAX_QUANTIZATION_TOTAL_ERROR: f32 = 0.002;
const GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS: f32 = 0.0075;
const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040;
const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;
#[test]
fn quantized_matmul() -> Result<()> {
    let cpu = &Device::Cpu;
    let (m, k, n) = (3, 64, 4);
    let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
    let mut dst = vec![42.; 3 * 4];
    let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
    let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
    let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
    k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
    k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
    assert_eq!(
        dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
        &[
            85120.0, 214562.0, 345455.0, 474748.0, 213475.0, 604465.0, 1000686.0, 1388317.0,
            341876.0, 994283.0, 1655709.0, 2301518.0
        ]
    );
    let mm = tensor_lhs.matmul(&tensor_rhs)?;
    assert_eq!(
        mm.to_vec2::<f32>()?,
        &[
            [85344.0, 214368.0, 343392.0, 472416.0],
            [214368.0, 605536.0, 996704.0, 1387872.0],
            [343392.0, 996704.0, 1650016.0, 2303328.0]
        ]
    );

    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor);
    let res = matmul.forward(&tensor_lhs)?;
    assert_eq!(
        to_vec2_round(&res, 0)?,
        &[
            [85120.0, 214562.0, 345455.0, 474748.0],
            [213475.0, 604465.0, 1000686.0, 1388317.0],
            [341876.0, 994283.0, 1655709.0, 2301518.0]
        ]
    );

    Ok(())
}
#[test]
fn quantized_matmul_neg() -> Result<()> {
    let cpu = &Device::Cpu;
    let (m, k, n) = (3, 64, 4);
    let lhs = (0..(m * k))
        .map(|v| v as f32 - (m * k) as f32 / 2.0)
        .collect::<Vec<_>>();
    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
    let mut dst = vec![42.; 3 * 4];
    let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
    let rhs = (0..k * n)
        .map(|v| v as f32 - (k * n) as f32 / 3.0)
        .collect::<Vec<_>>();
    let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
    k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
    k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
    assert_eq!(
        dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
        &[
            243524.0, -19596.0, -285051.0, -549815.0, 23777.0, 21651.0, 19398.0, 18367.0,
            -196472.0, 63012.0, 324585.0, 587902.0
        ]
    );
    let mm = tensor_lhs.matmul(&tensor_rhs)?;
    assert_eq!(
        to_vec2_round(&mm, 0)?,
        &[
            [244064.0, -20128.0, -284320.0, -548512.0],
            [23563.0, 21515.0, 19467.0, 17419.0],
            [-196939.0, 63157.0, 323253.0, 583349.0]
        ]
    );

    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor);
    let res = matmul.forward(&tensor_lhs)?;
    assert_eq!(
        to_vec2_round(&res, 0)?,
        &[
            [243524.0, -19596.0, -285051.0, -549815.0],
            [23777.0, 21651.0, 19398.0, 18367.0],
            [-196472.0, 63012.0, 324585.0, 587902.0]
        ]
    );

    Ok(())
}
#[test]
fn quantize_q4_0() -> Result<()> {
    use k_quants::BlockQ4_0;

    let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
    let mut dst = vec![0f32; 32 * 4];
    let mut quant = vec![BlockQ4_0::zeros(); 4];
    BlockQ4_0::from_float(&src, &mut quant)?;
    BlockQ4_0::to_float(&quant, dst.as_mut_slice())?;
    assert_eq!(
        dst,
        &[
            -0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625,
            11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25,
            23.25, 27.125, 27.125, 27.125, 27.125, 31.0, 31.0, 31.5, 31.5, 31.5, 31.5, 39.375,
            39.375, 39.375, 39.375, 39.375, 39.375, 39.375, 39.375, 47.25, 47.25, 47.25, 47.25,
            47.25, 47.25, 47.25, 47.25, 55.125, 55.125, 55.125, 55.125, 55.125, 55.125, 55.125,
            55.125, 63.0, 63.0, 63.0, 63.0, 59.375, 59.375, 71.25, 71.25, 71.25, 71.25, 71.25,
            71.25, 71.25, 71.25, 71.25, 71.25, 71.25, 71.25, 83.125, 83.125, 83.125, 83.125,
            83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 95.0, 95.0, 95.0, 95.0,
            95.0, 95.0, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 111.125, 111.125,
            111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125,
            111.125, 111.125, 111.125, 111.125, 111.125, 127.0, 127.0, 127.0, 127.0, 127.0, 127.0,
            127.0, 127.0
        ]
    );
    ggml_quantization_error_test::<BlockQ4_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
    Ok(())
}
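The staircase in the expected output is a direct consequence of the Q4_0 layout: each block of 32 floats keeps a single scale plus 32 four-bit quants, so a block covering 0..=31 can only reproduce multiples of 31/8 = 3.875 (and the leading -0.0 suggests a negative scale multiplied by a zero quant). A back-of-the-envelope check of the first block, with the (q - 8) * d convention assumed from the output rather than taken from the format spec:

fn q4_0_first_block_sketch() {
    // Hypothetical reconstruction: amax = 31 gives a scale d = -31/8 and
    // dequantized values (q - 8) * d for a 4-bit q, i.e. multiples of 3.875.
    let d = 31f32 / -8f32;
    let dequant: Vec<f32> = (0..16).map(|q| (q as f32 - 8.0) * d).collect();
    assert_eq!(dequant[8], 0.0); // q = 8 maps to a (negative) zero
    assert_eq!(dequant[7], 3.875);
    assert_eq!(dequant[0], 31.0);
}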
#[test]
fn quantize_q4_1() -> Result<()> {
    use k_quants::BlockQ4_1;

    let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
    let mut dst = vec![0f32; 32 * 4];
    let mut quant = vec![BlockQ4_1::zeros(); 4];
    BlockQ4_1::from_float(&src, &mut quant)?;
    BlockQ4_1::to_float(&quant, dst.as_mut_slice())?;
    assert_eq!(
        round_vector(&dst),
        &[
            0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332,
            12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73,
            22.73, 24.797, 24.797, 26.863, 26.863, 28.93, 28.93, 30.996, 30.996, 32.0, 32.0,
            34.066, 34.066, 36.133, 36.133, 38.199, 38.199, 40.266, 40.266, 42.332, 42.332, 44.398,
            44.398, 46.465, 46.465, 48.531, 48.531, 50.598, 50.598, 52.664, 52.664, 54.73, 54.73,
            56.797, 56.797, 58.863, 58.863, 60.93, 60.93, 62.996, 62.996, 64.0, 64.0, 66.066,
            66.066, 68.133, 68.133, 70.199, 70.199, 72.266, 72.266, 74.332, 74.332, 76.398, 76.398,
            78.465, 78.465, 80.531, 80.531, 82.598, 82.598, 84.664, 84.664, 86.73, 86.73, 88.797,
            88.797, 90.863, 90.863, 92.93, 92.93, 94.996, 94.996, 96.0, 96.0, 98.066, 98.066,
            100.133, 100.133, 102.199, 102.199, 104.266, 104.266, 106.332, 106.332, 108.398,
            108.398, 110.465, 110.465, 112.531, 112.531, 114.598, 114.598, 116.664, 116.664,
            118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996
        ]
    );
    ggml_quantization_error_test::<BlockQ4_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
    Ok(())
}
#[test]
fn quantize_q5_0() -> Result<()> {
    use k_quants::BlockQ5_0;

    let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
    let mut dst = vec![0f32; 32 * 4];
    let mut quant = vec![BlockQ5_0::zeros(); 4];
    BlockQ5_0::from_float(&src, &mut quant)?;
    BlockQ5_0::to_float(&quant, dst.as_mut_slice())?;
    assert_eq!(
        round_vector(&dst),
        &[
            -0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625,
            11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313,
            23.25, 23.25, 25.188, 25.188, 27.125, 27.125, 29.063, 29.063, 31.0, 31.5, 31.5, 35.438,
            35.438, 35.438, 35.438, 39.375, 39.375, 39.375, 39.375, 43.313, 43.313, 43.313, 43.313,
            47.25, 47.25, 47.25, 47.25, 51.188, 51.188, 51.188, 51.188, 55.125, 55.125, 55.125,
            55.125, 59.063, 59.063, 59.063, 59.063, 63.0, 63.0, 65.313, 65.313, 65.313, 65.313,
            65.313, 71.25, 71.25, 71.25, 71.25, 71.25, 71.25, 77.188, 77.188, 77.188, 77.188,
            77.188, 77.188, 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 89.063, 89.063, 89.063,
            89.063, 89.063, 89.063, 95.0, 95.0, 95.0, 95.25, 95.25, 95.25, 95.25, 103.188, 103.188,
            103.188, 103.188, 103.188, 103.188, 103.188, 103.188, 111.125, 111.125, 111.125,
            111.125, 111.125, 111.125, 111.125, 111.125, 119.063, 119.063, 119.063, 119.063,
            119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0
        ]
    );
    ggml_quantization_error_test::<BlockQ5_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
    Ok(())
}
#[test]
fn quantize_q5_1() -> Result<()> {
    use k_quants::BlockQ5_1;

    let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
    let mut dst = vec![0f32; 32 * 4];
    let mut quant = vec![BlockQ5_1::zeros(); 4];
    BlockQ5_1::from_float(&src, &mut quant)?;
    BlockQ5_1::to_float(&quant, dst.as_mut_slice())?;
    assert_eq!(
        dst,
        &[
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
            16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
            30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0,
            44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0,
            58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0,
            72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 80.0, 81.0, 82.0, 83.0, 84.0, 85.0,
            86.0, 87.0, 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, 94.0, 95.0, 96.0, 97.0, 98.0, 99.0,
            100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0,
            112.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0,
            124.0, 125.0, 126.0, 127.0
        ]
    );

    ggml_quantization_error_test::<BlockQ5_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
    Ok(())
}
/// Generates a small test vector ranging from -`bound` to `bound` with `size` steps
fn get_test_vector(bound: f32, size: usize) -> (Vec<f32>, Vec<f32>) {
    assert!(
        size % k_quants::QK_K == 0,
        "size must be a multiple of {}",
        k_quants::QK_K
    );

    let src = (0..size)
        .map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.))
        .collect::<Vec<_>>();

    let dst = vec![0f32; size];
    assert_eq!([src[0], src[size / 2]], [-bound, 0.0]);
    (src, dst)
}
/// Round a vector
fn round_vector(values: &[f32]) -> Vec<f32> {
    values
        .iter()
        .map(|x| (1000. * x).round() / 1000.)
        .collect::<Vec<_>>()
}
fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
    for (i, (value, expected_value)) in values.iter().zip(expected.iter()).enumerate() {
        let difference = (value - expected_value).abs();

        assert!(
            difference < tolerance,
            "Error at index {}: value = {}, expected = {}. Difference = {} exceeds tolerance = {}.",
            i,
            value,
            expected_value,
            difference,
            tolerance
        );
    }
}
/// Creates a vector similarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
    (0..GGML_TEST_SIZE)
        .map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
        .collect()
}
/// Calculates the root mean square error between two vectors
fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    let sum = a
        .iter()
        .zip(b)
        .map(|(a, b)| (a - b).powi(2))
        .sum::<f32>()
        .sqrt();
    sum / a.len() as f32
}
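Note that this helper divides the square root of the summed squared differences by the vector length, i.e. error = sqrt(sum_i (a_i - b_i)^2) / n, rather than the textbook sqrt(sum / n); that appears to be the convention of the GGML reference test linked below, which is what makes the thresholds above directly comparable.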
/// Mirrors the GGML quantization unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
    let src = create_ggml_like_vector(0.0);
    let mut dst = vec![0.0; GGML_TEST_SIZE];
    let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
    let error = calculate_rmse(src.as_slice(), dst.as_slice());
    if error > max_error {
        candle_core::bail!(
            "Quantization error {} exceeds max error {}",
            error,
            max_error
        );
    }
    Ok(())
}
fn quantize_roundtrip<T: GgmlType>(src: &[f32], dst: &mut [f32]) -> Result<Vec<T>> {
    let mut quant = vec![T::zeros(); src.len() / T::BLCK_SIZE];
    T::from_float(src, &mut quant)?;
    T::to_float(&quant, dst)?;
    Ok(quant)
}
#[test]
fn quantize_q2k() -> Result<()> {
    use k_quants::BlockQ2K;

    let (src, mut dst) = get_test_vector(0.5, 1024);
    let _quant = quantize_roundtrip::<BlockQ2K>(src.as_slice(), dst.as_mut_slice())?;
    compare_with_error(dst.as_slice(), src.as_slice(), 0.1);

    // Test some specific values
    assert_eq!(
        [src[0], src[128], src[256], src[512], src[800], src[1023]],
        [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
    );
    let dst = round_vector(&dst);
    assert_eq!(
        [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
        [-0.499, -0.366, -0.249, 0.0, 0.295, 0.492]
    );

    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
    let _quant_big = quantize_roundtrip::<BlockQ2K>(src_big.as_slice(), dst_big.as_mut_slice())?;
    compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0);

    ggml_quantization_error_test::<BlockQ2K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
    Ok(())
}
#[test]
fn quantize_q3k() -> Result<()> {
    use k_quants::BlockQ3K;

    let (src, mut dst) = get_test_vector(0.5, 1024);
    let _quant = quantize_roundtrip::<BlockQ3K>(src.as_slice(), dst.as_mut_slice())?;
    compare_with_error(dst.as_slice(), src.as_slice(), 0.03);

    // Test some specific values
    assert_eq!(
        [src[0], src[128], src[256], src[512], src[800], src[1023]],
        [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
    );
    let dst = round_vector(&dst);
    assert_eq!(
        [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
        [-0.493, -0.37, -0.243, -0.0, 0.292, 0.492]
    );

    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
    let _quant_big = quantize_roundtrip::<BlockQ3K>(src_big.as_slice(), dst_big.as_mut_slice())?;
    compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5);

    ggml_quantization_error_test::<BlockQ3K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
    Ok(())
}
#[test]
fn quantize_q4k() -> Result<()> {
    use k_quants::BlockQ4K;

    let (src, mut dst) = get_test_vector(0.5, 1024);
    let _quant = quantize_roundtrip::<BlockQ4K>(src.as_slice(), dst.as_mut_slice())?;
    compare_with_error(dst.as_slice(), src.as_slice(), 0.017);

    // Test some specific values
    assert_eq!(
        [src[0], src[128], src[256], src[512], src[800], src[1023]],
        [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
    );
    let dst = round_vector(&dst);
    assert_eq!(
        [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
        [-0.5, -0.373, -0.25, 0.0, 0.288, 0.498]
    );

    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
    let _quant_big = quantize_roundtrip::<BlockQ4K>(src_big.as_slice(), dst_big.as_mut_slice())?;
    compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5);

    ggml_quantization_error_test::<BlockQ4K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
    Ok(())
}
#[test]
fn quantize_q5k() -> Result<()> {
    use k_quants::BlockQ5K;

    let (src, mut dst) = get_test_vector(0.5, 1024);
    let _quant = quantize_roundtrip::<BlockQ5K>(src.as_slice(), dst.as_mut_slice())?;
    compare_with_error(dst.as_slice(), src.as_slice(), 0.008);

    // Test some specific values
    assert_eq!(
        [src[0], src[128], src[256], src[512], src[800], src[1023]],
        [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
    );
    let dst = round_vector(&dst);
    assert_eq!(
        [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
        [-0.499, -0.372, -0.249, 0.001, 0.279, 0.499]
    );

    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
    let _quant_big = quantize_roundtrip::<BlockQ5K>(src_big.as_slice(), dst_big.as_mut_slice())?;
    compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5);

    ggml_quantization_error_test::<BlockQ5K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;

    Ok(())
}
#[test]
fn quantize_q6k() -> Result<()> {
    use k_quants::BlockQ6K;

    let (src, mut dst) = get_test_vector(0.5, 1024);
    let _quant = quantize_roundtrip::<BlockQ6K>(src.as_slice(), dst.as_mut_slice())?;
    compare_with_error(dst.as_slice(), src.as_slice(), 0.008);

    // Test some specific values
    assert_eq!(
        [src[0], src[128], src[256], src[512], src[800], src[1023]],
        [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
    );
    let dst = round_vector(&dst);
    assert_eq!(
        [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
        [-0.497, -0.372, -0.25, -0.0, 0.284, 0.5]
    );

    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
    let _quant_big = quantize_roundtrip::<BlockQ6K>(src_big.as_slice(), dst_big.as_mut_slice())?;
    compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0);

    ggml_quantization_error_test::<BlockQ6K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;

    Ok(())
}
#[test]
fn quantize_q8k() -> Result<()> {
    use k_quants::BlockQ8K;

    let (src, mut dst) = get_test_vector(0.5, 1024);
    let _quant = quantize_roundtrip::<BlockQ8K>(src.as_slice(), dst.as_mut_slice())?;
    compare_with_error(dst.as_slice(), src.as_slice(), 0.003);

    // Test some specific values
    assert_eq!(
        [src[0], src[128], src[256], src[512], src[800], src[1023]],
        [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
    );
    let dst = round_vector(&dst);
    assert_eq!(
        [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
        [-0.5, -0.375, -0.25, -0.0, 0.281, 0.499]
    );

    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
    let _quant_big = quantize_roundtrip::<BlockQ8K>(src_big.as_slice(), dst_big.as_mut_slice())?;
    compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6);

    ggml_quantization_error_test::<BlockQ8K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;

    Ok(())
}
/// Very simple dot product implementation
fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(a, b)| a * b).sum()
}
/// Returns the error achieved by the GGML matmul unit test.
fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
    let err = match dtype {
        GgmlDType::F16 => 0.000010,
        GgmlDType::Q2K => 0.004086,
        GgmlDType::Q3K => 0.016148,
        GgmlDType::Q4K => 0.002425,
        GgmlDType::Q5K => 0.000740,
        GgmlDType::Q6K => 0.000952,
        GgmlDType::Q4_0 => 0.001143,
        GgmlDType::Q4_1 => 0.007784,
        GgmlDType::Q5_0 => 0.001353,
        GgmlDType::Q5_1 => 0.001363,
        GgmlDType::Q8_0 => 0.000092,
        _ => candle_core::bail!("No GGML results for quantization type {dtype:?}"),
    };
    Ok(err)
}
/// Mirrors the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
    let a = create_ggml_like_vector(0.0);
    let b = create_ggml_like_vector(1.0);
    let length = a.len();

    let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
    let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
    T::from_float(&a, &mut a_quant)?;
    T::VecDotType::from_float(&b, &mut b_quant)?;

    let result = T::vec_dot(length, &a_quant, &b_quant)?;
    let reference_result = vec_dot_reference(&a, &b);

    let error = (result - reference_result).abs() / length as f32;

    let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;

    if error > GGML_MAX_DOT_PRODUCT_ERROR {
        candle_core::bail!(
            "Dot product error {} exceeds max error {}",
            error,
            GGML_MAX_DOT_PRODUCT_ERROR
        );
    }

    // We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
    // => we use a slightly higher error threshold
    const ERROR_LENIENCY: f32 = 0.00001;
    if error - ERROR_LENIENCY > ggml_error {
        candle_core::bail!(
            "Dot product error {} exceeds ggml reference error {}",
            error,
            ggml_error
        );
    }
    Ok(())
}
/// Generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
fn get_random_tensors(
    m: usize,
    k: usize,
    n: usize,
    device: &Device,
) -> Result<(Tensor, Tensor, Tensor)> {
    let mut rng = StdRng::seed_from_u64(314159265358979);

    let lhs = (0..m * k)
        .map(|_| rng.gen::<f32>() - 0.5)
        .collect::<Vec<_>>();
    let rhs = (0..n * k)
        .map(|_| rng.gen::<f32>() - 0.5)
        .collect::<Vec<_>>();

    let lhs = Tensor::from_vec(lhs, (m, k), device)?;
    let rhs = Tensor::from_vec(rhs, (n, k), device)?;

    let mm = lhs.matmul(&rhs.t()?)?;
    Ok((lhs, rhs, mm))
}
#[test]
fn quantized_matmul_q2k() -> Result<()> {
    use k_quants::BlockQ2K;

    let cpu = &Device::Cpu;
    let (m, k, n) = (11, 512, 21);
    let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
    assert_eq!(mm.dims(), [m, n]);
    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

    let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs);
    let mm = rhs.forward(&lhs)?;

    assert_eq!(mm.dims(), [m, n]);
    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [0.916, 0.422, 0.215, 1.668]);

    ggml_matmul_error_test::<BlockQ2K>()?;

    Ok(())
}
#[test]
fn quantized_matmul_q3k() -> Result<()> {
    use k_quants::BlockQ3K;

    let cpu = &Device::Cpu;
    let (m, k, n) = (11, 512, 21);
    let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
    assert_eq!(mm.dims(), [m, n]);
    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

    let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs);
    let mm = rhs.forward(&lhs)?;

    assert_eq!(mm.dims(), [m, n]);
    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.029, 1.418, -0.314, 1.495]);

    ggml_matmul_error_test::<BlockQ3K>()?;

    Ok(())
}
#[test]
fn quantized_matmul_q4k() -> Result<()> {
    use k_quants::BlockQ4K;

    let cpu = &Device::Cpu;
    let (m, k, n) = (11, 512, 21);
    let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
    assert_eq!(mm.dims(), [m, n]);
    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

    let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs);
    let mm = rhs.forward(&lhs)?;

    assert_eq!(mm.dims(), [m, n]);
    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.125, 1.435, -0.201, 1.589]);

    ggml_matmul_error_test::<BlockQ4K>()?;

    Ok(())
}
#[test]
fn quantized_matmul_q5k() -> Result<()> {
    use k_quants::BlockQ5K;

    let cpu = &Device::Cpu;
    let (m, k, n) = (11, 512, 21);
    let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
    assert_eq!(mm.dims(), [m, n]);
    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

    let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs);
    let mm = rhs.forward(&lhs)?;

    assert_eq!(mm.dims(), [m, n]);
    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.192, 1.491, -0.18, 1.743]);

    // Expected: 0.000740408897
    ggml_matmul_error_test::<BlockQ5K>()?;

    Ok(())
}
#[test]
fn quantized_matmul_q6k() -> Result<()> {
    use k_quants::BlockQ6K;

    let cpu = &Device::Cpu;
    let (m, k, n) = (11, 512, 21);
    let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
    assert_eq!(mm.dims(), [m, n]);
    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

    let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs);
    let mm = rhs.forward(&lhs)?;

    assert_eq!(mm.dims(), [m, n]);
    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.324, 1.49, -0.164, 1.741]);

    ggml_matmul_error_test::<BlockQ6K>()?;
    Ok(())
}
@@ -1,5 +1,4 @@
mod test_utils;
use candle_core::{DType, Device, IndexOp, Result, Tensor};
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};

fn zeros(device: &Device) -> Result<()> {
    let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@@ -36,10 +35,10 @@ fn tensor_2d(device: &Device) -> Result<()> {

fn binary_op(device: &Device) -> Result<()> {
    let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
    let tensor = Tensor::new(data, device)?;
    let tensor1 = Tensor::new(data, device)?;
    let data2 = &[[5f32, 5., 5., 5., 5.], [2., 1., 7., 8., 2.]];
    let tensor2 = Tensor::new(data2, device)?;
    let tensor = (&tensor + (&tensor * &tensor)? / (&tensor + &tensor2))?;
    let tensor = (&tensor1 + (&tensor1 * &tensor1)? / (&tensor1 + &tensor2))?;
    let dims = tensor.dims2()?;
    assert_eq!(dims, (2, 5));
    let content: Vec<Vec<f32>> = tensor.to_vec2()?;
@@ -49,6 +48,17 @@ fn binary_op(device: &Device) -> Result<()> {
    let tensor = (&tensor - &tensor)?;
    let content: Vec<Vec<f32>> = tensor.to_vec2()?;
    assert_eq!(content[0], [0., 0., 0., 0., 0.]);

    let min = tensor1.minimum(&(&tensor2 * 0.5)?)?;
    let max = tensor1.maximum(&(&tensor2 * 0.5)?)?;
    assert_eq!(
        min.to_vec2::<f32>()?,
        [[2.5, 1.0, 2.5, 1.0, 2.5], [1.0, 0.5, 3.5, 4.0, 1.0]],
    );
    assert_eq!(
        max.to_vec2::<f32>()?,
        [[3.0, 2.5, 4.0, 2.5, 5.0], [2.0, 1.0, 7.0, 8.0, 2.0]]
    );
    Ok(())
}
@@ -747,6 +757,25 @@ fn matmul(device: &Device) -> Result<()> {
    Ok(())
}

fn broadcast_matmul(device: &Device) -> Result<()> {
    let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
    let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
    let out = lhs.broadcast_matmul(&rhs)?;
    assert_eq!(out.dims(), &[3, 6, 4, 2]);
    for idx1 in 0..3 {
        for idx2 in 0..6 {
            let out = out.i((idx1, idx2))?;
            let lhs = lhs.i((idx1, 0))?;
            let rhs = rhs.i(idx2)?;
            let out2 = lhs.matmul(&rhs);
            let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
            // With cuda, we see errors of up to ~1e-12.
            assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
        }
    }
    Ok(())
}
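As a shape check for the broadcast above: the lhs batch dims (3, 1) and the rhs batch dim (6,) broadcast together to (3, 6), while the trailing (4, 5) x (5, 2) matmul contracts to (4, 2), hence the expected (3, 6, 4, 2). A minimal sketch of just the shape rule:

use candle_core::{DType, Device, Result, Tensor};

fn broadcast_matmul_shape_sketch() -> Result<()> {
    // Every one of the 3 lhs batches is paired with every one of the 6 rhs
    // matrices, so the batch dims multiply out to (3, 6).
    let lhs = Tensor::zeros((3, 1, 4, 5), DType::F32, &Device::Cpu)?;
    let rhs = Tensor::zeros((6, 5, 2), DType::F32, &Device::Cpu)?;
    assert_eq!(lhs.broadcast_matmul(&rhs)?.dims(), &[3, 6, 4, 2]);
    Ok(())
}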
fn broadcasting(device: &Device) -> Result<()> {
    let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
    let t2 = Tensor::new(&[100f32, 200f32], device)?;
@@ -864,8 +893,20 @@ test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu);
test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);

// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
#[test]
fn randn_hasneg() -> Result<()> {
    let t = Tensor::randn(0f32, 1f32, 200, &Device::Cpu)?.to_vec1::<f32>()?;
    if t.iter().all(|&v| v >= 0.) {
        candle_core::bail!("all values in tensors are non-negative")
    }
    Ok(())
}
23 candle-datasets/Cargo.toml Normal file
@@ -0,0 +1,23 @@
[package]
name = "candle-datasets"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
readme = "README.md"

[dependencies]
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.2.1" }
hf-hub = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
rand = { workspace = true }
thiserror = { workspace = true }
parquet = { workspace = true }
image = { workspace = true }
1 candle-datasets/README.md Normal file
@@ -0,0 +1 @@
# candle-datasets
73 candle-datasets/src/hub.rs Normal file
@@ -0,0 +1,73 @@
use hf_hub::{
    api::sync::{Api, ApiRepo},
    Repo, RepoType,
};
use parquet::file::reader::SerializedFileReader;
use std::fs::File;

#[derive(thiserror::Error, Debug)]
pub enum Error {
    #[error("ApiError : {0}")]
    ApiError(#[from] hf_hub::api::sync::ApiError),

    #[error("IoError : {0}")]
    IoError(#[from] std::io::Error),

    #[error("ParquetError : {0}")]
    ParquetError(#[from] parquet::errors::ParquetError),
}

fn sibling_to_parquet(
    rfilename: &str,
    repo: &ApiRepo,
) -> Result<SerializedFileReader<File>, Error> {
    let local = repo.get(rfilename)?;
    let file = File::open(local)?;
    let reader = SerializedFileReader::new(file)?;
    Ok(reader)
}

pub fn from_hub(api: &Api, dataset_id: String) -> Result<Vec<SerializedFileReader<File>>, Error> {
    let repo = Repo::with_revision(
        dataset_id,
        RepoType::Dataset,
        "refs/convert/parquet".to_string(),
    );
    let repo = api.repo(repo);
    let info = repo.info()?;

    let files: Result<Vec<_>, _> = info
        .siblings
        .into_iter()
        .filter_map(|s| -> Option<Result<_, _>> {
            let filename = s.rfilename;
            if filename.ends_with(".parquet") {
                let reader_result = sibling_to_parquet(&filename, &repo);
                Some(reader_result)
            } else {
                None
            }
        })
        .collect();
    let files = files?;

    Ok(files)
}

#[cfg(test)]
mod tests {
    use super::*;
    use parquet::file::reader::FileReader;

    #[test]
    fn test_dataset() {
        let api = Api::new().unwrap();
        let files = from_hub(
            &api,
            "hf-internal-testing/dummy_image_text_data".to_string(),
        )
        .unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].metadata().file_metadata().num_rows(), 20);
    }
}
7 candle-datasets/src/lib.rs Normal file
@@ -0,0 +1,7 @@
//! Datasets & Dataloaders for Candle
pub mod batcher;
pub mod hub;
pub mod nlp;
pub mod vision;

pub use batcher::Batcher;
1 candle-datasets/src/nlp/mod.rs Normal file
@@ -0,0 +1 @@
pub mod tinystories;
122 candle-datasets/src/nlp/tinystories.rs Normal file
@@ -0,0 +1,122 @@
//! Helper functions for the tinystories dataset. This uses the pre-tokenized version as generated
//! by the tools from https://github.com/karpathy/llama2.c
use candle::{Device, Result, Tensor};

pub struct Dataset {
    valid_tokens: Vec<memmap2::Mmap>,
    train_tokens: Vec<memmap2::Mmap>,
}

fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
    let file = std::fs::File::open(p)?;
    let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
    Ok(mmap)
}

impl Dataset {
    pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
        let dir = dir.as_ref();
        let mut bin_files = vec![];
        for file in std::fs::read_dir(dir)?.flatten() {
            let file = file.path();
            if let Some(extension) = file.extension() {
                if extension == "bin" {
                    bin_files.push(file)
                }
            }
        }
        if bin_files.len() < 2 {
            candle::bail!("found less than two bin files in {:?}", dir)
        }
        bin_files.sort();
        let valid_tokens = mmap_file(&bin_files[0])?;
        let train_tokens = bin_files[1..]
            .iter()
            .map(mmap_file)
            .collect::<Result<Vec<_>>>()?;
        Ok(Self {
            valid_tokens: vec![valid_tokens],
            train_tokens,
        })
    }

    pub fn train_tokens(&self) -> usize {
        self.train_tokens.len()
    }

    pub fn valid_tokens(&self) -> usize {
        self.valid_tokens.len()
    }
}

pub struct DatasetRandomIter<'a> {
    all_tokens: &'a [memmap2::Mmap],
    tokens: Vec<&'a memmap2::Mmap>,
    current_tokens: &'a memmap2::Mmap,
    indexes_in_bytes: Vec<usize>,
    seq_len: usize,
    device: Device,
}

impl<'a> DatasetRandomIter<'a> {
    pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
        use rand::seq::SliceRandom;
        use rand::thread_rng;

        let all_tokens = if valid {
            &ds.valid_tokens
        } else {
            &ds.train_tokens
        };
        let mut tokens = all_tokens.iter().collect::<Vec<_>>();
        tokens.shuffle(&mut thread_rng());
        let current_tokens = tokens.pop().unwrap();
        let seq_len_in_bytes = seq_len * 2;
        let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
            .step_by(seq_len_in_bytes)
            .collect::<Vec<_>>();
        indexes_in_bytes.shuffle(&mut thread_rng());
        Self {
            all_tokens,
            tokens,
            current_tokens,
            indexes_in_bytes,
            seq_len,
            device,
        }
    }
}

impl<'a> Iterator for DatasetRandomIter<'a> {
    type Item = Result<(Tensor, Tensor)>;

    fn next(&mut self) -> Option<Self::Item> {
        use byteorder::{LittleEndian, ReadBytesExt};
        use rand::seq::SliceRandom;
        use rand::thread_rng;

        let seq_len = self.seq_len;
        if self.indexes_in_bytes.is_empty() {
            if self.tokens.is_empty() {
                self.tokens = self.all_tokens.iter().collect();
                self.tokens.shuffle(&mut thread_rng());
            }
            self.current_tokens = self.tokens.pop().unwrap();
            let seq_len_in_bytes = self.seq_len * 2;
            self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
                .step_by(seq_len_in_bytes)
                .collect::<Vec<_>>();
            self.indexes_in_bytes.shuffle(&mut thread_rng());
        }
        let start_idx = self.indexes_in_bytes.pop().unwrap();
        let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
        let mut tokens = vec![0u16; bytes.len() / 2];
        if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
            return Some(Err(err.into()));
        }
        let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
        let inputs = Tensor::new(&tokens[..seq_len], &self.device);
        let targets = Tensor::new(&tokens[1..], &self.device);
        Some(candle::error::zip(inputs, targets))
    }
}
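The byte arithmetic in the iterator relies on the llama2.c pre-tokenized format: every token is a little-endian u16, hence the seq_len * 2 stride and the 2 * (seq_len + 1) window that yields seq_len inputs plus the one-token-shifted targets. A small sketch of the decode step in isolation (hypothetical values):

use byteorder::{LittleEndian, ReadBytesExt};

fn decode_tokens_sketch() -> std::io::Result<()> {
    // The byte pairs [1, 0] and [2, 1] decode to the u16 tokens 1 and 258.
    let bytes = [1u8, 0, 2, 1];
    let mut tokens = vec![0u16; bytes.len() / 2];
    std::io::Cursor::new(&bytes[..]).read_u16_into::<LittleEndian>(&mut tokens)?;
    assert_eq!(tokens, [1, 258]);
    Ok(())
}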
122 candle-datasets/src/vision/mnist.rs Normal file
@@ -0,0 +1,122 @@
//! The MNIST hand-written digit dataset.
//!
//! The files can be obtained from the following link:
//! <http://yann.lecun.com/exdb/mnist/>
use candle::{DType, Device, Error, Result, Tensor};
use hf_hub::{api::sync::Api, Repo, RepoType};
use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File;
use std::io::{self, BufReader, Read};

fn read_u32<T: Read>(reader: &mut T) -> Result<u32> {
    let mut b = vec![0u8; 4];
    reader.read_exact(&mut b)?;
    let (result, _) = b.iter().rev().fold((0u64, 1u64), |(s, basis), &x| {
        (s + basis * u64::from(x), basis * 256)
    });
    Ok(result as u32)
}
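The fold above assembles a big-endian u32, as the IDX header fields are stored most-significant byte first: the bytes are iterated in reverse while the basis grows by a factor of 256 per step. An equivalent forward formulation, checked against the label-file magic number used below:

fn read_u32_sketch() {
    // The magic number 2049 arrives as the bytes [0, 0, 8, 1]:
    // 8 * 256 + 1 = 2049.
    let b = [0u8, 0, 8, 1];
    let v = b.iter().fold(0u32, |acc, &x| acc * 256 + u32::from(x));
    assert_eq!(v, 2049);
}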
fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {
    let magic_number = read_u32(reader)?;
    if magic_number != expected {
        Err(io::Error::new(
            io::ErrorKind::Other,
            format!("incorrect magic number {magic_number} != {expected}"),
        ))?;
    }
    Ok(())
}

fn read_labels(filename: &std::path::Path) -> Result<Tensor> {
    let mut buf_reader = BufReader::new(File::open(filename)?);
    check_magic_number(&mut buf_reader, 2049)?;
    let samples = read_u32(&mut buf_reader)?;
    let mut data = vec![0u8; samples as usize];
    buf_reader.read_exact(&mut data)?;
    let samples = data.len();
    Tensor::from_vec(data, samples, &Device::Cpu)
}

fn read_images(filename: &std::path::Path) -> Result<Tensor> {
    let mut buf_reader = BufReader::new(File::open(filename)?);
    check_magic_number(&mut buf_reader, 2051)?;
    let samples = read_u32(&mut buf_reader)? as usize;
    let rows = read_u32(&mut buf_reader)? as usize;
    let cols = read_u32(&mut buf_reader)? as usize;
    let data_len = samples * rows * cols;
    let mut data = vec![0u8; data_len];
    buf_reader.read_exact(&mut data)?;
    let tensor = Tensor::from_vec(data, (samples, rows * cols), &Device::Cpu)?;
    tensor.to_dtype(DType::F32)? / 255.
}

pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<crate::vision::Dataset> {
    let dir = dir.as_ref();
    let train_images = read_images(&dir.join("train-images-idx3-ubyte"))?;
    let train_labels = read_labels(&dir.join("train-labels-idx1-ubyte"))?;
    let test_images = read_images(&dir.join("t10k-images-idx3-ubyte"))?;
    let test_labels = read_labels(&dir.join("t10k-labels-idx1-ubyte"))?;
    Ok(crate::vision::Dataset {
        train_images,
        train_labels,
        test_images,
        test_labels,
        labels: 10,
    })
}

fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
    let samples = parquet.metadata().file_metadata().num_rows() as usize;
    let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 784);
    let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
    for row in parquet.into_iter().flatten() {
        for (_name, field) in row.get_column_iter() {
            if let parquet::record::Field::Group(subrow) = field {
                for (_name, field) in subrow.get_column_iter() {
                    if let parquet::record::Field::Bytes(value) = field {
                        let image = image::load_from_memory(value.data()).unwrap();
                        buffer_images.extend(image.to_luma8().as_raw());
                    }
                }
            } else if let parquet::record::Field::Long(label) = field {
                buffer_labels.push(*label as u8);
            }
        }
    }
    let images = (Tensor::from_vec(buffer_images, (samples, 784), &Device::Cpu)?
        .to_dtype(DType::F32)?
        / 255.)?;
    let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
    Ok((images, labels))
}

pub fn load() -> Result<crate::vision::Dataset> {
    let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
    let dataset_id = "mnist".to_string();
    let repo = Repo::with_revision(
        dataset_id,
        RepoType::Dataset,
        "refs/convert/parquet".to_string(),
    );
    let repo = api.repo(repo);
    let test_parquet_filename = repo
        .get("mnist/test/0000.parquet")
        .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
    let train_parquet_filename = repo
        .get("mnist/train/0000.parquet")
        .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
    let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
        .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
    let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
        .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
    let (test_images, test_labels) = load_parquet(test_parquet)?;
    let (train_images, train_labels) = load_parquet(train_parquet)?;
    Ok(crate::vision::Dataset {
        train_images,
        train_labels,
        test_images,
        test_labels,
        labels: 10,
    })
}
@@ -10,10 +10,12 @@ license.workspace = true
readme = "README.md"

[dependencies]
candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.1.0" }
candle-transformers = { path = "../candle-transformers", version = "0.1.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.0", optional = true }
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
candle-nn = { path = "../candle-nn", version = "0.2.1" }
candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
@@ -21,26 +23,33 @@ num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }

[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
clap = { workspace = true }
hf-hub = { workspace = true }
hf-hub = { workspace = true, features = ["tokio"] }
imageproc = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
rusttype = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"

[build-dependencies]
anyhow = { workspace = true }

[features]
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
@@ -1,5 +1,8 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;
mod model;

use anyhow::{anyhow, Error as E, Result};
@@ -39,6 +42,10 @@ struct Args {
    /// The number of times to run the prompt.
    #[arg(long, default_value = "1")]
    n: usize,

    /// L2 normalization for embeddings.
    #[arg(long, default_value = "true")]
    normalize_embeddings: bool,
}

impl Args {
@@ -55,16 +62,16 @@ impl Args {

        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
            let cache = Cache::default();
            let cache = Cache::default().repo(repo);
            (
                cache
                    .get(&repo, "config.json")
                    .get("config.json")
                    .ok_or(anyhow!("Missing config file in cache"))?,
                cache
                    .get(&repo, "tokenizer.json")
                    .get("tokenizer.json")
                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
                cache
                    .get(&repo, "model.safetensors")
                    .get("model.safetensors")
                    .ok_or(anyhow!("Missing weights file in cache"))?,
            )
        } else {
@@ -107,7 +114,10 @@ fn main() -> Result<()> {
    let device = &model.device;

    if let Some(prompt) = args.prompt {
        let tokenizer = tokenizer.with_padding(None).with_truncation(None);
        let tokenizer = tokenizer
            .with_padding(None)
            .with_truncation(None)
            .map_err(E::msg)?;
        let tokens = tokenizer
            .encode(prompt, true)
            .map_err(E::msg)?
@@ -164,7 +174,13 @@ fn main() -> Result<()> {
    // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
    let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
    let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
    let embeddings = if args.normalize_embeddings {
        normalize_l2(&embeddings)?
    } else {
        embeddings
    };
    println!("pooled embeddings {:?}", embeddings.shape());

    let mut similarities = vec![];
    for i in 0..n_sentences {
        let e_i = embeddings.get(i)?;
@@ -184,3 +200,7 @@ fn main() -> Result<()> {
    }
    Ok(())
}

pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}
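normalize_l2 divides every row by its L2 norm: sum_keepdim(1) keeps the reduced axis around so that broadcast_div lines up with the original shape. A quick worked instance (a standalone sketch, not part of the example):

use candle::{Device, Result, Tensor};

fn normalize_l2_sketch() -> Result<()> {
    // The row [3., 4.] has L2 norm 5, so it normalizes to [0.6, 0.8].
    let v = Tensor::new(&[[3f32, 4.]], &Device::Cpu)?;
    let n = v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?;
    assert_eq!(n.to_vec2::<f32>()?, [[0.6, 0.8]]);
    Ok(())
}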
@@ -1,5 +1,5 @@
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, VarBuilder};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;

pub const DTYPE: DType = DType::F32;
@@ -1,6 +1,9 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

@@ -65,10 +68,7 @@ impl TextGeneration {
            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            new_tokens.push(next_token);
            let token = self
                .tokenizer
                .decode(vec![next_token], true)
                .map_err(E::msg)?;
            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
            print!("{token}");
            std::io::stdout().flush()?;
        }
@@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};

fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
    let weight = vb.get((size2, size1), "weight")?;
@@ -2,19 +2,16 @@
// own forward pass (CPU and GPU versions) as well as their backward pass.
//
// In this example we add the RMS normalization operation and implement it for f32.
#![allow(dead_code)]
#![allow(unused)]

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[allow(unused)]
mod cuda_kernels;

use clap::Parser;

use candle::backend::BackendStorage;
use candle::cpu_backend;
use candle::{CpuStorage, CustomOp1, DType, Device, Layout, Result, Shape, Tensor};
use candle::{CpuStorage, CustomOp1, Layout, Result, Shape, Tensor};

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
@@ -57,8 +54,9 @@ impl CustomOp1 for LayerNorm {
        storage: &candle::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::{cudarc, WrapErr};
        use cudarc::driver::{LaunchAsync, LaunchConfig};
        use candle::backend::BackendStorage;
        use candle::cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig};
        use candle::cuda_backend::WrapErr;
        let (d1, d2) = layout.shape().dims2()?;
        let d1 = d1 as u32;
        let d2 = d2 as u32;
@@ -89,7 +87,7 @@ fn main() -> anyhow::Result<()> {
    let device = candle_examples::device(args.cpu)?;
    let t = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
    println!("{t}");
    let t = t.custom_op1(LayerNorm { eps: 1e-5 })?;
    let t = t.apply_op1(LayerNorm { eps: 1e-5 })?;
    println!("{t}");
    Ok(())
}
339 candle-examples/examples/dinov2/main.rs Normal file
@ -0,0 +1,339 @@
//! DINOv2: Learning Robust Visual Features without Supervision
//! https://github.com/facebookresearch/dinov2

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::Parser;

use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};

const IMG_SIZE: usize = 518;
const PATCH_SIZE: usize = 14;
const NUM_CLASSES: usize = 1000;

fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
    if bias {
        candle_nn::linear(in_dim, out_dim, vb)
    } else {
        candle_nn::linear_no_bias(in_dim, out_dim, vb)
    }
}

#[derive(Debug)]
struct Attention {
    qkv: Linear,
    proj: Linear,
    num_heads: usize,
    scale: f64,
}

impl Attention {
    fn new(
        vb: VarBuilder,
        dim: usize,
        num_heads: usize,
        qkv_bias: bool,
        proj_bias: bool,
    ) -> Result<Self> {
        let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
        let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
        let scale = 1. / ((dim / num_heads) as f64).sqrt();
        Ok(Self {
            qkv,
            proj,
            num_heads,
            scale,
        })
    }
}

impl Module for Attention {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (b, n, c) = xs.dims3()?;
        let qkv = self
            .qkv
            .forward(xs)?
            .reshape((b, n, 3, self.num_heads, c / self.num_heads))?
            .transpose(1, 2)? // 02134
            .transpose(0, 1)? // 20134
            .transpose(2, 3)?; // 20314
        let q = (qkv.i(0)? * self.scale)?;
        let k = qkv.i(1)?;
        let v = qkv.i(2)?;
        let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
        let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
        self.proj.forward(&attn)
    }
}
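The three transposes above permute the (b, n, 3, heads, head_dim) projection into (3, b, heads, n, head_dim), so q, k and v fall out via `i(0)`, `i(1)` and `i(2)`. A small shape walk-through with illustrative dimensions (b = 1, n = 4, heads = 2, head_dim = 3 are assumptions, not values from the example):

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    // 1 * 4 * 3 * 2 * 3 = 72 elements in total.
    let qkv = Tensor::arange(0f32, 72f32, &Device::Cpu)?
        .reshape((1, 4, 3, 2, 3))? // (b, n, 3, heads, head_dim)
        .transpose(1, 2)? // (b, 3, n, heads, head_dim)
        .transpose(0, 1)? // (3, b, n, heads, head_dim)
        .transpose(2, 3)?; // (3, b, heads, n, head_dim)
    assert_eq!(qkv.dims(), &[3, 1, 2, 4, 3]);
    Ok(())
}
```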
#[derive(Debug)]
struct LayerScale {
    gamma: Tensor,
}

impl LayerScale {
    fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
        let gamma = vb.get(dim, "gamma")?;
        Ok(Self { gamma })
    }
}

impl Module for LayerScale {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.broadcast_mul(&self.gamma)
    }
}

#[derive(Debug)]
struct Mlp {
    fc1: Linear,
    fc2: Linear,
}

impl Mlp {
    fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
        let out_features = in_features;
        let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
        let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
        Ok(Self { fc1, fc2 })
    }
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.fc1.forward(xs)?.gelu()?;
        self.fc2.forward(&xs)
    }
}

#[derive(Debug)]
struct Block {
    norm1: LayerNorm,
    attn: Attention,
    ls1: LayerScale,
    norm2: LayerNorm,
    mlp: Mlp,
    ls2: LayerScale,
}

impl Block {
    fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
        let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
        let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
        let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
        let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
        let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
        let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
        Ok(Self {
            norm1,
            attn,
            ls1,
            norm2,
            mlp,
            ls2,
        })
    }
}

impl Module for Block {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let residual = xs;
        let xs = self
            .ls1
            .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .ls2
            .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
        xs + residual
    }
}

#[derive(Debug)]
struct PatchEmbed {
    proj: candle_nn::Conv2d,
    patch_size: (usize, usize),
    num_patches: usize,
}

impl PatchEmbed {
    fn new(
        vb: VarBuilder,
        img_size: usize,
        patch_size: usize,
        in_chans: usize,
        embed_dim: usize,
    ) -> Result<Self> {
        let config = candle_nn::Conv2dConfig {
            stride: patch_size,
            ..Default::default()
        };
        let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
        let num_patches = (img_size / patch_size) * (img_size / patch_size);
        Ok(Self {
            proj,
            patch_size: (patch_size, patch_size),
            num_patches,
        })
    }
}

impl Module for PatchEmbed {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (_b, _c, h, w) = xs.dims4()?;
        let (patch_h, patch_w) = self.patch_size;
        if (h % patch_h) != 0 {
            candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
        }
        if (w % patch_w) != 0 {
            candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
        }
        let xs = self.proj.forward(xs)?;
        let (b, c, h, w) = xs.dims4()?;
        // flatten embeddings.
        xs.reshape((b, c, h * w))?.transpose(1, 2)
    }
}

#[derive(Debug)]
pub struct DinoVisionTransformer {
    patch_embed: PatchEmbed,
    cls_token: Tensor,
    pos_embed: Tensor,
    blocks: Vec<Block>,
    norm: LayerNorm,
    head: Linear,
}

impl DinoVisionTransformer {
    pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
        let patch_embed =
            PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
        let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
        let num_tokens = 1;
        let pos_embed = vb.get(
            (1, patch_embed.num_patches + num_tokens, embed_dim),
            "pos_embed",
        )?;
        let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
        let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
        let vb_b = vb.pp("blocks");
        let blocks = (0..depth)
            .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
            .collect::<Result<Vec<_>>>()?;
        Ok(Self {
            patch_embed,
            cls_token,
            pos_embed,
            blocks,
            norm,
            head,
        })
    }

    fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
        let npatch = xs.dim(1)? - 1;
        let n = self.pos_embed.dim(1)? - 1;
        let sqrt_n = (n as f64).sqrt();
        if npatch == n && w == h {
            return Ok(xs.clone());
        }
        let class_pos_embed = self.pos_embed.i((.., ..1))?;
        let patch_pos_embed = self.pos_embed.i((.., 1..))?;
        let dim = xs.dim(D::Minus1)?;
        let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
        let patch_pos_embed = patch_pos_embed
            .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
            .transpose(2, 3)?
            .transpose(1, 2)?;
        // This uses bicubic interpolation in the original implementation.
        let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
        let el_count = patch_pos_embed.shape().elem_count();
        let patch_pos_embed =
            patch_pos_embed
                .transpose(1, 2)?
                .transpose(2, 3)?
                .reshape((1, el_count / dim, dim))?;
        Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
    }

    fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
        let (_b, _nc, w, h) = xs.dims4()?;
        let xs = self.patch_embed.forward(xs)?;
        let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
        &xs + &self.interpolate_pos_encoding(&xs, w, h)?
    }
}

impl Module for DinoVisionTransformer {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = self.prepare_tokens_with_mask(xs)?;
        for blk in self.blocks.iter() {
            xs = blk.forward(&xs)?
        }
        let xs = self.norm.forward(&xs)?;
        let xs_norm_clstoken = xs.i((.., 0))?;
        let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
        let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
        self.head.forward(&xs)
    }
}

pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
    DinoVisionTransformer::new(vb, 12, 384, 6)
}
#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("lmz/candle-dino-v2".into());
            api.get("dinov2_vits14.safetensors")?
        }
        Some(model) => model.into(),
    };
    let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
    let weights = weights.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
    let model = vit_small(vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
candle-examples/examples/efficientnet/main.rs (new file, 429 lines)
@@ -0,0 +1,429 @@
//! EfficientNet implementation.
//!
//! https://arxiv.org/abs/1905.11946

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};

use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn as nn;
use nn::{Module, VarBuilder};

// Based on the Python version from torchvision.
// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
#[derive(Debug, Clone, Copy)]
pub struct MBConvConfig {
    expand_ratio: f64,
    kernel: usize,
    stride: usize,
    input_channels: usize,
    out_channels: usize,
    num_layers: usize,
}

fn make_divisible(v: f64, divisor: usize) -> usize {
    let min_value = divisor;
    let new_v = usize::max(
        min_value,
        (v + divisor as f64 * 0.5) as usize / divisor * divisor,
    );
    if (new_v as f64) < 0.9 * v {
        new_v + divisor
    } else {
        new_v
    }
}
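A worked example of `make_divisible`: for v = 35.2 and divisor = 8, the rounding step computes 35.2 + 4.0 = 39.2, truncates to 39, and snaps to the multiple 39 / 8 * 8 = 32; since 32 >= 0.9 * 35.2 = 31.68, the result stays 32. The guard branch only bumps the result up by one divisor when snapping down would lose more than 10% of the requested channel count.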
fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
    let bneck_conf = |e, k, s, i, o, n| {
        let input_channels = make_divisible(i as f64 * width_mult, 8);
        let out_channels = make_divisible(o as f64 * width_mult, 8);
        let num_layers = (n as f64 * depth_mult).ceil() as usize;
        MBConvConfig {
            expand_ratio: e,
            kernel: k,
            stride: s,
            input_channels,
            out_channels,
            num_layers,
        }
    };
    vec![
        bneck_conf(1., 3, 1, 32, 16, 1),
        bneck_conf(6., 3, 2, 16, 24, 2),
        bneck_conf(6., 5, 2, 24, 40, 2),
        bneck_conf(6., 3, 2, 40, 80, 3),
        bneck_conf(6., 5, 1, 80, 112, 3),
        bneck_conf(6., 5, 2, 112, 192, 4),
        bneck_conf(6., 3, 1, 192, 320, 1),
    ]
}

impl MBConvConfig {
    fn b0() -> Vec<Self> {
        bneck_confs(1.0, 1.0)
    }
    fn b1() -> Vec<Self> {
        bneck_confs(1.0, 1.1)
    }
    fn b2() -> Vec<Self> {
        bneck_confs(1.1, 1.2)
    }
    fn b3() -> Vec<Self> {
        bneck_confs(1.2, 1.4)
    }
    fn b4() -> Vec<Self> {
        bneck_confs(1.4, 1.8)
    }
    fn b5() -> Vec<Self> {
        bneck_confs(1.6, 2.2)
    }
    fn b6() -> Vec<Self> {
        bneck_confs(1.8, 2.6)
    }
    fn b7() -> Vec<Self> {
        bneck_confs(2.0, 3.1)
    }
}

/// Conv2D with same padding.
#[derive(Debug)]
struct Conv2DSame {
    conv2d: nn::Conv2d,
    s: usize,
    k: usize,
}

impl Conv2DSame {
    fn new(
        vb: VarBuilder,
        i: usize,
        o: usize,
        k: usize,
        stride: usize,
        groups: usize,
        bias: bool,
    ) -> Result<Self> {
        let conv_config = nn::Conv2dConfig {
            stride,
            groups,
            ..Default::default()
        };
        let conv2d = if bias {
            nn::conv2d(i, o, k, conv_config, vb)?
        } else {
            nn::conv2d_no_bias(i, o, k, conv_config, vb)?
        };
        Ok(Self {
            conv2d,
            s: stride,
            k,
        })
    }
}

impl Module for Conv2DSame {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let s = self.s;
        let k = self.k;
        let (_, _, ih, iw) = xs.dims4()?;
        let oh = (ih + s - 1) / s;
        let ow = (iw + s - 1) / s;
        let pad_h = usize::max((oh - 1) * s + k - ih, 0);
        let pad_w = usize::max((ow - 1) * s + k - iw, 0);
        if pad_h > 0 || pad_w > 0 {
            let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
            let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
            self.conv2d.forward(&xs)
        } else {
            self.conv2d.forward(xs)
        }
    }
}
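A worked example of the "same" padding arithmetic above: with ih = 7, k = 3 and s = 2, the target output size is oh = (7 + 2 - 1) / 2 = 4, so pad_h = (4 - 1) * 2 + 3 - 7 = 2, split as one row of zeros on each side. The padded convolution then produces the ceil(ih / s) = 4 spatial size that TensorFlow-style "same" padding requires.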
#[derive(Debug)]
struct ConvNormActivation {
    conv2d: Conv2DSame,
    bn2d: nn::BatchNorm,
    activation: bool,
}

impl ConvNormActivation {
    fn new(
        vb: VarBuilder,
        i: usize,
        o: usize,
        k: usize,
        stride: usize,
        groups: usize,
    ) -> Result<Self> {
        let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
        let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
        Ok(Self {
            conv2d,
            bn2d,
            activation: true,
        })
    }

    fn no_activation(self) -> Self {
        Self {
            activation: false,
            ..self
        }
    }
}

impl Module for ConvNormActivation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.conv2d.forward(xs)?;
        let xs = self.bn2d.forward(&xs)?;
        if self.activation {
            swish(&xs)
        } else {
            Ok(xs)
        }
    }
}

#[derive(Debug)]
struct SqueezeExcitation {
    fc1: Conv2DSame,
    fc2: Conv2DSame,
}

impl SqueezeExcitation {
    fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
        let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
        let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
        Ok(Self { fc1, fc2 })
    }
}

impl Module for SqueezeExcitation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let residual = xs;
        // equivalent to adaptive_avg_pool2d([1, 1])
        let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
        let xs = self.fc1.forward(&xs)?;
        let xs = swish(&xs)?;
        let xs = self.fc2.forward(&xs)?;
        let xs = nn::ops::sigmoid(&xs)?;
        residual.broadcast_mul(&xs)
    }
}

#[derive(Debug)]
struct MBConv {
    expand_cna: Option<ConvNormActivation>,
    depthwise_cna: ConvNormActivation,
    squeeze_excitation: SqueezeExcitation,
    project_cna: ConvNormActivation,
    config: MBConvConfig,
}

impl MBConv {
    fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
        let vb = vb.pp("block");
        let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
        let expand_cna = if exp != c.input_channels {
            Some(ConvNormActivation::new(
                vb.pp("0"),
                c.input_channels,
                exp,
                1,
                1,
                1,
            )?)
        } else {
            None
        };
        let start_index = if expand_cna.is_some() { 1 } else { 0 };
        let depthwise_cna =
            ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
        let squeeze_channels = usize::max(1, c.input_channels / 4);
        let squeeze_excitation =
            SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
        let project_cna =
            ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
                .no_activation();
        Ok(Self {
            expand_cna,
            depthwise_cna,
            squeeze_excitation,
            project_cna,
            config: c,
        })
    }
}

impl Module for MBConv {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let use_res_connect =
            self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
        let ys = match &self.expand_cna {
            Some(expand_cna) => expand_cna.forward(xs)?,
            None => xs.clone(),
        };
        let ys = self.depthwise_cna.forward(&ys)?;
        let ys = self.squeeze_excitation.forward(&ys)?;
        let ys = self.project_cna.forward(&ys)?;
        if use_res_connect {
            ys + xs
        } else {
            Ok(ys)
        }
    }
}

fn swish(s: &Tensor) -> Result<Tensor> {
    s * nn::ops::sigmoid(s)?
}
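Here swish(x) = x * sigmoid(x), i.e. the SiLU activation; it is the same function that the llama changes later in this comparison call through candle_nn::ops::silu.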
#[derive(Debug)]
struct EfficientNet {
    init_cna: ConvNormActivation,
    blocks: Vec<MBConv>,
    final_cna: ConvNormActivation,
    classifier: nn::Linear,
}

impl EfficientNet {
    fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
        let f_p = p.pp("features");
        let first_in_c = configs[0].input_channels;
        let last_out_c = configs.last().unwrap().out_channels;
        let final_out_c = 4 * last_out_c;
        let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
        let nconfigs = configs.len();
        let mut blocks = vec![];
        for (index, cnf) in configs.into_iter().enumerate() {
            let f_p = f_p.pp(index + 1);
            for r_index in 0..cnf.num_layers {
                let cnf = if r_index == 0 {
                    cnf
                } else {
                    MBConvConfig {
                        input_channels: cnf.out_channels,
                        stride: 1,
                        ..cnf
                    }
                };
                blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
            }
        }
        let final_cna =
            ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
        let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
        Ok(Self {
            init_cna,
            blocks,
            final_cna,
            classifier,
        })
    }
}

impl Module for EfficientNet {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = self.init_cna.forward(xs)?;
        for block in self.blocks.iter() {
            xs = block.forward(&xs)?
        }
        let xs = self.final_cna.forward(&xs)?;
        // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
        let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
        self.classifier.forward(&xs)
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    B0,
    B1,
    B2,
    B3,
    B4,
    B5,
    B6,
    B7,
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Variant of the model to use.
    #[arg(value_enum, long, default_value_t = Which::B2)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("lmz/candle-efficientnet".into());
            let filename = match args.which {
                Which::B0 => "efficientnet-b0.safetensors",
                Which::B1 => "efficientnet-b1.safetensors",
                Which::B2 => "efficientnet-b2.safetensors",
                Which::B3 => "efficientnet-b3.safetensors",
                Which::B4 => "efficientnet-b4.safetensors",
                Which::B5 => "efficientnet-b5.safetensors",
                Which::B6 => "efficientnet-b6.safetensors",
                Which::B7 => "efficientnet-b7.safetensors",
            };
            api.get(filename)?
        }
        Some(model) => model.into(),
    };
    let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
    let weights = weights.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
    let cfg = match args.which {
        Which::B0 => MBConvConfig::b0(),
        Which::B1 => MBConvConfig::b1(),
        Which::B2 => MBConvConfig::b2(),
        Which::B3 => MBConvConfig::b3(),
        Which::B4 => MBConvConfig::b4(),
        Which::B5 => MBConvConfig::b5(),
        Which::B6 => MBConvConfig::b6(),
        Which::B7 => MBConvConfig::b7(),
    };
    let model = EfficientNet::new(vb, cfg, candle_examples::imagenet::CLASS_COUNT as usize)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
@@ -1,5 +1,8 @@
// TODO: Add an offline mode.

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

@@ -19,6 +22,8 @@ struct TextGeneration {
    device: Device,
    tokenizer: Tokenizer,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
@@ -28,6 +33,8 @@ impl TextGeneration {
        seed: u64,
        temp: Option<f64>,
        device: &Device,
        repeat_penalty: f32,
        repeat_last_n: usize,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp);
        Self {
@@ -35,6 +42,8 @@ impl TextGeneration {
            tokenizer,
            logits_processor,
            device: device.clone(),
            repeat_penalty,
            repeat_last_n,
        }
    }

@@ -60,6 +69,16 @@ impl TextGeneration {
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input)?;
            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
@@ -69,16 +88,14 @@ impl TextGeneration {
                "{} token: {} '{}'",
                index + 1,
                next_token,
                self.tokenizer
                    .decode(vec![next_token], true)
                    .map_err(E::msg)?
                self.tokenizer.decode(&[next_token], true).map_err(E::msg)?
            );
        }
        let dt = start_gen.elapsed();
        println!(
            "{sample_len} tokens generated ({} token/s)\n----\n{}\n----",
            sample_len as f64 / dt.as_secs_f64(),
            self.tokenizer.decode(new_tokens, true).map_err(E::msg)?
            self.tokenizer.decode(&new_tokens, true).map_err(E::msg)?
        );
        Ok(())
    }
@@ -115,6 +132,14 @@ struct Args {

    #[arg(long, default_value = "refs/pr/43")]
    revision: String,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.0)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
@@ -161,7 +186,15 @@ fn main() -> Result<()> {
    let model = Falcon::load(vb, config)?;
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        &device,
        args.repeat_penalty,
        args.repeat_last_n,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
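The repeat penalty plumbed through here discourages loops in the sampled text by making recently emitted tokens less likely. A hedged sketch of the idea (not the exact candle_transformers::utils::apply_repeat_penalty implementation):

```rust
// Tokens seen in the recent context get their logit divided by the penalty
// when positive and multiplied when negative, lowering their probability
// either way; penalty = 1.0 leaves the logits untouched.
fn repeat_penalty_sketch(logits: &mut [f32], penalty: f32, context: &[u32]) {
    for &token in context {
        if let Some(logit) = logits.get_mut(token as usize) {
            *logit = if *logit >= 0.0 {
                *logit / penalty
            } else {
                *logit * penalty
            };
        }
    }
}
```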
@@ -1,6 +1,6 @@
use anyhow::Result;
use candle::{DType, Device, Tensor, D};
use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};

const MAX_SEQ_LEN: usize = 5000;
@ -1,199 +0,0 @@
|
||||
# Adapted from:
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
|
||||
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
"""
|
||||
Sample usage:
|
||||
|
||||
```
|
||||
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
|
||||
--input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
|
||||
```
|
||||
"""
|
||||
|
||||
INTERMEDIATE_SIZE_MAP = {
|
||||
"7B": 11008,
|
||||
"13B": 13824,
|
||||
"30B": 17920,
|
||||
"65B": 22016,
|
||||
}
|
||||
NUM_SHARDS = {
|
||||
"7B": 1,
|
||||
"13B": 2,
|
||||
"30B": 4,
|
||||
"65B": 8,
|
||||
}
|
||||
|
||||
|
||||
def compute_intermediate_size(n):
|
||||
return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
|
||||
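As a worked check on compute_intermediate_size: for n = 4096, ceil(4096 * 8 / 3) = 10923, plus 255 gives 11178, floor-divided by 256 gives 43, and 43 * 256 = 11008, matching the "7B" entry of INTERMEDIATE_SIZE_MAP above.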

def read_json(path):
    with open(path, "r") as f:
        return json.load(f)


def write_json(text, path):
    with open(path, "w") as f:
        json.dump(text, f)


def write_model(model_path, input_base_path, model_size):
    os.makedirs(model_path, exist_ok=True)

    params = read_json(os.path.join(input_base_path, "params.json"))
    num_shards = NUM_SHARDS[model_size]
    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
    n_heads_per_shard = n_heads // num_shards
    dim = params["dim"]
    dims_per_head = dim // n_heads
    base = 10000.0
    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))

    # permute for sliced rotary
    def permute(w):
        return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)

    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
    # Load weights
    if model_size == "7B":
        # Not sharded
        # (The sharded implementation would also work, but this is simpler.)
        loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
    else:
        # Sharded
        loaded = [
            torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
            for i in range(num_shards)
        ]
    param_count = 0
    all_dicts = {}
    for layer_i in range(n_layers):
        if model_size == "7B":
            # Unsharded
            state_dict = {
                f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
                    loaded[f"layers.{layer_i}.attention.wq.weight"]
                ),
                f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
                    loaded[f"layers.{layer_i}.attention.wk.weight"]
                ),
                f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
                f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
                f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
                f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
                f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
                f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
                f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
            }
        else:
            # Sharded
            # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
            # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
            # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.

            state_dict = {
                f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
                    f"layers.{layer_i}.attention_norm.weight"
                ].clone(),
                f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
                    f"layers.{layer_i}.ffn_norm.weight"
                ].clone(),
            }
            state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
                torch.cat(
                    [
                        loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
                        for i in range(num_shards)
                    ],
                    dim=0,
                ).reshape(dim, dim)
            )
            state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
                torch.cat(
                    [
                        loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(n_heads_per_shard, dims_per_head, dim)
                        for i in range(num_shards)
                    ],
                    dim=0,
                ).reshape(dim, dim)
            )
            state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
                [
                    loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(n_heads_per_shard, dims_per_head, dim)
                    for i in range(num_shards)
                ],
                dim=0,
            ).reshape(dim, dim)

            state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
                [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
            )
            state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
                [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
            )
            state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
                [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
            )
            state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
                [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
            )

        state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
        all_dicts |= state_dict

    if model_size == "7B":
        # Unsharded
        state_dict = {
            "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
            "model.norm.weight": loaded["norm.weight"],
            "lm_head.weight": loaded["output.weight"],
        }
    else:
        state_dict = {
            "model.norm.weight": loaded[0]["norm.weight"],
            "model.embed_tokens.weight": torch.cat(
                [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
            ),
            "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
        }
    all_dicts |= state_dict
    all_dicts = {k: v.numpy() for k, v in all_dicts.items()}
    np.savez(os.path.join(model_path, "llama.npz"), **all_dicts)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dir",
        help="Location of LLaMA weights, which contains tokenizer.model and model folders",
    )
    parser.add_argument(
        "--model_size",
        choices=["7B", "13B", "30B", "65B"],
    )
    parser.add_argument(
        "--output_dir",
        help="Location to write HF model and tokenizer",
    )
    args = parser.parse_args()
    write_model(
        model_path=args.output_dir,
        input_base_path=os.path.join(args.input_dir, args.model_size),
        model_size=args.model_size,
    )


if __name__ == "__main__":
    main()
@@ -5,76 +5,28 @@
//
// The tokenizer config can be retrieved from:
// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json
//
// In order to convert the llama weights to a .npz file, run:
// python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use anyhow::{Error as E, Result};
use anyhow::{bail, Error as E, Result};
use clap::Parser;

use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::api::sync::Api;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;

mod model;
use model::{Config, Llama};
use model::{Config, Llama, LlamaConfig};

const EOS_TOKEN: &str = "</s>";
const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = r"
EDWARD:
I wonder how our princely father 'scaped,
Or whether he be 'scaped away or no
From Clifford's and Northumberland's pursuit:
Had he been ta'en, we should have heard the news;
Had he been slain, we should have heard the news;
Or had he 'scaped, methinks we should have heard
The happy tidings of his good escape.
How fares my brother? why is he so sad?

RICHARD:
I cannot joy, until I be resolved
Where our right valiant father is become.
I saw him in the battle range about;
And watch'd him how he singled Clifford forth.
Methought he bore him in the thickest troop
As doth a lion in a herd of neat;
Or as a bear, encompass'd round with dogs,
Who having pinch'd a few and made them cry,
The rest stand all aloof, and bark at him.
So fared our father with his enemies;
So fled his enemies my warlike father:
Methinks, 'tis prize enough to be his son.
See how the morning opes her golden gates,
And takes her farewell of the glorious sun!
How well resembles it the prime of youth,
Trimm'd like a younker prancing to his love!

EDWARD:
Dazzle mine eyes, or do I see three suns?

RICHARD:
Three glorious suns, each one a perfect sun;
Not separated with the racking clouds,
But sever'd in a pale clear-shining sky.
See, see! they join, embrace, and seem to kiss,
As if they vow'd some league inviolable:
Now are they but one lamp, one light, one sun.
In this the heaven figures some event.

EDWARD:
'Tis wondrous strange, the like yet never heard of.
I think it cites us, brother, to the field,
That we, the sons of brave Plantagenet,
Each one already blazing by our meeds,
Should notwithstanding join our lights together
And over-shine the earth as this the world.
Whate'er it bodes, henceforward will I bear
Upon my target three fair-shining suns.
";
const DEFAULT_PROMPT: &str = "My favorite theorem is ";

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
@@ -107,38 +59,73 @@ struct Args {
    #[arg(long)]
    prompt: Option<String>,

    /// Use f32 computations rather than f16.
    /// Use different dtype than f16
    #[arg(long)]
    use_f32: bool,
    dtype: Option<String>,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    v1: bool,

    #[arg(long)]
    use_flash_attn: bool,

    /// The folder name that contains safetensor weights and json files
    /// (same structure as huggingface online)
    #[arg(long)]
    local_weights: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.0)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tokenizers::Tokenizer;
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let device = candle_examples::device(args.cpu)?;
    let dtype = match args.dtype.as_deref() {
        Some("f16") => DType::F16,
        Some("bf16") => DType::BF16,
        Some("f32") => DType::F32,
        Some(dtype) => bail!("Unsupported dtype {dtype}"),
        None => DType::F16,
    };
    let (llama, tokenizer_filename, cache) = match args.npy {
        Some(filename) => {
            let config = if args.v1 {
                Config::config_7b_v1(args.use_flash_attn)
            } else {
                Config::config_7b_v2(args.use_flash_attn)
            };
            let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
            let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
    let (llama, tokenizer_filename) = match args.npy {
        Some(filename) => {
            let vb = VarBuilder::from_npz(filename, dtype, &device)?;
            let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
            (Llama::load(vb, &cache, &config)?, tokenizer)
            (Llama::load(vb, &cache, &config)?, tokenizer, cache)
        }
        None => {
            let api = Api::new()?;
@@ -150,16 +137,36 @@ fn main() -> Result<()> {
                }
            });
            println!("loading the model weights from {model_id}");
            let api = api.model(model_id);
            let tokenizer_filename = api.get("tokenizer.json")?;
            let revision = args.revision.unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));

            let tokenizer_filename = match &args.local_weights {
                Some(path) => (path.to_owned() + "tokenizer.json").into(),
                _ => api.get("tokenizer.json")?,
            };

            let config_filename = match &args.local_weights {
                Some(path) => (path.to_owned() + "config.json").into(),
                _ => api.get("config.json")?,
            };
            let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
            let config = config.into_config(args.use_flash_attn);

            let mut filenames = vec![];
            for rfilename in [
                "model-00001-of-00002.safetensors",
                "model-00002-of-00002.safetensors",
            ] {
                match &args.local_weights {
                    Some(path) => {
                        filenames.push((path.to_owned() + rfilename).into());
                    }
                    _ => {
                        let filename = api.get(rfilename)?;
                        filenames.push(filename);
                    }
                };
            }

            println!("building the model");
            let handles = filenames
@@ -170,12 +177,14 @@ fn main() -> Result<()> {
                .iter()
                .map(|h| Ok(h.deserialize()?))
                .collect::<Result<Vec<_>>>()?;
            let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

            let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
            (Llama::load(vb, &cache, &config)?, tokenizer_filename)
            (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
        }
    };
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
    let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
    let mut tokens = tokenizer
        .encode(prompt, true)
@@ -184,12 +193,12 @@ fn main() -> Result<()> {
        .to_vec();

    println!("starting the inference loop");
    print!("{prompt}");
    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
    let mut new_tokens = vec![];
    let start_gen = std::time::Instant::now();
    let mut index_pos = 0;
    let mut token_generated = 0;
    for index in 0..args.sample_len {
        let start_gen = std::time::Instant::now();
        let context_size = if cache.use_kv_cache && index > 0 {
            1
        } else {
@@ -199,25 +208,40 @@ fn main() -> Result<()> {
        let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
        let logits = llama.forward(&input, index_pos)?;
        let logits = logits.squeeze(0)?;
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = tokens.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &tokens[start_at..],
            )?
        };
        index_pos += ctxt.len();

        let next_token = logits_processor.sample(&logits)?;
        token_generated += 1;
        tokens.push(next_token);
        new_tokens.push(next_token);
        println!("> {:?}", start_gen.elapsed());
        println!(
            "{} token: {} '{}'",
            index + 1,
            next_token,
            tokenizer.decode(vec![next_token], true).map_err(E::msg)?
        );

        // Extracting the last token as a string is complicated, here we just apply some simple
        // heuristics as it seems to work well enough for this example. See the following for more
        // details:
        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
        if let Some(text) = tokenizer.id_to_token(next_token) {
            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
            print!("{text}");
            std::io::stdout().flush()?;
        }
        if Some(next_token) == eos_token_id {
            break;
        }
    }
    let dt = start_gen.elapsed();
    println!(
        "{} tokens generated ({} token/s)\n----\n{}\n----",
        args.sample_len,
        args.sample_len as f64 / dt.as_secs_f64(),
        tokenizer.decode(new_tokens, true).map_err(E::msg)?
        "\n\n{} tokens generated ({} token/s)\n",
        token_generated,
        token_generated as f64 / dt.as_secs_f64(),
    );
    Ok(())
}
@@ -1,20 +1,54 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Linear, VarBuilder};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use super::MAX_SEQ_LEN;

#[derive(Deserialize)]
pub struct LlamaConfig {
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub vocab_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: Option<usize>,
    pub rms_norm_eps: f64,
    #[serde(default = "default_rope")]
    pub rope_theta: f32,
}

fn default_rope() -> f32 {
    10_000.0
}

impl LlamaConfig {
    pub fn into_config(self, use_flash_attn: bool) -> Config {
        Config {
            hidden_size: self.hidden_size,
            intermediate_size: self.intermediate_size,
            vocab_size: self.vocab_size,
            num_hidden_layers: self.num_hidden_layers,
            num_attention_heads: self.num_attention_heads,
            num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
            rms_norm_eps: self.rms_norm_eps,
            rope_theta: self.rope_theta,
            use_flash_attn,
        }
    }
}

pub struct Config {
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub vocab_size: usize,
    pub n_layer: usize,
    pub n_head: usize,
    pub n_embd: usize,
    pub n_key_value_head: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub use_flash_attn: bool,
    pub rms_norm_eps: f64,
    pub rope_theta: f32,
}

impl Config {
@@ -23,12 +57,12 @@ impl Config {
            hidden_size: 4096,
            intermediate_size: 11008,
            vocab_size: 32000,
            n_layer: 32,
            n_head: 32,
            n_embd: 4096,
            n_key_value_head: 32,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: 32,
            use_flash_attn,
            rms_norm_eps: 1e-6,
            rope_theta: 10_000.0,
        }
    }

@@ -37,16 +71,31 @@ impl Config {
            hidden_size: 4096,
            intermediate_size: 11008,
            vocab_size: 32000,
            n_layer: 32,
            n_head: 32,
            n_embd: 4096,
            n_key_value_head: 32,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: 32,
            use_flash_attn,
            rms_norm_eps: 1e-5,
            rope_theta: 10_000.0,
        }
    }
}

// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
#[derive(Debug)]
pub struct Linear {
    inner: candle_nn::Linear,
    span: tracing::Span,
}

impl Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}

#[derive(Clone)]
pub struct Cache {
    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
@@ -61,10 +110,10 @@ pub struct Cache {
impl Cache {
    pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
        // precompute freqs_cis
        let n_elem = config.n_embd / config.n_head;
        let n_elem = config.hidden_size / config.num_attention_heads;
        let theta: Vec<_> = (0..n_elem)
            .step_by(2)
            .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
            .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
            .collect();
        let theta = Tensor::new(theta.as_slice(), device)?;
        let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
@@ -79,7 +128,7 @@ impl Cache {
        Ok(Self {
            masks: Arc::new(Mutex::new(HashMap::new())),
            use_kv_cache,
            kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
            kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
            device: device.clone(),
            cos,
            sin,
@@ -101,13 +150,10 @@ impl Cache {
    }
}

fn silu(xs: &Tensor) -> Result<Tensor> {
    xs / (xs.neg()?.exp()? + 1.0)?
}

fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
    let weight = vb.get((size2, size1), "weight")?;
    Ok(Linear::new(weight, None))
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
    Ok(Linear { inner, span })
}

fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
@@ -116,32 +162,20 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
}

struct RmsNorm {
    scale: Tensor,
    eps: f64,
    inner: candle_nn::RmsNorm,
    span: tracing::Span,
}

impl RmsNorm {
    fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let scale = vb.get(size, "weight")?;
        Ok(Self { scale, eps })
        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
        let inner = candle_nn::rms_norm(size, eps, vb)?;
        Ok(Self { inner, span })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let in_dtype = x.dtype();
        // This is a no-op if x's dtype is already f32.
        let x = x.to_dtype(DType::F32)?;
        let (b_sz, seq_len, hidden_size) = x.dims3()?;
        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
        let size = self.scale.dims1()?;
        let scale = self
            .scale
            .to_dtype(DType::F32)?
            .broadcast_as((b_sz, seq_len, size))?;
        let x = (scale * x_normed)?;
        let x = x.to_dtype(in_dtype)?;
        Ok(x)
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}

@@ -150,11 +184,13 @@ struct CausalSelfAttention {
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    n_head: usize,
    n_key_value_head: usize,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
    cache: Cache,
    use_flash_attn: bool,
    span: tracing::Span,
    span_rot: tracing::Span,
}

#[cfg(feature = "flash-attn")]
@@ -175,32 +211,34 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten

impl CausalSelfAttention {
    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (b_sz, _, seq_len, n_embd) = x.dims4()?;
        let _enter = self.span_rot.enter();
        let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
        let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
        let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
        let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
        let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
        let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
        let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
        let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
        let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
        let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
        let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
        let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
        let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
        Ok(rope)
    }
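This is the rotate-half form of rotary position embeddings: rope(x) = x * cos + rotate_half(x) * sin, where rotate_half([x1, x2]) = [-x2, x1] swaps the two halves of the last dimension and negates the new first half, with cos and sin precomputed per position in the Cache above.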
|
||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
||||
let _enter = self.span.enter();
|
||||
let (b_sz, seq_len, hidden_size) = x.dims3()?;
|
||||
let q = self.q_proj.forward(x)?;
|
||||
let k = self.k_proj.forward(x)?;
|
||||
let v = self.v_proj.forward(x)?;
|
||||
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
||||
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
|
||||
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let mut v = v
|
||||
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
|
||||
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
||||
@ -249,13 +287,13 @@ impl CausalSelfAttention {
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
|
||||
};
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
|
||||
let y = self.o_proj.forward(&y)?;
|
||||
Ok(y)
|
||||
}
|
||||
|
||||
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
|
||||
let n_rep = self.n_head / self.n_key_value_head;
|
||||
let n_rep = self.num_attention_heads / self.num_key_value_heads;
|
||||
if n_rep == 1 {
|
||||
Ok(x)
|
||||
} else {
|
||||
@ -263,15 +301,17 @@ impl CausalSelfAttention {
|
||||
let x = x
|
||||
.unsqueeze(2)?
|
||||
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
|
||||
.reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
|
||||
.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
let size_in = cfg.hidden_size;
|
||||
let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
|
||||
let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
|
||||
let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
|
||||
let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
|
||||
let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
|
||||
let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
|
||||
let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
|
||||
@ -281,11 +321,13 @@ impl CausalSelfAttention {
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
n_head: cfg.n_head,
|
||||
n_key_value_head: cfg.n_key_value_head,
|
||||
head_dim: cfg.hidden_size / cfg.n_head,
|
||||
num_attention_heads: cfg.num_attention_heads,
|
||||
num_key_value_heads: cfg.num_key_value_heads,
|
||||
head_dim: cfg.hidden_size / cfg.num_attention_heads,
|
||||
cache: cache.clone(),
|
||||
use_flash_attn: cfg.use_flash_attn,
|
||||
span,
|
||||
span_rot,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -301,15 +343,18 @@ struct Mlp {
|
||||
c_fc1: Linear,
|
||||
c_fc2: Linear,
|
||||
c_proj: Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
|
||||
let _enter = self.span.enter();
|
||||
let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
|
||||
self.c_proj.forward(&x)
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "mlp");
|
||||
let h_size = cfg.hidden_size;
|
||||
let i_size = cfg.intermediate_size;
|
||||
let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
|
||||
@ -319,6 +364,7 @@ impl Mlp {
|
||||
c_fc1,
|
||||
c_fc2,
|
||||
c_proj,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -328,10 +374,12 @@ struct Block {
     attn: CausalSelfAttention,
     rms_2: RmsNorm,
     mlp: Mlp,
+    span: tracing::Span,
 }

 impl Block {
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let residual = x;
         let x = self.rms_1.forward(x)?;
         let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
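Spelled out, the block computes the standard pre-norm residual pair: x1 = x + Attn(RMSNorm1(x)), then x2 = x1 + MLP(RMSNorm2(x1)). Normalization is applied before each sub-layer and the raw input is added back, which keeps gradients flowing through the skip path.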
@@ -341,6 +389,7 @@ impl Block {
     }

     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "block");
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
         let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
@@ -354,6 +403,7 @@ impl Block {
             attn,
             rms_2,
             mlp,
+            span,
         })
     }
 }
@@ -382,7 +432,7 @@ impl Llama {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
-        let blocks: Vec<_> = (0..cfg.n_layer)
+        let blocks: Vec<_> = (0..cfg.num_hidden_layers)
             .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
             .collect();

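One thing worth noting rather than changing: the `.unwrap()` inside the layer-loading map panics on the first malformed tensor. Collecting into a `Result` propagates the error instead; here is the idiom in a self-contained form (a plain integer map stands in for `Block::load`):

    use candle::Result;

    fn main() -> Result<()> {
        // collect::<Result<Vec<_>>>() stops at the first Err instead of
        // panicking the way `.unwrap()` inside `.map()` does.
        let xs = (0..4u32)
            .map(|i| Ok(i * 2))
            .collect::<Result<Vec<_>>>()?;
        assert_eq!(xs, vec![0, 2, 4, 6]);
        Ok(())
    }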
@@ -1,5 +1,8 @@
 // https://github.com/karpathy/llama2.c

+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;

@@ -27,7 +30,7 @@ struct InferenceCmd {
     #[arg(long, default_value = "")]
     prompt: String,

-    /// Config file in binary format.
+    /// Config file in binary or safetensors format.
     #[arg(long)]
     config: Option<String>,

@@ -100,6 +103,14 @@ pub struct Args {
     /// Tokenizer config file.
     #[arg(long)]
     tokenizer: Option<String>,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.1)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    repeat_last_n: usize,
 }

 impl Args {
@@ -200,7 +211,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
             Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
         }
     });
-    let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
     for inp_tgt in batch_iter {
         let (inp, tgt) = inp_tgt?;
         let logits = model.forward(&inp, 0)?;
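Batcher::new_r2 consumes an iterator of Result<(Tensor, Tensor)> pairs and yields them re-grouped by batch_size; the only change here is that it moved from candle_nn to the candle_datasets crate. A small sketch of the same call on dummy data, with illustrative shapes:

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Four (input, target) pairs of shape (3,), as run_eval would produce.
        let pairs = (0u32..4).map(|i| -> Result<(Tensor, Tensor)> {
            let inp = Tensor::new(&[i, i + 1, i + 2], &dev)?;
            let tgt = Tensor::new(&[i + 1, i + 2, i + 3], &dev)?;
            Ok((inp, tgt))
        });
        let batch_iter = candle_datasets::Batcher::new_r2(pairs).batch_size(2);
        for inp_tgt in batch_iter {
            // Each batch stacks batch_size pairs along a new leading dim.
            let (inp, tgt) = inp_tgt?;
            println!("{:?} {:?}", inp.shape(), tgt.shape());
        }
        Ok(())
    }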
@@ -225,11 +236,22 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {

     let device = candle_examples::device(common_args.cpu)?;

-    let mut file = std::fs::File::open(config_path)?;
-    let config = Config::from_reader(&mut file)?;
-    println!("{config:?}");
-    let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
-    let vb = weights.var_builder(&config, &device)?;
+    let is_safetensors = config_path
+        .extension()
+        .map_or(false, |v| v == "safetensors");
+    let (vb, config) = if is_safetensors {
+        let config = Config::tiny();
+        let tensors = candle::safetensors::load(config_path, &device)?;
+        let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
+        (vb, config)
+    } else {
+        let mut file = std::fs::File::open(config_path)?;
+        let config = Config::from_reader(&mut file)?;
+        println!("{config:?}");
+        let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
+        let vb = weights.var_builder(&config, &device)?;
+        (vb, config)
+    };
     let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, config)?;

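The safetensors branch relies on two small building blocks: candle::safetensors::load, which returns a name-to-tensor map, and VarBuilder::from_tensors, which exposes that map through the usual vb.pp(...) paths. A round-trip sketch, where the file name and tensor name are made up for the example:

    use std::collections::HashMap;
    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let mut tensors = HashMap::new();
        tensors.insert("w".to_string(), Tensor::zeros((2, 3), DType::F32, &dev)?);
        candle::safetensors::save(&tensors, "tiny.safetensors")?;
        // Load the file back and wrap it in a VarBuilder, mirroring the
        // is_safetensors branch above.
        let loaded = candle::safetensors::load("tiny.safetensors", &dev)?;
        let vb = candle_nn::VarBuilder::from_tensors(loaded, DType::F32, &dev);
        let w = vb.get((2, 3), "w")?;
        assert_eq!(w.dims(), &[2, 3]);
        Ok(())
    }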
@@ -254,6 +276,16 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
         let logits = model.forward(&input, index_pos)?;
         let logits = logits.i((0, logits.dim(1)? - 1))?;
+        let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
+            logits
+        } else {
+            let start_at = tokens.len().saturating_sub(common_args.repeat_last_n);
+            candle_transformers::utils::apply_repeat_penalty(
+                &logits,
+                common_args.repeat_penalty,
+                &tokens[start_at..],
+            )?
+        };
         index_pos += ctxt.len();

         let next_token = logits_processor.sample(&logits)?;
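apply_repeat_penalty implements the usual llama.cpp-style rule: for every token id seen in the recent context window, a positive logit is divided by the penalty and a negative one is multiplied by it, pushing recently emitted tokens down in the next sampling step. A scalar sketch of that rule (not the candle_transformers implementation itself, just the convention it follows):

    // Dampen the logits of recently generated tokens.
    fn repeat_penalty_sketch(logits: &mut [f32], penalty: f32, context: &[u32]) {
        for &token in context {
            if let Some(logit) = logits.get_mut(token as usize) {
                if *logit >= 0.0 {
                    *logit /= penalty; // make likely repeats less likely
                } else {
                    *logit *= penalty; // push unlikely ones further down
                }
            }
        }
    }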
@@ -1,6 +1,6 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::linear_no_bias as linear;
-use candle_nn::{embedding, Embedding, Linear, VarBuilder};
+use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};

@@ -94,32 +94,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
     xs / (xs.neg()?.exp()? + 1.0)?
 }

-struct RmsNorm {
-    scale: Tensor,
-    eps: f64,
-}
-
-impl RmsNorm {
-    fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let scale = vb.get_or_init(size, "weight", candle_nn::Init::Const(1.))?;
-        Ok(Self { scale, eps })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (b_sz, seq_len, hidden_size) = x.dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
-        let size = self.scale.dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        Ok(x)
-    }
-}
-
 struct CausalSelfAttention {
     q_proj: Linear,
     k_proj: Linear,
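The hand-rolled RmsNorm deleted here and the candle_nn replacement compute the same thing: each position is scaled by the reciprocal root-mean-square of its features, then by a learned per-channel weight. For reference, the math as a small free function over candle tensors, equivalent to the removed forward assuming an f32 input:

    use candle::{Result, Tensor, D};

    // rms_norm(x) = scale * x / sqrt(mean(x^2) + eps), over the last dim.
    fn rms_norm_ref(x: &Tensor, scale: &Tensor, eps: f64) -> Result<Tensor> {
        let hidden_size = x.dim(D::Minus1)? as f64;
        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size)?;
        let x_normed = x.broadcast_div(&(norm_x + eps)?.sqrt()?)?;
        x_normed.broadcast_mul(scale)
    }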
@@ -290,9 +264,9 @@ impl Block {
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
-        let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
+        let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
         let post_attention_layernorm =
-            RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
+            rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
         Ok(Self::new(
             input_layernorm,
             attn,
@@ -325,7 +299,7 @@ impl Llama {
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
         let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
-        let ln_f = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
+        let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layers)
             .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
             .collect();
