mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Compare commits
451 Commits
ivarflakst
...
0.7.2
Author | SHA1 | Date | |
---|---|---|---|
3a3c48b14b | |||
261ed65f36 | |||
62525e8352 | |||
2c25754281 | |||
ed48f54b54 | |||
ad8a4c5e5a | |||
c3c392f45c | |||
a0184a4fe4 | |||
10d47183c0 | |||
d01207dbf3 | |||
8097559c1a | |||
829dcfa8dc | |||
c2fca0ca11 | |||
844d45cde4 | |||
af2104078f | |||
5fc4f17727 | |||
c58c5d5b01 | |||
382c6b51af | |||
6eea45a761 | |||
ebf722b446 | |||
c09afc211c | |||
b60faebea4 | |||
72d649058b | |||
0cb0bd1dfa | |||
afb6575835 | |||
5635650d38 | |||
13b2a8a4a0 | |||
e3261216b1 | |||
c02b7c3272 | |||
86613c00e2 | |||
29e25c458d | |||
aafa24ed93 | |||
fdc2622686 | |||
ccdbe87639 | |||
2ec8729d51 | |||
e3c146ada6 | |||
1e96b8b695 | |||
a8288b7a72 | |||
6070278a31 | |||
b47c0bc475 | |||
14fd2d97e0 | |||
31a1075f4b | |||
236b29ff15 | |||
58197e1896 | |||
736d8eb752 | |||
7cff5898ec | |||
b75ef051cf | |||
c1b9e07e35 | |||
69fdcfe96a | |||
2b75dd9551 | |||
53ce65f706 | |||
68aa9c7320 | |||
35e5f31397 | |||
d3fe989d08 | |||
14db029494 | |||
6e6c1c99b0 | |||
b7d9af00cc | |||
59bbc0d287 | |||
dfdce2b602 | |||
500c9f2882 | |||
2be9bd211e | |||
89eae41efd | |||
c0a559d427 | |||
aa7ac1832d | |||
19db6b9723 | |||
0fcb40b229 | |||
6991a37b94 | |||
9ca277a9d7 | |||
2e9c010609 | |||
ac51f477eb | |||
d4b6f6eef6 | |||
957d604a78 | |||
ce90287f45 | |||
1ba87a9450 | |||
bd80078acf | |||
fea46cb719 | |||
8696cf6494 | |||
4a52aeb437 | |||
24d54d0ff9 | |||
636eff652a | |||
0f5cbb08b3 | |||
ddafc61055 | |||
a925ae6bc6 | |||
6056fd5c90 | |||
ebc9aa60bc | |||
2489a606fe | |||
3c815b1dca | |||
42891cc613 | |||
f25173d68b | |||
6a4741bbf9 | |||
30cdd769f9 | |||
d74fbed334 | |||
c63048d374 | |||
a226a9736b | |||
25960676ca | |||
9cd54aa5d4 | |||
eec11ce2ce | |||
9182f9f5c2 | |||
ecff05d72b | |||
7f1ba8038c | |||
74e9e41911 | |||
e27aac0a06 | |||
a3dd87f15e | |||
242e006bbb | |||
6baa1d486b | |||
36cf54525d | |||
2b10aaa05d | |||
9f804af29d | |||
54ff971e35 | |||
b9fac7ec00 | |||
f65e90e7ef | |||
d39462856b | |||
cb180eb23a | |||
9182c828e6 | |||
3f13ad3d79 | |||
cd4d941ed1 | |||
03344d3c19 | |||
1ec3b2cc18 | |||
f7773d498a | |||
7abc3b8cd7 | |||
46012ed31f | |||
f3fade3b03 | |||
ea260aeffd | |||
0814dfd148 | |||
3ceca9901a | |||
1df2bddccf | |||
6f0b807ffd | |||
d54e02d73d | |||
45e235a747 | |||
31cf64147b | |||
77ea479a18 | |||
72e7ca529a | |||
7ff921c538 | |||
9b8537a62f | |||
7ebc3548e1 | |||
eefc1c77ef | |||
01545f7303 | |||
349c3e806a | |||
bdaa34216a | |||
cc80e065e5 | |||
13c64f6828 | |||
21f82a5155 | |||
9cff7bc3f4 | |||
d9bc5ec151 | |||
84328e2b60 | |||
82b641fd27 | |||
01794dc16e | |||
a75cd8164f | |||
b13a82a438 | |||
59b18d974e | |||
89f53b9d7b | |||
a09d451d11 | |||
fa06f5f5f9 | |||
09d4845aa8 | |||
a0d03aded1 | |||
3bbb88fcb4 | |||
ed7b99f525 | |||
287013ef28 | |||
eb26e2467e | |||
c68ed8963f | |||
e5c8b88f90 | |||
805f3be8e1 | |||
3b429f3023 | |||
96a48e5cc4 | |||
6cf82fd7a3 | |||
cfab6e7616 | |||
11d4a3c588 | |||
9d3f1c8af5 | |||
7211009179 | |||
6fadaf2eff | |||
8a05743a21 | |||
b2e816752b | |||
618ecf5e23 | |||
267601eec1 | |||
08a15cb79e | |||
c388be93e7 | |||
d22f1d4f4e | |||
0067fe00a8 | |||
587ee3bb6f | |||
dd78422701 | |||
9215e9ce8c | |||
52ae332910 | |||
8b390ddd29 | |||
c97d639fa0 | |||
b45c710dbf | |||
9c532aef47 | |||
f7a6468238 | |||
2b93dffe64 | |||
e6ee7ba4d4 | |||
1690ab45d2 | |||
8de0ce6cba | |||
ce6d08df94 | |||
2817643db9 | |||
4d14777673 | |||
f135b7963d | |||
af955f260c | |||
8ad822a983 | |||
e198bb0816 | |||
f7d5bf5b97 | |||
c119600d6e | |||
c449f65b12 | |||
db7dbf3071 | |||
4ecedb1598 | |||
53e5380bf6 | |||
50e49ecc5f | |||
4c88c3ce06 | |||
8b8fb630df | |||
fb805b8ca2 | |||
79e3bec789 | |||
e6d412b156 | |||
26cbbf8d84 | |||
2bf413caa3 | |||
3ad4770eb6 | |||
a0460cd2b1 | |||
b81ecf712d | |||
a4d5a414e3 | |||
798e0335cd | |||
718671a0d5 | |||
c5fe4a7f89 | |||
7f354473cf | |||
33c9b66554 | |||
9fd52b3b71 | |||
e662431acf | |||
ab892274d1 | |||
b869a659ec | |||
88f7793598 | |||
2ac302a5d1 | |||
ace282e5c2 | |||
c87381fc96 | |||
c5626b8271 | |||
e6a5b82ba6 | |||
5aebe53dd2 | |||
f76bb7794a | |||
30b145150f | |||
f48c07e242 | |||
8967c46563 | |||
1e46cf8b19 | |||
bd8db2a771 | |||
318d143224 | |||
2be1a35710 | |||
26226068a4 | |||
cd6b9e317c | |||
08c049def3 | |||
d17b2cdad9 | |||
fb918a23c8 | |||
b23436bf90 | |||
be9c200cbb | |||
ea0d8d3753 | |||
308ea070ed | |||
b20acd622c | |||
5522bbc57c | |||
888c09a3db | |||
318cb82f16 | |||
c7557b65dc | |||
cd29c7ccd4 | |||
f9954b73ba | |||
eead1dcead | |||
92f81d2fcb | |||
3144150b8d | |||
b190fd8592 | |||
efe4a0c84b | |||
665da30487 | |||
356a170ae9 | |||
7ecbc6d50b | |||
8ad12a0e81 | |||
eb1b27abcd | |||
708e422456 | |||
c5092f2c29 | |||
cdc8b57b5c | |||
b0340d72ec | |||
b3484e7a5e | |||
ada5d7c096 | |||
13ae5a34c7 | |||
ab86cd37c8 | |||
a9abde5f93 | |||
75b6d4b0da | |||
66f0a4eeea | |||
4523ecfb2a | |||
f5dfe883d7 | |||
196765e995 | |||
60676780a9 | |||
d3a8d291d5 | |||
cd254074f3 | |||
e7f8e72588 | |||
1b98f84a2b | |||
cf7d7fcf2f | |||
8c0db87992 | |||
e2b4829531 | |||
5e70821dd0 | |||
a62a97340c | |||
fdfe8fd129 | |||
790037390c | |||
6f877592a7 | |||
cc856db9ce | |||
fc1fe5e45b | |||
32f567bac4 | |||
fee33b45c2 | |||
6708870e63 | |||
a00e24d752 | |||
c07e4057ab | |||
c0bdd9c7a6 | |||
9563a5fee4 | |||
ec97c98e81 | |||
bb3ee48039 | |||
0c11e055be | |||
18036c6ccb | |||
0fddec762e | |||
74b7f59261 | |||
af7f8b87d3 | |||
b219903d0f | |||
469635a3eb | |||
455c42aa72 | |||
2a8679509e | |||
143c481c20 | |||
f115895b9e | |||
90fc82211f | |||
6a966cf9e0 | |||
04a61a9c72 | |||
58605252e8 | |||
d365ef32d9 | |||
754fa1e813 | |||
184105792f | |||
a15f859ab4 | |||
e316cb6997 | |||
ce9fbc3682 | |||
db8b24ae92 | |||
74bf6994b1 | |||
cdc4c172c4 | |||
e1f9c3776d | |||
3318fe30fb | |||
2bb9c683b9 | |||
ff03fd3fb3 | |||
df5f69444e | |||
0c5eecbc0f | |||
56c9d3ee7b | |||
dd00482ea3 | |||
936f6a4840 | |||
3440cec3a0 | |||
e7fc1daa21 | |||
be5b68cd0b | |||
ea984d0421 | |||
9634583781 | |||
758366160e | |||
0a3487a776 | |||
0c09d10f32 | |||
8a99cf7dd2 | |||
bd9ab9bc04 | |||
8cc0a183ba | |||
6530932285 | |||
924ccae30c | |||
60dc72b96b | |||
20abb72fec | |||
ca5d727ba2 | |||
09e0148cce | |||
de11623752 | |||
21f1d04976 | |||
4fff5b51f5 | |||
314630638d | |||
3e3def4134 | |||
6980774a91 | |||
64d4038e4f | |||
979deaca07 | |||
b485e4b6ee | |||
2c95b7394a | |||
4fd00b8900 | |||
57267cd536 | |||
60ee5cfd4d | |||
56e44aabe3 | |||
d0aca6c3c6 | |||
15e8644149 | |||
0c49e95dfb | |||
205767f9de | |||
5e526abc8c | |||
6400e1b0a0 | |||
32544a2ad6 | |||
badf886583 | |||
918136ba46 | |||
1a6043af51 | |||
2f22afd80e | |||
8d04f70f4d | |||
eeb7e2b683 | |||
11ea7aac4d | |||
32eb56d6b3 | |||
28057781aa | |||
544018b6d0 | |||
c753f72c85 | |||
8013b50829 | |||
45d5322d62 | |||
a2cb2edead | |||
fc67d878bb | |||
3ba37443e5 | |||
1fb728772d | |||
cb86b0c82c | |||
6284ad784c | |||
678d44a7f6 | |||
41416d2376 | |||
5ebcfeaf0f | |||
7c7400fb63 | |||
058a910d0e | |||
26fe162ab5 | |||
121a71e01f | |||
2d5f2a728d | |||
68f7655895 | |||
b60064780d | |||
14010a8498 | |||
0de0795220 | |||
c1b418586c | |||
ad73e93da2 | |||
13c67226e6 | |||
d0aa197b07 | |||
274bf11633 | |||
1e26d539d9 | |||
74497e6bf7 | |||
8ab384e63d | |||
27ffd644a9 | |||
bf20cc854c | |||
42ce593ec6 | |||
67589791d2 | |||
1c8d61f051 | |||
90447bc993 | |||
40ce16001b | |||
5657e596cd | |||
0dee8ea19b | |||
9cadd4e644 | |||
020a979de2 | |||
cdc3823d8f | |||
e5eb9602d0 | |||
b75e8945bc | |||
a90fc5ca5a | |||
adfae2460a | |||
678f64dd27 | |||
b545f54a19 | |||
1ba11f22d6 | |||
982722019b | |||
a83ca2ece0 | |||
153c940a9c | |||
50be8a98ba | |||
58cc896e69 | |||
5cdd84e0f6 | |||
a510ddec4e | |||
d32abbce53 | |||
dfab45e1c8 | |||
96bc704d17 | |||
a52d407ae6 | |||
9e824ec810 | |||
beadb1b434 | |||
6d83d42efb | |||
b6afb46601 | |||
73d79e6092 | |||
b1879f17f6 | |||
4f79f5df8a |
75
.github/workflows/ci_cuda.yaml
vendored
75
.github/workflows/ci_cuda.yaml
vendored
@ -5,49 +5,16 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
start-runner:
|
|
||||||
name: Start self-hosted EC2 runner
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
# Don't run on forks, they won't have access to secrets anyway.
|
|
||||||
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
|
|
||||||
env:
|
|
||||||
AWS_REGION: us-east-1
|
|
||||||
EC2_AMI_ID: ami-03cfed9ea28f4b002
|
|
||||||
EC2_INSTANCE_TYPE: g5.xlarge
|
|
||||||
EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
|
|
||||||
EC2_SECURITY_GROUP: sg-030175c435ac141d6
|
|
||||||
outputs:
|
|
||||||
label: ${{ steps.start-ec2-runner.outputs.label }}
|
|
||||||
ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
|
|
||||||
steps:
|
|
||||||
- name: Configure AWS credentials
|
|
||||||
uses: aws-actions/configure-aws-credentials@v1
|
|
||||||
with:
|
|
||||||
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
|
||||||
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
|
||||||
aws-region: ${{ env.AWS_REGION }}
|
|
||||||
- name: Start EC2 runner
|
|
||||||
id: start-ec2-runner
|
|
||||||
uses: philschmid/philschmid-ec2-github-runner@main
|
|
||||||
with:
|
|
||||||
mode: start
|
|
||||||
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
|
|
||||||
ec2-image-id: ${{ env.EC2_AMI_ID }}
|
|
||||||
ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
|
|
||||||
subnet-id: ${{ env.EC2_SUBNET_ID }}
|
|
||||||
security-group-id: ${{ env.EC2_SECURITY_GROUP }}
|
|
||||||
aws-resource-tags: > # optional, requires additional permissions
|
|
||||||
[
|
|
||||||
{"Key": "Name", "Value": "ec2-tgi-github-runner"},
|
|
||||||
{"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
|
|
||||||
]
|
|
||||||
|
|
||||||
test-cuda:
|
test-cuda:
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
|
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
needs: start-runner # required to start the main job when the runner is ready
|
runs-on:
|
||||||
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
|
group: aws-g4dn-2xlarge
|
||||||
|
container:
|
||||||
|
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
|
||||||
|
options: --gpus 0
|
||||||
|
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
|
||||||
permissions:
|
permissions:
|
||||||
contents: write
|
contents: write
|
||||||
packages: write
|
packages: write
|
||||||
@ -58,32 +25,10 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
- name: Install dependencies
|
||||||
|
run: apt-get update && apt install curl build-essential libssl-dev protobuf-compiler pkg-config -y
|
||||||
- name: Install Rust Stable
|
- name: Install Rust Stable
|
||||||
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
|
uses: actions-rust-lang/setup-rust-toolchain@v1
|
||||||
- uses: Swatinem/rust-cache@v2
|
- uses: Swatinem/rust-cache@v2
|
||||||
- run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
|
|
||||||
- name: Test (cuda)
|
- name: Test (cuda)
|
||||||
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
|
run: cargo test --features cuda
|
||||||
stop-runner:
|
|
||||||
name: Stop self-hosted EC2 runner
|
|
||||||
needs:
|
|
||||||
- start-runner
|
|
||||||
- test-cuda
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
env:
|
|
||||||
AWS_REGION: us-east-1
|
|
||||||
if: ${{ (success() || failure()) && github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }} # required to stop the runner even if the error happened in the previous jobs
|
|
||||||
steps:
|
|
||||||
- name: Configure AWS credentials
|
|
||||||
uses: aws-actions/configure-aws-credentials@v1
|
|
||||||
with:
|
|
||||||
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
|
||||||
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
|
||||||
aws-region: ${{ env.AWS_REGION }}
|
|
||||||
- name: Stop EC2 runner
|
|
||||||
uses: philschmid/philschmid-ec2-github-runner@main
|
|
||||||
with:
|
|
||||||
mode: stop
|
|
||||||
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
|
|
||||||
label: ${{ needs.start-runner.outputs.label }}
|
|
||||||
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
|
|
||||||
|
6
.github/workflows/python.yml
vendored
6
.github/workflows/python.yml
vendored
@ -18,9 +18,9 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest] # For now, only test on Linux
|
os: [ubuntu-latest] # For now, only test on Linux
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Rust
|
- name: Install Rust
|
||||||
uses: actions-rs/toolchain@v1
|
uses: actions-rs/toolchain@v1
|
||||||
@ -65,4 +65,4 @@ jobs:
|
|||||||
working-directory: ./candle-pyo3
|
working-directory: ./candle-pyo3
|
||||||
run: |
|
run: |
|
||||||
source .env/bin/activate
|
source .env/bin/activate
|
||||||
python -m pytest -s -v tests
|
python -m pytest -s -v tests
|
||||||
|
12
.github/workflows/rust-ci.yml
vendored
12
.github/workflows/rust-ci.yml
vendored
@ -1,6 +1,6 @@
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
@ -15,7 +15,7 @@ jobs:
|
|||||||
os: [ubuntu-latest, windows-latest, macOS-latest]
|
os: [ubuntu-latest, windows-latest, macOS-latest]
|
||||||
rust: [stable]
|
rust: [stable]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v4
|
||||||
- uses: actions-rs/toolchain@v1
|
- uses: actions-rs/toolchain@v1
|
||||||
with:
|
with:
|
||||||
profile: minimal
|
profile: minimal
|
||||||
@ -34,7 +34,7 @@ jobs:
|
|||||||
os: [ubuntu-latest, windows-latest, macOS-latest]
|
os: [ubuntu-latest, windows-latest, macOS-latest]
|
||||||
rust: [stable]
|
rust: [stable]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v4
|
||||||
- uses: actions-rs/toolchain@v1
|
- uses: actions-rs/toolchain@v1
|
||||||
with:
|
with:
|
||||||
profile: minimal
|
profile: minimal
|
||||||
@ -49,7 +49,7 @@ jobs:
|
|||||||
name: Rustfmt
|
name: Rustfmt
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v4
|
||||||
- uses: actions-rs/toolchain@v1
|
- uses: actions-rs/toolchain@v1
|
||||||
with:
|
with:
|
||||||
profile: minimal
|
profile: minimal
|
||||||
@ -65,7 +65,7 @@ jobs:
|
|||||||
name: Clippy
|
name: Clippy
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v4
|
||||||
- uses: actions-rs/toolchain@v1
|
- uses: actions-rs/toolchain@v1
|
||||||
with:
|
with:
|
||||||
profile: minimal
|
profile: minimal
|
||||||
|
15
.github/workflows/trufflehog.yml
vendored
Normal file
15
.github/workflows/trufflehog.yml
vendored
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
on:
|
||||||
|
push:
|
||||||
|
|
||||||
|
name: Secret Leaks
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
trufflehog:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
- name: Secret Scanning
|
||||||
|
uses: trufflesecurity/trufflehog@main
|
10
.gitignore
vendored
10
.gitignore
vendored
@ -9,6 +9,10 @@ target/
|
|||||||
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
|
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
|
||||||
Cargo.lock
|
Cargo.lock
|
||||||
|
|
||||||
|
# editor config
|
||||||
|
.helix
|
||||||
|
.vscode
|
||||||
|
|
||||||
# These are backup files generated by rustfmt
|
# These are backup files generated by rustfmt
|
||||||
**/*.rs.bk
|
**/*.rs.bk
|
||||||
|
|
||||||
@ -36,3 +40,9 @@ candle-wasm-examples/*/package-lock.json
|
|||||||
candle-wasm-examples/**/config*.json
|
candle-wasm-examples/**/config*.json
|
||||||
.DS_Store
|
.DS_Store
|
||||||
.idea/*
|
.idea/*
|
||||||
|
__pycache__
|
||||||
|
out.safetensors
|
||||||
|
out.wav
|
||||||
|
bria.mp3
|
||||||
|
bria.safetensors
|
||||||
|
bria.wav
|
||||||
|
36
Cargo.toml
36
Cargo.toml
@ -9,6 +9,7 @@ members = [
|
|||||||
"candle-transformers",
|
"candle-transformers",
|
||||||
"candle-wasm-examples/*",
|
"candle-wasm-examples/*",
|
||||||
"candle-wasm-tests",
|
"candle-wasm-tests",
|
||||||
|
"tensor-tools",
|
||||||
]
|
]
|
||||||
exclude = [
|
exclude = [
|
||||||
"candle-flash-attn",
|
"candle-flash-attn",
|
||||||
@ -19,7 +20,7 @@ exclude = [
|
|||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.3.3"
|
version = "0.7.2"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "Minimalist ML framework."
|
description = "Minimalist ML framework."
|
||||||
repository = "https://github.com/huggingface/candle"
|
repository = "https://github.com/huggingface/candle"
|
||||||
@ -28,48 +29,49 @@ categories = ["science"]
|
|||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
|
ab_glyph = "0.2.23"
|
||||||
accelerate-src = { version = "0.3.2" }
|
accelerate-src = { version = "0.3.2" }
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
byteorder = "1.4.3"
|
byteorder = "1.4.3"
|
||||||
candle = { path = "./candle-core", package = "candle-core" }
|
candle = { path = "./candle-core", package = "candle-core", version = "0.7.2" }
|
||||||
candle-datasets = { path = "./candle-datasets" }
|
candle-datasets = { path = "./candle-datasets", version = "0.7.2" }
|
||||||
candle-flash-attn = { path = "./candle-flash-attn" }
|
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.2" }
|
||||||
candle-kernels = { path = "./candle-kernels" }
|
candle-kernels = { path = "./candle-kernels", version = "0.7.2" }
|
||||||
candle-metal-kernels = { path = "./candle-metal-kernels" }
|
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.2" }
|
||||||
candle-nn = { path = "./candle-nn" }
|
candle-nn = { path = "./candle-nn", version = "0.7.2" }
|
||||||
candle-onnx = { path = "./candle-onnx" }
|
candle-onnx = { path = "./candle-onnx", version = "0.7.2" }
|
||||||
candle-transformers = { path = "./candle-transformers" }
|
candle-transformers = { path = "./candle-transformers", version = "0.7.2" }
|
||||||
clap = { version = "4.2.4", features = ["derive"] }
|
clap = { version = "4.2.4", features = ["derive"] }
|
||||||
criterion = { version = "0.5.1", default-features=false }
|
criterion = { version = "0.5.1", default-features=false }
|
||||||
cudarc = { version = "0.10.0", features = ["f16"] }
|
cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||||
|
fancy-regex = "0.13.0"
|
||||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||||
hf-hub = "0.3.0"
|
hf-hub = "0.3.0"
|
||||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||||
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
|
hound = "3.5.1"
|
||||||
imageproc = { version = "0.23.0", default-features = false }
|
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
|
||||||
|
imageproc = { version = "0.24.0", default-features = false }
|
||||||
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
||||||
libc = { version = "0.2.147" }
|
libc = { version = "0.2.147" }
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
|
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
|
||||||
num_cpus = "1.15.0"
|
num_cpus = "1.15.0"
|
||||||
num-traits = "0.2.15"
|
num-traits = "0.2.15"
|
||||||
parquet = { version = "50.0.0" }
|
parquet = { version = "51.0.0" }
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
rand_distr = "0.4.3"
|
rand_distr = "0.4.3"
|
||||||
rayon = "1.7.0"
|
rayon = "1.7.0"
|
||||||
rusttype = { version = "0.9", default-features = false }
|
|
||||||
safetensors = "0.4.1"
|
safetensors = "0.4.1"
|
||||||
serde = { version = "1.0.171", features = ["derive"] }
|
serde = { version = "1.0.171", features = ["derive"] }
|
||||||
serde_plain = "1.0.2"
|
serde_plain = "1.0.2"
|
||||||
serde_json = "1.0.99"
|
serde_json = "1.0.99"
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
tokenizers = { version = "0.15.0", default-features = false }
|
tokenizers = { version = "0.19.1", default-features = false }
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-chrome = "0.7.1"
|
tracing-chrome = "0.7.1"
|
||||||
tracing-subscriber = "0.3.7"
|
tracing-subscriber = "0.3.7"
|
||||||
wav = "1.0.0"
|
|
||||||
yoke = { version = "0.7.2", features = ["derive"] }
|
yoke = { version = "0.7.2", features = ["derive"] }
|
||||||
zip = { version = "0.6.6", default-features = false }
|
zip = { version = "1.1.1", default-features = false }
|
||||||
metal = { version = "0.27.0", features = ["mps"]}
|
metal = { version = "0.27.0", features = ["mps"]}
|
||||||
|
|
||||||
[profile.release-with-debug]
|
[profile.release-with-debug]
|
||||||
|
72
README.md
72
README.md
@ -60,20 +60,31 @@ These online demos run entirely in your browser:
|
|||||||
|
|
||||||
We also provide a some command line based examples using state of the art models:
|
We also provide a some command line based examples using state of the art models:
|
||||||
|
|
||||||
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
|
- [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes
|
||||||
the SOLAR-10.7B variant.
|
the SOLAR-10.7B variant.
|
||||||
- [Falcon](./candle-examples/examples/falcon/): general LLM.
|
- [Falcon](./candle-examples/examples/falcon/): general LLM.
|
||||||
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
|
- [Codegeex4](./candle-examples/examples/codegeex4-9b/): Code completion,code interpreter,web search,fuction calling,repository-level
|
||||||
|
- [GLM4](./candle-examples/examples/glm4/): Open Multilingual Multimodal Chat LMs by THUDM
|
||||||
|
- [Gemma v1 and v2](./candle-examples/examples/gemma/): 2b and 7b+/9b general LLMs from Google Deepmind.
|
||||||
|
- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
|
||||||
|
Griffin based models from Google that mix attention with a RNN like state.
|
||||||
|
- [Phi-1, Phi-1.5, Phi-2, and Phi-3](./candle-examples/examples/phi/): 1.3b,
|
||||||
|
2.7b, and 3.8b general LLMs with performance on par with 7b models.
|
||||||
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
|
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
|
||||||
pre-trained on 1T tokens of English and code datasets.
|
pre-trained on 1T tokens of English and code datasets. Also supports
|
||||||
- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal
|
StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
|
||||||
|
- [Mamba](./candle-examples/examples/mamba/): an inference only
|
||||||
implementation of the Mamba state space model.
|
implementation of the Mamba state space model.
|
||||||
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
|
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
|
||||||
better performance than all publicly available 13b models as of 2023-09-28.
|
better performance than all publicly available 13b models as of 2023-09-28.
|
||||||
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
|
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
|
||||||
experts 8x7b general LLM with better performance than a Llama 2 70B model with
|
experts 8x7b general LLM with better performance than a Llama 2 70B model with
|
||||||
much faster inference.
|
much faster inference.
|
||||||
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
|
- [StarCoder](./candle-examples/examples/bigcode/) and
|
||||||
|
[StarCoder2](./candle-examples/examples/starcoder2/): LLM specialized to code generation.
|
||||||
|
- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
|
||||||
|
- [RWKV v5 and v6](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
|
||||||
|
performance.
|
||||||
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
|
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
|
||||||
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
|
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
|
||||||
(English/Chinese) general LLMs with 6b and 34b parameters.
|
(English/Chinese) general LLMs with 6b and 34b parameters.
|
||||||
@ -103,7 +114,14 @@ We also provide a some command line based examples using state of the art models
|
|||||||
|
|
||||||
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
|
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
|
||||||
|
|
||||||
|
- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmentation model.
|
||||||
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
|
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
|
||||||
|
- [EnCodec](./candle-examples/examples/encodec/): high-quality audio compression
|
||||||
|
model using residual vector quantization.
|
||||||
|
- [MetaVoice](./candle-examples/examples/metavoice/): foundational model for
|
||||||
|
text-to-speech.
|
||||||
|
- [Parler-TTS](./candle-examples/examples/parler-tts/): large text-to-speech
|
||||||
|
model.
|
||||||
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
|
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
|
||||||
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
|
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
|
||||||
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
|
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
|
||||||
@ -111,11 +129,16 @@ We also provide a some command line based examples using state of the art models
|
|||||||
evaluation, segmentation).
|
evaluation, segmentation).
|
||||||
- [VGG](./candle-examples/examples/vgg/),
|
- [VGG](./candle-examples/examples/vgg/),
|
||||||
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
|
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
|
||||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
|
||||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||||
generate captions for an image.
|
generate captions for an image.
|
||||||
|
- [CLIP](./candle-examples/examples/clip/): multi-model vision and language
|
||||||
|
model.
|
||||||
|
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
|
||||||
|
dedicated submodels for hand-writing and printed recognition.
|
||||||
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
|
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
|
||||||
model, generates the translated text from the input text.
|
model, generates the translated text from the input text.
|
||||||
|
- [Moondream](./candle-examples/examples/moondream/): tiny computer-vision model
|
||||||
|
that can answer real-world questions about images.
|
||||||
|
|
||||||
Run them using commands like:
|
Run them using commands like:
|
||||||
```
|
```
|
||||||
@ -159,9 +182,11 @@ And then head over to
|
|||||||
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
|
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
|
||||||
serving local LLMs including an OpenAI compatible API server.
|
serving local LLMs including an OpenAI compatible API server.
|
||||||
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
|
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
|
||||||
|
- [`candle-coursera-ml`](https://github.com/vishpat/candle-coursera-ml): Implementation of ML algorithms from Coursera's [Machine Learning Specialization](https://www.coursera.org/specializations/machine-learning-introduction) course.
|
||||||
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
|
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
|
||||||
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
|
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
|
||||||
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
|
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
|
||||||
|
- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.
|
||||||
|
|
||||||
If you have an addition to this list, please submit a pull request.
|
If you have an addition to this list, please submit a pull request.
|
||||||
|
|
||||||
@ -180,17 +205,20 @@ If you have an addition to this list, please submit a pull request.
|
|||||||
- WASM support, run your models in a browser.
|
- WASM support, run your models in a browser.
|
||||||
- Included models.
|
- Included models.
|
||||||
- Language Models.
|
- Language Models.
|
||||||
- LLaMA v1 and v2 with variants such as SOLAR-10.7B.
|
- LLaMA v1, v2, and v3 with variants such as SOLAR-10.7B.
|
||||||
- Falcon.
|
- Falcon.
|
||||||
- StarCoder.
|
- StarCoder, StarCoder2.
|
||||||
- Phi 1, 1.5, and 2.
|
- Phi 1, 1.5, 2, and 3.
|
||||||
- Minimal Mamba
|
- Mamba, Minimal Mamba
|
||||||
|
- Gemma v1 2b and 7b+, v2 2b and 9b.
|
||||||
- Mistral 7b v0.1.
|
- Mistral 7b v0.1.
|
||||||
- Mixtral 8x7b v0.1.
|
- Mixtral 8x7b v0.1.
|
||||||
- StableLM-3B-4E1T.
|
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
|
||||||
- Replit-code-v1.5-3B.
|
- Replit-code-v1.5-3B.
|
||||||
- Bert.
|
- Bert.
|
||||||
- Yi-6B and Yi-34B.
|
- Yi-6B and Yi-34B.
|
||||||
|
- Qwen1.5, Qwen1.5 MoE.
|
||||||
|
- RWKV v5 and v6.
|
||||||
- Quantized LLMs.
|
- Quantized LLMs.
|
||||||
- Llama 7b, 13b, 70b, as well as the chat and code variants.
|
- Llama 7b, 13b, 70b, as well as the chat and code variants.
|
||||||
- Mistral 7b, and 7b instruct.
|
- Mistral 7b, and 7b instruct.
|
||||||
@ -200,16 +228,23 @@ If you have an addition to this list, please submit a pull request.
|
|||||||
- Text to text.
|
- Text to text.
|
||||||
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
|
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
|
||||||
- Marian MT (Machine Translation).
|
- Marian MT (Machine Translation).
|
||||||
- Whisper (multi-lingual support).
|
|
||||||
- Text to image.
|
- Text to image.
|
||||||
- Stable Diffusion v1.5, v2.1, XL v1.0.
|
- Stable Diffusion v1.5, v2.1, XL v1.0.
|
||||||
- Wurstchen v2.
|
- Wurstchen v2.
|
||||||
- Image to text.
|
- Image to text.
|
||||||
- BLIP.
|
- BLIP.
|
||||||
|
- TrOCR.
|
||||||
|
- Audio.
|
||||||
|
- Whisper, multi-lingual speech-to-text.
|
||||||
|
- EnCodec, audio compression model.
|
||||||
|
- MetaVoice-1B, text-to-speech model.
|
||||||
|
- Parler-TTS, text-to-speech model.
|
||||||
- Computer Vision Models.
|
- Computer Vision Models.
|
||||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG.
|
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
|
||||||
|
ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera, FastViT.
|
||||||
- yolo-v3, yolo-v8.
|
- yolo-v3, yolo-v8.
|
||||||
- Segment-Anything Model (SAM).
|
- Segment-Anything Model (SAM).
|
||||||
|
- SegFormer.
|
||||||
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
|
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
|
||||||
- Serverless (on CPU), small and fast deployments.
|
- Serverless (on CPU), small and fast deployments.
|
||||||
- Quantization support using the llama.cpp quantized types.
|
- Quantization support using the llama.cpp quantized types.
|
||||||
@ -346,9 +381,9 @@ git submodule update --init
|
|||||||
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
|
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
|
||||||
```
|
```
|
||||||
|
|
||||||
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
|
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the NVCC_CCBIN environment variable.
|
||||||
```
|
```
|
||||||
env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
|
env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Linking error on windows when running rustdoc or mdbook tests
|
#### Linking error on windows when running rustdoc or mdbook tests
|
||||||
@ -378,3 +413,10 @@ This may be caused by the models being loaded from `/mnt/c`, more details on
|
|||||||
|
|
||||||
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
|
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
|
||||||
error is generated.
|
error is generated.
|
||||||
|
|
||||||
|
#### CudaRC error
|
||||||
|
|
||||||
|
If you encounter an error like this one `called `Result::unwrap()` on an `Err` value: LoadLibraryExW { source: Os { code: 126, kind: Uncategorized, message: "The specified module could not be found." } }` on windows. To fix copy and rename these 3 files (make sure they are in path). The paths depend on your cuda version.
|
||||||
|
`c:\Windows\System32\nvcuda.dll` -> `cuda.dll`
|
||||||
|
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\cublas64_12.dll` -> `cublas.dll`
|
||||||
|
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\curand64_10.dll` -> `curand.dll`
|
||||||
|
@ -37,7 +37,6 @@ tokenizers = { workspace = true, features = ["onig"] }
|
|||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
tracing-chrome = { workspace = true }
|
tracing-chrome = { workspace = true }
|
||||||
tracing-subscriber = { workspace = true }
|
tracing-subscriber = { workspace = true }
|
||||||
wav = { workspace = true }
|
|
||||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||||
parquet = { workspace = true }
|
parquet = { workspace = true }
|
||||||
image = { workspace = true }
|
image = { workspace = true }
|
||||||
|
@ -81,7 +81,7 @@ let mut tp_shape = view.shape().to_vec();
|
|||||||
let size = tp_shape[0];
|
let size = tp_shape[0];
|
||||||
|
|
||||||
if size % world_size != 0 {
|
if size % world_size != 0 {
|
||||||
panic!("The dimension is not divisble by `world_size`");
|
panic!("The dimension is not divisible by `world_size`");
|
||||||
}
|
}
|
||||||
let block_size = size / world_size;
|
let block_size = size / world_size;
|
||||||
let start = rank * block_size;
|
let start = rank * block_size;
|
||||||
@ -106,8 +106,8 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
#[rustfmt::skip]
|
#[rustfmt::skip]
|
||||||
#[test]
|
|
||||||
fn book_training_1() -> Result<()>{
|
fn book_training_1() -> Result<()>{
|
||||||
// ANCHOR: book_training_1
|
// ANCHOR: book_training_1
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
@ -48,3 +48,7 @@ metal = ["dep:metal", "dep:candle-metal-kernels"]
|
|||||||
[[bench]]
|
[[bench]]
|
||||||
name = "bench_main"
|
name = "bench_main"
|
||||||
harness = false
|
harness = false
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "metal_basics"
|
||||||
|
required-features = ["metal"]
|
||||||
|
@ -5,5 +5,8 @@ criterion_main!(
|
|||||||
benchmarks::affine::benches,
|
benchmarks::affine::benches,
|
||||||
benchmarks::matmul::benches,
|
benchmarks::matmul::benches,
|
||||||
benchmarks::random::benches,
|
benchmarks::random::benches,
|
||||||
benchmarks::where_cond::benches
|
benchmarks::where_cond::benches,
|
||||||
|
benchmarks::conv_transpose2d::benches,
|
||||||
|
benchmarks::qmatmul::benches,
|
||||||
|
benchmarks::unary::benches
|
||||||
);
|
);
|
||||||
|
@ -12,7 +12,7 @@ fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name:
|
|||||||
let m = 1024;
|
let m = 1024;
|
||||||
let k = 1024;
|
let k = 1024;
|
||||||
|
|
||||||
let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
|
let tensor = Tensor::zeros((b, m, k), dtype, device).unwrap();
|
||||||
|
|
||||||
let flops = b * m * k * dtype.size_in_bytes();
|
let flops = b * m * k * dtype.size_in_bytes();
|
||||||
|
|
||||||
|
59
candle-core/benches/benchmarks/conv_transpose2d.rs
Normal file
59
candle-core/benches/benchmarks/conv_transpose2d.rs
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||||
|
use candle_core::{DType, Device, Tensor};
|
||||||
|
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
fn run(
|
||||||
|
x: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
padding: usize,
|
||||||
|
output_padding: usize,
|
||||||
|
stride: usize,
|
||||||
|
dilation: usize,
|
||||||
|
) {
|
||||||
|
x.conv_transpose2d(k, padding, output_padding, stride, dilation)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||||
|
let t = Tensor::arange(0.0f32, 10000.0, device)
|
||||||
|
.unwrap()
|
||||||
|
.reshape((1, 4, 50, 50))
|
||||||
|
.unwrap()
|
||||||
|
.to_dtype(dtype)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let kernel = Tensor::arange(0.0f32, 100.0, device)
|
||||||
|
.unwrap()
|
||||||
|
.reshape((4, 1, 5, 5))
|
||||||
|
.unwrap()
|
||||||
|
.to_dtype(dtype)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let flops = t.dims().iter().product::<usize>() * dtype.size_in_bytes();
|
||||||
|
|
||||||
|
let mut group = c.benchmark_group(device.bench_name(name));
|
||||||
|
group.throughput(Throughput::Bytes(flops as u64));
|
||||||
|
group.bench_function("iter", move |b| {
|
||||||
|
b.iter_custom(|iters| {
|
||||||
|
let start = Instant::now();
|
||||||
|
for _i in 0..iters {
|
||||||
|
run(black_box(&t), black_box(&kernel), 1, 0, 1, 2);
|
||||||
|
}
|
||||||
|
device.sync().unwrap();
|
||||||
|
start.elapsed()
|
||||||
|
})
|
||||||
|
});
|
||||||
|
group.finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn criterion_benchmark(c: &mut Criterion) {
|
||||||
|
let handler = BenchDeviceHandler::new().unwrap();
|
||||||
|
for device in handler.devices {
|
||||||
|
run_benchmark(c, &device, DType::F32, "conv_transpose2d_f32");
|
||||||
|
run_benchmark(c, &device, DType::F16, "conv_transpose2d_f16");
|
||||||
|
run_benchmark(c, &device, DType::BF16, "conv_transpose2d_bf16");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group!(benches, criterion_benchmark);
|
@ -1,6 +1,9 @@
|
|||||||
pub(crate) mod affine;
|
pub(crate) mod affine;
|
||||||
|
pub(crate) mod conv_transpose2d;
|
||||||
pub(crate) mod matmul;
|
pub(crate) mod matmul;
|
||||||
|
pub(crate) mod qmatmul;
|
||||||
pub(crate) mod random;
|
pub(crate) mod random;
|
||||||
|
pub(crate) mod unary;
|
||||||
pub(crate) mod where_cond;
|
pub(crate) mod where_cond;
|
||||||
|
|
||||||
use candle_core::{Device, Result};
|
use candle_core::{Device, Result};
|
||||||
|
72
candle-core/benches/benchmarks/qmatmul.rs
Normal file
72
candle-core/benches/benchmarks/qmatmul.rs
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||||
|
use candle_core::{
|
||||||
|
quantized::{self, GgmlDType, QMatMul},
|
||||||
|
Device, Module, Tensor,
|
||||||
|
};
|
||||||
|
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
fn run(matmul: &QMatMul, x: &Tensor) {
|
||||||
|
matmul.forward(x).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
|
||||||
|
let b = 1;
|
||||||
|
let m = 1;
|
||||||
|
let n = 1024;
|
||||||
|
let k = 1024;
|
||||||
|
|
||||||
|
let lhs = (0..(m * k))
|
||||||
|
.map(|v| v as f32 / (m * k) as f32)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let rhs = (0..(k * n))
|
||||||
|
.map(|v| v as f32 / (n * k) as f32)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let lhs = Tensor::from_slice(&lhs, (m, k), device).unwrap();
|
||||||
|
let rhs = Tensor::from_slice(&rhs, (k, n), device).unwrap();
|
||||||
|
|
||||||
|
let qtensor = quantized::QTensor::quantize(&rhs.t().unwrap(), dtype).unwrap();
|
||||||
|
let matmul = quantized::QMatMul::from_qtensor(qtensor).unwrap();
|
||||||
|
|
||||||
|
let flops = b * m * n * k;
|
||||||
|
|
||||||
|
let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype)));
|
||||||
|
group.sample_size(200);
|
||||||
|
group.throughput(Throughput::Bytes(flops as u64));
|
||||||
|
group.bench_function("iter", move |b| {
|
||||||
|
b.iter_custom(|iters| {
|
||||||
|
let start = Instant::now();
|
||||||
|
for _i in 0..iters {
|
||||||
|
run(black_box(&matmul), black_box(&lhs));
|
||||||
|
}
|
||||||
|
device.sync().unwrap();
|
||||||
|
start.elapsed()
|
||||||
|
})
|
||||||
|
});
|
||||||
|
group.finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn criterion_benchmark(c: &mut Criterion) {
|
||||||
|
let handler = BenchDeviceHandler::new().unwrap();
|
||||||
|
for device in handler.devices {
|
||||||
|
for dtype in [
|
||||||
|
GgmlDType::F32,
|
||||||
|
GgmlDType::F16,
|
||||||
|
GgmlDType::Q4_0,
|
||||||
|
GgmlDType::Q4_1,
|
||||||
|
GgmlDType::Q5_0,
|
||||||
|
GgmlDType::Q5_1,
|
||||||
|
GgmlDType::Q8_0,
|
||||||
|
GgmlDType::Q2K,
|
||||||
|
GgmlDType::Q3K,
|
||||||
|
GgmlDType::Q4K,
|
||||||
|
GgmlDType::Q5K,
|
||||||
|
GgmlDType::Q6K,
|
||||||
|
] {
|
||||||
|
run_bench(c, &device, dtype);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group!(benches, criterion_benchmark);
|
49
candle-core/benches/benchmarks/unary.rs
Normal file
49
candle-core/benches/benchmarks/unary.rs
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||||
|
use candle_core::{DType, Device, Tensor};
|
||||||
|
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
fn run(a: &Tensor) {
|
||||||
|
a.sqrt().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||||
|
let b = 1;
|
||||||
|
let m = 1024;
|
||||||
|
let k = 1024;
|
||||||
|
|
||||||
|
let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, device)
|
||||||
|
.unwrap()
|
||||||
|
.to_dtype(dtype)
|
||||||
|
.unwrap()
|
||||||
|
.reshape((b, m, k))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let flops = b * m * k * dtype.size_in_bytes();
|
||||||
|
|
||||||
|
let mut group = c.benchmark_group(device.bench_name(name));
|
||||||
|
group.throughput(Throughput::Bytes(flops as u64));
|
||||||
|
group.bench_function("iter", move |b| {
|
||||||
|
b.iter_custom(|iters| {
|
||||||
|
let start = Instant::now();
|
||||||
|
for _i in 0..iters {
|
||||||
|
run(black_box(&tensor));
|
||||||
|
}
|
||||||
|
device.sync().unwrap();
|
||||||
|
start.elapsed()
|
||||||
|
})
|
||||||
|
});
|
||||||
|
group.finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn criterion_benchmark(c: &mut Criterion) {
|
||||||
|
let handler = BenchDeviceHandler::new().unwrap();
|
||||||
|
for device in handler.devices {
|
||||||
|
for dtype in [DType::F32, DType::BF16, DType::F16] {
|
||||||
|
let name = format!("sqrt_{:?}", dtype);
|
||||||
|
run_unary_benchmark(c, &device, dtype, &name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group!(benches, criterion_benchmark);
|
@ -25,9 +25,9 @@ const SIZE: usize = B * M * K;
|
|||||||
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
|
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
|
||||||
|
|
||||||
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||||
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
|
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap();
|
||||||
let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
|
let on_true = Tensor::ones((B, M, K), dtype, device).unwrap();
|
||||||
let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
|
let on_false = Tensor::zeros((B, M, K), dtype, device).unwrap();
|
||||||
|
|
||||||
let elements = B * M * K;
|
let elements = B * M * K;
|
||||||
// E.g. 2 f32 tensors + 1 u8 tensor
|
// E.g. 2 f32 tensors + 1 u8 tensor
|
||||||
|
@ -9,21 +9,25 @@ use candle_core::{Device, Tensor};
|
|||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let device = Device::new_cuda(0)?;
|
let device = Device::new_cuda(0)?;
|
||||||
let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
|
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
|
||||||
let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
|
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
||||||
println!("{out_t}");
|
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
|
||||||
let in_t = in_t.to_device(&Device::Cpu)?;
|
let _x1 = x.matmul(&x)?;
|
||||||
let k_t = k_t.to_device(&Device::Cpu)?;
|
drop(_x1);
|
||||||
let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
|
let start_time = std::time::Instant::now();
|
||||||
let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
|
let _x1 = x.matmul(&x)?;
|
||||||
.sqr()?
|
device.synchronize()?;
|
||||||
.sum_all()?;
|
println!("fp32: {:?}", start_time.elapsed());
|
||||||
println!("{diff}");
|
drop(_x1);
|
||||||
|
candle_core::cuda::set_gemm_reduced_precision_f32(true);
|
||||||
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
|
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
|
||||||
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
|
let _x1 = x.matmul(&x)?;
|
||||||
let res = t.conv2d(&w, 1, 1, 1, 1)?;
|
drop(_x1);
|
||||||
println!("{res:?}");
|
let start_time = std::time::Instant::now();
|
||||||
|
let _x1 = x.matmul(&x)?;
|
||||||
|
device.synchronize()?;
|
||||||
|
println!("tf32: {:?}", start_time.elapsed());
|
||||||
|
drop(_x1);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
28
candle-core/examples/metal_basics.rs
Normal file
28
candle-core/examples/metal_basics.rs
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use candle_core::{Device, Tensor};
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
// This requires the code to be run with MTL_CAPTURE_ENABLED=1
|
||||||
|
let device = Device::new_metal(0)?;
|
||||||
|
let metal_device = match &device {
|
||||||
|
Device::Metal(m) => m,
|
||||||
|
_ => anyhow::bail!("unexpected device"),
|
||||||
|
};
|
||||||
|
metal_device.capture("/tmp/candle.gputrace")?;
|
||||||
|
// This first synchronize ensures that a new command buffer gets created after setting up the
|
||||||
|
// capture scope.
|
||||||
|
device.synchronize()?;
|
||||||
|
let x = Tensor::randn(0f32, 1.0, (128, 128), &device)?;
|
||||||
|
let x1 = x.add(&x)?;
|
||||||
|
println!("{x1:?}");
|
||||||
|
// This second synchronize ensures that the command buffer gets commited before the end of the
|
||||||
|
// capture scope.
|
||||||
|
device.synchronize()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -380,6 +380,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
|
|||||||
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn vs_exp_inplace(y: &mut [f32]) {
|
||||||
|
unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn vd_exp_inplace(y: &mut [f64]) {
|
||||||
|
unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||||
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
||||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
@ -402,6 +412,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
|
||||||
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
|
*y = -v
|
||||||
|
}
|
||||||
|
vs_exp_inplace(ys);
|
||||||
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
|
*y = v / (1.0 + *y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
|
||||||
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
|
*y = -v
|
||||||
|
}
|
||||||
|
vd_exp_inplace(ys);
|
||||||
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
|
*y = v / (1.0 + *y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
macro_rules! binary_op {
|
macro_rules! binary_op {
|
||||||
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
|
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
|
||||||
#[inline]
|
#[inline]
|
||||||
|
@ -98,6 +98,19 @@ pub trait BackendStorage: Sized {
|
|||||||
) -> Result<Self>;
|
) -> Result<Self>;
|
||||||
|
|
||||||
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
|
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
// Similar to cudaMemcpy2D, though values are in elements and not in bytes.
|
||||||
|
fn copy2d(
|
||||||
|
&self,
|
||||||
|
_: &mut Self,
|
||||||
|
_d1: usize,
|
||||||
|
_d2: usize,
|
||||||
|
_src_stride1: usize,
|
||||||
|
_dst_stride1: usize,
|
||||||
|
_src_offset: usize,
|
||||||
|
_dst_offset: usize,
|
||||||
|
) -> Result<()>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
||||||
@ -114,11 +127,24 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
|||||||
|
|
||||||
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
/// This function is unsafe as it doesn't initialize the underlying data store.
|
||||||
|
/// The caller should ensure that the data is properly initialized as early as possible
|
||||||
|
/// after this call.
|
||||||
|
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
||||||
|
|
||||||
|
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage>;
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
|
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
|
||||||
|
|
||||||
|
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
|
||||||
|
|
||||||
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
||||||
|
|
||||||
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
||||||
|
|
||||||
fn set_seed(&self, _: u64) -> Result<()>;
|
fn set_seed(&self, _: u64) -> Result<()>;
|
||||||
|
|
||||||
|
/// Synchronize should block until all the operations on the device are completed.
|
||||||
|
fn synchronize(&self) -> Result<()>;
|
||||||
}
|
}
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
/// Methods for backpropagation of gradients.
|
||||||
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
|
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
|
||||||
use crate::{Error, Result, Tensor, TensorId};
|
use crate::{Error, Result, Tensor, TensorId};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@ -111,9 +112,10 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
Op::Unary(_node, UnaryOp::Ceil)
|
Op::Unary(_node, UnaryOp::Ceil)
|
||||||
| Op::Unary(_node, UnaryOp::Floor)
|
| Op::Unary(_node, UnaryOp::Floor)
|
||||||
| Op::Unary(_node, UnaryOp::Round) => nodes,
|
| Op::Unary(_node, UnaryOp::Round)
|
||||||
|
| Op::Unary(_node, UnaryOp::Sign) => nodes,
|
||||||
Op::Reshape(node)
|
Op::Reshape(node)
|
||||||
| Op::UpsampleNearest1D(node)
|
| Op::UpsampleNearest1D { arg: node, .. }
|
||||||
| Op::UpsampleNearest2D { arg: node, .. }
|
| Op::UpsampleNearest2D { arg: node, .. }
|
||||||
| Op::AvgPool2D { arg: node, .. }
|
| Op::AvgPool2D { arg: node, .. }
|
||||||
| Op::MaxPool2D { arg: node, .. }
|
| Op::MaxPool2D { arg: node, .. }
|
||||||
@ -175,7 +177,7 @@ impl Tensor {
|
|||||||
// the backprop graph of the backprop itself. This would be an issue for second order
|
// the backprop graph of the backprop itself. This would be an issue for second order
|
||||||
// derivatives but these are out of scope at the moment.
|
// derivatives but these are out of scope at the moment.
|
||||||
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
|
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
|
||||||
let grad = if do_not_detach { grad } else { grad.detach()? };
|
let grad = if do_not_detach { grad } else { grad.detach() };
|
||||||
if let Some(op) = node.op() {
|
if let Some(op) = node.op() {
|
||||||
match op {
|
match op {
|
||||||
Op::Binary(lhs, rhs, BinaryOp::Add) => {
|
Op::Binary(lhs, rhs, BinaryOp::Add) => {
|
||||||
@ -250,6 +252,7 @@ impl Tensor {
|
|||||||
out_padding,
|
out_padding,
|
||||||
*stride,
|
*stride,
|
||||||
*dilation,
|
*dilation,
|
||||||
|
/* groups */ 1,
|
||||||
)?;
|
)?;
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||||
@ -309,9 +312,32 @@ impl Tensor {
|
|||||||
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
|
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
|
||||||
op: "conv-transpose1d",
|
op: "conv-transpose1d",
|
||||||
})?,
|
})?,
|
||||||
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
|
Op::ConvTranspose2D {
|
||||||
op: "conv-transpose2d",
|
arg,
|
||||||
})?,
|
kernel,
|
||||||
|
padding,
|
||||||
|
stride,
|
||||||
|
dilation,
|
||||||
|
output_padding: _output_padding,
|
||||||
|
} => {
|
||||||
|
let grad_arg = grad.conv2d(kernel, *padding, *stride, *dilation, 1)?;
|
||||||
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
|
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||||
|
|
||||||
|
let grad_kernel = grad
|
||||||
|
.transpose(0, 1)?
|
||||||
|
.conv2d(&arg.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
|
||||||
|
.transpose(0, 1)?;
|
||||||
|
let sum_grad = grads.or_insert(kernel)?;
|
||||||
|
let (_, _, k0, k1) = kernel.dims4()?;
|
||||||
|
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
|
||||||
|
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
|
||||||
|
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
|
||||||
|
} else {
|
||||||
|
grad_kernel
|
||||||
|
};
|
||||||
|
*sum_grad = sum_grad.add(&grad_kernel)?;
|
||||||
|
}
|
||||||
Op::AvgPool2D {
|
Op::AvgPool2D {
|
||||||
arg,
|
arg,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
@ -347,9 +373,18 @@ impl Tensor {
|
|||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||||
}
|
}
|
||||||
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
|
Op::UpsampleNearest1D { arg, target_size } => {
|
||||||
op: "upsample-nearest1d",
|
let (_n, c, size) = arg.dims3()?;
|
||||||
})?,
|
if target_size % size != 0 {
|
||||||
|
crate::bail!("backward not supported for non integer upscaling factors")
|
||||||
|
}
|
||||||
|
let scale = target_size / size;
|
||||||
|
|
||||||
|
let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
|
||||||
|
let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
|
||||||
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
|
*sum_grad = conv_sum;
|
||||||
|
}
|
||||||
Op::UpsampleNearest2D {
|
Op::UpsampleNearest2D {
|
||||||
arg,
|
arg,
|
||||||
target_h,
|
target_h,
|
||||||
@ -454,7 +489,6 @@ impl Tensor {
|
|||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
*sum_grad = sum_grad.add(&grad)?;
|
*sum_grad = sum_grad.add(&grad)?;
|
||||||
}
|
}
|
||||||
Op::Cmp(_args, _) => {}
|
|
||||||
Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
|
Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
|
||||||
let node = broadcast_back(arg, node, reduced_dims)?;
|
let node = broadcast_back(arg, node, reduced_dims)?;
|
||||||
let grad = broadcast_back(arg, &grad, reduced_dims)?;
|
let grad = broadcast_back(arg, &grad, reduced_dims)?;
|
||||||
@ -544,20 +578,18 @@ impl Tensor {
|
|||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
*sum_grad = sum_grad.add(&arg_grad)?
|
*sum_grad = sum_grad.add(&arg_grad)?
|
||||||
}
|
}
|
||||||
Op::Reduce(_, ReduceOp::ArgMin, _) => {}
|
Op::Unary(_, UnaryOp::Floor)
|
||||||
Op::Reduce(_, ReduceOp::ArgMax, _) => {}
|
| Op::Unary(_, UnaryOp::Round)
|
||||||
|
| Op::Reduce(_, ReduceOp::ArgMin, _)
|
||||||
|
| Op::Reduce(_, ReduceOp::ArgMax, _)
|
||||||
|
| Op::Unary(_, UnaryOp::Sign)
|
||||||
|
| Op::Cmp(_, _) => {}
|
||||||
Op::Reshape(arg) => {
|
Op::Reshape(arg) => {
|
||||||
let arg_grad = grad.reshape(arg.dims())?;
|
let arg_grad = grad.reshape(arg.dims())?;
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
*sum_grad = sum_grad.add(&arg_grad)?
|
*sum_grad = sum_grad.add(&arg_grad)?
|
||||||
}
|
}
|
||||||
Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
|
Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
|
||||||
Op::Unary(_, UnaryOp::Floor) => {
|
|
||||||
Err(Error::BackwardNotSupported { op: "floor" })?
|
|
||||||
}
|
|
||||||
Op::Unary(_, UnaryOp::Round) => {
|
|
||||||
Err(Error::BackwardNotSupported { op: "round" })?
|
|
||||||
}
|
|
||||||
Op::Unary(arg, UnaryOp::Gelu) => {
|
Op::Unary(arg, UnaryOp::Gelu) => {
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
let cube = arg.powf(3.)?;
|
let cube = arg.powf(3.)?;
|
||||||
@ -589,13 +621,21 @@ impl Tensor {
|
|||||||
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
|
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
|
||||||
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
|
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
|
||||||
}
|
}
|
||||||
|
Op::Unary(arg, UnaryOp::Silu) => {
|
||||||
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
|
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) = sigmoid(x) * (1 - node) + node
|
||||||
|
let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
|
||||||
|
let silu_grad = &sigmoid_arg * (1. - *node) + *node;
|
||||||
|
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
|
||||||
|
}
|
||||||
Op::Elu(arg, alpha) => {
|
Op::Elu(arg, alpha) => {
|
||||||
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
|
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
let zeros = arg.zeros_like()?;
|
let zeros = arg.zeros_like()?;
|
||||||
let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
|
let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
|
||||||
let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
|
let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
|
||||||
let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
|
// node == alpha * (e^x - 1) for x <= 0, reuse it
|
||||||
|
let negative_exp_mask = (negative_mask * (*node + *alpha))?;
|
||||||
let combined_mask = (positive_mask + negative_exp_mask)?;
|
let combined_mask = (positive_mask + negative_exp_mask)?;
|
||||||
*sum_grad = sum_grad.add(&(grad * combined_mask)?)?
|
*sum_grad = sum_grad.add(&(grad * combined_mask)?)?
|
||||||
}
|
}
|
||||||
@ -673,30 +713,38 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A store for gradients, associating a tensor id to the corresponding gradient tensor, used for back propagation.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct GradStore(HashMap<TensorId, Tensor>);
|
pub struct GradStore(HashMap<TensorId, Tensor>);
|
||||||
|
|
||||||
impl GradStore {
|
impl GradStore {
|
||||||
|
/// Create a new gradient store
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
GradStore(HashMap::new())
|
GradStore(HashMap::new())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the gradient tensor corresponding to the given tensor id
|
||||||
pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
|
pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
|
||||||
self.0.get(&id)
|
self.0.get(&id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the gradient tensor associated with the given tensor
|
||||||
pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
|
pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
|
||||||
self.0.get(&tensor.id())
|
self.0.get(&tensor.id())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Remove the gradient tensor associated with the given tensor, returning it if it exists
|
||||||
pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
|
pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
|
||||||
self.0.remove(&tensor.id())
|
self.0.remove(&tensor.id())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed
|
||||||
pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
|
pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
|
||||||
self.0.insert(tensor.id(), grad)
|
self.0.insert(tensor.id(), grad)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the gradient tensor associated with the given tensor, or, if it does not exist,
|
||||||
|
/// insert a tensor of zeroes, with the same shape and type as the given tensors and return it
|
||||||
fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
|
fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
|
||||||
use std::collections::hash_map::Entry;
|
use std::collections::hash_map::Entry;
|
||||||
let grad = match self.0.entry(tensor.id()) {
|
let grad = match self.0.entry(tensor.id()) {
|
||||||
@ -708,4 +756,9 @@ impl GradStore {
|
|||||||
};
|
};
|
||||||
Ok(grad)
|
Ok(grad)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the tensor ids of the stored gradient tensors
|
||||||
|
pub fn get_ids(&self) -> impl Iterator<Item = &TensorId> {
|
||||||
|
self.0.keys()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -187,36 +187,16 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Applies a 1D transposed convolution over the input tensor.
|
fn conv_transpose1d_single_group(
|
||||||
pub fn conv_transpose1d(
|
|
||||||
&self,
|
&self,
|
||||||
kernel: &Self,
|
kernel: &Self,
|
||||||
padding: usize,
|
params: &ParamsConvTranspose1D,
|
||||||
output_padding: usize,
|
|
||||||
stride: usize,
|
|
||||||
dilation: usize,
|
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let (b_size, c_in, l_in) = self.dims3()?;
|
|
||||||
let (c_in_k, c_out, k_size) = kernel.dims3()?;
|
|
||||||
if c_in != c_in_k {
|
|
||||||
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
|
|
||||||
}
|
|
||||||
let params = ParamsConvTranspose1D {
|
|
||||||
b_size,
|
|
||||||
l_in,
|
|
||||||
k_size,
|
|
||||||
c_out,
|
|
||||||
c_in,
|
|
||||||
padding,
|
|
||||||
output_padding,
|
|
||||||
stride,
|
|
||||||
dilation,
|
|
||||||
};
|
|
||||||
let storage = self.storage().conv_transpose1d(
|
let storage = self.storage().conv_transpose1d(
|
||||||
self.layout(),
|
self.layout(),
|
||||||
&kernel.storage(),
|
&kernel.storage(),
|
||||||
kernel.layout(),
|
kernel.layout(),
|
||||||
¶ms,
|
params,
|
||||||
)?;
|
)?;
|
||||||
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
|
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
|
||||||
arg,
|
arg,
|
||||||
@ -230,6 +210,49 @@ impl Tensor {
|
|||||||
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
|
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Applies a 1D transposed convolution over the input tensor.
|
||||||
|
pub fn conv_transpose1d(
|
||||||
|
&self,
|
||||||
|
kernel: &Self,
|
||||||
|
padding: usize,
|
||||||
|
output_padding: usize,
|
||||||
|
stride: usize,
|
||||||
|
dilation: usize,
|
||||||
|
groups: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let (c_in_k, c_out, k_size) = kernel.dims3()?;
|
||||||
|
let (b_size, c_in, l_in) = self.dims3()?;
|
||||||
|
if c_in != c_in_k {
|
||||||
|
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
|
||||||
|
}
|
||||||
|
if c_in % groups != 0 {
|
||||||
|
crate::bail!("in_channel {c_in} is not divisible by the number of groups")
|
||||||
|
}
|
||||||
|
let params = ParamsConvTranspose1D {
|
||||||
|
b_size,
|
||||||
|
l_in,
|
||||||
|
k_size,
|
||||||
|
c_out,
|
||||||
|
c_in: c_in / groups,
|
||||||
|
padding,
|
||||||
|
output_padding,
|
||||||
|
stride,
|
||||||
|
dilation,
|
||||||
|
};
|
||||||
|
if groups == 1 {
|
||||||
|
self.conv_transpose1d_single_group(kernel, ¶ms)
|
||||||
|
} else {
|
||||||
|
let blocks = self.chunk(groups, 1)?;
|
||||||
|
let kernel = kernel.chunk(groups, 0)?;
|
||||||
|
let blocks = blocks
|
||||||
|
.iter()
|
||||||
|
.zip(&kernel)
|
||||||
|
.map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, ¶ms))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
Tensor::cat(&blocks, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
|
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
|
||||||
let storage =
|
let storage =
|
||||||
self.storage()
|
self.storage()
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
pub mod erf;
|
pub mod erf;
|
||||||
pub mod kernels;
|
pub mod kernels;
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
trait Cpu<const ARR: usize> {
|
trait Cpu<const ARR: usize> {
|
||||||
type Unit;
|
type Unit;
|
||||||
type Array;
|
type Array;
|
||||||
@ -18,6 +19,7 @@ trait Cpu<const ARR: usize> {
|
|||||||
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
|
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
trait CpuF16<const ARR: usize> {
|
trait CpuF16<const ARR: usize> {
|
||||||
type Unit;
|
type Unit;
|
||||||
type Array;
|
type Array;
|
||||||
|
@ -4,7 +4,13 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
|
|||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
|
|
||||||
|
mod utils;
|
||||||
|
pub use utils::{
|
||||||
|
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
|
||||||
|
};
|
||||||
|
|
||||||
const USE_IM2COL_CONV1D: bool = true;
|
const USE_IM2COL_CONV1D: bool = true;
|
||||||
|
const USE_COL2IM_CONV1D_TR: bool = true;
|
||||||
const USE_IM2COL_CONV2D: bool = true;
|
const USE_IM2COL_CONV2D: bool = true;
|
||||||
|
|
||||||
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
|
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
|
||||||
@ -20,105 +26,20 @@ pub enum CpuStorage {
|
|||||||
F64(Vec<f64>),
|
F64(Vec<f64>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum CpuStorageRef<'a> {
|
||||||
|
U8(&'a [u8]),
|
||||||
|
U32(&'a [u32]),
|
||||||
|
I64(&'a [i64]),
|
||||||
|
BF16(&'a [bf16]),
|
||||||
|
F16(&'a [f16]),
|
||||||
|
F32(&'a [f32]),
|
||||||
|
F64(&'a [f64]),
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct CpuDevice;
|
pub struct CpuDevice;
|
||||||
|
|
||||||
pub trait Map1 {
|
|
||||||
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
|
|
||||||
|
|
||||||
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
|
|
||||||
match vs {
|
|
||||||
CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
|
|
||||||
CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
|
|
||||||
CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)),
|
|
||||||
CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
|
|
||||||
CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
|
|
||||||
CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
|
|
||||||
CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait Map1Any {
|
|
||||||
fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
|
|
||||||
&self,
|
|
||||||
vs: &[T],
|
|
||||||
layout: &Layout,
|
|
||||||
wrap: W,
|
|
||||||
) -> Result<CpuStorage>;
|
|
||||||
|
|
||||||
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
|
|
||||||
match vs {
|
|
||||||
CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
|
|
||||||
CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
|
|
||||||
CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?),
|
|
||||||
CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
|
|
||||||
CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
|
|
||||||
CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
|
|
||||||
CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type C = CpuStorage;
|
|
||||||
pub trait Map2 {
|
|
||||||
const OP: &'static str;
|
|
||||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
|
|
||||||
|
|
||||||
fn map(
|
|
||||||
&self,
|
|
||||||
v1: &CpuStorage,
|
|
||||||
l1: &Layout,
|
|
||||||
v2: &CpuStorage,
|
|
||||||
l2: &Layout,
|
|
||||||
) -> Result<CpuStorage> {
|
|
||||||
match (v1, v2) {
|
|
||||||
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
|
|
||||||
_ => Err(Error::DTypeMismatchBinaryOp {
|
|
||||||
lhs: v1.dtype(),
|
|
||||||
rhs: v2.dtype(),
|
|
||||||
op: Self::OP,
|
|
||||||
}
|
|
||||||
.bt()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait Map2U8 {
|
|
||||||
const OP: &'static str;
|
|
||||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
|
|
||||||
|
|
||||||
fn map(
|
|
||||||
&self,
|
|
||||||
v1: &CpuStorage,
|
|
||||||
l1: &Layout,
|
|
||||||
v2: &CpuStorage,
|
|
||||||
l2: &Layout,
|
|
||||||
) -> Result<CpuStorage> {
|
|
||||||
match (v1, v2) {
|
|
||||||
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
||||||
_ => Err(Error::DTypeMismatchBinaryOp {
|
|
||||||
lhs: v1.dtype(),
|
|
||||||
rhs: v2.dtype(),
|
|
||||||
op: Self::OP,
|
|
||||||
}
|
|
||||||
.bt()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Cmp(CmpOp);
|
struct Cmp(CmpOp);
|
||||||
impl Map2U8 for Cmp {
|
impl Map2U8 for Cmp {
|
||||||
const OP: &'static str = "cmp";
|
const OP: &'static str = "cmp";
|
||||||
@ -200,7 +121,8 @@ impl ReduceIndex {
|
|||||||
let dst_len = src_l.shape().elem_count() / reduce_dim_size;
|
let dst_len = src_l.shape().elem_count() / reduce_dim_size;
|
||||||
let mut dst: Vec<U> = Vec::with_capacity(dst_len);
|
let mut dst: Vec<U> = Vec::with_capacity(dst_len);
|
||||||
let dst_to_set = dst.spare_capacity_mut();
|
let dst_to_set = dst.spare_capacity_mut();
|
||||||
let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) };
|
let dst_to_set =
|
||||||
|
unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(dst_to_set) };
|
||||||
match src_l.contiguous_offsets() {
|
match src_l.contiguous_offsets() {
|
||||||
Some((o1, o2)) => {
|
Some((o1, o2)) => {
|
||||||
let src = &src[o1..o2];
|
let src = &src[o1..o2];
|
||||||
@ -365,275 +287,6 @@ impl<'a> Map1 for ReduceSum<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
|
|
||||||
vs: &[T],
|
|
||||||
layout: &Layout,
|
|
||||||
mut f: F,
|
|
||||||
) -> Vec<U> {
|
|
||||||
match layout.strided_blocks() {
|
|
||||||
crate::StridedBlocks::SingleBlock { start_offset, len } => vs
|
|
||||||
[start_offset..start_offset + len]
|
|
||||||
.iter()
|
|
||||||
.map(|&v| f(v))
|
|
||||||
.collect(),
|
|
||||||
crate::StridedBlocks::MultipleBlocks {
|
|
||||||
block_start_index,
|
|
||||||
block_len,
|
|
||||||
} => {
|
|
||||||
let mut result = Vec::with_capacity(layout.shape().elem_count());
|
|
||||||
// Specialize the case where block_len is one to avoid the second loop.
|
|
||||||
if block_len == 1 {
|
|
||||||
for index in block_start_index {
|
|
||||||
let v = unsafe { vs.get_unchecked(index) };
|
|
||||||
result.push(f(*v))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for index in block_start_index {
|
|
||||||
for offset in 0..block_len {
|
|
||||||
let v = unsafe { vs.get_unchecked(index + offset) };
|
|
||||||
result.push(f(*v))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
|
|
||||||
vs: &[T],
|
|
||||||
layout: &Layout,
|
|
||||||
mut f: F,
|
|
||||||
mut f_vec: FV,
|
|
||||||
) -> Vec<U> {
|
|
||||||
match layout.strided_blocks() {
|
|
||||||
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
|
||||||
let mut ys: Vec<U> = Vec::with_capacity(len);
|
|
||||||
let ys_to_set = ys.spare_capacity_mut();
|
|
||||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
|
|
||||||
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
|
|
||||||
// SAFETY: values are all set by f_vec.
|
|
||||||
unsafe { ys.set_len(len) };
|
|
||||||
ys
|
|
||||||
}
|
|
||||||
crate::StridedBlocks::MultipleBlocks {
|
|
||||||
block_start_index,
|
|
||||||
block_len,
|
|
||||||
} => {
|
|
||||||
let el_count = layout.shape().elem_count();
|
|
||||||
// Specialize the case where block_len is one to avoid the second loop.
|
|
||||||
if block_len == 1 {
|
|
||||||
let mut result = Vec::with_capacity(el_count);
|
|
||||||
for index in block_start_index {
|
|
||||||
let v = unsafe { vs.get_unchecked(index) };
|
|
||||||
result.push(f(*v))
|
|
||||||
}
|
|
||||||
result
|
|
||||||
} else {
|
|
||||||
let mut ys: Vec<U> = Vec::with_capacity(el_count);
|
|
||||||
let ys_to_set = ys.spare_capacity_mut();
|
|
||||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
|
|
||||||
let mut dst_index = 0;
|
|
||||||
for src_index in block_start_index {
|
|
||||||
let vs = &vs[src_index..src_index + block_len];
|
|
||||||
let ys = &mut ys_to_set[dst_index..dst_index + block_len];
|
|
||||||
f_vec(vs, ys);
|
|
||||||
dst_index += block_len;
|
|
||||||
}
|
|
||||||
// SAFETY: values are all set by f_vec.
|
|
||||||
unsafe { ys.set_len(el_count) };
|
|
||||||
ys
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// This function maps over two strided index sequences.
|
|
||||||
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
|
|
||||||
lhs_l: &Layout,
|
|
||||||
rhs_l: &Layout,
|
|
||||||
lhs: &[T],
|
|
||||||
rhs: &[T],
|
|
||||||
mut f: F,
|
|
||||||
) -> Vec<U> {
|
|
||||||
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
|
|
||||||
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
|
|
||||||
.iter()
|
|
||||||
.zip(rhs[o_r1..o_r2].iter())
|
|
||||||
.map(|(&l, &r)| f(l, r))
|
|
||||||
.collect(),
|
|
||||||
(Some((o_l1, o_l2)), None) => {
|
|
||||||
// TODO: Maybe we want to avoid going through the layout twice.
|
|
||||||
match rhs_l.offsets_b() {
|
|
||||||
Some(ob) => {
|
|
||||||
let mut i_in_block = 0;
|
|
||||||
let mut i_right_broadcast = 0;
|
|
||||||
lhs[o_l1..o_l2]
|
|
||||||
.iter()
|
|
||||||
.map(|&l| {
|
|
||||||
let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
|
|
||||||
i_right_broadcast += 1;
|
|
||||||
if i_right_broadcast >= ob.right_broadcast {
|
|
||||||
i_in_block += 1;
|
|
||||||
i_right_broadcast = 0;
|
|
||||||
}
|
|
||||||
if i_in_block >= ob.len {
|
|
||||||
i_in_block = 0
|
|
||||||
}
|
|
||||||
f(l, *r)
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
None => lhs_l
|
|
||||||
.strided_index()
|
|
||||||
.zip(rhs_l.strided_index())
|
|
||||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
||||||
.collect(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(None, Some((o_r1, o_r2))) => {
|
|
||||||
// TODO: Maybe we want to avoid going through the layout twice.
|
|
||||||
match lhs_l.offsets_b() {
|
|
||||||
Some(ob) => {
|
|
||||||
let mut i_in_block = 0;
|
|
||||||
let mut i_right_broadcast = 0;
|
|
||||||
rhs[o_r1..o_r2]
|
|
||||||
.iter()
|
|
||||||
.map(|&r| {
|
|
||||||
let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
|
|
||||||
i_right_broadcast += 1;
|
|
||||||
if i_right_broadcast >= ob.right_broadcast {
|
|
||||||
i_in_block += 1;
|
|
||||||
i_right_broadcast = 0;
|
|
||||||
}
|
|
||||||
if i_in_block >= ob.len {
|
|
||||||
i_in_block = 0
|
|
||||||
}
|
|
||||||
f(*l, r)
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
None => lhs_l
|
|
||||||
.strided_index()
|
|
||||||
.zip(rhs_l.strided_index())
|
|
||||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
||||||
.collect(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => lhs_l
|
|
||||||
.strided_index()
|
|
||||||
.zip(rhs_l.strided_index())
|
|
||||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
||||||
.collect(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Similar to binary_map but with vectorized variants.
|
|
||||||
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
|
|
||||||
lhs_l: &Layout,
|
|
||||||
rhs_l: &Layout,
|
|
||||||
lhs: &[T],
|
|
||||||
rhs: &[T],
|
|
||||||
mut f: F,
|
|
||||||
mut f_vec: FV,
|
|
||||||
) -> Vec<T> {
|
|
||||||
let el_count = lhs_l.shape().elem_count();
|
|
||||||
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
|
|
||||||
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
|
|
||||||
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
|
||||||
let ys_to_set = ys.spare_capacity_mut();
|
|
||||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
|
||||||
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
|
|
||||||
// SAFETY: values are all set by f_vec.
|
|
||||||
unsafe { ys.set_len(el_count) };
|
|
||||||
ys
|
|
||||||
}
|
|
||||||
(Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
|
|
||||||
Some(ob) if ob.right_broadcast == 1 => {
|
|
||||||
let rhs = &rhs[ob.start..ob.start + ob.len];
|
|
||||||
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
|
||||||
let ys_to_set = ys.spare_capacity_mut();
|
|
||||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
|
||||||
let mut dst_i = 0;
|
|
||||||
for src_i in (o_l1..o_l2).step_by(ob.len) {
|
|
||||||
f_vec(
|
|
||||||
&lhs[src_i..src_i + ob.len],
|
|
||||||
rhs,
|
|
||||||
&mut ys_to_set[dst_i..dst_i + ob.len],
|
|
||||||
);
|
|
||||||
dst_i += ob.len;
|
|
||||||
}
|
|
||||||
// SAFETY: values are all set by f_vec.
|
|
||||||
unsafe { ys.set_len(el_count) };
|
|
||||||
ys
|
|
||||||
}
|
|
||||||
Some(ob) => {
|
|
||||||
let rhs = &rhs[ob.start..ob.start + ob.len];
|
|
||||||
let mut ys = lhs[o_l1..o_l2].to_vec();
|
|
||||||
for idx_l in 0..ob.left_broadcast {
|
|
||||||
let start = idx_l * ob.len * ob.right_broadcast;
|
|
||||||
for (i, &r) in rhs.iter().enumerate() {
|
|
||||||
let start = start + i * ob.right_broadcast;
|
|
||||||
for v in ys[start..start + ob.right_broadcast].iter_mut() {
|
|
||||||
*v = f(*v, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ys
|
|
||||||
}
|
|
||||||
None => lhs_l
|
|
||||||
.strided_index()
|
|
||||||
.zip(rhs_l.strided_index())
|
|
||||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
||||||
.collect(),
|
|
||||||
},
|
|
||||||
(None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
|
|
||||||
Some(ob) if ob.right_broadcast == 1 => {
|
|
||||||
let lhs = &lhs[ob.start..ob.start + ob.len];
|
|
||||||
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
|
||||||
let ys_to_set = ys.spare_capacity_mut();
|
|
||||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
|
||||||
let mut dst_i = 0;
|
|
||||||
for src_i in (o_r1..o_r2).step_by(ob.len) {
|
|
||||||
f_vec(
|
|
||||||
lhs,
|
|
||||||
&rhs[src_i..src_i + ob.len],
|
|
||||||
&mut ys_to_set[dst_i..dst_i + ob.len],
|
|
||||||
);
|
|
||||||
dst_i += ob.len;
|
|
||||||
}
|
|
||||||
// SAFETY: values are all set by f_vec.
|
|
||||||
unsafe { ys.set_len(el_count) };
|
|
||||||
ys
|
|
||||||
}
|
|
||||||
Some(ob) => {
|
|
||||||
let lhs = &lhs[ob.start..ob.start + ob.len];
|
|
||||||
let mut ys = rhs[o_r1..o_r2].to_vec();
|
|
||||||
for idx_l in 0..ob.left_broadcast {
|
|
||||||
let start = idx_l * ob.len * ob.right_broadcast;
|
|
||||||
for (i, &l) in lhs.iter().enumerate() {
|
|
||||||
let start = start + i * ob.right_broadcast;
|
|
||||||
for v in ys[start..start + ob.right_broadcast].iter_mut() {
|
|
||||||
*v = f(l, *v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ys
|
|
||||||
}
|
|
||||||
None => lhs_l
|
|
||||||
.strided_index()
|
|
||||||
.zip(rhs_l.strided_index())
|
|
||||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
||||||
.collect(),
|
|
||||||
},
|
|
||||||
_ => lhs_l
|
|
||||||
.strided_index()
|
|
||||||
.zip(rhs_l.strided_index())
|
|
||||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
||||||
.collect(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Affine(f64, f64);
|
struct Affine(f64, f64);
|
||||||
|
|
||||||
impl Map1 for Affine {
|
impl Map1 for Affine {
|
||||||
@ -1022,6 +675,26 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn copy2d_<T: Copy>(
|
||||||
|
src: &[T],
|
||||||
|
dst: &mut [T],
|
||||||
|
d1: usize,
|
||||||
|
d2: usize,
|
||||||
|
src_stride1: usize,
|
||||||
|
dst_stride1: usize,
|
||||||
|
src_offset: usize,
|
||||||
|
dst_offset: usize,
|
||||||
|
) {
|
||||||
|
for i1 in 0..d1 {
|
||||||
|
let dst_idx = i1 * dst_stride1 + dst_offset;
|
||||||
|
let src_idx = i1 * src_stride1 + src_offset;
|
||||||
|
let dst = &mut dst[dst_idx..dst_idx + d2];
|
||||||
|
let src = &src[src_idx..src_idx + d2];
|
||||||
|
dst.copy_from_slice(src)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
|
fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
|
||||||
match src_l.strided_blocks() {
|
match src_l.strided_blocks() {
|
||||||
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||||
@ -1256,6 +929,34 @@ impl Map1 for Im2Col {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Col2Im1D {
|
||||||
|
stride: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Map1 for Col2Im1D {
|
||||||
|
fn f<T: WithDType>(&self, col: &[T], l: &Layout) -> Result<Vec<T>> {
|
||||||
|
let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
|
||||||
|
let stride = self.stride;
|
||||||
|
let l_out = (l_in - 1) * stride + k_size;
|
||||||
|
let mut im = vec![T::zero(); b_size * c_out * l_out];
|
||||||
|
let (dst_s0, dst_s1) = (c_out * l_out, l_out);
|
||||||
|
let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size);
|
||||||
|
for l_in_i in 0..l_in {
|
||||||
|
for k_i in 0..k_size {
|
||||||
|
let l_out_i = l_in_i * stride + k_i;
|
||||||
|
for b_i in 0..b_size {
|
||||||
|
for c_i in 0..c_out {
|
||||||
|
let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i;
|
||||||
|
let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
|
||||||
|
im[dst_idx] += col[src_idx]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(im)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
|
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
|
||||||
|
|
||||||
impl<'a> Map2 for ConvTranspose1D<'a> {
|
impl<'a> Map2 for ConvTranspose1D<'a> {
|
||||||
@ -1263,6 +964,7 @@ impl<'a> Map2 for ConvTranspose1D<'a> {
|
|||||||
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
||||||
let p = self.0;
|
let p = self.0;
|
||||||
let inp = &inp[inp_l.start_offset()..];
|
let inp = &inp[inp_l.start_offset()..];
|
||||||
|
let k = &k[k_l.start_offset()..];
|
||||||
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
|
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
|
||||||
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
|
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
|
||||||
let l_out = p.l_out();
|
let l_out = p.l_out();
|
||||||
@ -1514,6 +1216,30 @@ impl MatMul {
|
|||||||
}))
|
}))
|
||||||
.bt()
|
.bt()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
|
||||||
|
let lhs_stride = lhs_l.stride();
|
||||||
|
let rhs_stride = rhs_l.stride();
|
||||||
|
let rank = lhs_stride.len();
|
||||||
|
let (_b, m, n, k) = self.0;
|
||||||
|
let a_skip: usize = match lhs_stride[..rank - 2] {
|
||||||
|
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
||||||
|
[_, stride] if lhs_l.dims()[0] == 1 => stride,
|
||||||
|
[stride, _] if lhs_l.dims()[1] == 1 => stride,
|
||||||
|
[stride] => stride,
|
||||||
|
[] => m * k,
|
||||||
|
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
|
||||||
|
};
|
||||||
|
let b_skip: usize = match rhs_stride[..rank - 2] {
|
||||||
|
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
||||||
|
[_, stride] if rhs_l.dims()[0] == 1 => stride,
|
||||||
|
[stride, _] if rhs_l.dims()[1] == 1 => stride,
|
||||||
|
[stride] => stride,
|
||||||
|
[] => n * k,
|
||||||
|
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
|
||||||
|
};
|
||||||
|
Ok((a_skip, b_skip))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Map2 for MatMul {
|
impl Map2 for MatMul {
|
||||||
@ -1547,18 +1273,7 @@ impl Map2 for MatMul {
|
|||||||
let rhs_cs = rhs_stride[rank - 1];
|
let rhs_cs = rhs_stride[rank - 1];
|
||||||
let rhs_rs = rhs_stride[rank - 2];
|
let rhs_rs = rhs_stride[rank - 2];
|
||||||
|
|
||||||
let a_skip: usize = match lhs_stride[..rank - 2] {
|
let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
|
||||||
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
|
||||||
[stride] => stride,
|
|
||||||
[] => m * k,
|
|
||||||
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
|
|
||||||
};
|
|
||||||
let b_skip: usize = match rhs_stride[..rank - 2] {
|
|
||||||
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
|
||||||
[stride] => stride,
|
|
||||||
[] => n * k,
|
|
||||||
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
|
|
||||||
};
|
|
||||||
let c_skip: usize = m * n;
|
let c_skip: usize = m * n;
|
||||||
|
|
||||||
let dst_shape: Shape = (m, n).into();
|
let dst_shape: Shape = (m, n).into();
|
||||||
@ -1618,20 +1333,8 @@ impl Map2 for MatMul {
|
|||||||
|
|
||||||
let lhs_stride = lhs_l.stride();
|
let lhs_stride = lhs_l.stride();
|
||||||
let rhs_stride = rhs_l.stride();
|
let rhs_stride = rhs_l.stride();
|
||||||
let rank = lhs_stride.len();
|
|
||||||
|
|
||||||
let a_skip: usize = match lhs_stride[..rank - 2] {
|
let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
|
||||||
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
|
||||||
[stride] => stride,
|
|
||||||
[] => m * k,
|
|
||||||
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
|
|
||||||
};
|
|
||||||
let b_skip: usize = match rhs_stride[..rank - 2] {
|
|
||||||
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
|
||||||
[stride] => stride,
|
|
||||||
[] => n * k,
|
|
||||||
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
|
|
||||||
};
|
|
||||||
let c_skip: usize = m * n;
|
let c_skip: usize = m * n;
|
||||||
|
|
||||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||||
@ -1639,7 +1342,7 @@ impl Map2 for MatMul {
|
|||||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||||
|
|
||||||
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
|
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
||||||
(n as i32, b'N')
|
(n as i32, b'N')
|
||||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||||
(k as i32, b'T')
|
(k as i32, b'T')
|
||||||
@ -1647,7 +1350,7 @@ impl Map2 for MatMul {
|
|||||||
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
|
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
|
||||||
};
|
};
|
||||||
// The b tensor has dims batching, m, k (lhs)
|
// The b tensor has dims batching, m, k (lhs)
|
||||||
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
||||||
(k as i32, b'N')
|
(k as i32, b'N')
|
||||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||||
(m as i32, b'T')
|
(m as i32, b'T')
|
||||||
@ -1721,20 +1424,8 @@ impl Map2 for MatMul {
|
|||||||
|
|
||||||
let lhs_stride = lhs_l.stride();
|
let lhs_stride = lhs_l.stride();
|
||||||
let rhs_stride = rhs_l.stride();
|
let rhs_stride = rhs_l.stride();
|
||||||
let rank = lhs_stride.len();
|
|
||||||
|
|
||||||
let a_skip: usize = match lhs_stride[..rank - 2] {
|
let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
|
||||||
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
|
||||||
[stride] => stride,
|
|
||||||
[] => m * k,
|
|
||||||
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
|
|
||||||
};
|
|
||||||
let b_skip: usize = match rhs_stride[..rank - 2] {
|
|
||||||
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
|
||||||
[stride] => stride,
|
|
||||||
[] => n * k,
|
|
||||||
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
|
|
||||||
};
|
|
||||||
let c_skip: usize = m * n;
|
let c_skip: usize = m * n;
|
||||||
|
|
||||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||||
@ -1742,7 +1433,7 @@ impl Map2 for MatMul {
|
|||||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||||
|
|
||||||
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
|
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
||||||
(n as i32, b'N')
|
(n as i32, b'N')
|
||||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||||
(k as i32, b'T')
|
(k as i32, b'T')
|
||||||
@ -1750,7 +1441,7 @@ impl Map2 for MatMul {
|
|||||||
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
|
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
|
||||||
};
|
};
|
||||||
// The b tensor has dims batching, m, k (lhs)
|
// The b tensor has dims batching, m, k (lhs)
|
||||||
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
||||||
(k as i32, b'N')
|
(k as i32, b'N')
|
||||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||||
(m as i32, b'T')
|
(m as i32, b'T')
|
||||||
@ -2422,6 +2113,48 @@ impl BackendStorage for CpuStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn copy2d(
|
||||||
|
&self,
|
||||||
|
dst: &mut Self,
|
||||||
|
d1: usize,
|
||||||
|
d2: usize,
|
||||||
|
src_s: usize,
|
||||||
|
dst_s: usize,
|
||||||
|
src_o: usize,
|
||||||
|
dst_o: usize,
|
||||||
|
) -> Result<()> {
|
||||||
|
match (self, dst) {
|
||||||
|
(Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
|
||||||
|
(Self::U32(src), Self::U32(dst)) => {
|
||||||
|
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
|
||||||
|
}
|
||||||
|
(Self::I64(src), Self::I64(dst)) => {
|
||||||
|
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
|
||||||
|
}
|
||||||
|
(Self::BF16(src), Self::BF16(dst)) => {
|
||||||
|
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
|
||||||
|
}
|
||||||
|
(Self::F16(src), Self::F16(dst)) => {
|
||||||
|
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
|
||||||
|
}
|
||||||
|
(Self::F32(src), Self::F32(dst)) => {
|
||||||
|
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
|
||||||
|
}
|
||||||
|
(Self::F64(src), Self::F64(dst)) => {
|
||||||
|
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
|
||||||
|
}
|
||||||
|
(_, dst) => {
|
||||||
|
return Err(Error::DTypeMismatchBinaryOp {
|
||||||
|
lhs: self.dtype(),
|
||||||
|
rhs: dst.dtype(),
|
||||||
|
op: "copy2d",
|
||||||
|
}
|
||||||
|
.bt());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||||
match (self, dst) {
|
match (self, dst) {
|
||||||
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||||
@ -2490,7 +2223,10 @@ impl BackendStorage for CpuStorage {
|
|||||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||||
} else {
|
} else {
|
||||||
// Make the kernel contiguous if not already the case.
|
// Make the kernel contiguous if not already the case.
|
||||||
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
let mut kernel_c = unsafe {
|
||||||
|
self.device()
|
||||||
|
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
||||||
|
};
|
||||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||||
.transpose(1, 2)?
|
.transpose(1, 2)?
|
||||||
@ -2498,7 +2234,7 @@ impl BackendStorage for CpuStorage {
|
|||||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||||
};
|
};
|
||||||
let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
|
let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
|
||||||
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
|
||||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||||
Ok(res_t)
|
Ok(res_t)
|
||||||
}
|
}
|
||||||
@ -2510,7 +2246,52 @@ impl BackendStorage for CpuStorage {
|
|||||||
kernel_l: &Layout,
|
kernel_l: &Layout,
|
||||||
params: &crate::conv::ParamsConvTranspose1D,
|
params: &crate::conv::ParamsConvTranspose1D,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
ConvTranspose1D(params).map(self, l, kernel, kernel_l)
|
let can_use_col2im = kernel_l.is_contiguous()
|
||||||
|
&& params.dilation == 1
|
||||||
|
&& params.padding == 0
|
||||||
|
&& params.output_padding == 0;
|
||||||
|
if USE_COL2IM_CONV1D_TR && can_use_col2im {
|
||||||
|
let (b_size, c_in, l_in) = l.shape().dims3()?;
|
||||||
|
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
|
||||||
|
if !kernel_l.is_contiguous() {
|
||||||
|
crate::bail!(
|
||||||
|
"convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if c_in != c_in2 {
|
||||||
|
crate::bail!(
|
||||||
|
"convtr1d: shape mismatch on c_in {:?} {:?}",
|
||||||
|
l.shape(),
|
||||||
|
kernel_l.shape()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
let col = {
|
||||||
|
// This merges the last two dimensions of the kernel together.
|
||||||
|
let kernel_l_mm = Layout::new(
|
||||||
|
(b_size, c_in, k_size * c_out).into(),
|
||||||
|
vec![0, k_size * c_out, 1],
|
||||||
|
kernel_l.start_offset(),
|
||||||
|
);
|
||||||
|
self.matmul(
|
||||||
|
kernel,
|
||||||
|
(
|
||||||
|
b_size,
|
||||||
|
/* m */ l_in,
|
||||||
|
/* n */ c_out * k_size,
|
||||||
|
/* k */ c_in,
|
||||||
|
),
|
||||||
|
&l.transpose(1, 2)?,
|
||||||
|
&kernel_l_mm,
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
|
||||||
|
Col2Im1D {
|
||||||
|
stride: params.stride,
|
||||||
|
}
|
||||||
|
.map(&col, &col_l)
|
||||||
|
} else {
|
||||||
|
ConvTranspose1D(params).map(self, l, kernel, kernel_l)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn conv2d(
|
fn conv2d(
|
||||||
@ -2544,7 +2325,10 @@ impl BackendStorage for CpuStorage {
|
|||||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||||
} else {
|
} else {
|
||||||
// Make the kernel contiguous if not already the case.
|
// Make the kernel contiguous if not already the case.
|
||||||
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
let mut kernel_c = unsafe {
|
||||||
|
self.device()
|
||||||
|
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
||||||
|
};
|
||||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||||
.transpose(1, 2)?
|
.transpose(1, 2)?
|
||||||
@ -2554,7 +2338,7 @@ impl BackendStorage for CpuStorage {
|
|||||||
let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
|
let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
|
||||||
.transpose(1, 2)?
|
.transpose(1, 2)?
|
||||||
.transpose(1, 3)?;
|
.transpose(1, 3)?;
|
||||||
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
|
||||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||||
Ok(res_t)
|
Ok(res_t)
|
||||||
}
|
}
|
||||||
@ -2574,7 +2358,7 @@ impl BackendStorage for CpuStorage {
|
|||||||
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
||||||
Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
||||||
Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
||||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
|
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2583,7 +2367,7 @@ impl BackendStorage for CpuStorage {
|
|||||||
Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
||||||
Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
||||||
Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
||||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
|
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2600,7 +2384,7 @@ impl BackendStorage for CpuStorage {
|
|||||||
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||||
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||||
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
|
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2673,10 +2457,18 @@ impl BackendDevice for CpuDevice {
|
|||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||||
|
Ok(T::to_cpu_storage(s))
|
||||||
|
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
|
fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
|
||||||
Ok(s.clone())
|
Ok(s.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result<Self::Storage> {
|
||||||
|
Ok(s)
|
||||||
|
}
|
||||||
|
|
||||||
fn new(_: usize) -> Result<Self> {
|
fn new(_: usize) -> Result<Self> {
|
||||||
Ok(Self)
|
Ok(Self)
|
||||||
}
|
}
|
||||||
@ -2778,6 +2570,53 @@ impl BackendDevice for CpuDevice {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::uninit_vec)]
|
||||||
|
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
// The code below is highly unsafe but hopefully not directly unsound as we only consider
|
||||||
|
// types that are Copy, not Drop, and for which all bit patterns are proper values.
|
||||||
|
// It's still pretty risky, see the following for more details:
|
||||||
|
// https://github.com/rust-lang/rust-clippy/issues/4483
|
||||||
|
let storage = match dtype {
|
||||||
|
DType::U8 => {
|
||||||
|
let mut v = Vec::with_capacity(elem_count);
|
||||||
|
v.set_len(elem_count);
|
||||||
|
CpuStorage::U8(v)
|
||||||
|
}
|
||||||
|
DType::U32 => {
|
||||||
|
let mut v = Vec::with_capacity(elem_count);
|
||||||
|
v.set_len(elem_count);
|
||||||
|
CpuStorage::U32(v)
|
||||||
|
}
|
||||||
|
DType::I64 => {
|
||||||
|
let mut v = Vec::with_capacity(elem_count);
|
||||||
|
v.set_len(elem_count);
|
||||||
|
CpuStorage::I64(v)
|
||||||
|
}
|
||||||
|
DType::BF16 => {
|
||||||
|
let mut v = Vec::with_capacity(elem_count);
|
||||||
|
v.set_len(elem_count);
|
||||||
|
CpuStorage::BF16(v)
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
let mut v = Vec::with_capacity(elem_count);
|
||||||
|
v.set_len(elem_count);
|
||||||
|
CpuStorage::F16(v)
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
let mut v = Vec::with_capacity(elem_count);
|
||||||
|
v.set_len(elem_count);
|
||||||
|
CpuStorage::F32(v)
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
let mut v = Vec::with_capacity(elem_count);
|
||||||
|
v.set_len(elem_count);
|
||||||
|
CpuStorage::F64(v)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(storage)
|
||||||
|
}
|
||||||
|
|
||||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
|
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
|
||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
let storage = match dtype {
|
let storage = match dtype {
|
||||||
@ -2805,6 +2644,10 @@ impl BackendDevice for CpuDevice {
|
|||||||
};
|
};
|
||||||
Ok(storage)
|
Ok(storage)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn synchronize(&self) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
360
candle-core/src/cpu_backend/utils.rs
Normal file
360
candle-core/src/cpu_backend/utils.rs
Normal file
@ -0,0 +1,360 @@
|
|||||||
|
/// Helper functions to write CPU kernels.
|
||||||
|
use crate::backend::BackendStorage;
|
||||||
|
use crate::{Error, Layout, Result, WithDType};
|
||||||
|
|
||||||
|
type C = super::CpuStorage;
|
||||||
|
pub trait Map1 {
|
||||||
|
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
|
||||||
|
|
||||||
|
fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
|
||||||
|
match vs {
|
||||||
|
C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
|
||||||
|
C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
|
||||||
|
C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
|
||||||
|
C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
|
||||||
|
C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
|
||||||
|
C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
|
||||||
|
C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Map1Any {
|
||||||
|
fn f<T: WithDType, W: Fn(Vec<T>) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result<C>;
|
||||||
|
|
||||||
|
fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
|
||||||
|
match vs {
|
||||||
|
C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
|
||||||
|
C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
|
||||||
|
C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
|
||||||
|
C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
|
||||||
|
C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
|
||||||
|
C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
|
||||||
|
C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Map2 {
|
||||||
|
const OP: &'static str;
|
||||||
|
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
|
||||||
|
|
||||||
|
fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
|
||||||
|
match (v1, v2) {
|
||||||
|
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
|
||||||
|
_ => Err(Error::DTypeMismatchBinaryOp {
|
||||||
|
lhs: v1.dtype(),
|
||||||
|
rhs: v2.dtype(),
|
||||||
|
op: Self::OP,
|
||||||
|
}
|
||||||
|
.bt()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Map2U8 {
|
||||||
|
const OP: &'static str;
|
||||||
|
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
|
||||||
|
|
||||||
|
fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
|
||||||
|
match (v1, v2) {
|
||||||
|
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||||
|
_ => Err(Error::DTypeMismatchBinaryOp {
|
||||||
|
lhs: v1.dtype(),
|
||||||
|
rhs: v2.dtype(),
|
||||||
|
op: Self::OP,
|
||||||
|
}
|
||||||
|
.bt()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
|
||||||
|
lhs_l: &Layout,
|
||||||
|
rhs_l: &Layout,
|
||||||
|
lhs: &[T],
|
||||||
|
rhs: &[T],
|
||||||
|
mut f: F,
|
||||||
|
) -> Vec<U> {
|
||||||
|
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
|
||||||
|
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
|
||||||
|
.iter()
|
||||||
|
.zip(rhs[o_r1..o_r2].iter())
|
||||||
|
.map(|(&l, &r)| f(l, r))
|
||||||
|
.collect(),
|
||||||
|
(Some((o_l1, o_l2)), None) => {
|
||||||
|
// TODO: Maybe we want to avoid going through the layout twice.
|
||||||
|
match rhs_l.offsets_b() {
|
||||||
|
Some(ob) => {
|
||||||
|
let mut i_in_block = 0;
|
||||||
|
let mut i_right_broadcast = 0;
|
||||||
|
lhs[o_l1..o_l2]
|
||||||
|
.iter()
|
||||||
|
.map(|&l| {
|
||||||
|
let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
|
||||||
|
i_right_broadcast += 1;
|
||||||
|
if i_right_broadcast >= ob.right_broadcast {
|
||||||
|
i_in_block += 1;
|
||||||
|
i_right_broadcast = 0;
|
||||||
|
}
|
||||||
|
if i_in_block >= ob.len {
|
||||||
|
i_in_block = 0
|
||||||
|
}
|
||||||
|
f(l, *r)
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
None => lhs_l
|
||||||
|
.strided_index()
|
||||||
|
.zip(rhs_l.strided_index())
|
||||||
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(None, Some((o_r1, o_r2))) => {
|
||||||
|
// TODO: Maybe we want to avoid going through the layout twice.
|
||||||
|
match lhs_l.offsets_b() {
|
||||||
|
Some(ob) => {
|
||||||
|
let mut i_in_block = 0;
|
||||||
|
let mut i_right_broadcast = 0;
|
||||||
|
rhs[o_r1..o_r2]
|
||||||
|
.iter()
|
||||||
|
.map(|&r| {
|
||||||
|
let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
|
||||||
|
i_right_broadcast += 1;
|
||||||
|
if i_right_broadcast >= ob.right_broadcast {
|
||||||
|
i_in_block += 1;
|
||||||
|
i_right_broadcast = 0;
|
||||||
|
}
|
||||||
|
if i_in_block >= ob.len {
|
||||||
|
i_in_block = 0
|
||||||
|
}
|
||||||
|
f(*l, r)
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
None => lhs_l
|
||||||
|
.strided_index()
|
||||||
|
.zip(rhs_l.strided_index())
|
||||||
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => lhs_l
|
||||||
|
.strided_index()
|
||||||
|
.zip(rhs_l.strided_index())
|
||||||
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Similar to binary_map but with vectorized variants.
|
||||||
|
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
|
||||||
|
lhs_l: &Layout,
|
||||||
|
rhs_l: &Layout,
|
||||||
|
lhs: &[T],
|
||||||
|
rhs: &[T],
|
||||||
|
mut f: F,
|
||||||
|
mut f_vec: FV,
|
||||||
|
) -> Vec<T> {
|
||||||
|
let el_count = lhs_l.shape().elem_count();
|
||||||
|
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
|
||||||
|
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
|
||||||
|
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||||
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
|
let ys_to_set = unsafe {
|
||||||
|
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
|
||||||
|
};
|
||||||
|
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
|
||||||
|
// SAFETY: values are all set by f_vec.
|
||||||
|
unsafe { ys.set_len(el_count) };
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
(Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
|
||||||
|
Some(ob) if ob.right_broadcast == 1 => {
|
||||||
|
let rhs = &rhs[ob.start..ob.start + ob.len];
|
||||||
|
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||||
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
|
let ys_to_set = unsafe {
|
||||||
|
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
|
||||||
|
};
|
||||||
|
let mut dst_i = 0;
|
||||||
|
for src_i in (o_l1..o_l2).step_by(ob.len) {
|
||||||
|
f_vec(
|
||||||
|
&lhs[src_i..src_i + ob.len],
|
||||||
|
rhs,
|
||||||
|
&mut ys_to_set[dst_i..dst_i + ob.len],
|
||||||
|
);
|
||||||
|
dst_i += ob.len;
|
||||||
|
}
|
||||||
|
// SAFETY: values are all set by f_vec.
|
||||||
|
unsafe { ys.set_len(el_count) };
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
Some(ob) => {
|
||||||
|
let rhs = &rhs[ob.start..ob.start + ob.len];
|
||||||
|
let mut ys = lhs[o_l1..o_l2].to_vec();
|
||||||
|
for idx_l in 0..ob.left_broadcast {
|
||||||
|
let start = idx_l * ob.len * ob.right_broadcast;
|
||||||
|
for (i, &r) in rhs.iter().enumerate() {
|
||||||
|
let start = start + i * ob.right_broadcast;
|
||||||
|
for v in ys[start..start + ob.right_broadcast].iter_mut() {
|
||||||
|
*v = f(*v, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
None => lhs_l
|
||||||
|
.strided_index()
|
||||||
|
.zip(rhs_l.strided_index())
|
||||||
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||||
|
.collect(),
|
||||||
|
},
|
||||||
|
(None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
|
||||||
|
Some(ob) if ob.right_broadcast == 1 => {
|
||||||
|
let lhs = &lhs[ob.start..ob.start + ob.len];
|
||||||
|
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||||
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
|
let ys_to_set = unsafe {
|
||||||
|
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
|
||||||
|
};
|
||||||
|
let mut dst_i = 0;
|
||||||
|
for src_i in (o_r1..o_r2).step_by(ob.len) {
|
||||||
|
f_vec(
|
||||||
|
lhs,
|
||||||
|
&rhs[src_i..src_i + ob.len],
|
||||||
|
&mut ys_to_set[dst_i..dst_i + ob.len],
|
||||||
|
);
|
||||||
|
dst_i += ob.len;
|
||||||
|
}
|
||||||
|
// SAFETY: values are all set by f_vec.
|
||||||
|
unsafe { ys.set_len(el_count) };
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
Some(ob) => {
|
||||||
|
let lhs = &lhs[ob.start..ob.start + ob.len];
|
||||||
|
let mut ys = rhs[o_r1..o_r2].to_vec();
|
||||||
|
for idx_l in 0..ob.left_broadcast {
|
||||||
|
let start = idx_l * ob.len * ob.right_broadcast;
|
||||||
|
for (i, &l) in lhs.iter().enumerate() {
|
||||||
|
let start = start + i * ob.right_broadcast;
|
||||||
|
for v in ys[start..start + ob.right_broadcast].iter_mut() {
|
||||||
|
*v = f(l, *v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
None => lhs_l
|
||||||
|
.strided_index()
|
||||||
|
.zip(rhs_l.strided_index())
|
||||||
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||||
|
.collect(),
|
||||||
|
},
|
||||||
|
_ => lhs_l
|
||||||
|
.strided_index()
|
||||||
|
.zip(rhs_l.strided_index())
|
||||||
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
|
||||||
|
vs: &[T],
|
||||||
|
layout: &Layout,
|
||||||
|
mut f: F,
|
||||||
|
) -> Vec<U> {
|
||||||
|
match layout.strided_blocks() {
|
||||||
|
crate::StridedBlocks::SingleBlock { start_offset, len } => vs
|
||||||
|
[start_offset..start_offset + len]
|
||||||
|
.iter()
|
||||||
|
.map(|&v| f(v))
|
||||||
|
.collect(),
|
||||||
|
crate::StridedBlocks::MultipleBlocks {
|
||||||
|
block_start_index,
|
||||||
|
block_len,
|
||||||
|
} => {
|
||||||
|
let mut result = Vec::with_capacity(layout.shape().elem_count());
|
||||||
|
// Specialize the case where block_len is one to avoid the second loop.
|
||||||
|
if block_len == 1 {
|
||||||
|
for index in block_start_index {
|
||||||
|
let v = unsafe { vs.get_unchecked(index) };
|
||||||
|
result.push(f(*v))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for index in block_start_index {
|
||||||
|
for offset in 0..block_len {
|
||||||
|
let v = unsafe { vs.get_unchecked(index + offset) };
|
||||||
|
result.push(f(*v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
|
||||||
|
vs: &[T],
|
||||||
|
layout: &Layout,
|
||||||
|
mut f: F,
|
||||||
|
mut f_vec: FV,
|
||||||
|
) -> Vec<U> {
|
||||||
|
match layout.strided_blocks() {
|
||||||
|
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||||
|
let mut ys: Vec<U> = Vec::with_capacity(len);
|
||||||
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
|
let ys_to_set = unsafe {
|
||||||
|
std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
|
||||||
|
};
|
||||||
|
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
|
||||||
|
// SAFETY: values are all set by f_vec.
|
||||||
|
unsafe { ys.set_len(len) };
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
crate::StridedBlocks::MultipleBlocks {
|
||||||
|
block_start_index,
|
||||||
|
block_len,
|
||||||
|
} => {
|
||||||
|
let el_count = layout.shape().elem_count();
|
||||||
|
// Specialize the case where block_len is one to avoid the second loop.
|
||||||
|
if block_len == 1 {
|
||||||
|
let mut result = Vec::with_capacity(el_count);
|
||||||
|
for index in block_start_index {
|
||||||
|
let v = unsafe { vs.get_unchecked(index) };
|
||||||
|
result.push(f(*v))
|
||||||
|
}
|
||||||
|
result
|
||||||
|
} else {
|
||||||
|
let mut ys: Vec<U> = Vec::with_capacity(el_count);
|
||||||
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
|
let ys_to_set = unsafe {
|
||||||
|
std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
|
||||||
|
};
|
||||||
|
let mut dst_index = 0;
|
||||||
|
for src_index in block_start_index {
|
||||||
|
let vs = &vs[src_index..src_index + block_len];
|
||||||
|
let ys = &mut ys_to_set[dst_index..dst_index + block_len];
|
||||||
|
f_vec(vs, ys);
|
||||||
|
dst_index += block_len;
|
||||||
|
}
|
||||||
|
// SAFETY: values are all set by f_vec.
|
||||||
|
unsafe { ys.set_len(el_count) };
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,6 +1,6 @@
|
|||||||
use crate::WithDType;
|
use crate::WithDType;
|
||||||
use cudarc;
|
use cudarc;
|
||||||
use cudarc::cudnn::safe::{Conv2dForward, Cudnn};
|
use cudarc::cudnn::safe::{ConvForward, Cudnn};
|
||||||
use cudarc::driver::{CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits};
|
use cudarc::driver::{CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits};
|
||||||
use std::cell::RefCell;
|
use std::cell::RefCell;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@ -87,7 +87,7 @@ pub(crate) fn launch_conv2d<
|
|||||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||||
[params.b_size as i32, params.c_out as i32, h_out, w_out],
|
[params.b_size as i32, params.c_out as i32, h_out, w_out],
|
||||||
)?;
|
)?;
|
||||||
let conv2d = Conv2dForward {
|
let conv2d = ConvForward {
|
||||||
conv: &conv,
|
conv: &conv,
|
||||||
x: &x,
|
x: &x,
|
||||||
w: &w,
|
w: &w,
|
452
candle-core/src/cuda_backend/device.rs
Normal file
452
candle-core/src/cuda_backend/device.rs
Normal file
@ -0,0 +1,452 @@
|
|||||||
|
use crate::backend::BackendDevice;
|
||||||
|
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
||||||
|
pub use candle_kernels as kernels;
|
||||||
|
pub use cudarc;
|
||||||
|
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
|
||||||
|
use half::{bf16, f16};
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
|
||||||
|
|
||||||
|
/// Unique identifier for cuda devices.
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||||
|
pub struct DeviceId(usize);
|
||||||
|
|
||||||
|
impl DeviceId {
|
||||||
|
fn new() -> Self {
|
||||||
|
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
|
||||||
|
use std::sync::atomic;
|
||||||
|
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
|
||||||
|
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CudaRng(cudarc::curand::CudaRng);
|
||||||
|
unsafe impl Send for CudaRng {}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct CudaDevice {
|
||||||
|
id: DeviceId,
|
||||||
|
device: Arc<cudarc::driver::CudaDevice>,
|
||||||
|
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
|
||||||
|
curand: Arc<Mutex<CudaRng>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for CudaDevice {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "CudaDevice({:?})", self.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::ops::Deref for CudaDevice {
|
||||||
|
type Target = Arc<cudarc::driver::CudaDevice>;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.device
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CudaDevice {
|
||||||
|
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
|
||||||
|
self.device.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn id(&self) -> DeviceId {
|
||||||
|
self.id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||||
|
let slice = match dtype {
|
||||||
|
DType::U8 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
|
||||||
|
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
|
||||||
|
let params = (&data, v as u8, elem_count);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
CudaStorageSlice::U8(data)
|
||||||
|
}
|
||||||
|
DType::U32 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
|
||||||
|
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
|
||||||
|
let params = (&data, v as u32, elem_count);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
CudaStorageSlice::U32(data)
|
||||||
|
}
|
||||||
|
DType::I64 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
|
||||||
|
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
|
||||||
|
let params = (&data, v as i64, elem_count);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
CudaStorageSlice::I64(data)
|
||||||
|
}
|
||||||
|
DType::BF16 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
|
||||||
|
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
|
||||||
|
let params = (&data, bf16::from_f64(v), elem_count);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
CudaStorageSlice::BF16(data)
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
|
||||||
|
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
|
||||||
|
let params = (&data, f16::from_f64(v), elem_count);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
CudaStorageSlice::F16(data)
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
||||||
|
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
|
||||||
|
let params = (&data, v as f32, elem_count);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
CudaStorageSlice::F32(data)
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
||||||
|
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
|
||||||
|
let params = (&data, v, elem_count);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
CudaStorageSlice::F64(data)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
|
||||||
|
if !self.has_func(module_name, module_name) {
|
||||||
|
// Leaking the string here is a bit sad but we need a &'static str and this is only
|
||||||
|
// done once per kernel name.
|
||||||
|
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
|
||||||
|
self.load_ptx(ptx.into(), module_name, &[static_module_name])
|
||||||
|
.map_err(|cuda| CudaError::Load {
|
||||||
|
cuda,
|
||||||
|
module_name: module_name.to_string(),
|
||||||
|
})
|
||||||
|
.w()?;
|
||||||
|
}
|
||||||
|
self.get_func(module_name, module_name)
|
||||||
|
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
||||||
|
// able to only build the error value if needed.
|
||||||
|
.ok_or(CudaError::MissingKernel {
|
||||||
|
module_name: module_name.to_string(),
|
||||||
|
})
|
||||||
|
.w()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BackendDevice for CudaDevice {
|
||||||
|
type Storage = CudaStorage;
|
||||||
|
|
||||||
|
fn new(ordinal: usize) -> Result<Self> {
|
||||||
|
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
|
||||||
|
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
|
||||||
|
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
|
||||||
|
Ok(Self {
|
||||||
|
id: DeviceId::new(),
|
||||||
|
device,
|
||||||
|
blas: Arc::new(blas),
|
||||||
|
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_seed(&self, seed: u64) -> Result<()> {
|
||||||
|
// We do not call set_seed but instead create a new curand object. This ensures that the
|
||||||
|
// state will be identical and the same random numbers will be generated.
|
||||||
|
let mut curand = self.curand.lock().unwrap();
|
||||||
|
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn location(&self) -> crate::DeviceLocation {
|
||||||
|
crate::DeviceLocation::Cuda {
|
||||||
|
gpu_id: self.device.ordinal(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn same_device(&self, rhs: &Self) -> bool {
|
||||||
|
self.id == rhs.id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
let slice = match dtype {
|
||||||
|
DType::U8 => {
|
||||||
|
let data = self.alloc_zeros::<u8>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::U8(data)
|
||||||
|
}
|
||||||
|
DType::U32 => {
|
||||||
|
let data = self.alloc_zeros::<u32>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::U32(data)
|
||||||
|
}
|
||||||
|
DType::I64 => {
|
||||||
|
let data = self.alloc_zeros::<i64>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::I64(data)
|
||||||
|
}
|
||||||
|
DType::BF16 => {
|
||||||
|
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::BF16(data)
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
let data = self.alloc_zeros::<f16>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::F16(data)
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
let data = self.alloc_zeros::<f32>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::F32(data)
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
let data = self.alloc_zeros::<f64>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::F64(data)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
let curand = self.curand.lock().unwrap();
|
||||||
|
let slice = match dtype {
|
||||||
|
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
||||||
|
// cudarc changes.
|
||||||
|
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
||||||
|
Err(CudaError::UnsupportedDtype {
|
||||||
|
dtype,
|
||||||
|
op: "rand_uniform",
|
||||||
|
})
|
||||||
|
.w()?
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
||||||
|
curand.0.fill_with_uniform(&mut data).w()?;
|
||||||
|
CudaStorageSlice::F32(data)
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
||||||
|
curand.0.fill_with_uniform(&mut data).w()?;
|
||||||
|
CudaStorageSlice::F64(data)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let slice = if lo == 0. && up == 1.0 {
|
||||||
|
slice
|
||||||
|
} else {
|
||||||
|
use super::utils::Map1;
|
||||||
|
let layout = Layout::contiguous(shape);
|
||||||
|
super::Affine(up - lo, lo).map(&slice, self, &layout)?
|
||||||
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
|
||||||
|
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
||||||
|
// cudarc changes.
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
let curand = self.curand.lock().unwrap();
|
||||||
|
// curand can only generate an odd number of values.
|
||||||
|
// https://github.com/huggingface/candle/issues/734
|
||||||
|
let elem_count_round = if elem_count % 2 == 1 {
|
||||||
|
elem_count + 1
|
||||||
|
} else {
|
||||||
|
elem_count
|
||||||
|
};
|
||||||
|
let slice = match dtype {
|
||||||
|
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
||||||
|
Err(CudaError::UnsupportedDtype {
|
||||||
|
dtype,
|
||||||
|
op: "rand_normal",
|
||||||
|
})
|
||||||
|
.w()?
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
|
||||||
|
curand
|
||||||
|
.0
|
||||||
|
.fill_with_normal(&mut data, mean as f32, std as f32)
|
||||||
|
.w()?;
|
||||||
|
CudaStorageSlice::F32(data)
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
|
||||||
|
curand.0.fill_with_normal(&mut data, mean, std).w()?;
|
||||||
|
CudaStorageSlice::F64(data)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||||
|
self.const_impl(1., shape, dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
let slice = match dtype {
|
||||||
|
DType::U8 => {
|
||||||
|
let data = self.alloc::<u8>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::U8(data)
|
||||||
|
}
|
||||||
|
DType::U32 => {
|
||||||
|
let data = self.alloc::<u32>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::U32(data)
|
||||||
|
}
|
||||||
|
DType::I64 => {
|
||||||
|
let data = self.alloc::<i64>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::I64(data)
|
||||||
|
}
|
||||||
|
DType::BF16 => {
|
||||||
|
let data = self.alloc::<bf16>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::BF16(data)
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
let data = self.alloc::<f16>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::F16(data)
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
let data = self.alloc::<f32>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::F32(data)
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
let data = self.alloc::<f64>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::F64(data)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||||
|
let slice = match T::cpu_storage_ref(s) {
|
||||||
|
CpuStorageRef::U8(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::U8(data)
|
||||||
|
}
|
||||||
|
CpuStorageRef::U32(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::U32(data)
|
||||||
|
}
|
||||||
|
CpuStorageRef::I64(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::I64(data)
|
||||||
|
}
|
||||||
|
CpuStorageRef::BF16(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::BF16(data)
|
||||||
|
}
|
||||||
|
CpuStorageRef::F16(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::F16(data)
|
||||||
|
}
|
||||||
|
CpuStorageRef::F32(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::F32(data)
|
||||||
|
}
|
||||||
|
CpuStorageRef::F64(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::F64(data)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||||
|
let slice = match storage {
|
||||||
|
CpuStorage::U8(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::U8(data)
|
||||||
|
}
|
||||||
|
CpuStorage::U32(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::U32(data)
|
||||||
|
}
|
||||||
|
CpuStorage::I64(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::I64(data)
|
||||||
|
}
|
||||||
|
CpuStorage::BF16(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::BF16(data)
|
||||||
|
}
|
||||||
|
CpuStorage::F16(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::F16(data)
|
||||||
|
}
|
||||||
|
CpuStorage::F32(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::F32(data)
|
||||||
|
}
|
||||||
|
CpuStorage::F64(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::F64(data)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
|
||||||
|
let slice = match storage {
|
||||||
|
CpuStorage::U8(storage) => {
|
||||||
|
let data = self.htod_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::U8(data)
|
||||||
|
}
|
||||||
|
CpuStorage::U32(storage) => {
|
||||||
|
let data = self.htod_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::U32(data)
|
||||||
|
}
|
||||||
|
CpuStorage::I64(storage) => {
|
||||||
|
let data = self.htod_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::I64(data)
|
||||||
|
}
|
||||||
|
CpuStorage::BF16(storage) => {
|
||||||
|
let data = self.htod_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::BF16(data)
|
||||||
|
}
|
||||||
|
CpuStorage::F16(storage) => {
|
||||||
|
let data = self.htod_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::F16(data)
|
||||||
|
}
|
||||||
|
CpuStorage::F32(storage) => {
|
||||||
|
let data = self.htod_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::F32(data)
|
||||||
|
}
|
||||||
|
CpuStorage::F64(storage) => {
|
||||||
|
let data = self.htod_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::F64(data)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn synchronize(&self) -> Result<()> {
|
||||||
|
self.device.synchronize().map_err(crate::Error::wrap)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
62
candle-core/src/cuda_backend/error.rs
Normal file
62
candle-core/src/cuda_backend/error.rs
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
use crate::{DType, Layout};
|
||||||
|
|
||||||
|
/// cudarc related errors
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
pub enum CudaError {
|
||||||
|
#[error(transparent)]
|
||||||
|
Cuda(#[from] cudarc::driver::DriverError),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Compiler(#[from] cudarc::nvrtc::CompileError),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Cublas(#[from] cudarc::cublas::result::CublasError),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Curand(#[from] cudarc::curand::result::CurandError),
|
||||||
|
|
||||||
|
#[error("missing kernel '{module_name}'")]
|
||||||
|
MissingKernel { module_name: String },
|
||||||
|
|
||||||
|
#[error("unsupported dtype {dtype:?} for {op}")]
|
||||||
|
UnsupportedDtype { dtype: DType, op: &'static str },
|
||||||
|
|
||||||
|
#[error("internal error '{0}'")]
|
||||||
|
InternalError(&'static str),
|
||||||
|
|
||||||
|
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
||||||
|
MatMulNonContiguous {
|
||||||
|
lhs_stride: Layout,
|
||||||
|
rhs_stride: Layout,
|
||||||
|
mnk: (usize, usize, usize),
|
||||||
|
},
|
||||||
|
|
||||||
|
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
|
||||||
|
UnexpectedDType {
|
||||||
|
msg: &'static str,
|
||||||
|
expected: DType,
|
||||||
|
got: DType,
|
||||||
|
},
|
||||||
|
|
||||||
|
#[error("{cuda} when loading {module_name}")]
|
||||||
|
Load {
|
||||||
|
cuda: cudarc::driver::DriverError,
|
||||||
|
module_name: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<CudaError> for crate::Error {
|
||||||
|
fn from(val: CudaError) -> Self {
|
||||||
|
crate::Error::Cuda(Box::new(val)).bt()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait WrapErr<O> {
|
||||||
|
fn w(self) -> std::result::Result<O, crate::Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
|
||||||
|
fn w(self) -> std::result::Result<O, crate::Error> {
|
||||||
|
self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
|
||||||
|
}
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
172
candle-core/src/cuda_backend/utils.rs
Normal file
172
candle-core/src/cuda_backend/utils.rs
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
/// Helper functions to plug cuda kernels in candle.
|
||||||
|
use crate::{Layout, Result, Shape, WithDType};
|
||||||
|
pub use cudarc;
|
||||||
|
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};
|
||||||
|
|
||||||
|
use super::{CudaDevice, CudaError, WrapErr};
|
||||||
|
|
||||||
|
pub type S = super::CudaStorageSlice;
|
||||||
|
|
||||||
|
pub trait Map1 {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
layout: &Layout,
|
||||||
|
) -> Result<CudaSlice<T>>;
|
||||||
|
|
||||||
|
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
|
||||||
|
let out = match s {
|
||||||
|
S::U8(s) => S::U8(self.f(s, d, l)?),
|
||||||
|
S::U32(s) => S::U32(self.f(s, d, l)?),
|
||||||
|
S::I64(s) => S::I64(self.f(s, d, l)?),
|
||||||
|
S::BF16(s) => S::BF16(self.f(s, d, l)?),
|
||||||
|
S::F16(s) => S::F16(self.f(s, d, l)?),
|
||||||
|
S::F32(s) => S::F32(self.f(s, d, l)?),
|
||||||
|
S::F64(s) => S::F64(self.f(s, d, l)?),
|
||||||
|
};
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Map2 {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
src1: &CudaSlice<T>,
|
||||||
|
layout1: &Layout,
|
||||||
|
src2: &CudaSlice<T>,
|
||||||
|
layout2: &Layout,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<CudaSlice<T>>;
|
||||||
|
|
||||||
|
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
|
||||||
|
let out = match (s1, s2) {
|
||||||
|
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
||||||
|
};
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Map3 {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
src1: &CudaSlice<T>,
|
||||||
|
layout1: &Layout,
|
||||||
|
src2: &CudaSlice<T>,
|
||||||
|
layout2: &Layout,
|
||||||
|
src3: &CudaSlice<T>,
|
||||||
|
layout3: &Layout,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<CudaSlice<T>>;
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn map(
|
||||||
|
&self,
|
||||||
|
s1: &S,
|
||||||
|
l1: &Layout,
|
||||||
|
s2: &S,
|
||||||
|
l2: &Layout,
|
||||||
|
s3: &S,
|
||||||
|
l3: &Layout,
|
||||||
|
d: &CudaDevice,
|
||||||
|
) -> Result<S> {
|
||||||
|
let out = match (s1, s2, s3) {
|
||||||
|
(S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||||
|
(S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||||
|
(S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||||
|
(S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||||
|
(S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||||
|
(S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||||
|
(S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||||
|
_ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?,
|
||||||
|
};
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Map2InPlace {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
dst: &mut CudaSlice<T>,
|
||||||
|
dst_shape: &Shape,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
src_l: &Layout,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<()>;
|
||||||
|
|
||||||
|
fn map(
|
||||||
|
&self,
|
||||||
|
dst: &mut S,
|
||||||
|
dst_s: &Shape,
|
||||||
|
src: &S,
|
||||||
|
src_l: &Layout,
|
||||||
|
d: &CudaDevice,
|
||||||
|
) -> Result<()> {
|
||||||
|
match (dst, src) {
|
||||||
|
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Map1Any {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
||||||
|
&self,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
layout: &Layout,
|
||||||
|
wrap: W,
|
||||||
|
) -> Result<S>;
|
||||||
|
|
||||||
|
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
|
||||||
|
let out = match s {
|
||||||
|
S::U8(s) => self.f(s, d, l, S::U8)?,
|
||||||
|
S::U32(s) => self.f(s, d, l, S::U32)?,
|
||||||
|
S::I64(s) => self.f(s, d, l, S::I64)?,
|
||||||
|
S::BF16(s) => self.f(s, d, l, S::BF16)?,
|
||||||
|
S::F16(s) => self.f(s, d, l, S::F16)?,
|
||||||
|
S::F32(s) => self.f(s, d, l, S::F32)?,
|
||||||
|
S::F64(s) => self.f(s, d, l, S::F64)?,
|
||||||
|
};
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Map2Any {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
src1: &CudaSlice<T>,
|
||||||
|
layout1: &Layout,
|
||||||
|
src2: &CudaSlice<T>,
|
||||||
|
layout2: &Layout,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<S>;
|
||||||
|
|
||||||
|
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
|
||||||
|
let out = match (s1, s2) {
|
||||||
|
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||||
|
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||||
|
(S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||||
|
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||||
|
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||||
|
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||||
|
(S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||||
|
_ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
|
||||||
|
};
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
377
candle-core/src/custom_op.rs
Normal file
377
candle-core/src/custom_op.rs
Normal file
@ -0,0 +1,377 @@
|
|||||||
|
use crate::op::{BackpropOp, Op};
|
||||||
|
use crate::tensor::from_storage;
|
||||||
|
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Unary ops that can be defined in user-land.
|
||||||
|
pub trait CustomOp1 {
|
||||||
|
// Box<dyn> does not support const yet, so use a function to get the name.
|
||||||
|
fn name(&self) -> &'static str;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
|
||||||
|
Err(crate::Error::Cuda(
|
||||||
|
format!("no cuda implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn metal_fwd(
|
||||||
|
&self,
|
||||||
|
_storage: &MetalStorage,
|
||||||
|
_layout: &Layout,
|
||||||
|
) -> Result<(MetalStorage, Shape)> {
|
||||||
|
Err(crate::Error::Metal(
|
||||||
|
format!("no metal implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This function takes as argument the argument `arg` used in the forward pass, the result
|
||||||
|
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
|
||||||
|
/// The function should return the gradient of the argument.
|
||||||
|
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
|
||||||
|
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait CustomOp2 {
|
||||||
|
fn name(&self) -> &'static str;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cpu_fwd(
|
||||||
|
&self,
|
||||||
|
s1: &CpuStorage,
|
||||||
|
l1: &Layout,
|
||||||
|
s2: &CpuStorage,
|
||||||
|
l2: &Layout,
|
||||||
|
) -> Result<(CpuStorage, Shape)>;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cuda_fwd(
|
||||||
|
&self,
|
||||||
|
_: &CudaStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &CudaStorage,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<(CudaStorage, Shape)> {
|
||||||
|
Err(crate::Error::Cuda(
|
||||||
|
format!("no cuda implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn metal_fwd(
|
||||||
|
&self,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<(MetalStorage, Shape)> {
|
||||||
|
Err(crate::Error::Metal(
|
||||||
|
format!("no metal implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bwd(
|
||||||
|
&self,
|
||||||
|
_arg1: &Tensor,
|
||||||
|
_arg2: &Tensor,
|
||||||
|
_res: &Tensor,
|
||||||
|
_grad_res: &Tensor,
|
||||||
|
) -> Result<(Option<Tensor>, Option<Tensor>)> {
|
||||||
|
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait CustomOp3 {
|
||||||
|
fn name(&self) -> &'static str;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cpu_fwd(
|
||||||
|
&self,
|
||||||
|
s1: &CpuStorage,
|
||||||
|
l1: &Layout,
|
||||||
|
s2: &CpuStorage,
|
||||||
|
l2: &Layout,
|
||||||
|
s3: &CpuStorage,
|
||||||
|
l3: &Layout,
|
||||||
|
) -> Result<(CpuStorage, Shape)>;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cuda_fwd(
|
||||||
|
&self,
|
||||||
|
_: &CudaStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &CudaStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &CudaStorage,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<(CudaStorage, Shape)> {
|
||||||
|
Err(crate::Error::Cuda(
|
||||||
|
format!("no cuda implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn metal_fwd(
|
||||||
|
&self,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<(MetalStorage, Shape)> {
|
||||||
|
Err(crate::Error::Metal(
|
||||||
|
format!("no metal implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bwd(
|
||||||
|
&self,
|
||||||
|
_arg1: &Tensor,
|
||||||
|
_arg2: &Tensor,
|
||||||
|
_arg3: &Tensor,
|
||||||
|
_res: &Tensor,
|
||||||
|
_grad_res: &Tensor,
|
||||||
|
) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
|
||||||
|
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Tensor {
|
||||||
|
/// Applies a unary custom op without backward support
|
||||||
|
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
|
||||||
|
let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
|
||||||
|
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies a binary custom op without backward support
|
||||||
|
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
|
||||||
|
let (storage, shape) =
|
||||||
|
self.storage()
|
||||||
|
.apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
|
||||||
|
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies a ternary custom op without backward support
|
||||||
|
pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
|
||||||
|
let (storage, shape) = self.storage().apply_op3(
|
||||||
|
self.layout(),
|
||||||
|
&t2.storage(),
|
||||||
|
t2.layout(),
|
||||||
|
&t3.storage(),
|
||||||
|
t3.layout(),
|
||||||
|
c,
|
||||||
|
)?;
|
||||||
|
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies a unary custom op.
|
||||||
|
pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
|
||||||
|
let (storage, shape) = self
|
||||||
|
.storage()
|
||||||
|
.apply_op1(self.layout(), c.as_ref().as_ref())?;
|
||||||
|
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
|
||||||
|
Ok(from_storage(storage, shape, op, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
|
||||||
|
self.apply_op1_arc(Arc::new(Box::new(c)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies a binary custom op.
|
||||||
|
pub fn apply_op2_arc(
|
||||||
|
&self,
|
||||||
|
rhs: &Self,
|
||||||
|
c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let (storage, shape) = self.storage().apply_op2(
|
||||||
|
self.layout(),
|
||||||
|
&rhs.storage(),
|
||||||
|
rhs.layout(),
|
||||||
|
c.as_ref().as_ref(),
|
||||||
|
)?;
|
||||||
|
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
|
||||||
|
Ok(from_storage(storage, shape, op, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
|
||||||
|
self.apply_op2_arc(r, Arc::new(Box::new(c)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies a ternary custom op.
|
||||||
|
pub fn apply_op3_arc(
|
||||||
|
&self,
|
||||||
|
t2: &Self,
|
||||||
|
t3: &Self,
|
||||||
|
c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let (storage, shape) = self.storage().apply_op3(
|
||||||
|
self.layout(),
|
||||||
|
&t2.storage(),
|
||||||
|
t2.layout(),
|
||||||
|
&t3.storage(),
|
||||||
|
t3.layout(),
|
||||||
|
c.as_ref().as_ref(),
|
||||||
|
)?;
|
||||||
|
let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
|
||||||
|
Op::CustomOp3(t1, t2, t3, c.clone())
|
||||||
|
});
|
||||||
|
Ok(from_storage(storage, shape, op, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
|
||||||
|
&self,
|
||||||
|
t2: &Self,
|
||||||
|
t3: &Self,
|
||||||
|
c: C,
|
||||||
|
) -> Result<Self> {
|
||||||
|
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// In place ops.
|
||||||
|
|
||||||
|
/// Unary ops that can be defined in user-land.
|
||||||
|
/// These ops work in place and as such back-prop is unsupported.
|
||||||
|
pub trait InplaceOp1 {
|
||||||
|
// Box<dyn> does not support const yet, so use a function to get the name.
|
||||||
|
fn name(&self) -> &'static str;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> {
|
||||||
|
Err(crate::Error::Cuda(
|
||||||
|
format!("no cuda implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> {
|
||||||
|
Err(crate::Error::Metal(
|
||||||
|
format!("no metal implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait InplaceOp2 {
|
||||||
|
fn name(&self) -> &'static str;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout)
|
||||||
|
-> Result<()>;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> {
|
||||||
|
Err(crate::Error::Cuda(
|
||||||
|
format!("no cuda implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn metal_fwd(
|
||||||
|
&self,
|
||||||
|
_: &mut MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<()> {
|
||||||
|
Err(crate::Error::Metal(
|
||||||
|
format!("no metal implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait InplaceOp3 {
|
||||||
|
fn name(&self) -> &'static str;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cpu_fwd(
|
||||||
|
&self,
|
||||||
|
s1: &mut CpuStorage,
|
||||||
|
l1: &Layout,
|
||||||
|
s2: &CpuStorage,
|
||||||
|
l2: &Layout,
|
||||||
|
s3: &CpuStorage,
|
||||||
|
l3: &Layout,
|
||||||
|
) -> Result<()>;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cuda_fwd(
|
||||||
|
&self,
|
||||||
|
_: &mut CudaStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &CudaStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &CudaStorage,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<()> {
|
||||||
|
Err(crate::Error::Cuda(
|
||||||
|
format!("no cuda implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn metal_fwd(
|
||||||
|
&self,
|
||||||
|
_: &mut MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<()> {
|
||||||
|
Err(crate::Error::Metal(
|
||||||
|
format!("no metal implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Tensor {
|
||||||
|
/// Applies a unary custom op in place.
|
||||||
|
pub fn inplace_op1<C: InplaceOp1>(&self, c: &C) -> Result<()> {
|
||||||
|
self.storage_mut().inplace_op1(self.layout(), c)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies a unary custom op in place (for the first tensor).
|
||||||
|
pub fn inplace_op2<C: InplaceOp2>(&self, rhs: &Self, c: &C) -> Result<()> {
|
||||||
|
self.storage_mut()
|
||||||
|
.inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies a ternary custom op in place (for the first tensor).
|
||||||
|
pub fn inplace_op3<C: InplaceOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> {
|
||||||
|
self.storage_mut().inplace_op3(
|
||||||
|
self.layout(),
|
||||||
|
&t2.storage(),
|
||||||
|
t2.layout(),
|
||||||
|
&t3.storage(),
|
||||||
|
t3.layout(),
|
||||||
|
c,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
@ -171,6 +171,22 @@ impl Device {
|
|||||||
matches!(self, Self::Metal(_))
|
matches!(self, Self::Metal(_))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn supports_bf16(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::Cuda(_) | Self::Metal(_) => true,
|
||||||
|
Self::Cpu => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return `BF16` for devices that support it, otherwise default to `F32`.
|
||||||
|
pub fn bf16_default_to_f32(&self) -> DType {
|
||||||
|
if self.supports_bf16() {
|
||||||
|
DType::BF16
|
||||||
|
} else {
|
||||||
|
DType::F32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
|
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
|
||||||
if crate::utils::cuda_is_available() {
|
if crate::utils::cuda_is_available() {
|
||||||
Self::new_cuda(ordinal)
|
Self::new_cuda(ordinal)
|
||||||
@ -289,17 +305,48 @@ impl Device {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||||
|
match self {
|
||||||
|
Device::Cpu => {
|
||||||
|
let storage = CpuDevice.alloc_uninit(shape, dtype)?;
|
||||||
|
Ok(Storage::Cpu(storage))
|
||||||
|
}
|
||||||
|
Device::Cuda(device) => {
|
||||||
|
let storage = device.alloc_uninit(shape, dtype)?;
|
||||||
|
Ok(Storage::Cuda(storage))
|
||||||
|
}
|
||||||
|
Device::Metal(device) => {
|
||||||
|
let storage = device.alloc_uninit(shape, dtype)?;
|
||||||
|
Ok(Storage::Metal(storage))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
|
||||||
|
match self {
|
||||||
|
Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
|
||||||
|
Device::Cuda(device) => {
|
||||||
|
let storage = device.storage_from_slice(data)?;
|
||||||
|
Ok(Storage::Cuda(storage))
|
||||||
|
}
|
||||||
|
Device::Metal(device) => {
|
||||||
|
let storage = device.storage_from_slice(data)?;
|
||||||
|
Ok(Storage::Metal(storage))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
|
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
|
||||||
match self {
|
match self {
|
||||||
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
|
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
|
||||||
Device::Cuda(device) => {
|
Device::Cuda(device) => {
|
||||||
let storage = array.to_cpu_storage();
|
let storage = array.to_cpu_storage();
|
||||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
let storage = device.storage_from_cpu_storage_owned(storage)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
Device::Metal(device) => {
|
Device::Metal(device) => {
|
||||||
let storage = array.to_cpu_storage();
|
let storage = array.to_cpu_storage();
|
||||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
let storage = device.storage_from_cpu_storage_owned(storage)?;
|
||||||
Ok(Storage::Metal(storage))
|
Ok(Storage::Metal(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -310,14 +357,22 @@ impl Device {
|
|||||||
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
|
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
|
||||||
Device::Cuda(device) => {
|
Device::Cuda(device) => {
|
||||||
let storage = S::to_cpu_storage_owned(data);
|
let storage = S::to_cpu_storage_owned(data);
|
||||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
let storage = device.storage_from_cpu_storage_owned(storage)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
Device::Metal(device) => {
|
Device::Metal(device) => {
|
||||||
let storage = S::to_cpu_storage_owned(data);
|
let storage = S::to_cpu_storage_owned(data);
|
||||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
let storage = device.storage_from_cpu_storage_owned(storage)?;
|
||||||
Ok(Storage::Metal(storage))
|
Ok(Storage::Metal(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn synchronize(&self) -> Result<()> {
|
||||||
|
match self {
|
||||||
|
Self::Cpu => Ok(()),
|
||||||
|
Self::Cuda(d) => d.synchronize(),
|
||||||
|
Self::Metal(d) => d.synchronize(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -65,12 +65,13 @@ impl std::fmt::Debug for Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Options for Tensor pretty printing
|
/// Options for Tensor pretty printing
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct PrinterOptions {
|
pub struct PrinterOptions {
|
||||||
precision: usize,
|
pub precision: usize,
|
||||||
threshold: usize,
|
pub threshold: usize,
|
||||||
edge_items: usize,
|
pub edge_items: usize,
|
||||||
line_width: usize,
|
pub line_width: usize,
|
||||||
sci_mode: Option<bool>,
|
pub sci_mode: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
|
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
|
||||||
@ -89,6 +90,10 @@ impl PrinterOptions {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
|
||||||
|
&PRINT_OPTS
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_print_options(options: PrinterOptions) {
|
pub fn set_print_options(options: PrinterOptions) {
|
||||||
*PRINT_OPTS.lock().unwrap() = options
|
*PRINT_OPTS.lock().unwrap() = options
|
||||||
}
|
}
|
||||||
@ -117,6 +122,26 @@ pub fn set_print_options_full() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_line_width(line_width: usize) {
|
||||||
|
PRINT_OPTS.lock().unwrap().line_width = line_width
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_precision(precision: usize) {
|
||||||
|
PRINT_OPTS.lock().unwrap().precision = precision
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_edge_items(edge_items: usize) {
|
||||||
|
PRINT_OPTS.lock().unwrap().edge_items = edge_items
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_threshold(threshold: usize) {
|
||||||
|
PRINT_OPTS.lock().unwrap().threshold = threshold
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_sci_mode(sci_mode: Option<bool>) {
|
||||||
|
PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
|
||||||
|
}
|
||||||
|
|
||||||
struct FmtSize {
|
struct FmtSize {
|
||||||
current_size: usize,
|
current_size: usize,
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
//! Types for elements that can be stored and manipulated using tensors.
|
//! Types for elements that can be stored and manipulated using tensors.
|
||||||
#![allow(clippy::redundant_closure_call)]
|
#![allow(clippy::redundant_closure_call)]
|
||||||
use crate::backend::BackendStorage;
|
use crate::backend::BackendStorage;
|
||||||
use crate::{CpuStorage, Error, Result};
|
use crate::{CpuStorage, CpuStorageRef, Error, Result};
|
||||||
|
|
||||||
/// The different types of elements allowed in tensors.
|
/// The different types of elements allowed in tensors.
|
||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
@ -23,7 +23,15 @@ pub enum DType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
pub struct DTypeParseError;
|
pub struct DTypeParseError(String);
|
||||||
|
|
||||||
|
impl std::fmt::Display for DTypeParseError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "cannot parse '{}' as a dtype", self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for DTypeParseError {}
|
||||||
|
|
||||||
impl std::str::FromStr for DType {
|
impl std::str::FromStr for DType {
|
||||||
type Err = DTypeParseError;
|
type Err = DTypeParseError;
|
||||||
@ -36,7 +44,7 @@ impl std::str::FromStr for DType {
|
|||||||
"f16" => Ok(Self::F16),
|
"f16" => Ok(Self::F16),
|
||||||
"f32" => Ok(Self::F32),
|
"f32" => Ok(Self::F32),
|
||||||
"f64" => Ok(Self::F64),
|
"f64" => Ok(Self::F64),
|
||||||
_ => Err(DTypeParseError),
|
_ => Err(DTypeParseError(s.to_string())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -92,12 +100,14 @@ pub trait WithDType:
|
|||||||
+ 'static
|
+ 'static
|
||||||
+ Send
|
+ Send
|
||||||
+ Sync
|
+ Sync
|
||||||
|
+ std::any::Any
|
||||||
+ crate::cpu::kernels::VecOps
|
+ crate::cpu::kernels::VecOps
|
||||||
{
|
{
|
||||||
const DTYPE: DType;
|
const DTYPE: DType;
|
||||||
|
|
||||||
fn from_f64(v: f64) -> Self;
|
fn from_f64(v: f64) -> Self;
|
||||||
fn to_f64(self) -> f64;
|
fn to_f64(self) -> f64;
|
||||||
|
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
|
||||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
||||||
|
|
||||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||||
@ -121,6 +131,10 @@ macro_rules! with_dtype {
|
|||||||
$to_f64(self)
|
$to_f64(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
|
||||||
|
CpuStorageRef::$dtype(data)
|
||||||
|
}
|
||||||
|
|
||||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
|
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
|
||||||
CpuStorage::$dtype(data)
|
CpuStorage::$dtype(data)
|
||||||
}
|
}
|
||||||
|
@ -154,6 +154,19 @@ impl crate::backend::BackendStorage for CudaStorage {
|
|||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn copy2d(
|
||||||
|
&self,
|
||||||
|
_: &mut Self,
|
||||||
|
_: usize,
|
||||||
|
_: usize,
|
||||||
|
_: usize,
|
||||||
|
_: usize,
|
||||||
|
_: usize,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
@ -197,10 +210,22 @@ impl crate::backend::BackendDevice for CudaDevice {
|
|||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
@ -208,4 +233,38 @@ impl crate::backend::BackendDevice for CudaDevice {
|
|||||||
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn synchronize(&self) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||||
|
/// allowed with f16 GEMMs.
|
||||||
|
pub fn gemm_reduced_precision_f16() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||||
|
/// allowed with f16 GEMMs.
|
||||||
|
pub fn set_gemm_reduced_precision_f16(_: bool) {}
|
||||||
|
|
||||||
|
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||||
|
/// allowed with bf16 GEMMs.
|
||||||
|
pub fn gemm_reduced_precision_bf16() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||||
|
/// allowed with bf16 GEMMs.
|
||||||
|
pub fn set_gemm_reduced_precision_bf16(_: bool) {}
|
||||||
|
|
||||||
|
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
|
||||||
|
/// allowed with f32 GEMMs.
|
||||||
|
pub fn gemm_reduced_precision_f32() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
|
||||||
|
/// allowed with f32 GEMMs.
|
||||||
|
pub fn set_gemm_reduced_precision_f32(_b: bool) {}
|
||||||
|
@ -166,6 +166,19 @@ impl crate::backend::BackendStorage for MetalStorage {
|
|||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn copy2d(
|
||||||
|
&self,
|
||||||
|
_: &mut Self,
|
||||||
|
_: usize,
|
||||||
|
_: usize,
|
||||||
|
_: usize,
|
||||||
|
_: usize,
|
||||||
|
_: usize,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
@ -209,10 +222,22 @@ impl crate::backend::BackendDevice for MetalDevice {
|
|||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
@ -220,4 +245,8 @@ impl crate::backend::BackendDevice for MetalDevice {
|
|||||||
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn synchronize(&self) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -219,10 +219,14 @@ impl Error {
|
|||||||
Self::Wrapped(Box::new(err)).bt()
|
Self::Wrapped(Box::new(err)).bt()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
|
pub fn msg(err: impl std::error::Error) -> Self {
|
||||||
Self::Msg(err.to_string()).bt()
|
Self::Msg(err.to_string()).bt()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn debug(err: impl std::fmt::Debug) -> Self {
|
||||||
|
Self::Msg(format!("{err:?}")).bt()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn bt(self) -> Self {
|
pub fn bt(self) -> Self {
|
||||||
let backtrace = std::backtrace::Backtrace::capture();
|
let backtrace = std::backtrace::Backtrace::capture();
|
||||||
match backtrace.status() {
|
match backtrace.status() {
|
||||||
|
@ -141,28 +141,117 @@ impl<T> IndexOp<T> for Tensor
|
|||||||
where
|
where
|
||||||
T: Into<TensorIndexer>,
|
T: Into<TensorIndexer>,
|
||||||
{
|
{
|
||||||
|
///```rust
|
||||||
|
/// use candle_core::{Tensor, DType, Device, IndexOp};
|
||||||
|
/// let a = Tensor::new(&[
|
||||||
|
/// [0., 1.],
|
||||||
|
/// [2., 3.],
|
||||||
|
/// [4., 5.]
|
||||||
|
/// ], &Device::Cpu)?;
|
||||||
|
///
|
||||||
|
/// let b = a.i(0)?;
|
||||||
|
/// assert_eq!(b.shape().dims(), &[2]);
|
||||||
|
/// assert_eq!(b.to_vec1::<f64>()?, &[0., 1.]);
|
||||||
|
///
|
||||||
|
/// let c = a.i(..2)?;
|
||||||
|
/// assert_eq!(c.shape().dims(), &[2, 2]);
|
||||||
|
/// assert_eq!(c.to_vec2::<f64>()?, &[
|
||||||
|
/// [0., 1.],
|
||||||
|
/// [2., 3.]
|
||||||
|
/// ]);
|
||||||
|
///
|
||||||
|
/// let d = a.i(1..)?;
|
||||||
|
/// assert_eq!(d.shape().dims(), &[2, 2]);
|
||||||
|
/// assert_eq!(d.to_vec2::<f64>()?, &[
|
||||||
|
/// [2., 3.],
|
||||||
|
/// [4., 5.]
|
||||||
|
/// ]);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
fn i(&self, index: T) -> Result<Tensor, Error> {
|
fn i(&self, index: T) -> Result<Tensor, Error> {
|
||||||
self.index(&[index.into()])
|
self.index(&[index.into()])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<A> IndexOp<(A,)> for Tensor
|
||||||
|
where
|
||||||
|
A: Into<TensorIndexer>,
|
||||||
|
{
|
||||||
|
///```rust
|
||||||
|
/// use candle_core::{Tensor, DType, Device, IndexOp};
|
||||||
|
/// let a = Tensor::new(&[
|
||||||
|
/// [0f32, 1.],
|
||||||
|
/// [2. , 3.],
|
||||||
|
/// [4. , 5.]
|
||||||
|
/// ], &Device::Cpu)?;
|
||||||
|
///
|
||||||
|
/// let b = a.i((0,))?;
|
||||||
|
/// assert_eq!(b.shape().dims(), &[2]);
|
||||||
|
/// assert_eq!(b.to_vec1::<f32>()?, &[0., 1.]);
|
||||||
|
///
|
||||||
|
/// let c = a.i((..2,))?;
|
||||||
|
/// assert_eq!(c.shape().dims(), &[2, 2]);
|
||||||
|
/// assert_eq!(c.to_vec2::<f32>()?, &[
|
||||||
|
/// [0., 1.],
|
||||||
|
/// [2., 3.]
|
||||||
|
/// ]);
|
||||||
|
///
|
||||||
|
/// let d = a.i((1..,))?;
|
||||||
|
/// assert_eq!(d.shape().dims(), &[2, 2]);
|
||||||
|
/// assert_eq!(d.to_vec2::<f32>()?, &[
|
||||||
|
/// [2., 3.],
|
||||||
|
/// [4., 5.]
|
||||||
|
/// ]);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
|
fn i(&self, (a,): (A,)) -> Result<Tensor, Error> {
|
||||||
|
self.index(&[a.into()])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
impl<A, B> IndexOp<(A, B)> for Tensor
|
||||||
|
where
|
||||||
|
A: Into<TensorIndexer>,
|
||||||
|
B: Into<TensorIndexer>,
|
||||||
|
{
|
||||||
|
///```rust
|
||||||
|
/// use candle_core::{Tensor, DType, Device, IndexOp};
|
||||||
|
/// let a = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.], [6., 7., 8.]], &Device::Cpu)?;
|
||||||
|
///
|
||||||
|
/// let b = a.i((1, 0))?;
|
||||||
|
/// assert_eq!(b.to_vec0::<f32>()?, 3.);
|
||||||
|
///
|
||||||
|
/// let c = a.i((..2, 1))?;
|
||||||
|
/// assert_eq!(c.shape().dims(), &[2]);
|
||||||
|
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
|
||||||
|
///
|
||||||
|
/// let d = a.i((2.., ..))?;
|
||||||
|
/// assert_eq!(c.shape().dims(), &[2]);
|
||||||
|
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
|
fn i(&self, (a, b): (A, B)) -> Result<Tensor, Error> {
|
||||||
|
self.index(&[a.into(), b.into()])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
macro_rules! index_op_tuple {
|
macro_rules! index_op_tuple {
|
||||||
($($t:ident),+) => {
|
($doc:tt, $($t:ident),+) => {
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
impl<$($t),*> IndexOp<($($t,)*)> for Tensor
|
impl<$($t),*> IndexOp<($($t,)*)> for Tensor
|
||||||
where
|
where
|
||||||
$($t: Into<TensorIndexer>,)*
|
$($t: Into<TensorIndexer>,)*
|
||||||
{
|
{
|
||||||
|
#[doc=$doc]
|
||||||
fn i(&self, ($($t,)*): ($($t,)*)) -> Result<Tensor, Error> {
|
fn i(&self, ($($t,)*): ($($t,)*)) -> Result<Tensor, Error> {
|
||||||
self.index(&[$($t.into(),)*])
|
self.index(&[$($t.into(),)*])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
index_op_tuple!(A);
|
|
||||||
index_op_tuple!(A, B);
|
index_op_tuple!("see [TensorIndex#method.i]", A, B, C);
|
||||||
index_op_tuple!(A, B, C);
|
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D);
|
||||||
index_op_tuple!(A, B, C, D);
|
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E);
|
||||||
index_op_tuple!(A, B, C, D, E);
|
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F);
|
||||||
index_op_tuple!(A, B, C, D, E, F);
|
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F, G);
|
||||||
index_op_tuple!(A, B, C, D, E, F, G);
|
|
||||||
|
@ -70,7 +70,7 @@ impl Layout {
|
|||||||
self.shape.is_fortran_contiguous(&self.stride)
|
self.shape.is_fortran_contiguous(&self.stride)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
|
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
|
||||||
let dims = self.shape().dims();
|
let dims = self.shape().dims();
|
||||||
if dim >= dims.len() {
|
if dim >= dims.len() {
|
||||||
Err(Error::DimOutOfRange {
|
Err(Error::DimOutOfRange {
|
||||||
@ -99,7 +99,7 @@ impl Layout {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
|
pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
|
||||||
let rank = self.shape.rank();
|
let rank = self.shape.rank();
|
||||||
if rank <= dim1 || rank <= dim2 {
|
if rank <= dim1 || rank <= dim2 {
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
@ -120,7 +120,7 @@ impl Layout {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
|
pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
|
||||||
let is_permutation =
|
let is_permutation =
|
||||||
idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
|
idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
|
||||||
if !is_permutation {
|
if !is_permutation {
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
//!
|
//!
|
||||||
//! ## Features
|
//! ## Features
|
||||||
//!
|
//!
|
||||||
//! - Simple syntax (looks and like PyTorch)
|
//! - Simple syntax (looks and feels like PyTorch)
|
||||||
//! - CPU and Cuda backends (and M1 support)
|
//! - CPU and Cuda backends (and M1 support)
|
||||||
//! - Enable serverless (CPU) small and fast deployments
|
//! - Enable serverless (CPU) small and fast deployments
|
||||||
//! - Model training
|
//! - Model training
|
||||||
@ -37,18 +37,17 @@
|
|||||||
mod accelerate;
|
mod accelerate;
|
||||||
pub mod backend;
|
pub mod backend;
|
||||||
pub mod backprop;
|
pub mod backprop;
|
||||||
mod conv;
|
pub mod conv;
|
||||||
mod convert;
|
mod convert;
|
||||||
pub mod cpu;
|
pub mod cpu;
|
||||||
pub mod cpu_backend;
|
pub mod cpu_backend;
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
pub mod cuda_backend;
|
pub mod cuda_backend;
|
||||||
#[cfg(feature = "cudnn")]
|
mod custom_op;
|
||||||
pub mod cudnn;
|
|
||||||
mod device;
|
mod device;
|
||||||
pub mod display;
|
pub mod display;
|
||||||
mod dtype;
|
mod dtype;
|
||||||
mod dummy_cuda_backend;
|
pub mod dummy_cuda_backend;
|
||||||
mod dummy_metal_backend;
|
mod dummy_metal_backend;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
mod indexer;
|
mod indexer;
|
||||||
@ -58,37 +57,46 @@ pub mod metal_backend;
|
|||||||
#[cfg(feature = "mkl")]
|
#[cfg(feature = "mkl")]
|
||||||
mod mkl;
|
mod mkl;
|
||||||
pub mod npy;
|
pub mod npy;
|
||||||
mod op;
|
pub mod op;
|
||||||
pub mod pickle;
|
pub mod pickle;
|
||||||
pub mod quantized;
|
pub mod quantized;
|
||||||
pub mod safetensors;
|
pub mod safetensors;
|
||||||
pub mod scalar;
|
pub mod scalar;
|
||||||
pub mod shape;
|
pub mod shape;
|
||||||
|
mod sort;
|
||||||
mod storage;
|
mod storage;
|
||||||
|
pub mod streaming;
|
||||||
mod strided_index;
|
mod strided_index;
|
||||||
mod tensor;
|
mod tensor;
|
||||||
|
mod tensor_cat;
|
||||||
pub mod test_utils;
|
pub mod test_utils;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
mod variable;
|
mod variable;
|
||||||
|
|
||||||
pub use cpu_backend::CpuStorage;
|
#[cfg(feature = "cudnn")]
|
||||||
|
pub use cuda_backend::cudnn;
|
||||||
|
|
||||||
|
pub use cpu_backend::{CpuStorage, CpuStorageRef};
|
||||||
|
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
|
||||||
pub use device::{Device, DeviceLocation, NdArray};
|
pub use device::{Device, DeviceLocation, NdArray};
|
||||||
pub use dtype::{DType, FloatDType, IntDType, WithDType};
|
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
|
||||||
pub use error::{Error, Result};
|
pub use error::{Error, Result};
|
||||||
pub use indexer::IndexOp;
|
pub use indexer::{IndexOp, TensorIndexer};
|
||||||
pub use layout::Layout;
|
pub use layout::Layout;
|
||||||
pub use op::{CustomOp1, CustomOp2, CustomOp3};
|
|
||||||
pub use shape::{Shape, D};
|
pub use shape::{Shape, D};
|
||||||
pub use storage::Storage;
|
pub use storage::Storage;
|
||||||
|
pub use streaming::{StreamTensor, StreamingBinOp, StreamingModule};
|
||||||
pub use strided_index::{StridedBlocks, StridedIndex};
|
pub use strided_index::{StridedBlocks, StridedIndex};
|
||||||
pub use tensor::{Tensor, TensorId};
|
pub use tensor::{Tensor, TensorId};
|
||||||
pub use variable::Var;
|
pub use variable::Var;
|
||||||
|
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
pub use cuda_backend::{CudaDevice, CudaStorage};
|
pub use cuda_backend as cuda;
|
||||||
|
|
||||||
#[cfg(not(feature = "cuda"))]
|
#[cfg(not(feature = "cuda"))]
|
||||||
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
|
pub use dummy_cuda_backend as cuda;
|
||||||
|
|
||||||
|
pub use cuda::{CudaDevice, CudaStorage};
|
||||||
|
|
||||||
#[cfg(feature = "metal")]
|
#[cfg(feature = "metal")]
|
||||||
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
|
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
|
||||||
@ -129,6 +137,15 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<M: Module> Module for Option<&M> {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
match self {
|
||||||
|
None => Ok(xs.clone()),
|
||||||
|
Some(m) => m.forward(xs),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// A trait defining a module with forward method using a single tensor argument and a flag to
|
// A trait defining a module with forward method using a single tensor argument and a flag to
|
||||||
// separate the training and evaluation behaviors.
|
// separate the training and evaluation behaviors.
|
||||||
pub trait ModuleT {
|
pub trait ModuleT {
|
||||||
|
324
candle-core/src/metal_backend/device.rs
Normal file
324
candle-core/src/metal_backend/device.rs
Normal file
@ -0,0 +1,324 @@
|
|||||||
|
use crate::{DType, Result};
|
||||||
|
use candle_metal_kernels::Kernels;
|
||||||
|
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::ffi::c_void;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::sync::{Arc, Mutex, RwLock};
|
||||||
|
|
||||||
|
use super::MetalError;
|
||||||
|
|
||||||
|
/// Unique identifier for cuda devices.
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||||
|
pub struct DeviceId(usize);
|
||||||
|
|
||||||
|
impl DeviceId {
|
||||||
|
pub(crate) fn new() -> Self {
|
||||||
|
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
|
||||||
|
use std::sync::atomic;
|
||||||
|
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
|
||||||
|
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
|
||||||
|
pub(crate) struct Commands {
|
||||||
|
/// Single command queue for the entire device.
|
||||||
|
command_queue: CommandQueue,
|
||||||
|
/// One command buffer at a time.
|
||||||
|
/// The scheduler works by allowing multiple
|
||||||
|
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
|
||||||
|
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
|
||||||
|
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
|
||||||
|
/// to start to work).
|
||||||
|
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
|
||||||
|
/// for their START time, but there's no guarantee that command buffer1 will finish before
|
||||||
|
/// command buffer2 starts (or there are metal bugs there)
|
||||||
|
command_buffer: CommandBuffer,
|
||||||
|
/// Keeps track of the current amount of compute command encoders on the current
|
||||||
|
/// command buffer
|
||||||
|
/// Arc, RwLock because of the interior mutability.
|
||||||
|
command_buffer_index: usize,
|
||||||
|
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
|
||||||
|
compute_per_buffer: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Commands {
|
||||||
|
pub(crate) fn new(command_queue: CommandQueue) -> Result<Self> {
|
||||||
|
let command_buffer = command_queue.new_command_buffer().to_owned();
|
||||||
|
command_buffer.enqueue();
|
||||||
|
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
|
||||||
|
Ok(val) => val.parse()?,
|
||||||
|
_ => 50,
|
||||||
|
};
|
||||||
|
Ok(Self {
|
||||||
|
command_queue,
|
||||||
|
command_buffer,
|
||||||
|
command_buffer_index: 0,
|
||||||
|
compute_per_buffer,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn command_buffer(&mut self) -> Result<(bool, CommandBuffer)> {
|
||||||
|
let mut command_buffer = self.command_buffer.to_owned();
|
||||||
|
let mut flushed = false;
|
||||||
|
if self.command_buffer_index > self.compute_per_buffer {
|
||||||
|
self.command_buffer.commit();
|
||||||
|
command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||||
|
self.command_buffer = command_buffer.clone();
|
||||||
|
self.command_buffer_index = 0;
|
||||||
|
flushed = true;
|
||||||
|
}
|
||||||
|
self.command_buffer_index += 1;
|
||||||
|
Ok((flushed, command_buffer))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn wait_until_completed(&mut self) -> Result<()> {
|
||||||
|
match self.command_buffer.status() {
|
||||||
|
metal::MTLCommandBufferStatus::Committed
|
||||||
|
| metal::MTLCommandBufferStatus::Scheduled
|
||||||
|
| metal::MTLCommandBufferStatus::Completed => {
|
||||||
|
panic!("Already committed");
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
self.command_buffer.commit();
|
||||||
|
self.command_buffer.wait_until_completed();
|
||||||
|
self.command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct MetalDevice {
|
||||||
|
/// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
|
||||||
|
/// the device itself.
|
||||||
|
pub(crate) id: DeviceId,
|
||||||
|
|
||||||
|
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
|
||||||
|
pub(crate) device: metal::Device,
|
||||||
|
|
||||||
|
pub(crate) commands: Arc<RwLock<Commands>>,
|
||||||
|
|
||||||
|
/// Simple allocator struct.
|
||||||
|
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
|
||||||
|
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
|
||||||
|
/// (could be linked to FFI communication overhead).
|
||||||
|
///
|
||||||
|
/// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the
|
||||||
|
/// graph calculation, and only we the allocator kept a reference to it, therefore it's free
|
||||||
|
/// to be reused. However, in order for this to work, we need to guarantee the order of
|
||||||
|
/// operation, so that this buffer is not being used by another kernel at the same time.
|
||||||
|
/// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
|
||||||
|
///
|
||||||
|
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
|
||||||
|
/// (strong_count = 1).
|
||||||
|
pub(crate) buffers: Arc<RwLock<BufferMap>>,
|
||||||
|
|
||||||
|
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
|
||||||
|
/// Heavily used by [`candle_metal_kernels`]
|
||||||
|
pub(crate) kernels: Arc<Kernels>,
|
||||||
|
/// Seed for random number generation.
|
||||||
|
pub(crate) seed: Arc<Mutex<Buffer>>,
|
||||||
|
/// Whether to use the MLX matmul kernels instead of the MFA ones.
|
||||||
|
pub(crate) use_mlx_mm: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for MetalDevice {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "MetalDevice({:?})", self.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::ops::Deref for MetalDevice {
|
||||||
|
type Target = metal::DeviceRef;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.device
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MetalDevice {
|
||||||
|
pub fn set_use_mlx_mm(&mut self, use_mlx_mm: bool) {
|
||||||
|
self.use_mlx_mm = use_mlx_mm
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn id(&self) -> DeviceId {
|
||||||
|
self.id
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn metal_device(&self) -> &metal::Device {
|
||||||
|
&self.device
|
||||||
|
}
|
||||||
|
|
||||||
|
fn drop_unused_buffers(&self) -> Result<()> {
|
||||||
|
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||||
|
for subbuffers in buffers.values_mut() {
|
||||||
|
let newbuffers = subbuffers
|
||||||
|
.iter()
|
||||||
|
.filter(|s| Arc::strong_count(*s) > 1)
|
||||||
|
.map(Arc::clone)
|
||||||
|
.collect();
|
||||||
|
*subbuffers = newbuffers;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn command_buffer(&self) -> Result<CommandBuffer> {
|
||||||
|
let mut commands = self.commands.write().map_err(MetalError::from)?;
|
||||||
|
let (flushed, command_buffer) = commands.command_buffer()?;
|
||||||
|
if flushed {
|
||||||
|
self.drop_unused_buffers()?
|
||||||
|
}
|
||||||
|
Ok(command_buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn wait_until_completed(&self) -> Result<()> {
|
||||||
|
let mut commands = self.commands.write().map_err(MetalError::from)?;
|
||||||
|
commands.wait_until_completed()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn kernels(&self) -> &Kernels {
|
||||||
|
&self.kernels
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn device(&self) -> &metal::Device {
|
||||||
|
&self.device
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a new buffer (not necessarily zeroed).
|
||||||
|
/// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
|
||||||
|
/// This means the buffer data cannot be read on the CPU directly.
|
||||||
|
///
|
||||||
|
/// [`name`] is only used to keep track of the resource origin in case of bugs
|
||||||
|
pub fn new_buffer(
|
||||||
|
&self,
|
||||||
|
element_count: usize,
|
||||||
|
dtype: DType,
|
||||||
|
name: &str,
|
||||||
|
) -> Result<Arc<Buffer>> {
|
||||||
|
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
||||||
|
self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a new buffer (not necessarily zeroed).
|
||||||
|
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
|
||||||
|
/// This means the buffer can be read on the CPU but will require manual
|
||||||
|
/// synchronization when the CPU memory is modified
|
||||||
|
/// Used as a bridge to gather data back from the GPU
|
||||||
|
pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
|
||||||
|
self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a new buffer from data.
|
||||||
|
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
|
||||||
|
///
|
||||||
|
/// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
|
||||||
|
/// allocates the buffer and copies over the existing data before returning the MTLBuffer.
|
||||||
|
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
|
||||||
|
let size = core::mem::size_of_val(data) as NSUInteger;
|
||||||
|
let new_buffer = self.device.new_buffer_with_data(
|
||||||
|
data.as_ptr() as *const c_void,
|
||||||
|
size,
|
||||||
|
MTLResourceOptions::StorageModeManaged,
|
||||||
|
);
|
||||||
|
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||||
|
|
||||||
|
let subbuffers = buffers
|
||||||
|
.entry((size, MTLResourceOptions::StorageModeManaged))
|
||||||
|
.or_insert(vec![]);
|
||||||
|
|
||||||
|
let new_buffer = Arc::new(new_buffer);
|
||||||
|
subbuffers.push(new_buffer.clone());
|
||||||
|
Ok(new_buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
|
||||||
|
let buffer = self.allocate_buffer(
|
||||||
|
size_in_bytes as NSUInteger,
|
||||||
|
MTLResourceOptions::StorageModePrivate,
|
||||||
|
"allocate_zeros",
|
||||||
|
)?;
|
||||||
|
let command_buffer = self.command_buffer()?;
|
||||||
|
command_buffer.set_label("zeros");
|
||||||
|
let blit = command_buffer.new_blit_command_encoder();
|
||||||
|
blit.fill_buffer(
|
||||||
|
&buffer,
|
||||||
|
metal::NSRange {
|
||||||
|
location: 0,
|
||||||
|
length: buffer.length(),
|
||||||
|
},
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
blit.end_encoding();
|
||||||
|
Ok(buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The critical allocator algorithm
|
||||||
|
fn allocate_buffer(
|
||||||
|
&self,
|
||||||
|
size: NSUInteger,
|
||||||
|
option: MTLResourceOptions,
|
||||||
|
_name: &str,
|
||||||
|
) -> Result<Arc<Buffer>> {
|
||||||
|
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||||
|
if let Some(b) = find_available_buffer(size, option, &buffers) {
|
||||||
|
// Cloning also ensures we increment the strong count
|
||||||
|
return Ok(b.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let size = buf_size(size);
|
||||||
|
let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
|
||||||
|
|
||||||
|
let new_buffer = self.device.new_buffer(size as NSUInteger, option);
|
||||||
|
let new_buffer = Arc::new(new_buffer);
|
||||||
|
subbuffers.push(new_buffer.clone());
|
||||||
|
|
||||||
|
Ok(new_buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a metal GPU capture trace on [`path`].
|
||||||
|
pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
|
||||||
|
let capture = metal::CaptureManager::shared();
|
||||||
|
let descriptor = metal::CaptureDescriptor::new();
|
||||||
|
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
|
||||||
|
descriptor.set_capture_device(self);
|
||||||
|
// The [set_output_url] call requires an absolute path so we convert it if needed.
|
||||||
|
if path.as_ref().is_absolute() {
|
||||||
|
descriptor.set_output_url(path);
|
||||||
|
} else {
|
||||||
|
let path = std::env::current_dir()?.join(path);
|
||||||
|
descriptor.set_output_url(path);
|
||||||
|
}
|
||||||
|
|
||||||
|
capture
|
||||||
|
.start_capture(&descriptor)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn buf_size(size: NSUInteger) -> NSUInteger {
|
||||||
|
size.saturating_sub(1).next_power_of_two() as NSUInteger
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find_available_buffer(
|
||||||
|
size: NSUInteger,
|
||||||
|
option: MTLResourceOptions,
|
||||||
|
buffers: &BufferMap,
|
||||||
|
) -> Option<Arc<Buffer>> {
|
||||||
|
let mut best_buffer: Option<&Arc<Buffer>> = None;
|
||||||
|
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
|
||||||
|
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
|
||||||
|
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
|
||||||
|
for sub in subbuffers {
|
||||||
|
if Arc::strong_count(sub) == 1 {
|
||||||
|
best_buffer = Some(sub);
|
||||||
|
best_buffer_size = *buffer_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
best_buffer.cloned()
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
@ -333,6 +333,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
|
|||||||
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn vs_exp_inplace(y: &mut [f32]) {
|
||||||
|
unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn vd_exp_inplace(y: &mut [f64]) {
|
||||||
|
unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||||
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
||||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
@ -355,6 +365,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
|
||||||
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
|
*y = -v
|
||||||
|
}
|
||||||
|
vs_exp_inplace(ys);
|
||||||
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
|
*y = v / (1.0 + *y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
|
||||||
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
|
*y = -v
|
||||||
|
}
|
||||||
|
vd_exp_inplace(ys);
|
||||||
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
|
*y = v / (1.0 + *y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
macro_rules! binary_op {
|
macro_rules! binary_op {
|
||||||
($fn_name:ident, $ty:ty, $mkl_name:ident) => {
|
($fn_name:ident, $ty:ty, $mkl_name:ident) => {
|
||||||
#[inline]
|
#[inline]
|
||||||
|
@ -330,7 +330,7 @@ impl Tensor {
|
|||||||
path: P,
|
path: P,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
|
let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
|
||||||
let options =
|
let options: zip::write::FileOptions<()> =
|
||||||
zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
|
zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
|
||||||
|
|
||||||
for (name, tensor) in ts.iter() {
|
for (name, tensor) in ts.iter() {
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
#![allow(clippy::redundant_closure_call)]
|
#![allow(clippy::redundant_closure_call)]
|
||||||
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
|
use crate::Tensor;
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use num_traits::float::Float;
|
use num_traits::float::Float;
|
||||||
|
|
||||||
@ -61,10 +61,12 @@ pub enum UnaryOp {
|
|||||||
GeluErf,
|
GeluErf,
|
||||||
Erf,
|
Erf,
|
||||||
Relu,
|
Relu,
|
||||||
|
Silu,
|
||||||
Tanh,
|
Tanh,
|
||||||
Floor,
|
Floor,
|
||||||
Ceil,
|
Ceil,
|
||||||
Round,
|
Round,
|
||||||
|
Sign,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -131,7 +133,10 @@ pub enum Op {
|
|||||||
stride: (usize, usize),
|
stride: (usize, usize),
|
||||||
},
|
},
|
||||||
|
|
||||||
UpsampleNearest1D(Tensor),
|
UpsampleNearest1D {
|
||||||
|
arg: Tensor,
|
||||||
|
target_size: usize,
|
||||||
|
},
|
||||||
UpsampleNearest2D {
|
UpsampleNearest2D {
|
||||||
arg: Tensor,
|
arg: Tensor,
|
||||||
target_h: usize,
|
target_h: usize,
|
||||||
@ -157,168 +162,23 @@ pub enum Op {
|
|||||||
Permute(Tensor, Vec<usize>),
|
Permute(Tensor, Vec<usize>),
|
||||||
Elu(Tensor, f64),
|
Elu(Tensor, f64),
|
||||||
Powf(Tensor, f64),
|
Powf(Tensor, f64),
|
||||||
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
|
CustomOp1(
|
||||||
|
Tensor,
|
||||||
|
std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
|
||||||
|
),
|
||||||
CustomOp2(
|
CustomOp2(
|
||||||
Tensor,
|
Tensor,
|
||||||
Tensor,
|
Tensor,
|
||||||
std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
|
std::sync::Arc<Box<dyn crate::CustomOp2 + Send + Sync>>,
|
||||||
),
|
),
|
||||||
CustomOp3(
|
CustomOp3(
|
||||||
Tensor,
|
Tensor,
|
||||||
Tensor,
|
Tensor,
|
||||||
Tensor,
|
Tensor,
|
||||||
std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
|
std::sync::Arc<Box<dyn crate::CustomOp3 + Send + Sync>>,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Unary ops that can be defined in user-land.
|
|
||||||
pub trait CustomOp1 {
|
|
||||||
// Box<dyn> does not support const yet, so use a function to get the name.
|
|
||||||
fn name(&self) -> &'static str;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
|
|
||||||
Err(crate::Error::Cuda(
|
|
||||||
format!("no cuda implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn metal_fwd(
|
|
||||||
&self,
|
|
||||||
_storage: &MetalStorage,
|
|
||||||
_layout: &Layout,
|
|
||||||
) -> Result<(MetalStorage, Shape)> {
|
|
||||||
Err(crate::Error::Metal(
|
|
||||||
format!("no metal implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This function takes as argument the argument `arg` used in the forward pass, the result
|
|
||||||
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
|
|
||||||
/// The function should return the gradient of the argument.
|
|
||||||
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
|
|
||||||
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait CustomOp2 {
|
|
||||||
fn name(&self) -> &'static str;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cpu_fwd(
|
|
||||||
&self,
|
|
||||||
s1: &CpuStorage,
|
|
||||||
l1: &Layout,
|
|
||||||
s2: &CpuStorage,
|
|
||||||
l2: &Layout,
|
|
||||||
) -> Result<(CpuStorage, Shape)>;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cuda_fwd(
|
|
||||||
&self,
|
|
||||||
_: &CudaStorage,
|
|
||||||
_: &Layout,
|
|
||||||
_: &CudaStorage,
|
|
||||||
_: &Layout,
|
|
||||||
) -> Result<(CudaStorage, Shape)> {
|
|
||||||
Err(crate::Error::Cuda(
|
|
||||||
format!("no cuda implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn metal_fwd(
|
|
||||||
&self,
|
|
||||||
_: &MetalStorage,
|
|
||||||
_: &Layout,
|
|
||||||
_: &MetalStorage,
|
|
||||||
_: &Layout,
|
|
||||||
) -> Result<(MetalStorage, Shape)> {
|
|
||||||
Err(crate::Error::Metal(
|
|
||||||
format!("no metal implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn bwd(
|
|
||||||
&self,
|
|
||||||
_arg1: &Tensor,
|
|
||||||
_arg2: &Tensor,
|
|
||||||
_res: &Tensor,
|
|
||||||
_grad_res: &Tensor,
|
|
||||||
) -> Result<(Option<Tensor>, Option<Tensor>)> {
|
|
||||||
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait CustomOp3 {
|
|
||||||
fn name(&self) -> &'static str;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cpu_fwd(
|
|
||||||
&self,
|
|
||||||
s1: &CpuStorage,
|
|
||||||
l1: &Layout,
|
|
||||||
s2: &CpuStorage,
|
|
||||||
l2: &Layout,
|
|
||||||
s3: &CpuStorage,
|
|
||||||
l3: &Layout,
|
|
||||||
) -> Result<(CpuStorage, Shape)>;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cuda_fwd(
|
|
||||||
&self,
|
|
||||||
_: &CudaStorage,
|
|
||||||
_: &Layout,
|
|
||||||
_: &CudaStorage,
|
|
||||||
_: &Layout,
|
|
||||||
_: &CudaStorage,
|
|
||||||
_: &Layout,
|
|
||||||
) -> Result<(CudaStorage, Shape)> {
|
|
||||||
Err(crate::Error::Cuda(
|
|
||||||
format!("no cuda implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn metal_fwd(
|
|
||||||
&self,
|
|
||||||
_: &MetalStorage,
|
|
||||||
_: &Layout,
|
|
||||||
_: &MetalStorage,
|
|
||||||
_: &Layout,
|
|
||||||
_: &MetalStorage,
|
|
||||||
_: &Layout,
|
|
||||||
) -> Result<(MetalStorage, Shape)> {
|
|
||||||
Err(crate::Error::Metal(
|
|
||||||
format!("no metal implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn bwd(
|
|
||||||
&self,
|
|
||||||
_arg1: &Tensor,
|
|
||||||
_arg2: &Tensor,
|
|
||||||
_arg3: &Tensor,
|
|
||||||
_res: &Tensor,
|
|
||||||
_grad_res: &Tensor,
|
|
||||||
) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
|
|
||||||
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait UnaryOpT {
|
pub trait UnaryOpT {
|
||||||
const NAME: &'static str;
|
const NAME: &'static str;
|
||||||
const KERNEL: &'static str;
|
const KERNEL: &'static str;
|
||||||
@ -390,10 +250,12 @@ pub(crate) struct Gelu;
|
|||||||
pub(crate) struct GeluErf;
|
pub(crate) struct GeluErf;
|
||||||
pub(crate) struct Erf;
|
pub(crate) struct Erf;
|
||||||
pub(crate) struct Relu;
|
pub(crate) struct Relu;
|
||||||
|
pub(crate) struct Silu;
|
||||||
pub(crate) struct Tanh;
|
pub(crate) struct Tanh;
|
||||||
pub(crate) struct Floor;
|
pub(crate) struct Floor;
|
||||||
pub(crate) struct Ceil;
|
pub(crate) struct Ceil;
|
||||||
pub(crate) struct Round;
|
pub(crate) struct Round;
|
||||||
|
pub(crate) struct Sign;
|
||||||
|
|
||||||
macro_rules! bin_op {
|
macro_rules! bin_op {
|
||||||
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
|
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
|
||||||
@ -597,6 +459,13 @@ unary_op!(Recip, "recip", v, v.recip());
|
|||||||
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
|
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
|
||||||
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
|
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
|
||||||
|
|
||||||
|
// Hardcode the value for sqrt(2/pi)
|
||||||
|
// https://github.com/huggingface/candle/issues/1982
|
||||||
|
#[allow(clippy::excessive_precision)]
|
||||||
|
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
|
||||||
|
#[allow(clippy::excessive_precision)]
|
||||||
|
const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
|
||||||
|
|
||||||
/// Tanh based approximation of the `gelu` operation
|
/// Tanh based approximation of the `gelu` operation
|
||||||
/// GeluErf is the more precise one.
|
/// GeluErf is the more precise one.
|
||||||
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
||||||
@ -609,7 +478,7 @@ impl UnaryOpT for Gelu {
|
|||||||
* v
|
* v
|
||||||
* (bf16::ONE
|
* (bf16::ONE
|
||||||
+ bf16::tanh(
|
+ bf16::tanh(
|
||||||
(bf16::from_f32_const(2.0) / bf16::PI).sqrt()
|
bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
|
||||||
* v
|
* v
|
||||||
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
|
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
|
||||||
))
|
))
|
||||||
@ -620,22 +489,18 @@ impl UnaryOpT for Gelu {
|
|||||||
* v
|
* v
|
||||||
* (f16::ONE
|
* (f16::ONE
|
||||||
+ f16::tanh(
|
+ f16::tanh(
|
||||||
(f16::from_f32_const(2.0) / f16::PI).sqrt()
|
f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
|
||||||
* v
|
* v
|
||||||
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
|
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn f32(v: f32) -> f32 {
|
fn f32(v: f32) -> f32 {
|
||||||
0.5 * v
|
0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
|
||||||
* (1.0
|
|
||||||
+ f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
|
|
||||||
}
|
}
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn f64(v: f64) -> f64 {
|
fn f64(v: f64) -> f64 {
|
||||||
0.5 * v
|
0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
|
||||||
* (1.0
|
|
||||||
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
|
|
||||||
}
|
}
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn u8(_: u8) -> u8 {
|
fn u8(_: u8) -> u8 {
|
||||||
@ -724,6 +589,77 @@ impl UnaryOpT for Erf {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Silu operation
|
||||||
|
impl UnaryOpT for Silu {
|
||||||
|
const NAME: &'static str = "silu";
|
||||||
|
const V: Self = Silu;
|
||||||
|
#[inline(always)]
|
||||||
|
fn bf16(v: bf16) -> bf16 {
|
||||||
|
v / (bf16::ONE + (-v).exp())
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn f16(v: f16) -> f16 {
|
||||||
|
v / (f16::ONE + (-v).exp())
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn f32(v: f32) -> f32 {
|
||||||
|
v / (1.0 + (-v).exp())
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn f64(v: f64) -> f64 {
|
||||||
|
v / (1.0 + (-v).exp())
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn u8(_: u8) -> u8 {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn u32(_: u32) -> u32 {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn i64(_: i64) -> i64 {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
const KERNEL: &'static str = "usilu";
|
||||||
|
|
||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
const F32_VEC: bool = true;
|
||||||
|
|
||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
#[inline(always)]
|
||||||
|
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
||||||
|
crate::mkl::vs_silu(xs, ys)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
const F64_VEC: bool = true;
|
||||||
|
|
||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
#[inline(always)]
|
||||||
|
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||||
|
crate::mkl::vd_silu(xs, ys)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
const F32_VEC: bool = true;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
#[inline(always)]
|
||||||
|
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
||||||
|
crate::accelerate::vs_silu(xs, ys)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
const F64_VEC: bool = true;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
#[inline(always)]
|
||||||
|
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||||
|
crate::accelerate::vd_silu(xs, ys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl UnaryOpT for Abs {
|
impl UnaryOpT for Abs {
|
||||||
const NAME: &'static str = "abs";
|
const NAME: &'static str = "abs";
|
||||||
const KERNEL: &'static str = "uabs";
|
const KERNEL: &'static str = "uabs";
|
||||||
@ -991,3 +927,37 @@ impl std::ops::Deref for BackpropOp {
|
|||||||
&self.0
|
&self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl UnaryOpT for Sign {
|
||||||
|
const NAME: &'static str = "sign";
|
||||||
|
const KERNEL: &'static str = "usign";
|
||||||
|
const V: Self = Sign;
|
||||||
|
#[inline(always)]
|
||||||
|
fn bf16(v: bf16) -> bf16 {
|
||||||
|
bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn f16(v: f16) -> f16 {
|
||||||
|
f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn f32(v: f32) -> f32 {
|
||||||
|
f32::from(v > 0.) - f32::from(v < 0.)
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn f64(v: f64) -> f64 {
|
||||||
|
f64::from(v > 0.) - f64::from(v < 0.)
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn u8(v: u8) -> u8 {
|
||||||
|
u8::min(1, v)
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn u32(v: u32) -> u32 {
|
||||||
|
u32::min(1, v)
|
||||||
|
}
|
||||||
|
#[inline(always)]
|
||||||
|
fn i64(v: i64) -> i64 {
|
||||||
|
(v > 0) as i64 - (v < 0) as i64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -42,7 +42,7 @@ pub enum OpCode {
|
|||||||
Stop = b'.',
|
Stop = b'.',
|
||||||
NewObj = 0x81,
|
NewObj = 0x81,
|
||||||
EmptyList = b']',
|
EmptyList = b']',
|
||||||
BinFloat = b'g',
|
BinFloat = b'G',
|
||||||
Append = b'a',
|
Append = b'a',
|
||||||
Appends = b'e',
|
Appends = b'e',
|
||||||
}
|
}
|
||||||
@ -217,6 +217,13 @@ impl Object {
|
|||||||
let args = args.remove(1);
|
let args = args.remove(1);
|
||||||
(callable, args)
|
(callable, args)
|
||||||
}
|
}
|
||||||
|
Object::Class {
|
||||||
|
module_name,
|
||||||
|
class_name,
|
||||||
|
} if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
|
||||||
|
let mut args = args.tuple()?;
|
||||||
|
args.remove(0).reduce()?
|
||||||
|
}
|
||||||
_ => (callable, args),
|
_ => (callable, args),
|
||||||
};
|
};
|
||||||
match callable {
|
match callable {
|
||||||
@ -227,13 +234,11 @@ impl Object {
|
|||||||
_ => return Ok(None),
|
_ => return Ok(None),
|
||||||
};
|
};
|
||||||
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
|
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
|
||||||
let mut path = dir_name.to_path_buf();
|
|
||||||
path.push(file_path);
|
|
||||||
Ok(Some(TensorInfo {
|
Ok(Some(TensorInfo {
|
||||||
name,
|
name,
|
||||||
dtype,
|
dtype,
|
||||||
layout,
|
layout,
|
||||||
path: path.to_string_lossy().into_owned(),
|
path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
|
||||||
storage_size,
|
storage_size,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
@ -345,8 +350,10 @@ impl Stack {
|
|||||||
module_name,
|
module_name,
|
||||||
class_name,
|
class_name,
|
||||||
} => {
|
} => {
|
||||||
if module_name == "collections" && class_name == "OrderedDict" {
|
if module_name == "collections"
|
||||||
// TODO: have a separate ordered dict.
|
&& (class_name == "OrderedDict" || class_name == "defaultdict")
|
||||||
|
{
|
||||||
|
// TODO: have a separate ordered dict and a separate default dict.
|
||||||
Some(Object::Dict(vec![]))
|
Some(Object::Dict(vec![]))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
@ -455,7 +462,10 @@ impl Stack {
|
|||||||
self.push(Object::Int(arg))
|
self.push(Object::Int(arg))
|
||||||
}
|
}
|
||||||
OpCode::BinFloat => {
|
OpCode::BinFloat => {
|
||||||
let arg = r.read_f64::<LittleEndian>()?;
|
// Somehow floats are encoded using BigEndian whereas int types use LittleEndian.
|
||||||
|
// https://github.com/python/cpython/blob/0c80da4c14d904a367968955544dd6ae58c8101c/Lib/pickletools.py#L855
|
||||||
|
// https://github.com/pytorch/pytorch/blob/372d078f361e726bb4ac0884ac334b04c58179ef/torch/_weights_only_unpickler.py#L243
|
||||||
|
let arg = r.read_f64::<byteorder::BigEndian>()?;
|
||||||
self.push(Object::Float(arg))
|
self.push(Object::Float(arg))
|
||||||
}
|
}
|
||||||
OpCode::BinUnicode => {
|
OpCode::BinUnicode => {
|
||||||
@ -627,9 +637,16 @@ pub struct TensorInfo {
|
|||||||
pub storage_size: usize,
|
pub storage_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Read the tensor info from a .pth file.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `file` - The path to the .pth file.
|
||||||
|
/// * `verbose` - Whether to print debug information.
|
||||||
|
/// * `key` - Optional key to retrieve `state_dict` from the pth file.
|
||||||
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
||||||
file: P,
|
file: P,
|
||||||
verbose: bool,
|
verbose: bool,
|
||||||
|
key: Option<&str>,
|
||||||
) -> Result<Vec<TensorInfo>> {
|
) -> Result<Vec<TensorInfo>> {
|
||||||
let file = std::fs::File::open(file)?;
|
let file = std::fs::File::open(file)?;
|
||||||
let zip_reader = std::io::BufReader::new(file);
|
let zip_reader = std::io::BufReader::new(file);
|
||||||
@ -651,8 +668,9 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
|||||||
stack.read_loop(&mut reader)?;
|
stack.read_loop(&mut reader)?;
|
||||||
let obj = stack.finalize()?;
|
let obj = stack.finalize()?;
|
||||||
if VERBOSE || verbose {
|
if VERBOSE || verbose {
|
||||||
println!("{obj:?}");
|
println!("{obj:#?}");
|
||||||
}
|
}
|
||||||
|
|
||||||
let obj = match obj {
|
let obj = match obj {
|
||||||
Object::Build { callable, args } => match *callable {
|
Object::Build { callable, args } => match *callable {
|
||||||
Object::Reduce { callable, args: _ } => match *callable {
|
Object::Reduce { callable, args: _ } => match *callable {
|
||||||
@ -666,6 +684,24 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
|||||||
},
|
},
|
||||||
obj => obj,
|
obj => obj,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// If key is provided, then we need to extract the state_dict from the object.
|
||||||
|
let obj = if let Some(key) = key {
|
||||||
|
if let Object::Dict(key_values) = obj {
|
||||||
|
key_values
|
||||||
|
.into_iter()
|
||||||
|
.find(|(k, _)| *k == Object::Unicode(key.to_owned()))
|
||||||
|
.map(|(_, v)| v)
|
||||||
|
.ok_or_else(|| E::Msg(format!("key {key} not found")))?
|
||||||
|
} else {
|
||||||
|
obj
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
obj
|
||||||
|
};
|
||||||
|
|
||||||
|
// If the object is a dict, then we can extract the tensor info from it.
|
||||||
|
// NOTE: We are assuming that the `obj` is state_dict by this stage.
|
||||||
if let Object::Dict(key_values) = obj {
|
if let Object::Dict(key_values) = obj {
|
||||||
for (name, value) in key_values.into_iter() {
|
for (name, value) in key_values.into_iter() {
|
||||||
match value.into_tensor_info(name, &dir_name) {
|
match value.into_tensor_info(name, &dir_name) {
|
||||||
@ -688,8 +724,8 @@ pub struct PthTensors {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl PthTensors {
|
impl PthTensors {
|
||||||
pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
|
pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
|
||||||
let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
|
let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
|
||||||
let tensor_infos = tensor_infos
|
let tensor_infos = tensor_infos
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|ti| (ti.name.to_string(), ti))
|
.map(|ti| (ti.name.to_string(), ti))
|
||||||
@ -712,10 +748,12 @@ impl PthTensors {
|
|||||||
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
|
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
|
||||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||||
let mut reader = zip.by_name(&tensor_info.path)?;
|
let mut reader = zip.by_name(&tensor_info.path)?;
|
||||||
|
let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
|
||||||
|
let rank = tensor_info.layout.shape().rank();
|
||||||
|
|
||||||
// Reading the data is a bit tricky as it can be strided, for now only support the basic
|
// Reading the data is a bit tricky as it can be strided, for now only support the basic
|
||||||
// case.
|
// case and when the tensor is fortran contiguous.
|
||||||
if !tensor_info.layout.is_contiguous() {
|
if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
|
||||||
crate::bail!(
|
crate::bail!(
|
||||||
"cannot retrieve non-contiguous tensors {:?}",
|
"cannot retrieve non-contiguous tensors {:?}",
|
||||||
tensor_info.layout
|
tensor_info.layout
|
||||||
@ -733,13 +771,33 @@ impl PthTensors {
|
|||||||
tensor_info.dtype,
|
tensor_info.dtype,
|
||||||
&mut reader,
|
&mut reader,
|
||||||
)?;
|
)?;
|
||||||
Ok(Some(tensor))
|
|
||||||
|
if rank > 1 && is_fortran_contiguous {
|
||||||
|
// Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
|
||||||
|
let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
|
||||||
|
let tensor = tensor.reshape(shape_reversed)?;
|
||||||
|
|
||||||
|
// Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
|
||||||
|
let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
|
||||||
|
let tensor = tensor.permute(dim_indeces_reversed)?;
|
||||||
|
Ok(Some(tensor))
|
||||||
|
} else {
|
||||||
|
Ok(Some(tensor))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Read all the tensors from a PyTorch pth file.
|
/// Read all the tensors from a PyTorch pth file with a given key.
|
||||||
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
|
///
|
||||||
let pth = PthTensors::new(path)?;
|
/// # Arguments
|
||||||
|
/// * `path` - Path to the pth file.
|
||||||
|
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
|
||||||
|
/// contains multiple objects and the state_dict is the one we are interested in.
|
||||||
|
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
||||||
|
path: P,
|
||||||
|
key: Option<&str>,
|
||||||
|
) -> Result<Vec<(String, Tensor)>> {
|
||||||
|
let pth = PthTensors::new(path, key)?;
|
||||||
let tensor_names = pth.tensor_infos.keys();
|
let tensor_names = pth.tensor_infos.keys();
|
||||||
let mut tensors = Vec::with_capacity(tensor_names.len());
|
let mut tensors = Vec::with_capacity(tensor_names.len());
|
||||||
for name in tensor_names {
|
for name in tensor_names {
|
||||||
@ -749,3 +807,11 @@ pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tenso
|
|||||||
}
|
}
|
||||||
Ok(tensors)
|
Ok(tensors)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Read all the tensors from a PyTorch pth file.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `path` - Path to the pth file.
|
||||||
|
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
|
||||||
|
read_all_with_key(path, None)
|
||||||
|
}
|
||||||
|
680
candle-core/src/quantized/cuda.rs
Normal file
680
candle-core/src/quantized/cuda.rs
Normal file
@ -0,0 +1,680 @@
|
|||||||
|
use super::{GgmlDType, QStorage};
|
||||||
|
use crate::quantized::k_quants::GgmlType;
|
||||||
|
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
|
||||||
|
use crate::{CudaDevice, CudaStorage, Result};
|
||||||
|
use half::f16;
|
||||||
|
|
||||||
|
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct QCudaStorage {
|
||||||
|
data: CudaSlice<u8>,
|
||||||
|
dtype: GgmlDType,
|
||||||
|
device: CudaDevice,
|
||||||
|
}
|
||||||
|
|
||||||
|
static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
|
||||||
|
|
||||||
|
pub fn set_force_dmmv(f: bool) {
|
||||||
|
FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const WARP_SIZE: usize = 32;
|
||||||
|
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
|
||||||
|
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
|
||||||
|
pub const NWARPS_Q4_0_AMPERE: usize = 4;
|
||||||
|
pub const GGML_CUDA_MMV_X: usize = 32;
|
||||||
|
pub const GGML_CUDA_MMV_Y: usize = 1;
|
||||||
|
pub const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256;
|
||||||
|
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
|
||||||
|
pub const MATRIX_ROW_PADDING: usize = 512;
|
||||||
|
|
||||||
|
fn ceil_div(p: usize, q: usize) -> usize {
|
||||||
|
(p + q - 1) / q
|
||||||
|
}
|
||||||
|
|
||||||
|
fn pad(p: usize, q: usize) -> usize {
|
||||||
|
ceil_div(p, q) * q
|
||||||
|
}
|
||||||
|
|
||||||
|
fn quantize_q8_1(
|
||||||
|
src: &CudaView<f32>,
|
||||||
|
dst: &mut CudaSlice<u8>,
|
||||||
|
elem_count: usize,
|
||||||
|
ky: usize,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<()> {
|
||||||
|
use cudarc::driver::LaunchAsync;
|
||||||
|
|
||||||
|
let kx = elem_count;
|
||||||
|
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
|
||||||
|
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
|
||||||
|
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
|
||||||
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
|
grid_dim: (num_blocks as u32, ky as u32, 1),
|
||||||
|
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
|
||||||
|
shared_mem_bytes: 0,
|
||||||
|
};
|
||||||
|
let params = (src, dst, kx as i32, kx_padded as i32);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dequantize_f32(
|
||||||
|
data: &CudaSlice<u8>,
|
||||||
|
dtype: GgmlDType,
|
||||||
|
elem_count: usize,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<CudaStorage> {
|
||||||
|
use cudarc::driver::LaunchAsync;
|
||||||
|
|
||||||
|
let nb = (elem_count + 255) / 256;
|
||||||
|
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||||
|
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
|
||||||
|
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
|
||||||
|
GgmlDType::Q5_0 => (
|
||||||
|
"dequantize_block_q5_0_f32",
|
||||||
|
false,
|
||||||
|
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||||
|
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||||
|
),
|
||||||
|
GgmlDType::Q5_1 => (
|
||||||
|
"dequantize_block_q5_1_f32",
|
||||||
|
false,
|
||||||
|
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||||
|
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||||
|
),
|
||||||
|
GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb),
|
||||||
|
GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb),
|
||||||
|
GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb),
|
||||||
|
GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb),
|
||||||
|
GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb),
|
||||||
|
GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb),
|
||||||
|
GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
|
||||||
|
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||||
|
};
|
||||||
|
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
||||||
|
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
|
||||||
|
// See e.g.
|
||||||
|
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||||
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
|
grid_dim: (num_blocks as u32, 1, 1),
|
||||||
|
block_dim: (block_dim as u32, 1, 1),
|
||||||
|
shared_mem_bytes: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
if is_k {
|
||||||
|
let params = (data, &dst);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
} else {
|
||||||
|
let nb32 = match dtype {
|
||||||
|
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
||||||
|
_ => elem_count / 32,
|
||||||
|
};
|
||||||
|
let params = (data, &dst, nb32 as i32);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
}
|
||||||
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dequantize_f16(
|
||||||
|
data: &CudaSlice<u8>,
|
||||||
|
dtype: GgmlDType,
|
||||||
|
elem_count: usize,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<CudaStorage> {
|
||||||
|
use cudarc::driver::LaunchAsync;
|
||||||
|
|
||||||
|
let nb = (elem_count + 255) / 256;
|
||||||
|
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||||
|
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
|
||||||
|
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
|
||||||
|
GgmlDType::Q5_0 => (
|
||||||
|
"dequantize_block_q5_0_f16",
|
||||||
|
false,
|
||||||
|
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||||
|
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||||
|
),
|
||||||
|
GgmlDType::Q5_1 => (
|
||||||
|
"dequantize_block_q5_1_f16",
|
||||||
|
false,
|
||||||
|
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||||
|
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||||
|
),
|
||||||
|
GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb),
|
||||||
|
GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb),
|
||||||
|
GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb),
|
||||||
|
GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb),
|
||||||
|
GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb),
|
||||||
|
GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb),
|
||||||
|
GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
|
||||||
|
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||||
|
};
|
||||||
|
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
||||||
|
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
|
||||||
|
// See e.g.
|
||||||
|
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||||
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
|
grid_dim: (num_blocks as u32, 1, 1),
|
||||||
|
block_dim: (block_dim as u32, 1, 1),
|
||||||
|
shared_mem_bytes: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
if is_k {
|
||||||
|
let params = (data, &dst);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
} else {
|
||||||
|
let nb32 = match dtype {
|
||||||
|
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
||||||
|
_ => elem_count / 32,
|
||||||
|
};
|
||||||
|
let params = (data, &dst, nb32 as i32);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
}
|
||||||
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dequantize_mul_mat_vec(
|
||||||
|
data: &CudaSlice<u8>,
|
||||||
|
y: &CudaView<f32>,
|
||||||
|
dtype: GgmlDType,
|
||||||
|
ncols: usize,
|
||||||
|
nrows: usize,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<CudaStorage> {
|
||||||
|
use cudarc::driver::LaunchAsync;
|
||||||
|
|
||||||
|
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
|
||||||
|
if data_elems < ncols * nrows {
|
||||||
|
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
||||||
|
}
|
||||||
|
if y.len() != ncols {
|
||||||
|
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
|
||||||
|
}
|
||||||
|
let kernel_name = match dtype {
|
||||||
|
GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
|
||||||
|
GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
|
||||||
|
GgmlDType::Q5_0 => "dequantize_mul_mat_vec_q5_0_cuda",
|
||||||
|
GgmlDType::Q5_1 => "dequantize_mul_mat_vec_q5_1_cuda",
|
||||||
|
GgmlDType::Q8_0 => "dequantize_mul_mat_vec_q8_0_cuda",
|
||||||
|
GgmlDType::Q2K => "dequantize_mul_mat_vec_q2_k",
|
||||||
|
GgmlDType::Q3K => "dequantize_mul_mat_vec_q3_k",
|
||||||
|
GgmlDType::Q4K => "dequantize_mul_mat_vec_q4_k",
|
||||||
|
GgmlDType::Q5K => "dequantize_mul_mat_vec_q5_k",
|
||||||
|
GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
|
||||||
|
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||||
|
};
|
||||||
|
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
||||||
|
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
|
||||||
|
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
|
||||||
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
|
grid_dim: (block_num_y as u32, 1, 1),
|
||||||
|
block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
|
||||||
|
shared_mem_bytes: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let params = (data, y, &dst, ncols as i32, nrows as i32);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mul_mat_vec_via_q8_1(
|
||||||
|
data: &CudaSlice<u8>,
|
||||||
|
y: &CudaView<f32>,
|
||||||
|
dtype: GgmlDType,
|
||||||
|
ncols: usize,
|
||||||
|
nrows: usize,
|
||||||
|
b_size: usize,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<CudaStorage> {
|
||||||
|
use cudarc::driver::LaunchAsync;
|
||||||
|
|
||||||
|
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
|
||||||
|
if data_elems < ncols * nrows {
|
||||||
|
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
||||||
|
}
|
||||||
|
if y.len() != ncols * b_size {
|
||||||
|
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
|
||||||
|
}
|
||||||
|
if b_size == 0 || b_size > 8 {
|
||||||
|
crate::bail!("only bsize between 1 and 8 are supported, got {b_size}")
|
||||||
|
}
|
||||||
|
// Start by quantizing y
|
||||||
|
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
|
||||||
|
let y_size_in_bytes =
|
||||||
|
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||||
|
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
||||||
|
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
|
||||||
|
|
||||||
|
let kernel_name = match dtype {
|
||||||
|
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
|
||||||
|
GgmlDType::Q4_1 => "mul_mat_vec_q4_1_q8_1_cuda",
|
||||||
|
GgmlDType::Q5_0 => "mul_mat_vec_q5_0_q8_1_cuda",
|
||||||
|
GgmlDType::Q5_1 => "mul_mat_vec_q5_1_q8_1_cuda",
|
||||||
|
GgmlDType::Q8_0 => "mul_mat_vec_q8_0_q8_1_cuda",
|
||||||
|
GgmlDType::Q2K => "mul_mat_vec_q2_K_q8_1_cuda",
|
||||||
|
GgmlDType::Q3K => "mul_mat_vec_q3_K_q8_1_cuda",
|
||||||
|
GgmlDType::Q4K => "mul_mat_vec_q4_K_q8_1_cuda",
|
||||||
|
GgmlDType::Q5K => "mul_mat_vec_q5_K_q8_1_cuda",
|
||||||
|
GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
|
||||||
|
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||||
|
};
|
||||||
|
let kernel_name = format!("{kernel_name}{b_size}");
|
||||||
|
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
|
||||||
|
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
|
||||||
|
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
|
||||||
|
let (nblocks, nwarps) = match b_size {
|
||||||
|
1 => (nrows as u32, 4),
|
||||||
|
2..=4 => ((nrows as u32 + 1) / 2, 4),
|
||||||
|
5..=8 => ((nrows as u32 + 1) / 2, 2),
|
||||||
|
_ => crate::bail!("unexpected bsize {b_size}"),
|
||||||
|
};
|
||||||
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
|
grid_dim: (nblocks, 1, 1),
|
||||||
|
block_dim: (WARP_SIZE as u32, nwarps, 1),
|
||||||
|
shared_mem_bytes: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let params = (
|
||||||
|
data,
|
||||||
|
&y_q8_1,
|
||||||
|
&dst,
|
||||||
|
/* ncols_x */ ncols as i32,
|
||||||
|
/* nrows_x */ nrows as i32,
|
||||||
|
/* nrows_y */ ncols_padded as i32,
|
||||||
|
/* nrows_dst */ nrows as i32,
|
||||||
|
);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn mul_mat_via_q8_1(
|
||||||
|
data: &CudaSlice<u8>,
|
||||||
|
y: &CudaView<f32>,
|
||||||
|
dtype: GgmlDType,
|
||||||
|
x_rows: usize,
|
||||||
|
x_cols: usize,
|
||||||
|
y_rows: usize,
|
||||||
|
y_cols: usize,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<CudaStorage> {
|
||||||
|
use cudarc::driver::LaunchAsync;
|
||||||
|
|
||||||
|
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
|
||||||
|
if data_elems < x_rows * x_cols {
|
||||||
|
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
|
||||||
|
}
|
||||||
|
if y.len() != y_rows * y_cols {
|
||||||
|
crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len())
|
||||||
|
}
|
||||||
|
if x_cols != y_rows {
|
||||||
|
crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}")
|
||||||
|
}
|
||||||
|
let k = x_cols;
|
||||||
|
// Start by quantizing y
|
||||||
|
let k_padded = pad(k, MATRIX_ROW_PADDING);
|
||||||
|
let y_size_in_bytes =
|
||||||
|
k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||||
|
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
||||||
|
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
|
||||||
|
|
||||||
|
let (kernel_name, mmq_x, mmq_y) = match dtype {
|
||||||
|
GgmlDType::Q4_0 => ("mul_mat_q4_0", 64, 128),
|
||||||
|
GgmlDType::Q4_1 => ("mul_mat_q4_1", 64, 128),
|
||||||
|
GgmlDType::Q5_0 => ("mul_mat_q5_0", 128, 64),
|
||||||
|
GgmlDType::Q5_1 => ("mul_mat_q5_1", 128, 64),
|
||||||
|
GgmlDType::Q8_0 => ("mul_mat_q8_0", 128, 64),
|
||||||
|
GgmlDType::Q2K => ("mul_mat_q2_K", 64, 128),
|
||||||
|
GgmlDType::Q3K => ("mul_mat_q3_K", 128, 128),
|
||||||
|
GgmlDType::Q4K => ("mul_mat_q4_K", 64, 128),
|
||||||
|
GgmlDType::Q5K => ("mul_mat_q5_K", 64, 128),
|
||||||
|
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
|
||||||
|
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||||
|
};
|
||||||
|
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
||||||
|
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
|
||||||
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
|
grid_dim: (
|
||||||
|
ceil_div(x_rows, mmq_y) as u32,
|
||||||
|
ceil_div(y_cols, mmq_x) as u32,
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
block_dim: (WARP_SIZE as u32, 4, 1),
|
||||||
|
shared_mem_bytes: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let params = (
|
||||||
|
/* vx */ data,
|
||||||
|
/* vy */ &y_q8_1,
|
||||||
|
/* dst */ &dst,
|
||||||
|
/* ncols_x */ x_cols as i32,
|
||||||
|
/* nrows_x */ x_rows as i32,
|
||||||
|
/* ncols_y */ y_cols as i32,
|
||||||
|
/* nrows_y */ k_padded as i32,
|
||||||
|
/* nrows_dst */ x_rows as i32,
|
||||||
|
);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
impl QCudaStorage {
|
||||||
|
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
|
||||||
|
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
|
||||||
|
let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
|
||||||
|
Ok(QCudaStorage {
|
||||||
|
data,
|
||||||
|
device: device.clone(),
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dtype(&self) -> GgmlDType {
|
||||||
|
self.dtype
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn device(&self) -> &CudaDevice {
|
||||||
|
&self.device
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
|
||||||
|
fn deq<T: GgmlType>(buffer: &[u8], n: usize, dst: &mut [f32]) -> Result<()> {
|
||||||
|
let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
|
||||||
|
let vec = slice.to_vec();
|
||||||
|
T::to_float(&vec, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
let fast_kernel = matches!(
|
||||||
|
self.dtype,
|
||||||
|
GgmlDType::Q4_0
|
||||||
|
| GgmlDType::Q4_1
|
||||||
|
| GgmlDType::Q5_0
|
||||||
|
| GgmlDType::Q5_1
|
||||||
|
| GgmlDType::Q8_0
|
||||||
|
| GgmlDType::Q2K
|
||||||
|
| GgmlDType::Q3K
|
||||||
|
| GgmlDType::Q4K
|
||||||
|
| GgmlDType::Q5K
|
||||||
|
| GgmlDType::Q6K
|
||||||
|
| GgmlDType::Q8K
|
||||||
|
);
|
||||||
|
if fast_kernel {
|
||||||
|
return dequantize_f32(&self.data, self.dtype, elem_count, self.device());
|
||||||
|
}
|
||||||
|
// Run the dequantization on cpu.
|
||||||
|
|
||||||
|
let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
|
||||||
|
let mut out = vec![0.0; elem_count];
|
||||||
|
let block_len = elem_count / self.dtype.block_size();
|
||||||
|
match self.dtype {
|
||||||
|
GgmlDType::F32 => deq::<f32>(&buffer, block_len, &mut out)?,
|
||||||
|
GgmlDType::F16 => deq::<half::f16>(&buffer, block_len, &mut out)?,
|
||||||
|
GgmlDType::Q4_0 => deq::<crate::quantized::BlockQ4_0>(&buffer, block_len, &mut out)?,
|
||||||
|
GgmlDType::Q4_1 => deq::<crate::quantized::BlockQ4_1>(&buffer, block_len, &mut out)?,
|
||||||
|
GgmlDType::Q5_0 => deq::<crate::quantized::BlockQ5_0>(&buffer, block_len, &mut out)?,
|
||||||
|
GgmlDType::Q5_1 => deq::<crate::quantized::BlockQ5_1>(&buffer, block_len, &mut out)?,
|
||||||
|
GgmlDType::Q8_0 => deq::<crate::quantized::BlockQ8_0>(&buffer, block_len, &mut out)?,
|
||||||
|
GgmlDType::Q8_1 => deq::<crate::quantized::BlockQ8_1>(&buffer, block_len, &mut out)?,
|
||||||
|
GgmlDType::Q2K => deq::<crate::quantized::BlockQ2K>(&buffer, block_len, &mut out)?,
|
||||||
|
GgmlDType::Q3K => deq::<crate::quantized::BlockQ3K>(&buffer, block_len, &mut out)?,
|
||||||
|
GgmlDType::Q4K => deq::<crate::quantized::BlockQ4K>(&buffer, block_len, &mut out)?,
|
||||||
|
GgmlDType::Q5K => deq::<crate::quantized::BlockQ5K>(&buffer, block_len, &mut out)?,
|
||||||
|
GgmlDType::Q6K => deq::<crate::quantized::BlockQ6K>(&buffer, block_len, &mut out)?,
|
||||||
|
GgmlDType::Q8K => deq::<crate::quantized::BlockQ8K>(&buffer, block_len, &mut out)?,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.device
|
||||||
|
.storage_from_cpu_storage(&crate::CpuStorage::F32(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dequantize_f16(&self, elem_count: usize) -> Result<CudaStorage> {
|
||||||
|
dequantize_f16(&self.data, self.dtype, elem_count, self.device())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
|
||||||
|
// Run the quantization on cpu.
|
||||||
|
let src = match &src.slice {
|
||||||
|
crate::cuda_backend::CudaStorageSlice::F32(data) => {
|
||||||
|
self.device.dtoh_sync_copy(data).w()?
|
||||||
|
}
|
||||||
|
_ => crate::bail!("only f32 can be quantized"),
|
||||||
|
};
|
||||||
|
let src_len = src.len();
|
||||||
|
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
|
||||||
|
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
|
||||||
|
qcpu_storage.quantize(&src)?;
|
||||||
|
let data = qcpu_storage.data()?;
|
||||||
|
let data = self.device.htod_sync_copy(data.as_ref()).w()?;
|
||||||
|
self.data = data;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn storage_size_in_bytes(&self) -> usize {
|
||||||
|
self.data.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fwd(
|
||||||
|
&self,
|
||||||
|
self_shape: &crate::Shape,
|
||||||
|
storage: &CudaStorage,
|
||||||
|
layout: &crate::Layout,
|
||||||
|
) -> Result<(CudaStorage, crate::Shape)> {
|
||||||
|
let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
|
||||||
|
1
|
||||||
|
} else {
|
||||||
|
8
|
||||||
|
};
|
||||||
|
let use_vec_kernel = match layout.shape().dims() {
|
||||||
|
[b, m, _k] => b * m <= max_bm,
|
||||||
|
[b, _k] => *b <= max_bm,
|
||||||
|
_ => false,
|
||||||
|
};
|
||||||
|
if use_vec_kernel {
|
||||||
|
self.dequantize_matmul_vec(self_shape, storage, layout)
|
||||||
|
} else {
|
||||||
|
self.dequantize_matmul(self_shape, storage, layout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl QCudaStorage {
|
||||||
|
fn dequantize_matmul_vec(
|
||||||
|
&self,
|
||||||
|
self_shape: &crate::Shape,
|
||||||
|
rhs: &CudaStorage,
|
||||||
|
rhs_l: &crate::Layout,
|
||||||
|
) -> Result<(CudaStorage, crate::Shape)> {
|
||||||
|
let (nrows, ncols) = self_shape.dims2()?;
|
||||||
|
let rhs = rhs.as_cuda_slice::<f32>()?;
|
||||||
|
let rhs = match rhs_l.contiguous_offsets() {
|
||||||
|
Some((o1, o2)) => rhs.slice(o1..o2),
|
||||||
|
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
|
||||||
|
};
|
||||||
|
let (b_size, k) = match rhs_l.shape().dims() {
|
||||||
|
[b, m, k] => (b * m, *k),
|
||||||
|
[b, k] => (*b, *k),
|
||||||
|
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
|
||||||
|
};
|
||||||
|
if ncols != k {
|
||||||
|
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
|
||||||
|
}
|
||||||
|
|
||||||
|
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
|
||||||
|
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
|
||||||
|
} else {
|
||||||
|
mul_mat_vec_via_q8_1(
|
||||||
|
&self.data,
|
||||||
|
&rhs,
|
||||||
|
self.dtype,
|
||||||
|
ncols,
|
||||||
|
nrows,
|
||||||
|
b_size,
|
||||||
|
self.device(),
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
let mut out_shape = rhs_l.shape().dims().to_vec();
|
||||||
|
out_shape.pop();
|
||||||
|
out_shape.push(nrows);
|
||||||
|
Ok((out, out_shape.into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dequantize_matmul(
|
||||||
|
&self,
|
||||||
|
self_shape: &crate::Shape,
|
||||||
|
storage: &CudaStorage,
|
||||||
|
layout: &crate::Layout,
|
||||||
|
) -> Result<(CudaStorage, crate::Shape)> {
|
||||||
|
use crate::backend::BackendStorage;
|
||||||
|
let (n, k) = self_shape.dims2()?;
|
||||||
|
let (b, m, k2) = match layout.shape().dims() {
|
||||||
|
&[b, m, k2] => (b, m, k2),
|
||||||
|
&[m, k2] => (1, m, k2),
|
||||||
|
s => crate::bail!("unexpected shape for input {s:?}"),
|
||||||
|
};
|
||||||
|
if k2 != k {
|
||||||
|
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
|
||||||
|
}
|
||||||
|
|
||||||
|
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
|
||||||
|
let data_f32 = self.dequantize(n * k)?;
|
||||||
|
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
|
||||||
|
storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?
|
||||||
|
} else {
|
||||||
|
let storage = storage.as_cuda_slice::<f32>()?;
|
||||||
|
let storage = match layout.contiguous_offsets() {
|
||||||
|
Some((o1, o2)) => storage.slice(o1..o2),
|
||||||
|
None => Err(crate::Error::RequiresContiguous {
|
||||||
|
op: "quantized-matmul",
|
||||||
|
}
|
||||||
|
.bt())?,
|
||||||
|
};
|
||||||
|
mul_mat_via_q8_1(
|
||||||
|
&self.data,
|
||||||
|
&storage,
|
||||||
|
self.dtype,
|
||||||
|
/* x_rows */ n,
|
||||||
|
/* x_cols */ k,
|
||||||
|
/* y_rows */ k,
|
||||||
|
/* y_cols */ b * m,
|
||||||
|
self.device(),
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
let mut out_shape = layout.shape().dims().to_vec();
|
||||||
|
out_shape.pop();
|
||||||
|
out_shape.push(n);
|
||||||
|
Ok((out, out_shape.into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
||||||
|
device: &CudaDevice,
|
||||||
|
data: &[T],
|
||||||
|
) -> Result<super::QStorage> {
|
||||||
|
let data = unsafe {
|
||||||
|
std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data))
|
||||||
|
};
|
||||||
|
let data = device.htod_sync_copy(data).w()?;
|
||||||
|
Ok(QStorage::Cuda(QCudaStorage {
|
||||||
|
data,
|
||||||
|
device: device.clone(),
|
||||||
|
dtype: T::DTYPE,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cuda_quantize_q8_1() -> Result<()> {
|
||||||
|
let dev = CudaDevice::new(0)?;
|
||||||
|
let el = 256;
|
||||||
|
let el_padded = pad(el, MATRIX_ROW_PADDING);
|
||||||
|
let y_size_in_bytes =
|
||||||
|
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||||
|
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
||||||
|
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
|
||||||
|
let y = dev.htod_sync_copy(&vs).w()?;
|
||||||
|
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cuda_mmv_q8_1() -> Result<()> {
|
||||||
|
let dev = CudaDevice::new(0)?;
|
||||||
|
let ncols = 256;
|
||||||
|
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
|
||||||
|
let y = dev.htod_sync_copy(&vs).w()?;
|
||||||
|
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
|
||||||
|
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||||
|
let cuda_storage = mul_mat_vec_via_q8_1(
|
||||||
|
&xs.data,
|
||||||
|
&y.slice(..),
|
||||||
|
/* dtype */ GgmlDType::Q4_0,
|
||||||
|
/* ncols */ ncols,
|
||||||
|
/* nrows */ 1,
|
||||||
|
/* b_size */ 1,
|
||||||
|
&dev,
|
||||||
|
)?;
|
||||||
|
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||||
|
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
|
||||||
|
assert_eq!(vs.len(), 1);
|
||||||
|
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
|
||||||
|
// Q8 means 1/256 precision.
|
||||||
|
assert_eq!(vs[0], 5561664.5);
|
||||||
|
|
||||||
|
let cuda_storage = dequantize_mul_mat_vec(
|
||||||
|
&xs.data,
|
||||||
|
&y.slice(..),
|
||||||
|
/* dtype */ GgmlDType::Q4_0,
|
||||||
|
/* ncols */ ncols,
|
||||||
|
/* nrows */ 1,
|
||||||
|
&dev,
|
||||||
|
)?;
|
||||||
|
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||||
|
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
|
||||||
|
assert_eq!(vs.len(), 1);
|
||||||
|
assert_eq!(vs[0], 5561851.0);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cuda_mm_q8_1() -> Result<()> {
|
||||||
|
let dev = CudaDevice::new(0)?;
|
||||||
|
let ncols = 256;
|
||||||
|
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
|
||||||
|
let y = dev.htod_sync_copy(&vs).w()?;
|
||||||
|
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
|
||||||
|
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||||
|
let cuda_storage = mul_mat_via_q8_1(
|
||||||
|
&xs.data,
|
||||||
|
&y.slice(..),
|
||||||
|
/* dtype */ GgmlDType::Q4_0,
|
||||||
|
/* x_rows */ 4,
|
||||||
|
/* x_cols */ ncols,
|
||||||
|
/* y_rows */ ncols,
|
||||||
|
/* y_cols */ 4,
|
||||||
|
&dev,
|
||||||
|
)?;
|
||||||
|
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||||
|
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
|
||||||
|
|
||||||
|
/*
|
||||||
|
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
|
||||||
|
x @ x.t() / 16
|
||||||
|
tensor([[ 347480.0000, 869720.0000, 1391960.0000, 1914200.0000],
|
||||||
|
[ 869720.0000, 2440536.0000, 4011352.0000, 5582166.5000],
|
||||||
|
[ 1391960.0000, 4011352.0000, 6630742.0000, 9250132.0000],
|
||||||
|
[ 1914200.0000, 5582166.5000, 9250132.0000, 12918099.0000]])
|
||||||
|
*/
|
||||||
|
assert_eq!(vs.len(), 16);
|
||||||
|
assert_eq!(vs[0], 347604.0);
|
||||||
|
assert_eq!(vs[1], 888153.06);
|
||||||
|
assert_eq!(vs[4], 869780.7);
|
||||||
|
assert_eq!(vs[5], 2483145.0);
|
||||||
|
assert_eq!(vs[11], 9407368.0);
|
||||||
|
assert_eq!(vs[14], 9470856.0);
|
||||||
|
assert_eq!(vs[15], 13138824.0);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
54
candle-core/src/quantized/dummy_cuda.rs
Normal file
54
candle-core/src/quantized/dummy_cuda.rs
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
#![allow(unused)]
|
||||||
|
use super::GgmlDType;
|
||||||
|
use crate::{CudaDevice, CudaStorage, Error, Result};
|
||||||
|
|
||||||
|
pub struct QCudaStorage {
|
||||||
|
dtype: GgmlDType,
|
||||||
|
device: CudaDevice,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl QCudaStorage {
|
||||||
|
pub fn zeros(_: &CudaDevice, _: usize, _: GgmlDType) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dtype(&self) -> GgmlDType {
|
||||||
|
self.dtype
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn device(&self) -> &CudaDevice {
|
||||||
|
&self.device
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dequantize(&self, _elem_count: usize) -> Result<CudaStorage> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dequantize_f16(&self, _elem_count: usize) -> Result<CudaStorage> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn storage_size_in_bytes(&self) -> usize {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fwd(
|
||||||
|
&self,
|
||||||
|
_self_shape: &crate::Shape,
|
||||||
|
_storage: &CudaStorage,
|
||||||
|
_layout: &crate::Layout,
|
||||||
|
) -> Result<(CudaStorage, crate::Shape)> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
||||||
|
_device: &CudaDevice,
|
||||||
|
_data: &[T],
|
||||||
|
) -> Result<super::QStorage> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
50
candle-core/src/quantized/dummy_metal.rs
Normal file
50
candle-core/src/quantized/dummy_metal.rs
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
#![allow(unused)]
|
||||||
|
use super::GgmlDType;
|
||||||
|
use crate::{Error, MetalDevice, MetalStorage, Result};
|
||||||
|
|
||||||
|
pub struct QMetalStorage {
|
||||||
|
dtype: GgmlDType,
|
||||||
|
device: MetalDevice,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl QMetalStorage {
|
||||||
|
pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dtype(&self) -> GgmlDType {
|
||||||
|
self.dtype
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn device(&self) -> &MetalDevice {
|
||||||
|
&self.device
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn storage_size_in_bytes(&self) -> usize {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fwd(
|
||||||
|
&self,
|
||||||
|
_self_shape: &crate::Shape,
|
||||||
|
_storage: &MetalStorage,
|
||||||
|
_layout: &crate::Layout,
|
||||||
|
) -> Result<(MetalStorage, crate::Shape)> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
||||||
|
_device: &MetalDevice,
|
||||||
|
_data: &[T],
|
||||||
|
) -> Result<super::QStorage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
@ -1,7 +1,5 @@
|
|||||||
//! Support for the GGML file format.
|
//! Support for the GGML file format.
|
||||||
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
use super::metal::load_quantized_metal;
|
|
||||||
use super::{k_quants, GgmlDType, QStorage};
|
use super::{k_quants, GgmlDType, QStorage};
|
||||||
use crate::{Device, Result};
|
use crate::{Device, Result};
|
||||||
use byteorder::{LittleEndian, ReadBytesExt};
|
use byteorder::{LittleEndian, ReadBytesExt};
|
||||||
@ -130,13 +128,8 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
|
|||||||
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
||||||
let data: QStorage = match device {
|
let data: QStorage = match device {
|
||||||
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
|
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
|
||||||
#[cfg(feature = "metal")]
|
Device::Metal(metal) => super::metal::load_quantized(metal, data)?,
|
||||||
Device::Metal(metal) => load_quantized_metal(metal, data)?,
|
Device::Cuda(cuda) => super::cuda::load_quantized(cuda, data)?,
|
||||||
#[cfg(not(feature = "metal"))]
|
|
||||||
Device::Metal(_metal) => {
|
|
||||||
crate::bail!("Metal backend requires `metal` feature")
|
|
||||||
}
|
|
||||||
device => unimplemented!("Implement quantized tensor for device {device:?}"),
|
|
||||||
};
|
};
|
||||||
super::QTensor::new(data, dims)
|
super::QTensor::new(data, dims)
|
||||||
}
|
}
|
||||||
@ -233,6 +226,7 @@ pub struct Content {
|
|||||||
pub hparams: HParams,
|
pub hparams: HParams,
|
||||||
pub vocab: Vocab,
|
pub vocab: Vocab,
|
||||||
pub tensors: HashMap<String, super::QTensor>,
|
pub tensors: HashMap<String, super::QTensor>,
|
||||||
|
pub device: Device,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Content {
|
impl Content {
|
||||||
@ -252,11 +246,13 @@ impl Content {
|
|||||||
let (name, tensor) = read_one_tensor(reader, magic, device)?;
|
let (name, tensor) = read_one_tensor(reader, magic, device)?;
|
||||||
tensors.insert(name, tensor);
|
tensors.insert(name, tensor);
|
||||||
}
|
}
|
||||||
|
let device = device.clone();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
magic,
|
magic,
|
||||||
hparams,
|
hparams,
|
||||||
vocab,
|
vocab,
|
||||||
tensors,
|
tensors,
|
||||||
|
device,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,7 +135,6 @@ pub enum ValueType {
|
|||||||
// The value is a UTF-8 non-null-terminated string, with length prepended.
|
// The value is a UTF-8 non-null-terminated string, with length prepended.
|
||||||
String,
|
String,
|
||||||
// The value is an array of other values, with the length and type prepended.
|
// The value is an array of other values, with the length and type prepended.
|
||||||
///
|
|
||||||
// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
|
// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
|
||||||
Array,
|
Array,
|
||||||
}
|
}
|
||||||
@ -218,10 +217,16 @@ impl Value {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// This will also automatically upcast any integral types which will not truncate.
|
||||||
pub fn to_u64(&self) -> Result<u64> {
|
pub fn to_u64(&self) -> Result<u64> {
|
||||||
match self {
|
match self {
|
||||||
Self::U64(v) => Ok(*v),
|
Self::U64(v) => Ok(*v),
|
||||||
v => crate::bail!("not a u64 {v:?}"),
|
// Autoupcast cases here
|
||||||
|
Self::U8(v) => Ok(*v as u64),
|
||||||
|
Self::U16(v) => Ok(*v as u64),
|
||||||
|
Self::U32(v) => Ok(*v as u64),
|
||||||
|
Self::Bool(v) => Ok(*v as u64),
|
||||||
|
v => crate::bail!("not a u64 or upcastable to u64 {v:?}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
use super::{GgmlDType, QStorage};
|
use super::{GgmlDType, QStorage};
|
||||||
use crate::{DType, MetalDevice, MetalStorage, Result};
|
use crate::backend::BackendStorage;
|
||||||
|
use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
|
||||||
use metal::Buffer;
|
use metal::Buffer;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
@ -10,23 +11,31 @@ pub struct QMetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl QMetalStorage {
|
impl QMetalStorage {
|
||||||
|
pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> {
|
||||||
|
let size = elem_count * dtype.type_size() / dtype.block_size();
|
||||||
|
let buffer = device.allocate_zeros(size)?;
|
||||||
|
Ok(Self {
|
||||||
|
buffer,
|
||||||
|
device: device.clone(),
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub fn dtype(&self) -> GgmlDType {
|
pub fn dtype(&self) -> GgmlDType {
|
||||||
self.dtype
|
self.dtype
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn device(&self) -> &MetalDevice {
|
||||||
|
&self.device
|
||||||
|
}
|
||||||
|
|
||||||
pub fn buffer(&self) -> &Buffer {
|
pub fn buffer(&self) -> &Buffer {
|
||||||
&self.buffer
|
&self.buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self {
|
|
||||||
Self {
|
|
||||||
device,
|
|
||||||
buffer,
|
|
||||||
dtype,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
|
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
|
||||||
|
use crate::quantized::k_quants::GgmlType;
|
||||||
|
|
||||||
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
||||||
let command_buffer = self.device.command_buffer()?;
|
let command_buffer = self.device.command_buffer()?;
|
||||||
command_buffer.set_label("to_cpu");
|
command_buffer.set_label("to_cpu");
|
||||||
@ -36,87 +45,73 @@ impl QMetalStorage {
|
|||||||
blit.end_encoding();
|
blit.end_encoding();
|
||||||
self.device.wait_until_completed()?;
|
self.device.wait_until_completed()?;
|
||||||
let mut out = vec![0.0; elem_count];
|
let mut out = vec![0.0; elem_count];
|
||||||
|
let block_len = elem_count / self.dtype.block_size();
|
||||||
match self.dtype {
|
match self.dtype {
|
||||||
GgmlDType::F32 => {
|
GgmlDType::F32 => {
|
||||||
let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
|
let vec: Vec<f32> = read_to_vec(&buffer, block_len);
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
f32::to_float(&vec, &mut out)?;
|
f32::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::F16 => {
|
GgmlDType::F16 => {
|
||||||
let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
|
let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
half::f16::to_float(&vec, &mut out)?;
|
half::f16::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q4_0 => {
|
GgmlDType::Q4_0 => {
|
||||||
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
|
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q4_1 => {
|
GgmlDType::Q4_1 => {
|
||||||
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
|
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q5_0 => {
|
GgmlDType::Q5_0 => {
|
||||||
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
|
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q5_1 => {
|
GgmlDType::Q5_1 => {
|
||||||
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
|
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q8_0 => {
|
GgmlDType::Q8_0 => {
|
||||||
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
|
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q8_1 => {
|
GgmlDType::Q8_1 => {
|
||||||
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
|
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q2K => {
|
GgmlDType::Q2K => {
|
||||||
let vec: Vec<crate::quantized::BlockQ2K> =
|
let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
|
||||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q3K => {
|
GgmlDType::Q3K => {
|
||||||
let vec: Vec<crate::quantized::BlockQ3K> =
|
let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
|
||||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q4K => {
|
GgmlDType::Q4K => {
|
||||||
let vec: Vec<crate::quantized::BlockQ4K> =
|
let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
|
||||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q5K => {
|
GgmlDType::Q5K => {
|
||||||
let vec: Vec<crate::quantized::BlockQ5K> =
|
let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
|
||||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q6K => {
|
GgmlDType::Q6K => {
|
||||||
let vec: Vec<crate::quantized::BlockQ6K> =
|
let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
|
||||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q8K => {
|
GgmlDType::Q8K => {
|
||||||
let vec: Vec<crate::quantized::BlockQ8K> =
|
let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
|
||||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let buffer = self.device.new_buffer_with_data(&out)?;
|
let buffer = self.device.new_buffer_with_data(&out)?;
|
||||||
Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
|
Ok(MetalStorage::new(
|
||||||
|
buffer,
|
||||||
|
self.device.clone(),
|
||||||
|
elem_count,
|
||||||
|
DType::F32,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
|
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
|
||||||
@ -130,9 +125,70 @@ impl QMetalStorage {
|
|||||||
self.buffer = buffer;
|
self.buffer = buffer;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn storage_size_in_bytes(&self) -> usize {
|
||||||
|
self.buffer.length() as usize
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fwd(
|
||||||
|
&self,
|
||||||
|
self_shape: &Shape,
|
||||||
|
storage: &MetalStorage,
|
||||||
|
layout: &crate::Layout,
|
||||||
|
) -> Result<(MetalStorage, Shape)> {
|
||||||
|
use crate::MetalError;
|
||||||
|
|
||||||
|
if !layout.is_contiguous() {
|
||||||
|
crate::bail!("input tensor is not contiguous {layout:?}")
|
||||||
|
}
|
||||||
|
let src_shape = layout.shape();
|
||||||
|
// self is transposed so n is first then k.
|
||||||
|
if src_shape.rank() < 2 {
|
||||||
|
crate::bail!("input tensor has only one dimension {layout:?}")
|
||||||
|
}
|
||||||
|
let (n, k) = self_shape.dims2()?;
|
||||||
|
let mut dst_shape = src_shape.dims().to_vec();
|
||||||
|
|
||||||
|
// We always use a single batch dimension and stack all the tensors in the batch on the
|
||||||
|
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
|
||||||
|
// properly.
|
||||||
|
let m = match dst_shape.len() {
|
||||||
|
3 => dst_shape[0] * dst_shape[1],
|
||||||
|
2 => dst_shape[0],
|
||||||
|
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
||||||
|
};
|
||||||
|
let last_k = dst_shape.pop().unwrap();
|
||||||
|
if last_k != k {
|
||||||
|
crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
|
||||||
|
}
|
||||||
|
dst_shape.push(n);
|
||||||
|
let dst_shape = Shape::from(dst_shape);
|
||||||
|
let device = storage.device().clone();
|
||||||
|
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
|
||||||
|
let command_buffer = device.command_buffer()?;
|
||||||
|
// In some cases it would be better to use the mm variant, though it has its drawbacks
|
||||||
|
// around memory alignemnt.
|
||||||
|
for batch_id in 0..m {
|
||||||
|
candle_metal_kernels::call_quantized_matmul_mv_t(
|
||||||
|
device.device(),
|
||||||
|
&command_buffer,
|
||||||
|
device.kernels(),
|
||||||
|
self.dtype.into(),
|
||||||
|
(1, 1, n, k),
|
||||||
|
storage.buffer(),
|
||||||
|
(layout.start_offset() + batch_id * k) * storage.dtype().size_in_bytes(),
|
||||||
|
&self.buffer,
|
||||||
|
batch_id * n * DType::F32.size_in_bytes(),
|
||||||
|
&dst,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
}
|
||||||
|
let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
|
||||||
|
Ok((dst_storage, dst_shape))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
|
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
||||||
device: &MetalDevice,
|
device: &MetalDevice,
|
||||||
data: &[T],
|
data: &[T],
|
||||||
) -> Result<QStorage> {
|
) -> Result<QStorage> {
|
||||||
@ -151,3 +207,24 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
|||||||
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
|
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
|
||||||
slice.to_vec()
|
slice.to_vec()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
|
||||||
|
fn from(value: GgmlDType) -> Self {
|
||||||
|
match value {
|
||||||
|
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
|
||||||
|
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
|
||||||
|
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
|
||||||
|
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
|
||||||
|
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
|
||||||
|
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
|
||||||
|
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
|
||||||
|
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
|
||||||
|
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
|
||||||
|
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
|
||||||
|
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
|
||||||
|
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
|
||||||
|
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
|
||||||
|
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,16 +1,27 @@
|
|||||||
#[cfg(feature = "metal")]
|
use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
|
||||||
use crate::{backend::BackendStorage, DType};
|
|
||||||
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
|
|
||||||
use k_quants::*;
|
use k_quants::*;
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
|
|
||||||
#[cfg(target_feature = "avx")]
|
#[cfg(target_feature = "avx")]
|
||||||
pub mod avx;
|
pub mod avx;
|
||||||
|
mod dummy_cuda;
|
||||||
|
mod dummy_metal;
|
||||||
pub mod ggml_file;
|
pub mod ggml_file;
|
||||||
pub mod gguf_file;
|
pub mod gguf_file;
|
||||||
pub mod k_quants;
|
pub mod k_quants;
|
||||||
#[cfg(feature = "metal")]
|
#[cfg(feature = "metal")]
|
||||||
pub mod metal;
|
pub mod metal;
|
||||||
|
#[cfg(not(feature = "metal"))]
|
||||||
|
mod metal {
|
||||||
|
pub use super::dummy_metal::*;
|
||||||
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub mod cuda;
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
mod cuda {
|
||||||
|
pub use super::dummy_cuda::*;
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(target_feature = "neon")]
|
#[cfg(target_feature = "neon")]
|
||||||
pub mod neon;
|
pub mod neon;
|
||||||
#[cfg(target_feature = "simd128")]
|
#[cfg(target_feature = "simd128")]
|
||||||
@ -32,22 +43,13 @@ impl Device {
|
|||||||
let storage = dtype.cpu_zeros(elem_count);
|
let storage = dtype.cpu_zeros(elem_count);
|
||||||
Ok(QStorage::Cpu(storage))
|
Ok(QStorage::Cpu(storage))
|
||||||
}
|
}
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
Device::Metal(metal) => {
|
Device::Metal(metal) => {
|
||||||
let size = elem_count * dtype.type_size() / dtype.block_size();
|
let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
|
||||||
let buffer = metal.allocate_zeros(size)?;
|
Ok(QStorage::Metal(storage))
|
||||||
Ok(QStorage::Metal(metal::QMetalStorage::new(
|
|
||||||
buffer,
|
|
||||||
metal.clone(),
|
|
||||||
dtype,
|
|
||||||
)))
|
|
||||||
}
|
}
|
||||||
#[cfg(not(feature = "metal"))]
|
Device::Cuda(cuda) => {
|
||||||
Device::Metal(_metal) => {
|
let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
|
||||||
crate::bail!("Metal feature not activated");
|
Ok(QStorage::Cuda(storage))
|
||||||
}
|
|
||||||
Device::Cuda(_cuda) => {
|
|
||||||
crate::bail!("Cuda ggml quantization not supported");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -55,32 +57,40 @@ impl Device {
|
|||||||
|
|
||||||
pub enum QStorage {
|
pub enum QStorage {
|
||||||
Cpu(Box<dyn QuantizedType>),
|
Cpu(Box<dyn QuantizedType>),
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
Metal(metal::QMetalStorage),
|
Metal(metal::QMetalStorage),
|
||||||
|
Cuda(cuda::QCudaStorage),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl QStorage {
|
impl QStorage {
|
||||||
fn block_size(&self) -> usize {
|
fn block_size(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
QStorage::Cpu(storage) => storage.block_size(),
|
QStorage::Cpu(storage) => storage.block_size(),
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
QStorage::Metal(storage) => storage.dtype().block_size(),
|
QStorage::Metal(storage) => storage.dtype().block_size(),
|
||||||
|
QStorage::Cuda(storage) => storage.dtype().block_size(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn dtype(&self) -> GgmlDType {
|
fn dtype(&self) -> GgmlDType {
|
||||||
match self {
|
match self {
|
||||||
QStorage::Cpu(storage) => storage.dtype(),
|
QStorage::Cpu(storage) => storage.dtype(),
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
QStorage::Metal(storage) => storage.dtype(),
|
QStorage::Metal(storage) => storage.dtype(),
|
||||||
|
QStorage::Cuda(storage) => storage.dtype(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn device(&self) -> Device {
|
||||||
|
match self {
|
||||||
|
QStorage::Cpu(_storage) => Device::Cpu,
|
||||||
|
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
|
||||||
|
QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn size_in_bytes(&self) -> usize {
|
fn size_in_bytes(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
|
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
|
||||||
#[cfg(feature = "metal")]
|
QStorage::Metal(storage) => storage.storage_size_in_bytes(),
|
||||||
QStorage::Metal(storage) => storage.buffer().length() as usize,
|
QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,8 +99,8 @@ impl QStorage {
|
|||||||
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
|
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
|
||||||
storage.from_float(src.as_slice::<f32>()?)?;
|
storage.from_float(src.as_slice::<f32>()?)?;
|
||||||
}
|
}
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
|
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
|
||||||
|
(QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
|
||||||
_ => crate::bail!("Invalid dequantize storage locations do not match"),
|
_ => crate::bail!("Invalid dequantize storage locations do not match"),
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -99,8 +109,8 @@ impl QStorage {
|
|||||||
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
|
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
|
||||||
match self {
|
match self {
|
||||||
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
|
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
|
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
|
||||||
|
QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,8 +122,7 @@ impl QStorage {
|
|||||||
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
||||||
Ok(Cow::from(data))
|
Ok(Cow::from(data))
|
||||||
}
|
}
|
||||||
#[cfg(feature = "metal")]
|
QStorage::Metal(_) | QStorage::Cuda(_) => {
|
||||||
QStorage::Metal(_storage) => {
|
|
||||||
crate::bail!("not implemented");
|
crate::bail!("not implemented");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -336,6 +345,10 @@ impl QTensor {
|
|||||||
self.storage.dtype()
|
self.storage.dtype()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn device(&self) -> Device {
|
||||||
|
self.storage.device()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn rank(&self) -> usize {
|
pub fn rank(&self) -> usize {
|
||||||
self.shape.rank()
|
self.shape.rank()
|
||||||
}
|
}
|
||||||
@ -347,9 +360,24 @@ impl QTensor {
|
|||||||
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
|
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
|
||||||
let storage = self.storage.dequantize(self.shape.elem_count())?;
|
let storage = self.storage.dequantize(self.shape.elem_count())?;
|
||||||
let none = crate::op::BackpropOp::none();
|
let none = crate::op::BackpropOp::none();
|
||||||
let is_variable = false;
|
crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
|
||||||
crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
|
}
|
||||||
.to_device(device)
|
|
||||||
|
pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
|
||||||
|
// In the CUDA case, we have a specialized kernel as this can be useful for volta
|
||||||
|
// architectures. https://github.com/huggingface/candle/issues/2136
|
||||||
|
match &self.storage {
|
||||||
|
QStorage::Cuda(s) => {
|
||||||
|
let s = s.dequantize_f16(self.shape.elem_count())?;
|
||||||
|
let none = crate::op::BackpropOp::none();
|
||||||
|
crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
|
||||||
|
.to_device(device)
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
|
||||||
|
Ok(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn storage_size_in_bytes(&self) -> usize {
|
pub fn storage_size_in_bytes(&self) -> usize {
|
||||||
@ -365,6 +393,7 @@ impl QTensor {
|
|||||||
pub enum QMatMul {
|
pub enum QMatMul {
|
||||||
QTensor(std::sync::Arc<QTensor>),
|
QTensor(std::sync::Arc<QTensor>),
|
||||||
Tensor(Tensor),
|
Tensor(Tensor),
|
||||||
|
TensorF16(Tensor),
|
||||||
}
|
}
|
||||||
|
|
||||||
thread_local! {
|
thread_local! {
|
||||||
@ -378,6 +407,17 @@ thread_local! {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
thread_local! {
|
||||||
|
static DEQUANTIZE_ALL_F16: bool = {
|
||||||
|
match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
|
||||||
|
Ok(s) => {
|
||||||
|
!s.is_empty() && s != "0"
|
||||||
|
},
|
||||||
|
Err(_) => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl QMatMul {
|
impl QMatMul {
|
||||||
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
|
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
|
||||||
let dequantize = match qtensor.dtype() {
|
let dequantize = match qtensor.dtype() {
|
||||||
@ -385,8 +425,11 @@ impl QMatMul {
|
|||||||
_ => DEQUANTIZE_ALL.with(|b| *b),
|
_ => DEQUANTIZE_ALL.with(|b| *b),
|
||||||
};
|
};
|
||||||
let t = if dequantize {
|
let t = if dequantize {
|
||||||
let tensor = qtensor.dequantize(&Device::Cpu)?;
|
let tensor = qtensor.dequantize(&qtensor.device())?;
|
||||||
Self::Tensor(tensor)
|
Self::Tensor(tensor)
|
||||||
|
} else if DEQUANTIZE_ALL_F16.with(|b| *b) {
|
||||||
|
let tensor = qtensor.dequantize_f16(&qtensor.device())?;
|
||||||
|
Self::TensorF16(tensor)
|
||||||
} else {
|
} else {
|
||||||
Self::QTensor(qtensor)
|
Self::QTensor(qtensor)
|
||||||
};
|
};
|
||||||
@ -396,6 +439,25 @@ impl QMatMul {
|
|||||||
pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
|
pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
|
||||||
Self::from_arc(std::sync::Arc::new(qtensor))
|
Self::from_arc(std::sync::Arc::new(qtensor))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn dequantize_f16(&self) -> Result<Tensor> {
|
||||||
|
match self {
|
||||||
|
Self::QTensor(t) => t.dequantize_f16(&t.device()),
|
||||||
|
Self::Tensor(t) => t.to_dtype(DType::F16),
|
||||||
|
Self::TensorF16(t) => Ok(t.clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let w = self.dequantize_f16()?;
|
||||||
|
let in_dtype = xs.dtype();
|
||||||
|
let w = match *xs.dims() {
|
||||||
|
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
|
||||||
|
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
|
||||||
|
_ => w.t()?,
|
||||||
|
};
|
||||||
|
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl crate::CustomOp1 for QTensor {
|
impl crate::CustomOp1 for QTensor {
|
||||||
@ -427,8 +489,7 @@ impl crate::CustomOp1 for QTensor {
|
|||||||
#[allow(clippy::infallible_destructuring_match)]
|
#[allow(clippy::infallible_destructuring_match)]
|
||||||
let self_storage = match &self.storage {
|
let self_storage = match &self.storage {
|
||||||
QStorage::Cpu(storage) => storage,
|
QStorage::Cpu(storage) => storage,
|
||||||
#[cfg(feature = "metal")]
|
QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
|
||||||
_ => crate::bail!("Invalid storage"),
|
|
||||||
};
|
};
|
||||||
let slice = storage.as_slice::<f32>()?;
|
let slice = storage.as_slice::<f32>()?;
|
||||||
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||||
@ -437,79 +498,28 @@ impl crate::CustomOp1 for QTensor {
|
|||||||
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
|
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
fn metal_fwd(
|
fn metal_fwd(
|
||||||
&self,
|
&self,
|
||||||
storage: &crate::MetalStorage,
|
storage: &crate::MetalStorage,
|
||||||
layout: &crate::Layout,
|
layout: &crate::Layout,
|
||||||
) -> Result<(crate::MetalStorage, Shape)> {
|
) -> Result<(crate::MetalStorage, Shape)> {
|
||||||
use crate::MetalError;
|
let self_storage = match &self.storage {
|
||||||
|
QStorage::Metal(metal) => metal,
|
||||||
if !layout.is_contiguous() {
|
|
||||||
crate::bail!("input tensor is not contiguous {layout:?}")
|
|
||||||
}
|
|
||||||
let src_shape = layout.shape();
|
|
||||||
// self is transposed so n is first then k.
|
|
||||||
if src_shape.rank() < 2 {
|
|
||||||
crate::bail!("input tensor has only one dimension {layout:?}")
|
|
||||||
}
|
|
||||||
let (n, k) = self.shape.dims2()?;
|
|
||||||
let mut dst_shape = src_shape.dims().to_vec();
|
|
||||||
|
|
||||||
let (b, m) = match dst_shape.len() {
|
|
||||||
3 => (dst_shape[0], dst_shape[1]),
|
|
||||||
2 => (1, dst_shape[0]),
|
|
||||||
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
|
||||||
};
|
|
||||||
let last_k = dst_shape.pop().unwrap();
|
|
||||||
if last_k != k {
|
|
||||||
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
|
|
||||||
}
|
|
||||||
dst_shape.push(n);
|
|
||||||
let dst_shape = Shape::from(dst_shape);
|
|
||||||
let device = storage.device().clone();
|
|
||||||
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
|
|
||||||
let (buffer, dtype) = match &self.storage {
|
|
||||||
QStorage::Metal(metal) => (metal.buffer(), metal.dtype()),
|
|
||||||
_ => unreachable!("Cannot call metal matmul on non metal QTensor"),
|
_ => unreachable!("Cannot call metal matmul on non metal QTensor"),
|
||||||
};
|
};
|
||||||
let command_buffer = device.command_buffer()?;
|
self_storage.fwd(&self.shape, storage, layout)
|
||||||
candle_metal_kernels::call_quantized_matmul_t(
|
|
||||||
device.device(),
|
|
||||||
&command_buffer,
|
|
||||||
device.kernels(),
|
|
||||||
dtype.into(),
|
|
||||||
(b, m, n, k),
|
|
||||||
storage.buffer(),
|
|
||||||
layout.start_offset() * storage.dtype().size_in_bytes(),
|
|
||||||
buffer,
|
|
||||||
&dst,
|
|
||||||
)
|
|
||||||
.map_err(MetalError::from)?;
|
|
||||||
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
|
|
||||||
Ok((dst_storage, dst_shape))
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "metal")]
|
fn cuda_fwd(
|
||||||
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
|
&self,
|
||||||
fn from(value: GgmlDType) -> Self {
|
storage: &crate::CudaStorage,
|
||||||
match value {
|
layout: &crate::Layout,
|
||||||
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
|
) -> Result<(crate::CudaStorage, Shape)> {
|
||||||
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
|
let self_storage = match &self.storage {
|
||||||
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
|
QStorage::Cuda(cuda) => cuda,
|
||||||
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
|
_ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
|
||||||
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
|
};
|
||||||
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
|
self_storage.fwd(&self.shape, storage, layout)
|
||||||
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
|
|
||||||
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
|
|
||||||
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
|
|
||||||
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
|
|
||||||
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
|
|
||||||
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
|
|
||||||
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
|
|
||||||
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -525,6 +535,15 @@ impl crate::Module for QMatMul {
|
|||||||
};
|
};
|
||||||
xs.matmul(&w)
|
xs.matmul(&w)
|
||||||
}
|
}
|
||||||
|
Self::TensorF16(w) => {
|
||||||
|
let in_dtype = xs.dtype();
|
||||||
|
let w = match *xs.dims() {
|
||||||
|
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
|
||||||
|
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
|
||||||
|
_ => w.t()?,
|
||||||
|
};
|
||||||
|
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -349,6 +349,30 @@ impl MmapedSafetensors {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct SliceSafetensors<'a> {
|
||||||
|
safetensors: SafeTensors<'a>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> SliceSafetensors<'a> {
|
||||||
|
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
|
||||||
|
pub fn new(buffer: &'a [u8]) -> Result<Self> {
|
||||||
|
let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
|
||||||
|
Ok(Self { safetensors })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
|
||||||
|
self.safetensors.tensor(name)?.load(dev)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
|
||||||
|
self.safetensors.tensors()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
|
||||||
|
Ok(self.safetensors.tensor(name)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct BufferedSafetensors {
|
pub struct BufferedSafetensors {
|
||||||
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
|
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
|
||||||
}
|
}
|
||||||
|
@ -171,7 +171,7 @@ impl Shape {
|
|||||||
}
|
}
|
||||||
let mut acc = 1;
|
let mut acc = 1;
|
||||||
for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
|
for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
|
||||||
if stride != acc {
|
if dim > 1 && stride != acc {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
acc *= dim;
|
acc *= dim;
|
||||||
@ -186,7 +186,7 @@ impl Shape {
|
|||||||
}
|
}
|
||||||
let mut acc = 1;
|
let mut acc = 1;
|
||||||
for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
|
for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
|
||||||
if stride != acc {
|
if dim > 1 && stride != acc {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
acc *= dim;
|
acc *= dim;
|
||||||
@ -304,6 +304,7 @@ impl Dim for usize {
|
|||||||
pub enum D {
|
pub enum D {
|
||||||
Minus1,
|
Minus1,
|
||||||
Minus2,
|
Minus2,
|
||||||
|
Minus(usize),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl D {
|
impl D {
|
||||||
@ -311,6 +312,7 @@ impl D {
|
|||||||
let dim = match self {
|
let dim = match self {
|
||||||
Self::Minus1 => -1,
|
Self::Minus1 => -1,
|
||||||
Self::Minus2 => -2,
|
Self::Minus2 => -2,
|
||||||
|
Self::Minus(u) => -(*u as i32),
|
||||||
};
|
};
|
||||||
Error::DimOutOfRange {
|
Error::DimOutOfRange {
|
||||||
shape: shape.clone(),
|
shape: shape.clone(),
|
||||||
@ -327,6 +329,7 @@ impl Dim for D {
|
|||||||
match self {
|
match self {
|
||||||
Self::Minus1 if rank >= 1 => Ok(rank - 1),
|
Self::Minus1 if rank >= 1 => Ok(rank - 1),
|
||||||
Self::Minus2 if rank >= 2 => Ok(rank - 2),
|
Self::Minus2 if rank >= 2 => Ok(rank - 2),
|
||||||
|
Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
|
||||||
_ => Err(self.out_of_range(shape, op)),
|
_ => Err(self.out_of_range(shape, op)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -336,6 +339,7 @@ impl Dim for D {
|
|||||||
match self {
|
match self {
|
||||||
Self::Minus1 => Ok(rank),
|
Self::Minus1 => Ok(rank),
|
||||||
Self::Minus2 if rank >= 1 => Ok(rank - 1),
|
Self::Minus2 if rank >= 1 => Ok(rank - 1),
|
||||||
|
Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
|
||||||
_ => Err(self.out_of_range(shape, op)),
|
_ => Err(self.out_of_range(shape, op)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
239
candle-core/src/sort.rs
Normal file
239
candle-core/src/sort.rs
Normal file
@ -0,0 +1,239 @@
|
|||||||
|
use crate::{Result, Tensor};
|
||||||
|
use rayon::prelude::*;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
struct ArgSort {
|
||||||
|
asc: bool,
|
||||||
|
last_dim: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ArgSort {
|
||||||
|
fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
|
||||||
|
#[allow(clippy::uninit_vec)]
|
||||||
|
// Safety: indexes are set later in the parallelized section.
|
||||||
|
let mut sort_indexes = unsafe {
|
||||||
|
let el_count = layout.shape().elem_count();
|
||||||
|
let mut v = Vec::with_capacity(el_count);
|
||||||
|
v.set_len(el_count);
|
||||||
|
v
|
||||||
|
};
|
||||||
|
if self.asc {
|
||||||
|
sort_indexes
|
||||||
|
.par_chunks_exact_mut(self.last_dim)
|
||||||
|
.zip(vs.par_chunks_exact(self.last_dim))
|
||||||
|
.for_each(|(indexes, vs)| {
|
||||||
|
indexes
|
||||||
|
.iter_mut()
|
||||||
|
.enumerate()
|
||||||
|
.for_each(|(i, v)| *v = i as u32);
|
||||||
|
indexes.sort_by(|&i, &j| {
|
||||||
|
vs[i as usize]
|
||||||
|
.partial_cmp(&vs[j as usize])
|
||||||
|
.unwrap_or(std::cmp::Ordering::Greater)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
sort_indexes
|
||||||
|
.par_chunks_exact_mut(self.last_dim)
|
||||||
|
.zip(vs.par_chunks_exact(self.last_dim))
|
||||||
|
.for_each(|(indexes, vs)| {
|
||||||
|
indexes
|
||||||
|
.iter_mut()
|
||||||
|
.enumerate()
|
||||||
|
.for_each(|(i, v)| *v = i as u32);
|
||||||
|
indexes.sort_by(|&j, &i| {
|
||||||
|
vs[i as usize]
|
||||||
|
.partial_cmp(&vs[j as usize])
|
||||||
|
.unwrap_or(std::cmp::Ordering::Greater)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
}
|
||||||
|
sort_indexes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl crate::CustomOp1 for ArgSort {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"argsort"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cpu_fwd(
|
||||||
|
&self,
|
||||||
|
storage: &crate::CpuStorage,
|
||||||
|
layout: &crate::Layout,
|
||||||
|
) -> Result<(crate::CpuStorage, crate::Shape)> {
|
||||||
|
let sort_indexes = match storage {
|
||||||
|
crate::CpuStorage::U8(vs) => self.asort(vs, layout),
|
||||||
|
crate::CpuStorage::U32(vs) => self.asort(vs, layout),
|
||||||
|
crate::CpuStorage::I64(vs) => self.asort(vs, layout),
|
||||||
|
crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
|
||||||
|
crate::CpuStorage::F16(vs) => self.asort(vs, layout),
|
||||||
|
crate::CpuStorage::F32(vs) => self.asort(vs, layout),
|
||||||
|
crate::CpuStorage::F64(vs) => self.asort(vs, layout),
|
||||||
|
};
|
||||||
|
let sort_indexes = crate::CpuStorage::U32(sort_indexes);
|
||||||
|
Ok((sort_indexes, layout.shape().into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn cuda_fwd(
|
||||||
|
&self,
|
||||||
|
storage: &crate::CudaStorage,
|
||||||
|
layout: &crate::Layout,
|
||||||
|
) -> Result<(crate::CudaStorage, crate::Shape)> {
|
||||||
|
use crate::cuda_backend::cudarc::driver::{
|
||||||
|
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
||||||
|
};
|
||||||
|
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr};
|
||||||
|
use crate::{CudaDevice, WithDType};
|
||||||
|
|
||||||
|
impl Map1Any for ArgSort {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
||||||
|
&self,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
layout: &crate::Layout,
|
||||||
|
_wrap: W,
|
||||||
|
) -> Result<S> {
|
||||||
|
let slice = match layout.contiguous_offsets() {
|
||||||
|
None => crate::bail!("input has to be contiguous"),
|
||||||
|
Some((o1, o2)) => src.slice(o1..o2),
|
||||||
|
};
|
||||||
|
let elem_count = layout.shape().elem_count();
|
||||||
|
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
|
||||||
|
let func = if self.asc {
|
||||||
|
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
|
||||||
|
} else {
|
||||||
|
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
|
||||||
|
};
|
||||||
|
let ncols = self.last_dim;
|
||||||
|
let nrows = elem_count / ncols;
|
||||||
|
let ncols_pad = next_power_of_2(ncols);
|
||||||
|
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
|
||||||
|
let cfg = LaunchConfig {
|
||||||
|
grid_dim: (1, nrows as u32, 1),
|
||||||
|
block_dim: (ncols_pad as u32, 1, 1),
|
||||||
|
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
|
||||||
|
};
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
Ok(S::U32(dst))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
use crate::backend::BackendStorage;
|
||||||
|
let dev = storage.device();
|
||||||
|
let slice = self.map(&storage.slice, dev, layout)?;
|
||||||
|
let dst = crate::cuda_backend::CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: dev.clone(),
|
||||||
|
};
|
||||||
|
Ok((dst, layout.shape().clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "metal")]
|
||||||
|
fn metal_fwd(
|
||||||
|
&self,
|
||||||
|
storage: &crate::MetalStorage,
|
||||||
|
layout: &crate::Layout,
|
||||||
|
) -> Result<(crate::MetalStorage, crate::Shape)> {
|
||||||
|
use crate::backend::BackendStorage;
|
||||||
|
use crate::DType;
|
||||||
|
|
||||||
|
let name = {
|
||||||
|
if self.asc {
|
||||||
|
match storage.dtype() {
|
||||||
|
DType::BF16 => "asort_asc_bf16",
|
||||||
|
DType::F16 => "asort_asc_f16",
|
||||||
|
DType::F32 => "asort_asc_f32",
|
||||||
|
DType::F64 => "asort_asc_f64",
|
||||||
|
DType::U8 => "asort_asc_u8",
|
||||||
|
DType::U32 => "asort_asc_u32",
|
||||||
|
DType::I64 => "asort_asc_i64",
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
match storage.dtype() {
|
||||||
|
DType::BF16 => "asort_desc_bf16",
|
||||||
|
DType::F16 => "asort_desc_f16",
|
||||||
|
DType::F32 => "asort_desc_f32",
|
||||||
|
DType::F64 => "asort_desc_f64",
|
||||||
|
DType::U8 => "asort_desc_u8",
|
||||||
|
DType::U32 => "asort_desc_u32",
|
||||||
|
DType::I64 => "asort_desc_i64",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let device = storage.device();
|
||||||
|
let kernels = device.kernels();
|
||||||
|
let command_buffer = device.command_buffer()?;
|
||||||
|
let el = layout.shape().elem_count();
|
||||||
|
let ncols = self.last_dim;
|
||||||
|
let nrows = el / ncols;
|
||||||
|
let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
|
||||||
|
let dst = device.new_buffer(el, DType::U32, "asort")?;
|
||||||
|
let mut ncols_pad = 1;
|
||||||
|
while ncols_pad < ncols {
|
||||||
|
ncols_pad *= 2;
|
||||||
|
}
|
||||||
|
candle_metal_kernels::call_arg_sort(
|
||||||
|
device.metal_device(),
|
||||||
|
&command_buffer,
|
||||||
|
kernels,
|
||||||
|
name,
|
||||||
|
nrows,
|
||||||
|
ncols,
|
||||||
|
ncols_pad,
|
||||||
|
src,
|
||||||
|
&dst,
|
||||||
|
)
|
||||||
|
.map_err(crate::Error::wrap)?;
|
||||||
|
let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
|
||||||
|
Ok((dst, layout.shape().clone()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
fn next_power_of_2(x: usize) -> usize {
|
||||||
|
let mut n = 1;
|
||||||
|
while n < x {
|
||||||
|
n *= 2
|
||||||
|
}
|
||||||
|
n
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Tensor {
|
||||||
|
/// Returns the indices that sort the tensor along the last dimension.
|
||||||
|
///
|
||||||
|
/// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
|
||||||
|
/// descending order. The sort is unstable so there is no guarantees on the final order when it
|
||||||
|
/// comes to ties.
|
||||||
|
pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
|
||||||
|
if !self.is_contiguous() {
|
||||||
|
return Err(crate::Error::RequiresContiguous {
|
||||||
|
op: "arg_sort_last_dim",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let last_dim = match self.dims().last() {
|
||||||
|
None => crate::bail!("empty last-dim in arg-sort"),
|
||||||
|
Some(last_dim) => *last_dim,
|
||||||
|
};
|
||||||
|
// No need for a backward pass for arg sort.
|
||||||
|
self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sorts the tensor along the last dimension, returns the sorted tensor together with the
|
||||||
|
/// sorted indexes.
|
||||||
|
///
|
||||||
|
/// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
|
||||||
|
/// descending order. The sort is unstable so there is no guarantees on the final order when it
|
||||||
|
/// comes to ties.
|
||||||
|
pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
|
||||||
|
if !self.is_contiguous() {
|
||||||
|
return Err(crate::Error::RequiresContiguous {
|
||||||
|
op: "sort_last_dim",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let asort = self.arg_sort_last_dim(asc)?;
|
||||||
|
let sorted = self.gather(&asort, crate::D::Minus1)?;
|
||||||
|
Ok((sorted, asort))
|
||||||
|
}
|
||||||
|
}
|
@ -1,6 +1,7 @@
|
|||||||
use crate::backend::BackendStorage;
|
use crate::backend::BackendStorage;
|
||||||
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
|
use crate::op::{self, CmpOp, ReduceOp};
|
||||||
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
|
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
|
||||||
|
use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
|
||||||
|
|
||||||
// We do not want to implement Clone on Storage as cloning may fail because of
|
// We do not want to implement Clone on Storage as cloning may fail because of
|
||||||
// out of memory. Instead try_clone should be used.
|
// out of memory. Instead try_clone should be used.
|
||||||
@ -43,9 +44,19 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
|
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
|
||||||
let lhs = self.device().location();
|
let lhs_device = self.device();
|
||||||
let rhs = rhs.device().location();
|
let rhs_device = rhs.device();
|
||||||
if lhs != rhs {
|
let lhs = lhs_device.location();
|
||||||
|
let rhs = rhs_device.location();
|
||||||
|
let same_device = if self.device().is_metal() {
|
||||||
|
// On metal, we require the device to be exactly the same rather than
|
||||||
|
// having the same location. In cuda this is not necessary as all CudaDevice on the
|
||||||
|
// same GPU will use the same cuda stream.
|
||||||
|
lhs_device.same_device(&rhs_device)
|
||||||
|
} else {
|
||||||
|
lhs == rhs
|
||||||
|
};
|
||||||
|
if !same_device {
|
||||||
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
|
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
|
||||||
} else {
|
} else {
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -252,6 +263,51 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> {
|
||||||
|
match self {
|
||||||
|
Self::Cpu(storage) => c.cpu_fwd(storage, l),
|
||||||
|
Self::Cuda(storage) => c.cuda_fwd(storage, l),
|
||||||
|
Self::Metal(storage) => c.metal_fwd(storage, l),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn inplace_op2(
|
||||||
|
&mut self,
|
||||||
|
l1: &Layout,
|
||||||
|
t2: &Self,
|
||||||
|
l2: &Layout,
|
||||||
|
c: &dyn InplaceOp2,
|
||||||
|
) -> Result<()> {
|
||||||
|
self.same_device(t2, c.name())?;
|
||||||
|
match (self, t2) {
|
||||||
|
(Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2),
|
||||||
|
(Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2),
|
||||||
|
(Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2),
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn inplace_op3(
|
||||||
|
&mut self,
|
||||||
|
l1: &Layout,
|
||||||
|
t2: &Self,
|
||||||
|
l2: &Layout,
|
||||||
|
t3: &Self,
|
||||||
|
l3: &Layout,
|
||||||
|
c: &dyn InplaceOp3,
|
||||||
|
) -> Result<()> {
|
||||||
|
self.same_device(t2, c.name())?;
|
||||||
|
self.same_device(t3, c.name())?;
|
||||||
|
match (self, t2, t3) {
|
||||||
|
(Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3),
|
||||||
|
(Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3),
|
||||||
|
(Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
|
||||||
|
c.metal_fwd(s1, l1, s2, l2, s3, l3)
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||||
match self {
|
match self {
|
||||||
Storage::Cpu(storage) => {
|
Storage::Cpu(storage) => {
|
||||||
@ -352,6 +408,10 @@ impl Storage {
|
|||||||
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
||||||
Ok(Self::Cuda(s))
|
Ok(Self::Cuda(s))
|
||||||
}
|
}
|
||||||
|
(Storage::Metal(inp), Storage::Metal(kernel)) => {
|
||||||
|
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
||||||
|
Ok(Self::Metal(s))
|
||||||
|
}
|
||||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||||
lhs: lhs.device().location(),
|
lhs: lhs.device().location(),
|
||||||
rhs: rhs.device().location(),
|
rhs: rhs.device().location(),
|
||||||
@ -697,4 +757,32 @@ impl Storage {
|
|||||||
.bt()),
|
.bt()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub(crate) fn copy2d(
|
||||||
|
&self,
|
||||||
|
dst: &mut Self,
|
||||||
|
d1: usize,
|
||||||
|
d2: usize,
|
||||||
|
src_s: usize,
|
||||||
|
dst_s: usize,
|
||||||
|
src_o: usize,
|
||||||
|
dst_o: usize,
|
||||||
|
) -> Result<()> {
|
||||||
|
match (self, dst) {
|
||||||
|
(Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
|
||||||
|
(Self::Cuda(src), Self::Cuda(dst)) => {
|
||||||
|
Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
|
||||||
|
}
|
||||||
|
(Self::Metal(src), Self::Metal(dst)) => {
|
||||||
|
Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
|
||||||
|
}
|
||||||
|
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||||
|
lhs: lhs.device().location(),
|
||||||
|
rhs: rhs.device().location(),
|
||||||
|
op: "copy2d",
|
||||||
|
}
|
||||||
|
.bt()),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
206
candle-core/src/streaming.rs
Normal file
206
candle-core/src/streaming.rs
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
use crate::{Result, Shape, Tensor};
|
||||||
|
|
||||||
|
pub trait Dim: crate::shape::Dim + Copy {}
|
||||||
|
impl<T: crate::shape::Dim + Copy> Dim for T {}
|
||||||
|
|
||||||
|
/// A stream tensor is used in streaming module. It can either contain an actual tensor or be
|
||||||
|
/// empty.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct StreamTensor(Option<Tensor>);
|
||||||
|
|
||||||
|
impl std::fmt::Debug for StreamTensor {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match &self.0 {
|
||||||
|
Some(t) => write!(f, "{:?}", t.shape()),
|
||||||
|
None => write!(f, "Empty"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::convert::From<Option<Tensor>> for StreamTensor {
|
||||||
|
fn from(value: Option<Tensor>) -> Self {
|
||||||
|
Self(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::convert::From<Tensor> for StreamTensor {
|
||||||
|
fn from(value: Tensor) -> Self {
|
||||||
|
Self(Some(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::convert::From<()> for StreamTensor {
|
||||||
|
fn from(_value: ()) -> Self {
|
||||||
|
Self(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StreamTensor {
|
||||||
|
pub fn empty() -> Self {
|
||||||
|
Self(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_tensor(tensor: Tensor) -> Self {
|
||||||
|
Self(Some(tensor))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn shape(&self) -> Option<&Shape> {
|
||||||
|
self.0.as_ref().map(|t| t.shape())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn cat2<D: Dim>(&self, rhs: &Self, dim: D) -> Result<Self> {
|
||||||
|
let xs = match (&self.0, &rhs.0) {
|
||||||
|
(Some(lhs), Some(rhs)) => {
|
||||||
|
let xs = Tensor::cat(&[lhs, rhs], dim)?;
|
||||||
|
Some(xs)
|
||||||
|
}
|
||||||
|
(Some(xs), None) | (None, Some(xs)) => Some(xs.clone()),
|
||||||
|
(None, None) => None,
|
||||||
|
};
|
||||||
|
Ok(Self(xs))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn seq_len<D: Dim>(&self, dim: D) -> Result<usize> {
|
||||||
|
match &self.0 {
|
||||||
|
None => Ok(0),
|
||||||
|
Some(v) => v.dim(dim),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reset(&mut self) {
|
||||||
|
self.0 = None
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn narrow<D: Dim>(&self, dim: D, offset: usize, len: usize) -> Result<StreamTensor> {
|
||||||
|
let t = match &self.0 {
|
||||||
|
None => None,
|
||||||
|
Some(t) => {
|
||||||
|
let seq_len = t.dim(dim)?;
|
||||||
|
if seq_len <= offset {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let t = t.narrow(dim, offset, usize::min(len, seq_len - offset))?;
|
||||||
|
Some(t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(Self(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Splits the Streaming Tensor on the time axis `dim` with the first `lhs_len` elements
|
||||||
|
/// returned in the first output and the remaining in the second output.
|
||||||
|
pub fn split<D: Dim>(&self, dim: D, lhs_len: usize) -> Result<(Self, Self)> {
|
||||||
|
match &self.0 {
|
||||||
|
None => Ok((Self::empty(), Self::empty())),
|
||||||
|
Some(t) => {
|
||||||
|
let seq_len = t.dim(dim)?;
|
||||||
|
let lhs_len = usize::min(seq_len, lhs_len);
|
||||||
|
if lhs_len == 0 {
|
||||||
|
Ok((Self::empty(), t.clone().into()))
|
||||||
|
} else {
|
||||||
|
let lhs = Self::from_tensor(t.narrow(dim, 0, lhs_len)?);
|
||||||
|
let rhs_len = seq_len - lhs_len;
|
||||||
|
let rhs = if rhs_len == 0 {
|
||||||
|
Self::empty()
|
||||||
|
} else {
|
||||||
|
Self::from_tensor(t.narrow(dim, lhs_len, rhs_len)?)
|
||||||
|
};
|
||||||
|
Ok((lhs, rhs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn as_option(&self) -> Option<&Tensor> {
|
||||||
|
self.0.as_ref()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
|
||||||
|
match &self.0 {
|
||||||
|
None => Ok(Self::empty()),
|
||||||
|
Some(t) => Ok(Self::from_tensor(t.apply(m)?)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Streaming modules take as input a stream tensor and return a stream tensor. They may perform
|
||||||
|
/// some internal buffering so that enough data has been received for the module to be able to
|
||||||
|
/// perform some operations.
|
||||||
|
pub trait StreamingModule {
|
||||||
|
// TODO: Should we also have a flush method?
|
||||||
|
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor>;
|
||||||
|
fn reset_state(&mut self);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
|
pub enum BinOp {
|
||||||
|
Add,
|
||||||
|
Mul,
|
||||||
|
Sub,
|
||||||
|
Div,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct StreamingBinOp {
|
||||||
|
prev_lhs: StreamTensor,
|
||||||
|
prev_rhs: StreamTensor,
|
||||||
|
pub op: BinOp,
|
||||||
|
pub dim: crate::D,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StreamingBinOp {
|
||||||
|
pub fn new(op: BinOp, dim: crate::D) -> Self {
|
||||||
|
Self {
|
||||||
|
prev_lhs: StreamTensor::empty(),
|
||||||
|
prev_rhs: StreamTensor::empty(),
|
||||||
|
op,
|
||||||
|
dim,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reset_state(&mut self) {
|
||||||
|
self.prev_lhs.reset();
|
||||||
|
self.prev_rhs.reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor> {
|
||||||
|
match self.op {
|
||||||
|
BinOp::Add => Tensor::add(lhs, rhs),
|
||||||
|
BinOp::Mul => Tensor::mul(lhs, rhs),
|
||||||
|
BinOp::Sub => Tensor::sub(lhs, rhs),
|
||||||
|
BinOp::Div => Tensor::div(lhs, rhs),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn step(&mut self, lhs: &StreamTensor, rhs: &StreamTensor) -> Result<StreamTensor> {
|
||||||
|
let lhs = StreamTensor::cat2(&self.prev_lhs, lhs, self.dim)?;
|
||||||
|
let rhs = StreamTensor::cat2(&self.prev_rhs, rhs, self.dim)?;
|
||||||
|
let lhs_len = lhs.seq_len(self.dim)?;
|
||||||
|
let rhs_len = rhs.seq_len(self.dim)?;
|
||||||
|
let common_len = usize::min(lhs_len, rhs_len);
|
||||||
|
let (lhs, prev_lhs) = lhs.split(self.dim, common_len)?;
|
||||||
|
let (rhs, prev_rhs) = rhs.split(self.dim, common_len)?;
|
||||||
|
let ys = match (lhs.0, rhs.0) {
|
||||||
|
(Some(lhs), Some(rhs)) => {
|
||||||
|
let ys = self.forward(&lhs, &rhs)?;
|
||||||
|
StreamTensor::from_tensor(ys)
|
||||||
|
}
|
||||||
|
(None, None) => StreamTensor::empty(),
|
||||||
|
(lhs, rhs) => crate::bail!("INTERNAL ERROR inconsistent lhs and rhs {lhs:?} {rhs:?}"),
|
||||||
|
};
|
||||||
|
self.prev_lhs = prev_lhs;
|
||||||
|
self.prev_rhs = prev_rhs;
|
||||||
|
Ok(ys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Simple wrapper that doesn't do any buffering.
|
||||||
|
pub struct Map<T: crate::Module>(T);
|
||||||
|
|
||||||
|
impl<T: crate::Module> StreamingModule for Map<T> {
|
||||||
|
fn reset_state(&mut self) {}
|
||||||
|
|
||||||
|
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||||
|
xs.apply(&self.0)
|
||||||
|
}
|
||||||
|
}
|
@ -1,9 +1,7 @@
|
|||||||
//! Tensors are N-dimensional matrixes of elements using a single data type.
|
//! Tensors are N-dimensional matrixes of elements using a single data type.
|
||||||
#![allow(clippy::redundant_closure_call)]
|
#![allow(clippy::redundant_closure_call)]
|
||||||
use crate::backend::{BackendDevice, BackendStorage};
|
use crate::backend::{BackendDevice, BackendStorage};
|
||||||
use crate::op::{
|
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
|
||||||
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
|
|
||||||
};
|
|
||||||
use crate::scalar::TensorOrScalar;
|
use crate::scalar::TensorOrScalar;
|
||||||
use crate::shape::{Dim, Dims};
|
use crate::shape::{Dim, Dims};
|
||||||
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||||
@ -81,6 +79,9 @@ macro_rules! unary_op {
|
|||||||
($fn_name:ident, $op_name:ident) => {
|
($fn_name:ident, $op_name:ident) => {
|
||||||
pub fn $fn_name(&self) -> Result<Self> {
|
pub fn $fn_name(&self) -> Result<Self> {
|
||||||
let shape = self.shape();
|
let shape = self.shape();
|
||||||
|
if shape.elem_count() == 0 {
|
||||||
|
return Ok(self.clone());
|
||||||
|
}
|
||||||
let storage = self
|
let storage = self
|
||||||
.storage()
|
.storage()
|
||||||
.unary_impl::<crate::op::$op_name>(self.layout())?;
|
.unary_impl::<crate::op::$op_name>(self.layout())?;
|
||||||
@ -94,6 +95,9 @@ macro_rules! binary_op {
|
|||||||
($fn_name:ident, $op_name:ident) => {
|
($fn_name:ident, $op_name:ident) => {
|
||||||
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
|
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
|
||||||
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
|
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
|
||||||
|
if shape.elem_count() == 0 {
|
||||||
|
return Ok(self.clone());
|
||||||
|
}
|
||||||
let storage = self.storage().binary_impl::<crate::op::$op_name>(
|
let storage = self.storage().binary_impl::<crate::op::$op_name>(
|
||||||
&*rhs.storage(),
|
&*rhs.storage(),
|
||||||
self.layout(),
|
self.layout(),
|
||||||
@ -116,6 +120,9 @@ macro_rules! binary_op_scalar {
|
|||||||
.broadcast_as(self.shape())?,
|
.broadcast_as(self.shape())?,
|
||||||
};
|
};
|
||||||
let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
|
let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
|
||||||
|
if self.elem_count() == 0 {
|
||||||
|
return Ok(self.clone());
|
||||||
|
}
|
||||||
let storage = self.storage().binary_impl::<crate::op::$op_name>(
|
let storage = self.storage().binary_impl::<crate::op::$op_name>(
|
||||||
&*rhs.storage(),
|
&*rhs.storage(),
|
||||||
self.layout(),
|
self.layout(),
|
||||||
@ -363,6 +370,15 @@ impl Tensor {
|
|||||||
|
|
||||||
/// Returns a new tensor with all the elements having the same specified value. Note that
|
/// Returns a new tensor with all the elements having the same specified value. Note that
|
||||||
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
|
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
|
||||||
|
///```rust
|
||||||
|
/// use candle_core::{Tensor, Device};
|
||||||
|
/// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
|
||||||
|
///
|
||||||
|
/// assert_eq!(a.to_vec2::<f64>()?, &[
|
||||||
|
/// [3.5, 3.5, 3.5, 3.5],
|
||||||
|
/// [3.5, 3.5, 3.5, 3.5],
|
||||||
|
/// ]);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
pub fn full<D: crate::WithDType, S: Into<Shape>>(
|
pub fn full<D: crate::WithDType, S: Into<Shape>>(
|
||||||
value: D,
|
value: D,
|
||||||
shape: S,
|
shape: S,
|
||||||
@ -372,6 +388,13 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a new 1D tensor from an iterator.
|
/// Creates a new 1D tensor from an iterator.
|
||||||
|
///```rust
|
||||||
|
/// use candle_core::{Tensor, Device};
|
||||||
|
/// let a = Tensor::from_iter( [1.0, 2.0, 3.0, 4.0].into_iter(), &Device::Cpu)?;
|
||||||
|
///
|
||||||
|
/// assert_eq!(a.to_vec1::<f64>()?, &[1.0, 2.0, 3.0, 4.0]);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
pub fn from_iter<D: crate::WithDType>(
|
pub fn from_iter<D: crate::WithDType>(
|
||||||
iter: impl IntoIterator<Item = D>,
|
iter: impl IntoIterator<Item = D>,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
@ -383,12 +406,26 @@ impl Tensor {
|
|||||||
|
|
||||||
/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
|
/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
|
||||||
/// difference `1` from `start`.
|
/// difference `1` from `start`.
|
||||||
|
///```rust
|
||||||
|
/// use candle_core::{Tensor, Device};
|
||||||
|
/// let a = Tensor::arange(2., 5., &Device::Cpu)?;
|
||||||
|
///
|
||||||
|
/// assert_eq!(a.to_vec1::<f64>()?, &[2., 3., 4.]);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
|
pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
|
||||||
Self::arange_step(start, end, D::one(), device)
|
Self::arange_step(start, end, D::one(), device)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
|
/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
|
||||||
/// difference `step` from `start`.
|
/// difference `step` from `start`.
|
||||||
|
///```rust
|
||||||
|
/// use candle_core::{Tensor, Device};
|
||||||
|
/// let a = Tensor::arange_step(2.0, 4.0, 0.5, &Device::Cpu)?;
|
||||||
|
///
|
||||||
|
/// assert_eq!(a.to_vec1::<f64>()?, &[2.0, 2.5, 3.0, 3.5]);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
pub fn arange_step<D: crate::WithDType>(
|
pub fn arange_step<D: crate::WithDType>(
|
||||||
start: D,
|
start: D,
|
||||||
end: D,
|
end: D,
|
||||||
@ -434,6 +471,16 @@ impl Tensor {
|
|||||||
/// Creates a new tensor initialized with values from the input vector. The number of elements
|
/// Creates a new tensor initialized with values from the input vector. The number of elements
|
||||||
/// in this vector must be the same as the number of elements defined by the shape.
|
/// in this vector must be the same as the number of elements defined by the shape.
|
||||||
/// If the device is cpu, no data copy is made.
|
/// If the device is cpu, no data copy is made.
|
||||||
|
///```rust
|
||||||
|
/// use candle_core::{Tensor, Device};
|
||||||
|
/// let a = Tensor::from_vec(vec!{1., 2., 3., 4., 5., 6.}, (2, 3), &Device::Cpu)?;
|
||||||
|
///
|
||||||
|
/// assert_eq!(a.to_vec2::<f64>()?, &[
|
||||||
|
/// [1., 2., 3.],
|
||||||
|
/// [4., 5., 6.]
|
||||||
|
/// ]);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
|
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
|
||||||
data: Vec<D>,
|
data: Vec<D>,
|
||||||
shape: S,
|
shape: S,
|
||||||
@ -444,12 +491,31 @@ impl Tensor {
|
|||||||
|
|
||||||
/// Creates a new tensor initialized with values from the input slice. The number of elements
|
/// Creates a new tensor initialized with values from the input slice. The number of elements
|
||||||
/// in this vector must be the same as the number of elements defined by the shape.
|
/// in this vector must be the same as the number of elements defined by the shape.
|
||||||
|
///```rust
|
||||||
|
/// use candle_core::{Tensor, Device};
|
||||||
|
/// let values = vec![1., 2., 3., 4., 5., 6., 7., 8.];
|
||||||
|
/// let a = Tensor::from_slice(&values[1..7], (2, 3), &Device::Cpu)?;
|
||||||
|
///
|
||||||
|
/// assert_eq!(a.to_vec2::<f64>()?, &[
|
||||||
|
/// [2., 3., 4.],
|
||||||
|
/// [5., 6., 7.]
|
||||||
|
/// ]);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
|
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
|
||||||
array: &[D],
|
array: &[D],
|
||||||
shape: S,
|
shape: S,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
Self::new_impl(array, shape.into(), device, false)
|
let shape = shape.into();
|
||||||
|
let n: usize = shape.elem_count();
|
||||||
|
let buffer_size: usize = array.len();
|
||||||
|
if buffer_size != n {
|
||||||
|
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
|
||||||
|
}
|
||||||
|
let storage = device.storage_from_slice(array)?;
|
||||||
|
let none = BackpropOp::none();
|
||||||
|
Ok(from_storage(storage, shape, none, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
|
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
|
||||||
@ -508,9 +574,11 @@ impl Tensor {
|
|||||||
unary_op!(gelu_erf, GeluErf);
|
unary_op!(gelu_erf, GeluErf);
|
||||||
unary_op!(erf, Erf);
|
unary_op!(erf, Erf);
|
||||||
unary_op!(relu, Relu);
|
unary_op!(relu, Relu);
|
||||||
|
unary_op!(silu, Silu);
|
||||||
unary_op!(ceil, Ceil);
|
unary_op!(ceil, Ceil);
|
||||||
unary_op!(floor, Floor);
|
unary_op!(floor, Floor);
|
||||||
unary_op!(round, Round);
|
unary_op!(round, Round);
|
||||||
|
unary_op!(sign, Sign);
|
||||||
|
|
||||||
/// Round element of the input tensor to the nearest integer.
|
/// Round element of the input tensor to the nearest integer.
|
||||||
///
|
///
|
||||||
@ -573,9 +641,9 @@ impl Tensor {
|
|||||||
///
|
///
|
||||||
/// * `args` - A slice of 1D tensors.
|
/// * `args` - A slice of 1D tensors.
|
||||||
/// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
|
/// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
|
||||||
/// first dimension corresponds to the cardinality of the second input and the second
|
/// first dimension corresponds to the cardinality of the second input and the second
|
||||||
/// dimension corresponds to the cardinality of the first input. If ij is selected, the
|
/// dimension corresponds to the cardinality of the first input. If ij is selected, the
|
||||||
/// dimensions are in the same order as the cardinality of the inputs.
|
/// dimensions are in the same order as the cardinality of the inputs.
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
///
|
///
|
||||||
@ -646,6 +714,9 @@ impl Tensor {
|
|||||||
/// # Ok::<(), candle_core::Error>(())
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
/// ```
|
/// ```
|
||||||
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
|
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
|
||||||
|
if self.elem_count() == 0 {
|
||||||
|
return Ok(self.clone());
|
||||||
|
}
|
||||||
let storage = self.storage().affine(self.layout(), mul, add)?;
|
let storage = self.storage().affine(self.layout(), mul, add)?;
|
||||||
let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
|
let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
|
||||||
Ok(from_storage(storage, self.shape(), op, false))
|
Ok(from_storage(storage, self.shape(), op, false))
|
||||||
@ -653,6 +724,9 @@ impl Tensor {
|
|||||||
|
|
||||||
/// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
|
/// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
|
||||||
pub fn elu(&self, alpha: f64) -> Result<Self> {
|
pub fn elu(&self, alpha: f64) -> Result<Self> {
|
||||||
|
if self.elem_count() == 0 {
|
||||||
|
return Ok(self.clone());
|
||||||
|
}
|
||||||
let storage = self.storage().elu(self.layout(), alpha)?;
|
let storage = self.storage().elu(self.layout(), alpha)?;
|
||||||
let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
|
let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
|
||||||
Ok(from_storage(storage, self.shape(), op, false))
|
Ok(from_storage(storage, self.shape(), op, false))
|
||||||
@ -660,12 +734,15 @@ impl Tensor {
|
|||||||
|
|
||||||
/// Raise the tensor to some float exponent `e`.
|
/// Raise the tensor to some float exponent `e`.
|
||||||
pub fn powf(&self, e: f64) -> Result<Self> {
|
pub fn powf(&self, e: f64) -> Result<Self> {
|
||||||
|
if self.elem_count() == 0 {
|
||||||
|
return Ok(self.clone());
|
||||||
|
}
|
||||||
let storage = self.storage().powf(self.layout(), e)?;
|
let storage = self.storage().powf(self.layout(), e)?;
|
||||||
let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
|
let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
|
||||||
Ok(from_storage(storage, self.shape(), op, false))
|
Ok(from_storage(storage, self.shape(), op, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
|
pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
|
||||||
if dim >= self.dims().len() {
|
if dim >= self.dims().len() {
|
||||||
Err(Error::DimOutOfRange {
|
Err(Error::DimOutOfRange {
|
||||||
shape: self.shape().clone(),
|
shape: self.shape().clone(),
|
||||||
@ -706,6 +783,30 @@ impl Tensor {
|
|||||||
|
|
||||||
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
||||||
/// ranges from `start` to `start + len`.
|
/// ranges from `start` to `start + len`.
|
||||||
|
/// ```
|
||||||
|
/// use candle_core::{Tensor, Device};
|
||||||
|
/// let a = Tensor::new(&[
|
||||||
|
/// [0f32, 1., 2.],
|
||||||
|
/// [3. , 4., 5.],
|
||||||
|
/// [6. , 7., 8.]
|
||||||
|
/// ], &Device::Cpu)?;
|
||||||
|
///
|
||||||
|
/// let b = a.narrow(0, 1, 2)?;
|
||||||
|
/// assert_eq!(b.shape().dims(), &[2, 3]);
|
||||||
|
/// assert_eq!(b.to_vec2::<f32>()?, &[
|
||||||
|
/// [3., 4., 5.],
|
||||||
|
/// [6., 7., 8.]
|
||||||
|
/// ]);
|
||||||
|
///
|
||||||
|
/// let c = a.narrow(1, 1, 1)?;
|
||||||
|
/// assert_eq!(c.shape().dims(), &[3, 1]);
|
||||||
|
/// assert_eq!(c.to_vec2::<f32>()?, &[
|
||||||
|
/// [1.],
|
||||||
|
/// [4.],
|
||||||
|
/// [7.]
|
||||||
|
/// ]);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
|
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
|
||||||
let dims = self.dims();
|
let dims = self.dims();
|
||||||
let dim = dim.to_index(self.shape(), "narrow")?;
|
let dim = dim.to_index(self.shape(), "narrow")?;
|
||||||
@ -804,6 +905,35 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Roll the tensor input along the given dimension.
|
||||||
|
/// Elements that are shifted beyond the last position are re-introduced at the first position.
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use candle_core::{Tensor, Device};
|
||||||
|
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||||
|
/// let tensor = tensor.roll(1, 0)?;
|
||||||
|
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
|
||||||
|
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||||
|
/// let tensor = tensor.roll(-1, 0)?;
|
||||||
|
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
|
pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
|
||||||
|
where
|
||||||
|
D: Dim + Clone,
|
||||||
|
{
|
||||||
|
let dim = dim.to_index(self.shape(), "roll")?;
|
||||||
|
let dim_size = self.dim(dim)?;
|
||||||
|
let shift = shift.rem_euclid(dim_size as i32) as usize;
|
||||||
|
if shift == 0 {
|
||||||
|
Ok(self.clone())
|
||||||
|
} else {
|
||||||
|
let a = self.narrow(dim, 0, dim_size - shift)?;
|
||||||
|
let b = self.narrow(dim, dim_size - shift, shift)?;
|
||||||
|
Tensor::cat(&[&b, &a], dim)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
|
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
|
||||||
/// input dimensions.
|
/// input dimensions.
|
||||||
///
|
///
|
||||||
@ -985,7 +1115,7 @@ impl Tensor {
|
|||||||
/// tensor also has three dimensions, `(batch, channels, target_size)`.
|
/// tensor also has three dimensions, `(batch, channels, target_size)`.
|
||||||
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
|
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
|
||||||
let (n, c, _l) = self.dims3()?;
|
let (n, c, _l) = self.dims3()?;
|
||||||
let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
|
let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
|
||||||
let storage = self
|
let storage = self
|
||||||
.storage()
|
.storage()
|
||||||
.upsample_nearest1d(self.layout(), target_size)?;
|
.upsample_nearest1d(self.layout(), target_size)?;
|
||||||
@ -1125,6 +1255,9 @@ impl Tensor {
|
|||||||
let n = b_dims[dim - 1];
|
let n = b_dims[dim - 1];
|
||||||
|
|
||||||
let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
|
let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
|
||||||
|
if c_shape.elem_count() == 0 || k == 0 {
|
||||||
|
return Tensor::zeros(c_shape, self.dtype(), self.device());
|
||||||
|
}
|
||||||
let batching: usize = a_dims[..dim - 2].iter().product();
|
let batching: usize = a_dims[..dim - 2].iter().product();
|
||||||
let batching_b: usize = b_dims[..dim - 2].iter().product();
|
let batching_b: usize = b_dims[..dim - 2].iter().product();
|
||||||
if k != k2 || batching != batching_b {
|
if k != k2 || batching != batching_b {
|
||||||
@ -1321,7 +1454,7 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
.bt())?
|
.bt())?
|
||||||
}
|
}
|
||||||
let mut storage = self.device().zeros(self.shape(), self.dtype())?;
|
let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
|
||||||
self.storage()
|
self.storage()
|
||||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||||
let offset = start * src.dims()[1..].iter().product::<usize>();
|
let offset = start * src.dims()[1..].iter().product::<usize>();
|
||||||
@ -1853,9 +1986,9 @@ impl Tensor {
|
|||||||
/// this new node. The storage of this tensor is shared with the initial tensor.
|
/// this new node. The storage of this tensor is shared with the initial tensor.
|
||||||
///
|
///
|
||||||
/// If the tensor is already detached from the computation graph, the same tensor is returned.
|
/// If the tensor is already detached from the computation graph, the same tensor is returned.
|
||||||
pub fn detach(&self) -> Result<Tensor> {
|
pub fn detach(&self) -> Tensor {
|
||||||
if self.op.is_none() && !self.is_variable {
|
if self.op.is_none() && !self.is_variable {
|
||||||
Ok(self.clone())
|
self.clone()
|
||||||
} else {
|
} else {
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
@ -1866,7 +1999,7 @@ impl Tensor {
|
|||||||
dtype: self.dtype,
|
dtype: self.dtype,
|
||||||
device: self.device.clone(),
|
device: self.device.clone(),
|
||||||
};
|
};
|
||||||
Ok(Tensor(Arc::new(tensor_)))
|
Tensor(Arc::new(tensor_))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1892,7 +2025,11 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
|
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
|
||||||
_ => {
|
_ => {
|
||||||
bail!("not implemented yet")
|
bail!(
|
||||||
|
"not implemented yet, self.device: {:?}, device: {:?}",
|
||||||
|
self.device(),
|
||||||
|
device
|
||||||
|
)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let op = BackpropOp::new1(self, Op::ToDevice);
|
let op = BackpropOp::new1(self, Op::ToDevice);
|
||||||
@ -1971,7 +2108,7 @@ impl Tensor {
|
|||||||
Ok(self.clone())
|
Ok(self.clone())
|
||||||
} else {
|
} else {
|
||||||
let shape = self.shape();
|
let shape = self.shape();
|
||||||
let mut storage = self.device().zeros(shape, self.dtype())?;
|
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
|
||||||
self.storage()
|
self.storage()
|
||||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||||
let op = BackpropOp::new1(self, Op::Copy);
|
let op = BackpropOp::new1(self, Op::Copy);
|
||||||
@ -1979,11 +2116,21 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a tensor that is in row major order. This always makes a copy.
|
||||||
|
pub fn force_contiguous(&self) -> Result<Tensor> {
|
||||||
|
let shape = self.shape();
|
||||||
|
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
|
||||||
|
self.storage()
|
||||||
|
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||||
|
let op = BackpropOp::new1(self, Op::Copy);
|
||||||
|
Ok(from_storage(storage, shape.clone(), op, false))
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a variable based on the values currently stored in a tensor. The storage is always
|
/// Create a variable based on the values currently stored in a tensor. The storage is always
|
||||||
/// copied.
|
/// copied.
|
||||||
pub(crate) fn make_var(&self) -> Result<Tensor> {
|
pub(crate) fn make_var(&self) -> Result<Tensor> {
|
||||||
let shape = self.shape().clone();
|
let shape = self.shape().clone();
|
||||||
let mut storage = self.device().zeros(&shape, self.dtype())?;
|
let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
|
||||||
self.storage()
|
self.storage()
|
||||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||||
Ok(from_storage(storage, shape, BackpropOp::none(), true))
|
Ok(from_storage(storage, shape, BackpropOp::none(), true))
|
||||||
@ -2036,7 +2183,7 @@ impl Tensor {
|
|||||||
};
|
};
|
||||||
Ok(Tensor(Arc::new(tensor_)))
|
Ok(Tensor(Arc::new(tensor_)))
|
||||||
} else {
|
} else {
|
||||||
let mut storage = self.device().zeros(&shape, self.dtype())?;
|
let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
|
||||||
self.storage()
|
self.storage()
|
||||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||||
Ok(from_storage(storage, shape, op, false))
|
Ok(from_storage(storage, shape, op, false))
|
||||||
@ -2063,8 +2210,19 @@ impl Tensor {
|
|||||||
let dim = dim.to_index(self.shape(), "squeeze")?;
|
let dim = dim.to_index(self.shape(), "squeeze")?;
|
||||||
if dims[dim] == 1 {
|
if dims[dim] == 1 {
|
||||||
let mut dims = dims.to_vec();
|
let mut dims = dims.to_vec();
|
||||||
|
let mut strides = self.stride().to_vec();
|
||||||
dims.remove(dim);
|
dims.remove(dim);
|
||||||
self.reshape(dims)
|
strides.remove(dim);
|
||||||
|
let tensor_ = Tensor_ {
|
||||||
|
id: TensorId::new(),
|
||||||
|
storage: self.storage.clone(),
|
||||||
|
layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
|
||||||
|
op: BackpropOp::new1(self, Op::Reshape),
|
||||||
|
is_variable: false,
|
||||||
|
dtype: self.dtype,
|
||||||
|
device: self.device.clone(),
|
||||||
|
};
|
||||||
|
Ok(Tensor(Arc::new(tensor_)))
|
||||||
} else {
|
} else {
|
||||||
Ok(self.clone())
|
Ok(self.clone())
|
||||||
}
|
}
|
||||||
@ -2085,10 +2243,24 @@ impl Tensor {
|
|||||||
/// ```
|
/// ```
|
||||||
pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
|
pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||||
let mut dims = self.dims().to_vec();
|
let mut dims = self.dims().to_vec();
|
||||||
|
let mut strides = self.stride().to_vec();
|
||||||
let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
|
let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
|
||||||
// Cannot panic because to_index_plus_one already checks dimensions
|
// Cannot panic because to_index_plus_one already checks dimensions
|
||||||
dims.insert(dim, 1);
|
dims.insert(dim, 1);
|
||||||
self.reshape(dims)
|
// Any stride would work here, but we pick one so as to maximize the probability to remain
|
||||||
|
// C contiguous.
|
||||||
|
let stride = if dim < strides.len() { strides[dim] } else { 1 };
|
||||||
|
strides.insert(dim, stride);
|
||||||
|
let tensor_ = Tensor_ {
|
||||||
|
id: TensorId::new(),
|
||||||
|
storage: self.storage.clone(),
|
||||||
|
layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
|
||||||
|
op: BackpropOp::new1(self, Op::Reshape),
|
||||||
|
is_variable: false,
|
||||||
|
dtype: self.dtype,
|
||||||
|
device: self.device.clone(),
|
||||||
|
};
|
||||||
|
Ok(Tensor(Arc::new(tensor_)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stacks two or more tensors along a particular dimension.
|
/// Stacks two or more tensors along a particular dimension.
|
||||||
@ -2119,152 +2291,6 @@ impl Tensor {
|
|||||||
Self::cat(&args, dim)
|
Self::cat(&args, dim)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Concatenates two or more tensors along a particular dimension.
|
|
||||||
///
|
|
||||||
/// All tensors must of the same rank, and the output will have
|
|
||||||
/// the same rank
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// # use candle_core::{Tensor, DType, Device};
|
|
||||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
|
||||||
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
|
||||||
///
|
|
||||||
/// let c = Tensor::cat(&[&a, &b], 0)?;
|
|
||||||
/// assert_eq!(c.shape().dims(), &[4, 3]);
|
|
||||||
///
|
|
||||||
/// let c = Tensor::cat(&[&a, &b], 1)?;
|
|
||||||
/// assert_eq!(c.shape().dims(), &[2, 6]);
|
|
||||||
/// # Ok::<(), candle_core::Error>(())
|
|
||||||
/// ```
|
|
||||||
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
|
|
||||||
if args.is_empty() {
|
|
||||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
|
||||||
}
|
|
||||||
let arg0 = args[0].as_ref();
|
|
||||||
if args.len() == 1 {
|
|
||||||
return Ok(arg0.clone());
|
|
||||||
}
|
|
||||||
let dim = dim.to_index(arg0.shape(), "cat")?;
|
|
||||||
for arg in args {
|
|
||||||
arg.as_ref().check_dim(dim, "cat")?;
|
|
||||||
}
|
|
||||||
for (arg_idx, arg) in args.iter().enumerate() {
|
|
||||||
let arg = arg.as_ref();
|
|
||||||
if arg0.rank() != arg.rank() {
|
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
|
||||||
expected: arg0.rank(),
|
|
||||||
got: arg.rank(),
|
|
||||||
shape: arg.shape().clone(),
|
|
||||||
}
|
|
||||||
.bt())?
|
|
||||||
}
|
|
||||||
for (dim_idx, (v1, v2)) in arg0
|
|
||||||
.shape()
|
|
||||||
.dims()
|
|
||||||
.iter()
|
|
||||||
.zip(arg.shape().dims().iter())
|
|
||||||
.enumerate()
|
|
||||||
{
|
|
||||||
if dim_idx != dim && v1 != v2 {
|
|
||||||
Err(Error::ShapeMismatchCat {
|
|
||||||
dim: dim_idx,
|
|
||||||
first_shape: arg0.shape().clone(),
|
|
||||||
n: arg_idx + 1,
|
|
||||||
nth_shape: arg.shape().clone(),
|
|
||||||
}
|
|
||||||
.bt())?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if dim == 0 {
|
|
||||||
Self::cat0(args)
|
|
||||||
} else {
|
|
||||||
// TODO: Avoid these transpositions and have an implementation that works
|
|
||||||
// for dim != 0...
|
|
||||||
let args: Vec<Tensor> = args
|
|
||||||
.iter()
|
|
||||||
.map(|a| a.as_ref().transpose(0, dim))
|
|
||||||
.collect::<Result<Vec<_>>>()?;
|
|
||||||
let cat = Self::cat0(&args)?;
|
|
||||||
cat.transpose(0, dim)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
|
|
||||||
if args.is_empty() {
|
|
||||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
|
||||||
}
|
|
||||||
let arg0 = args[0].as_ref();
|
|
||||||
if args.len() == 1 {
|
|
||||||
return Ok(arg0.clone());
|
|
||||||
}
|
|
||||||
let rank = arg0.rank();
|
|
||||||
let device = arg0.device();
|
|
||||||
let dtype = arg0.dtype();
|
|
||||||
let first_dims = arg0.shape().dims();
|
|
||||||
let mut cat_dims = first_dims.to_vec();
|
|
||||||
cat_dims[0] = 0;
|
|
||||||
let mut offsets = vec![0usize];
|
|
||||||
for (arg_idx, arg) in args.iter().enumerate() {
|
|
||||||
let arg = arg.as_ref();
|
|
||||||
if arg.dtype() != dtype {
|
|
||||||
Err(Error::DTypeMismatchBinaryOp {
|
|
||||||
lhs: dtype,
|
|
||||||
rhs: arg.dtype(),
|
|
||||||
op: "cat",
|
|
||||||
}
|
|
||||||
.bt())?
|
|
||||||
}
|
|
||||||
if arg.device().location() != device.location() {
|
|
||||||
Err(Error::DeviceMismatchBinaryOp {
|
|
||||||
lhs: device.location(),
|
|
||||||
rhs: arg.device().location(),
|
|
||||||
op: "cat",
|
|
||||||
}
|
|
||||||
.bt())?
|
|
||||||
}
|
|
||||||
if rank != arg.rank() {
|
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
|
||||||
expected: rank,
|
|
||||||
got: arg.rank(),
|
|
||||||
shape: arg.shape().clone(),
|
|
||||||
}
|
|
||||||
.bt())?
|
|
||||||
}
|
|
||||||
for (dim_idx, (v1, v2)) in arg0
|
|
||||||
.shape()
|
|
||||||
.dims()
|
|
||||||
.iter()
|
|
||||||
.zip(arg.shape().dims().iter())
|
|
||||||
.enumerate()
|
|
||||||
{
|
|
||||||
if dim_idx == 0 {
|
|
||||||
cat_dims[0] += v2;
|
|
||||||
}
|
|
||||||
if dim_idx != 0 && v1 != v2 {
|
|
||||||
Err(Error::ShapeMismatchCat {
|
|
||||||
dim: dim_idx,
|
|
||||||
first_shape: arg0.shape().clone(),
|
|
||||||
n: arg_idx + 1,
|
|
||||||
nth_shape: arg.shape().clone(),
|
|
||||||
}
|
|
||||||
.bt())?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let next_offset = offsets.last().unwrap() + arg.elem_count();
|
|
||||||
offsets.push(next_offset);
|
|
||||||
}
|
|
||||||
let shape = Shape::from(cat_dims);
|
|
||||||
let op = BackpropOp::new(args, |args| Op::Cat(args, 0));
|
|
||||||
let mut storage = device.zeros(&shape, dtype)?;
|
|
||||||
for (arg, &offset) in args.iter().zip(offsets.iter()) {
|
|
||||||
let arg = arg.as_ref();
|
|
||||||
arg.storage()
|
|
||||||
.copy_strided_src(&mut storage, offset, arg.layout())?;
|
|
||||||
}
|
|
||||||
Ok(from_storage(storage, shape, op, false))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
|
/// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
|
||||||
/// input tensor values and `right` elements after.
|
/// input tensor values and `right` elements after.
|
||||||
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
|
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
|
||||||
@ -2347,6 +2373,10 @@ impl Tensor {
|
|||||||
self.storage.read().unwrap()
|
self.storage.read().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> {
|
||||||
|
self.storage.write().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
// If we extend the visibility of this function to be usable outside of this crate, we should
|
// If we extend the visibility of this function to be usable outside of this crate, we should
|
||||||
// make it unsafe.
|
// make it unsafe.
|
||||||
pub(crate) fn storage_mut_and_layout(
|
pub(crate) fn storage_mut_and_layout(
|
||||||
@ -2368,96 +2398,6 @@ impl Tensor {
|
|||||||
std::ptr::eq(lhs, rhs)
|
std::ptr::eq(lhs, rhs)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Applies a unary custom op without backward support
|
|
||||||
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
|
|
||||||
let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
|
|
||||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Applies a binary custom op without backward support
|
|
||||||
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
|
|
||||||
let (storage, shape) =
|
|
||||||
self.storage()
|
|
||||||
.apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
|
|
||||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Applies a ternary custom op without backward support
|
|
||||||
pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
|
|
||||||
let (storage, shape) = self.storage().apply_op3(
|
|
||||||
self.layout(),
|
|
||||||
&t2.storage(),
|
|
||||||
t2.layout(),
|
|
||||||
&t3.storage(),
|
|
||||||
t3.layout(),
|
|
||||||
c,
|
|
||||||
)?;
|
|
||||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Applies a unary custom op.
|
|
||||||
pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
|
|
||||||
let (storage, shape) = self
|
|
||||||
.storage()
|
|
||||||
.apply_op1(self.layout(), c.as_ref().as_ref())?;
|
|
||||||
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
|
|
||||||
Ok(from_storage(storage, shape, op, false))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
|
|
||||||
self.apply_op1_arc(Arc::new(Box::new(c)))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Applies a binary custom op.
|
|
||||||
pub fn apply_op2_arc(
|
|
||||||
&self,
|
|
||||||
rhs: &Self,
|
|
||||||
c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let (storage, shape) = self.storage().apply_op2(
|
|
||||||
self.layout(),
|
|
||||||
&rhs.storage(),
|
|
||||||
rhs.layout(),
|
|
||||||
c.as_ref().as_ref(),
|
|
||||||
)?;
|
|
||||||
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
|
|
||||||
Ok(from_storage(storage, shape, op, false))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
|
|
||||||
self.apply_op2_arc(r, Arc::new(Box::new(c)))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Applies a ternary custom op.
|
|
||||||
pub fn apply_op3_arc(
|
|
||||||
&self,
|
|
||||||
t2: &Self,
|
|
||||||
t3: &Self,
|
|
||||||
c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let (storage, shape) = self.storage().apply_op3(
|
|
||||||
self.layout(),
|
|
||||||
&t2.storage(),
|
|
||||||
t2.layout(),
|
|
||||||
&t3.storage(),
|
|
||||||
t3.layout(),
|
|
||||||
c.as_ref().as_ref(),
|
|
||||||
)?;
|
|
||||||
let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
|
|
||||||
Op::CustomOp3(t1, t2, t3, c.clone())
|
|
||||||
});
|
|
||||||
Ok(from_storage(storage, shape, op, false))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
|
|
||||||
&self,
|
|
||||||
t2: &Self,
|
|
||||||
t3: &Self,
|
|
||||||
c: C,
|
|
||||||
) -> Result<Self> {
|
|
||||||
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Normalize a 'relative' axis value: positive values are kept, negative
|
/// Normalize a 'relative' axis value: positive values are kept, negative
|
||||||
/// values means counting the dimensions from the back.
|
/// values means counting the dimensions from the back.
|
||||||
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
|
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
|
||||||
@ -2579,9 +2519,19 @@ impl Tensor {
|
|||||||
|
|
||||||
/// Returns log(sum(exp(tensor), dim)).
|
/// Returns log(sum(exp(tensor), dim)).
|
||||||
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
|
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
|
||||||
let exp = self.exp()?;
|
let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
|
||||||
let sum = exp.sum(sum_dims)?;
|
if sum_dims.is_empty() {
|
||||||
sum.log()
|
return Ok(self.clone());
|
||||||
|
}
|
||||||
|
let max = sum_dims[1..]
|
||||||
|
.iter()
|
||||||
|
.try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
|
||||||
|
max.max_keepdim(dim)
|
||||||
|
})?;
|
||||||
|
let exp = self.broadcast_sub(&max)?.exp()?;
|
||||||
|
let sum = exp.sum(sum_dims.clone())?;
|
||||||
|
|
||||||
|
sum.log()? + max.squeeze_dims(&sum_dims)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Pointwise pow operation.
|
/// Pointwise pow operation.
|
||||||
|
300
candle-core/src/tensor_cat.rs
Normal file
300
candle-core/src/tensor_cat.rs
Normal file
@ -0,0 +1,300 @@
|
|||||||
|
use crate::{shape::Dim, Error, Result, Shape, Tensor};
|
||||||
|
|
||||||
|
impl Tensor {
|
||||||
|
/// Concatenates two or more tensors along a particular dimension.
|
||||||
|
///
|
||||||
|
/// All tensors must of the same rank, and the output will have
|
||||||
|
/// the same rank
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use candle_core::{Tensor, DType, Device};
|
||||||
|
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||||
|
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||||
|
///
|
||||||
|
/// let c = Tensor::cat(&[&a, &b], 0)?;
|
||||||
|
/// assert_eq!(c.shape().dims(), &[4, 3]);
|
||||||
|
///
|
||||||
|
/// let c = Tensor::cat(&[&a, &b], 1)?;
|
||||||
|
/// assert_eq!(c.shape().dims(), &[2, 6]);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
|
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
|
||||||
|
if args.is_empty() {
|
||||||
|
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||||
|
}
|
||||||
|
let arg0 = args[0].as_ref();
|
||||||
|
if args.len() == 1 {
|
||||||
|
return Ok(arg0.clone());
|
||||||
|
}
|
||||||
|
let dim = dim.to_index(arg0.shape(), "cat")?;
|
||||||
|
for arg in args {
|
||||||
|
arg.as_ref().check_dim(dim, "cat")?;
|
||||||
|
}
|
||||||
|
for (arg_idx, arg) in args.iter().enumerate() {
|
||||||
|
let arg = arg.as_ref();
|
||||||
|
if arg0.rank() != arg.rank() {
|
||||||
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
|
expected: arg0.rank(),
|
||||||
|
got: arg.rank(),
|
||||||
|
shape: arg.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
for (dim_idx, (v1, v2)) in arg0
|
||||||
|
.shape()
|
||||||
|
.dims()
|
||||||
|
.iter()
|
||||||
|
.zip(arg.shape().dims().iter())
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
|
if dim_idx != dim && v1 != v2 {
|
||||||
|
Err(Error::ShapeMismatchCat {
|
||||||
|
dim: dim_idx,
|
||||||
|
first_shape: arg0.shape().clone(),
|
||||||
|
n: arg_idx + 1,
|
||||||
|
nth_shape: arg.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
|
||||||
|
if all_contiguous {
|
||||||
|
Self::cat_contiguous(args, dim)
|
||||||
|
} else if dim == 0 {
|
||||||
|
Self::cat0(args)
|
||||||
|
} else {
|
||||||
|
let args: Vec<Tensor> = args
|
||||||
|
.iter()
|
||||||
|
.map(|a| a.as_ref().transpose(0, dim))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let cat = Self::cat0(&args)?;
|
||||||
|
cat.transpose(0, dim)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
|
||||||
|
if args.is_empty() {
|
||||||
|
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||||
|
}
|
||||||
|
let arg0 = args[0].as_ref();
|
||||||
|
if args.len() == 1 {
|
||||||
|
return Ok(arg0.clone());
|
||||||
|
}
|
||||||
|
let rank = arg0.rank();
|
||||||
|
let device = arg0.device();
|
||||||
|
let dtype = arg0.dtype();
|
||||||
|
let first_dims = arg0.shape().dims();
|
||||||
|
let mut cat_dims = first_dims.to_vec();
|
||||||
|
cat_dims[0] = 0;
|
||||||
|
let mut offsets = vec![0usize];
|
||||||
|
for (arg_idx, arg) in args.iter().enumerate() {
|
||||||
|
let arg = arg.as_ref();
|
||||||
|
if arg.dtype() != dtype {
|
||||||
|
Err(Error::DTypeMismatchBinaryOp {
|
||||||
|
lhs: dtype,
|
||||||
|
rhs: arg.dtype(),
|
||||||
|
op: "cat",
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
if arg.device().location() != device.location() {
|
||||||
|
Err(Error::DeviceMismatchBinaryOp {
|
||||||
|
lhs: device.location(),
|
||||||
|
rhs: arg.device().location(),
|
||||||
|
op: "cat",
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
if rank != arg.rank() {
|
||||||
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
|
expected: rank,
|
||||||
|
got: arg.rank(),
|
||||||
|
shape: arg.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
for (dim_idx, (v1, v2)) in arg0
|
||||||
|
.shape()
|
||||||
|
.dims()
|
||||||
|
.iter()
|
||||||
|
.zip(arg.shape().dims().iter())
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
|
if dim_idx == 0 {
|
||||||
|
cat_dims[0] += v2;
|
||||||
|
}
|
||||||
|
if dim_idx != 0 && v1 != v2 {
|
||||||
|
Err(Error::ShapeMismatchCat {
|
||||||
|
dim: dim_idx,
|
||||||
|
first_shape: arg0.shape().clone(),
|
||||||
|
n: arg_idx + 1,
|
||||||
|
nth_shape: arg.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let next_offset = offsets.last().unwrap() + arg.elem_count();
|
||||||
|
offsets.push(next_offset);
|
||||||
|
}
|
||||||
|
let shape = Shape::from(cat_dims);
|
||||||
|
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
|
||||||
|
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
|
||||||
|
for (arg, &offset) in args.iter().zip(offsets.iter()) {
|
||||||
|
let arg = arg.as_ref();
|
||||||
|
arg.storage()
|
||||||
|
.copy_strided_src(&mut storage, offset, arg.layout())?;
|
||||||
|
}
|
||||||
|
Ok(crate::tensor::from_storage(storage, shape, op, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cat_contiguous<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
|
||||||
|
if args.is_empty() {
|
||||||
|
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||||
|
}
|
||||||
|
let arg0 = args[0].as_ref();
|
||||||
|
if args.len() == 1 {
|
||||||
|
return Ok(arg0.clone());
|
||||||
|
}
|
||||||
|
let rank = arg0.rank();
|
||||||
|
let device = arg0.device();
|
||||||
|
let dtype = arg0.dtype();
|
||||||
|
let first_dims = arg0.shape().dims();
|
||||||
|
let mut cat_dims = first_dims.to_vec();
|
||||||
|
cat_dims[dim] = 0;
|
||||||
|
for (arg_idx, arg) in args.iter().enumerate() {
|
||||||
|
let arg = arg.as_ref();
|
||||||
|
if arg.dtype() != dtype {
|
||||||
|
Err(Error::DTypeMismatchBinaryOp {
|
||||||
|
lhs: dtype,
|
||||||
|
rhs: arg.dtype(),
|
||||||
|
op: "cat",
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
if arg.device().location() != device.location() {
|
||||||
|
Err(Error::DeviceMismatchBinaryOp {
|
||||||
|
lhs: device.location(),
|
||||||
|
rhs: arg.device().location(),
|
||||||
|
op: "cat",
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
if rank != arg.rank() {
|
||||||
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
|
expected: rank,
|
||||||
|
got: arg.rank(),
|
||||||
|
shape: arg.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
for (dim_idx, (v1, v2)) in arg0
|
||||||
|
.shape()
|
||||||
|
.dims()
|
||||||
|
.iter()
|
||||||
|
.zip(arg.shape().dims().iter())
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
|
if dim_idx == dim {
|
||||||
|
cat_dims[dim] += v2;
|
||||||
|
}
|
||||||
|
if dim_idx != dim && v1 != v2 {
|
||||||
|
Err(Error::ShapeMismatchCat {
|
||||||
|
dim: dim_idx,
|
||||||
|
first_shape: arg0.shape().clone(),
|
||||||
|
n: arg_idx + 1,
|
||||||
|
nth_shape: arg.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let cat_target_dim_len = cat_dims[dim];
|
||||||
|
let block_size: usize = cat_dims.iter().skip(1 + dim).product();
|
||||||
|
let shape = Shape::from(cat_dims);
|
||||||
|
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
|
||||||
|
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
|
||||||
|
let mut dst_o = 0;
|
||||||
|
for arg in args.iter() {
|
||||||
|
let arg = arg.as_ref();
|
||||||
|
let arg_dims = arg.shape().dims();
|
||||||
|
let d1: usize = arg_dims.iter().take(dim).product();
|
||||||
|
let d2 = block_size * arg_dims[dim];
|
||||||
|
let dst_s = block_size * cat_target_dim_len;
|
||||||
|
let src_o = arg.layout().start_offset();
|
||||||
|
arg.storage().copy2d(
|
||||||
|
&mut storage,
|
||||||
|
d1,
|
||||||
|
d2,
|
||||||
|
/* src_s */ d2,
|
||||||
|
dst_s,
|
||||||
|
src_o,
|
||||||
|
dst_o,
|
||||||
|
)?;
|
||||||
|
dst_o += d2;
|
||||||
|
}
|
||||||
|
Ok(crate::tensor::from_storage(storage, shape, op, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the values on `self` using values from `src`. The copy starts at the specified
|
||||||
|
/// `offset` for the target dimension `dim` on `self`.
|
||||||
|
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
|
||||||
|
/// has to be greater than or equal to `offset` plus the `src` size.
|
||||||
|
///
|
||||||
|
/// Note that this modifies `self` in place and as such is not compatibel with
|
||||||
|
/// back-propagation.
|
||||||
|
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
|
||||||
|
let dim = dim.to_index(self.shape(), "slice-set")?;
|
||||||
|
if !self.is_contiguous() || !src.is_contiguous() {
|
||||||
|
Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
|
||||||
|
}
|
||||||
|
if self.dtype() != src.dtype() {
|
||||||
|
Err(Error::DTypeMismatchBinaryOp {
|
||||||
|
lhs: self.dtype(),
|
||||||
|
rhs: src.dtype(),
|
||||||
|
op: "slice-set",
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
if self.device().location() != src.device().location() {
|
||||||
|
Err(Error::DeviceMismatchBinaryOp {
|
||||||
|
lhs: self.device().location(),
|
||||||
|
rhs: src.device().location(),
|
||||||
|
op: "slice-set",
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
if self.rank() != src.rank() {
|
||||||
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
|
expected: self.rank(),
|
||||||
|
got: src.rank(),
|
||||||
|
shape: self.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
|
||||||
|
if dim_idx == dim && *v2 + offset > *v1 {
|
||||||
|
crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
|
||||||
|
}
|
||||||
|
if dim_idx != dim && v1 != v2 {
|
||||||
|
crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let block_size: usize = src.dims().iter().skip(1 + dim).product();
|
||||||
|
let d1: usize = src.dims().iter().take(dim).product();
|
||||||
|
let d2 = block_size * src.dims()[dim];
|
||||||
|
let dst_o = self.layout().start_offset() + offset * block_size;
|
||||||
|
let src_o = src.layout().start_offset();
|
||||||
|
src.storage().copy2d(
|
||||||
|
&mut self.storage_mut(),
|
||||||
|
d1,
|
||||||
|
d2,
|
||||||
|
/* src_s */ d2,
|
||||||
|
/* dst_s */ block_size * self.dims()[dim],
|
||||||
|
src_o,
|
||||||
|
dst_o,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
@ -34,9 +34,14 @@ impl Var {
|
|||||||
Ok(Self(inner))
|
Ok(Self(inner))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert a tensor to a variable, if the tensor is already a variable then it is returned as is.
|
||||||
pub fn from_tensor(t: &Tensor) -> Result<Self> {
|
pub fn from_tensor(t: &Tensor) -> Result<Self> {
|
||||||
let inner = t.make_var()?;
|
if t.is_variable() {
|
||||||
Ok(Self(inner))
|
Ok(Self(t.clone()))
|
||||||
|
} else {
|
||||||
|
let inner = t.make_var()?;
|
||||||
|
Ok(Self(inner))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn rand_f64<S: Into<Shape>>(
|
pub fn rand_f64<S: Into<Shape>>(
|
||||||
@ -107,6 +112,10 @@ impl Var {
|
|||||||
Ok(Self(inner))
|
Ok(Self(inner))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn as_detached_tensor(&self) -> Tensor {
|
||||||
|
self.0.detach()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn as_tensor(&self) -> &Tensor {
|
pub fn as_tensor(&self) -> &Tensor {
|
||||||
&self.0
|
&self.0
|
||||||
}
|
}
|
||||||
|
@ -18,6 +18,9 @@ w_t = w.transpose(0, 1)
|
|||||||
res = torch.nn.functional.conv_transpose1d(t, w_t)
|
res = torch.nn.functional.conv_transpose1d(t, w_t)
|
||||||
print(res.shape)
|
print(res.shape)
|
||||||
print(res)
|
print(res)
|
||||||
|
res = torch.nn.functional.conv_transpose1d(t, w_t, groups=2)
|
||||||
|
print(res.shape)
|
||||||
|
print(res)
|
||||||
*/
|
*/
|
||||||
fn conv1d(dev: &Device) -> Result<()> {
|
fn conv1d(dev: &Device) -> Result<()> {
|
||||||
let t = Tensor::new(
|
let t = Tensor::new(
|
||||||
@ -50,8 +53,11 @@ fn conv1d(dev: &Device) -> Result<()> {
|
|||||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||||
);
|
);
|
||||||
if dev.is_cpu() {
|
|
||||||
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
let w = w.transpose(0, 1)?;
|
||||||
|
// The CPU kernels applied in the contiguous and non contiguous cases are different.
|
||||||
|
for w in [w.clone(), w.contiguous()?] {
|
||||||
|
let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 1)?;
|
||||||
assert_eq!(res.dims(), [1, 2, 7]);
|
assert_eq!(res.dims(), [1, 2, 7]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||||
@ -60,6 +66,17 @@ fn conv1d(dev: &Device) -> Result<()> {
|
|||||||
4.7076, -5.9745, -0.8276, 1.621
|
4.7076, -5.9745, -0.8276, 1.621
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 2)?;
|
||||||
|
assert_eq!(res.dims(), [1, 4, 7]);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
|
||||||
|
[
|
||||||
|
[-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
|
||||||
|
[0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
|
||||||
|
[1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
|
||||||
|
[1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
|
||||||
|
]
|
||||||
|
);
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -118,7 +135,7 @@ fn conv2d(dev: &Device) -> Result<()> {
|
|||||||
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
|
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
|
||||||
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
|
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
|
||||||
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
|
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
|
||||||
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
|
-0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
|
||||||
],
|
],
|
||||||
dev,
|
dev,
|
||||||
)?;
|
)?;
|
||||||
@ -146,7 +163,9 @@ fn conv2d(dev: &Device) -> Result<()> {
|
|||||||
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||||
|
|
||||||
assert_eq!(res.dims(), [1, 2, 7, 7]);
|
assert_eq!(res.dims(), [1, 2, 7, 7]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
||||||
@ -171,6 +190,7 @@ fn conv2d(dev: &Device) -> Result<()> {
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
// Dilations.
|
// Dilations.
|
||||||
let res = t.conv2d(&w, 0, 1, 2, 1)?;
|
let res = t.conv2d(&w, 0, 1, 2, 1)?;
|
||||||
assert_eq!(res.dims(), [1, 2, 1, 1]);
|
assert_eq!(res.dims(), [1, 2, 1, 1]);
|
||||||
@ -209,6 +229,7 @@ fn conv2d(dev: &Device) -> Result<()> {
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -255,13 +276,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
|
|||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||||
[
|
[
|
||||||
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1640,
|
||||||
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1640, -0.0111, -0.1742, 0.0000, 0.0000,
|
-0.0111, -0.1742, 0.0, 0.0, 0.0, 0.0, 2.6437, -2.0268, 1.1823, 0.0, 0.0, 0.0, 0.0,
|
||||||
0.0000, 0.0000, 2.6437, -2.0268, 1.1823, 0.0000, 0.0000, 0.0000, 0.0000, 3.2855,
|
3.2855, -1.0324, 0.2539, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
-1.0324, 0.2539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
0.0, 0.0, 0.0, 0.0
|
||||||
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
|
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||||
assert_eq!(res.dims(), [1, 1, 3, 3]);
|
assert_eq!(res.dims(), [1, 1, 3, 3]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -363,6 +384,7 @@ print(w.grad.shape)
|
|||||||
print(w.grad[0])
|
print(w.grad[0])
|
||||||
*/
|
*/
|
||||||
fn conv2d_grad(dev: &Device) -> Result<()> {
|
fn conv2d_grad(dev: &Device) -> Result<()> {
|
||||||
|
// conv-transposes are not implemented for metal
|
||||||
use candle_core::Var;
|
use candle_core::Var;
|
||||||
let t = Var::from_slice(
|
let t = Var::from_slice(
|
||||||
&[
|
&[
|
||||||
@ -375,7 +397,7 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
|
|||||||
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
|
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
|
||||||
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
|
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
|
||||||
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
|
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
|
||||||
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
|
-0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
|
||||||
],
|
],
|
||||||
(1, 4, 5, 5),
|
(1, 4, 5, 5),
|
||||||
dev,
|
dev,
|
||||||
@ -560,6 +582,251 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
|
|||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Conv Transpose 2d Test
|
||||||
|
//tested against following python
|
||||||
|
|
||||||
|
// import torch
|
||||||
|
// torch.manual_seed(4242)
|
||||||
|
// padding = 4
|
||||||
|
// outpadding = 2
|
||||||
|
// dilation = 3
|
||||||
|
// stride = 3
|
||||||
|
// input = torch.randn((1, 4, 7, 5), requires_grad=True)
|
||||||
|
// kernel = torch.randn((4, 2, 3, 5), requires_grad=True)
|
||||||
|
// print("input", input.flatten())
|
||||||
|
// print("kernel", kernel.flatten())
|
||||||
|
// res = torch.nn.functional.conv_transpose2d(
|
||||||
|
// input,
|
||||||
|
// kernel,
|
||||||
|
// stride=stride,
|
||||||
|
// padding=padding,
|
||||||
|
// dilation=dilation,
|
||||||
|
// output_padding=outpadding,
|
||||||
|
// )
|
||||||
|
// res.retain_grad()
|
||||||
|
// print(res.shape)
|
||||||
|
// loss = (res**2).sum()
|
||||||
|
// print(loss)
|
||||||
|
// loss.backward()
|
||||||
|
// print(input.grad.shape)
|
||||||
|
// print("input grad", torch.round(input.grad, decimals=1))
|
||||||
|
// print(kernel.grad.shape)
|
||||||
|
// print("kernel grad", torch.round(kernel.grad.flatten(), decimals=1))
|
||||||
|
|
||||||
|
let padding = 4;
|
||||||
|
let outpadding = 2;
|
||||||
|
let dilation = 3;
|
||||||
|
let stride = 3;
|
||||||
|
|
||||||
|
let t = Var::from_slice(
|
||||||
|
&[
|
||||||
|
0.4056_f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997,
|
||||||
|
3.0616, 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843,
|
||||||
|
0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013,
|
||||||
|
-0.6836, 0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130,
|
||||||
|
1.3123, 1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071,
|
||||||
|
1.1586, 0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090,
|
||||||
|
0.2049, 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323,
|
||||||
|
-1.3712, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742,
|
||||||
|
0.3790, -0.4431, -0.4720, -0.7890, 0.2620, 0.5411, -1.1715, -2.4997, 2.3249, -0.8912,
|
||||||
|
-0.4733, -0.5701, -2.8888, -1.4112, -0.5471, -0.9234, -1.1660, 0.4189, -0.7465,
|
||||||
|
-0.6473, 0.1402, 0.7875, 0.5377, -0.6779, -0.8088, -0.4864, -0.2312, 0.9279, 0.1264,
|
||||||
|
1.5480, 0.8265, -0.1025, 0.5138, -0.2512, 0.1576, 1.2705, 0.3641, -0.9325, 0.6451,
|
||||||
|
-0.8537, 0.2378, 0.1794, 0.2752, -0.3687, -1.1149, -0.1410, -0.5829, -0.0892, 1.4258,
|
||||||
|
-2.2789, 0.5270, 0.1825, 1.7007, -0.5263, -0.2954, 0.4440, 0.5537, 0.3492, 0.6186,
|
||||||
|
1.6475, 0.2219,
|
||||||
|
],
|
||||||
|
(1, 4, 7, 5),
|
||||||
|
dev,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
#[rustfmt::skip]
|
||||||
|
let w = Var::from_slice(
|
||||||
|
&[
|
||||||
|
-1.1744_f32, 0.3266, 2.5893, 1.0142, 0.1763, 0.7752, 0.6604, 0.2029, -0.2145, 0.7234,
|
||||||
|
-0.3441, -1.5400, -0.6333, 0.6613, 0.2083, 0.6230, -1.7002, 0.3393, 0.4049, 1.0762,
|
||||||
|
0.2723, 1.4181, 0.0029, -0.2122, 1.7668, 1.4168, 0.3320, -0.2719, 0.7932, -0.7204,
|
||||||
|
0.4447, 0.1211, 0.5908, 1.0089, -0.1646, 1.8033, -0.6286, 0.2016, -0.3370, 1.2555,
|
||||||
|
0.8009, -0.6488, -0.4652, -1.5685, 1.5860, 0.5583, 0.4623, 0.6026, 0.8828, 2.4990,
|
||||||
|
0.6811, -0.3369, 1.3320, 1.7669, -1.1067, 1.2958, -0.9415, -0.9655, -0.4462, 0.7181,
|
||||||
|
0.5181, -1.1658, -1.8467, -0.7763, 1.2769, 0.8651, 0.9890, 1.5092, 0.7207, -0.8481,
|
||||||
|
0.7417, 0.3375, -1.2685, 1.4572, 1.0915, 0.1093, -0.8550, -0.5831, -0.6309, -0.2509,
|
||||||
|
0.5220, -0.0914, 0.7900, 0.1096, 0.3258, 0.2723, -1.0942, -0.3393, -0.1653, 0.5732,
|
||||||
|
-0.8014, 1.8194, -1.9023, 0.2127, 1.8636, -0.8979, 0.1927, -0.2778, 0.3105, 0.0071,
|
||||||
|
-1.1823, 0.2476, -0.7178, -1.3821, 1.0769, -0.4376, -0.9967, -0.1227, 1.6197, -1.0604,
|
||||||
|
0.1372, 0.8141, -0.6163, 0.7304, -0.8285, 2.0636, -0.7176, 0.2495, -0.2581, -0.4478,
|
||||||
|
],
|
||||||
|
(4, 2, 3, 5),
|
||||||
|
dev,
|
||||||
|
)?;
|
||||||
|
let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?;
|
||||||
|
let loss = res.sqr()?.sum_all()?;
|
||||||
|
assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 2904.0);
|
||||||
|
let grads = loss.backward()?;
|
||||||
|
|
||||||
|
let grad_t = grads.get(&t).unwrap();
|
||||||
|
let grad_w = grads.get(&w).unwrap();
|
||||||
|
assert_eq!(grad_t.dims(), [1, 4, 7, 5]);
|
||||||
|
assert_eq!(grad_w.dims(), [4, 2, 3, 5]);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?,
|
||||||
|
[
|
||||||
|
// torch gets 89.1
|
||||||
|
-89.0, -135.3, 136.7, 102.0, -53.4, 117.9, 118.6, -43.9, -218.0, -58.5, -114.3, -150.0,
|
||||||
|
-15.6, 172.1, 66.3, -64.3, -27.9, -19.8, 31.7, 62.1, 5.5, 92.6, 28.2, -29.6, 55.9,
|
||||||
|
52.7, -72.7, -119.8, 53.8, -25.5, 128.8, 19.3, 68.0, 190.9, -64.1, -86.2, -111.2,
|
||||||
|
106.6, -67.7, 37.8, 115.9, 50.4, -77.7, -54.9, 22.3, -4.6, 89.8, 61.7, 122.4, 192.6,
|
||||||
|
-27.8, -104.6, 57.0, 166.4, 27.1, 6.1, 18.7, -93.2, 31.5, 168.2, -3.7, -99.5, -55.5,
|
||||||
|
-10.8, 17.5, 20.8, 16.9, 43.8, 42.0, -89.2, 18.8, -9.6, -84.1, 212.6, 19.7, -50.0,
|
||||||
|
-52.0, -40.0, -166.6, -73.2, -10.8, -73.3, 31.5, -23.4, -79.3, -27.0, -84.4, -42.9,
|
||||||
|
-20.3, 51.8, -16.7, 76.3, -120.5, -65.8, 96.5, -10.7, -45.9, -88.1, 65.4, -7.0, -1.5,
|
||||||
|
92.8, -25.1, -114.2, -5.8, -14.8, -51.2, -20.7, 54.2, -79.8, 47.7, -29.2, -8.8, 53.5,
|
||||||
|
-28.4, 85.0, -18.3, 107.0, 28.3, -71.8
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec3_round(&grad_t.i(0)?, 1)?,
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[32.3, -41.6, -24.0, 14.1, 17.6],
|
||||||
|
[-11.8, 72.5, 87.6, 46.4, 61.5],
|
||||||
|
[115.0, 108.5, -48.6, -63.4, -50.0],
|
||||||
|
[51.3, 5.4, 31.3, 91.1, -30.9],
|
||||||
|
[52.7, 92.8, -68.0, -47.0, 83.0],
|
||||||
|
// pytorch gets -107.1
|
||||||
|
[-10.2, -107.0, -5.4, 213.1, -31.4],
|
||||||
|
[-2.4, 65.1, 9.2, -146.2, -24.2]
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[-72.6, -63.9, -61.9, 45.3, 33.0],
|
||||||
|
[79.3, -0.5, -26.2, 78.2, 42.7],
|
||||||
|
[90.9, 141.6, 40.1, -62.7, 37.0],
|
||||||
|
[32.8, 198.2, -0.8, -31.1, 27.3],
|
||||||
|
// torch gets 48.0
|
||||||
|
[34.5, 34.9, -47.9, 127.6, -12.3],
|
||||||
|
[-61.4, -3.2, -2.9, -10.9, -16.6],
|
||||||
|
[74.6, 60.1, -68.9, 34.5, -50.4]
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[37.5, -56.9, -43.6, -13.5, -9.9],
|
||||||
|
[40.0, 97.3, 28.6, 14.2, -30.1],
|
||||||
|
[-22.3, -126.3, -68.8, -8.2, 26.1],
|
||||||
|
[-32.9, 37.3, 108.5, -54.8, 29.6],
|
||||||
|
[34.9, -176.9, -125.0, -28.3, -13.9],
|
||||||
|
[-54.9, 142.6, 62.1, -80.4, -65.6],
|
||||||
|
[7.4, -91.1, -67.6, 35.0, 39.7]
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[-57.2, -40.9, -10.1, 32.6, 29.4],
|
||||||
|
[18.7, -18.0, 29.5, -1.2, 59.2],
|
||||||
|
[-14.0, -74.4, 19.8, -117.0, 58.2],
|
||||||
|
[-21.8, 163.5, -71.1, -99.0, 80.9],
|
||||||
|
[-58.9, -10.9, 93.8, -139.6, 98.0],
|
||||||
|
// torch gets 54.5
|
||||||
|
[-54.4, 135.3, 6.0, -79.1, 134.6],
|
||||||
|
[27.5, -76.0, 43.4, -2.8, -7.8]
|
||||||
|
]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
// Test the same, but then with the following properties, t & w are unmodified.
|
||||||
|
let padding = 1;
|
||||||
|
let outpadding = 1;
|
||||||
|
let dilation = 1;
|
||||||
|
let stride = 2;
|
||||||
|
|
||||||
|
let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?;
|
||||||
|
let loss = res.sqr()?.sum_all()?;
|
||||||
|
assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 3627.0); // torch gives 3626.8560
|
||||||
|
|
||||||
|
let grads = loss.backward()?;
|
||||||
|
|
||||||
|
let grad_t = grads.get(&t).unwrap();
|
||||||
|
let grad_w = grads.get(&w).unwrap();
|
||||||
|
assert_eq!(grad_t.dims(), [1, 4, 7, 5]);
|
||||||
|
assert_eq!(grad_w.dims(), [4, 2, 3, 5]);
|
||||||
|
|
||||||
|
#[rustfmt::skip]
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec3_round(&grad_t.i(0)?, 1)?,
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[ 13.2, -40.7, -9.7, -47.3, -82.7],
|
||||||
|
[ -98.2, 9.7, 57.7, -6.2, 180.7],
|
||||||
|
[ 100.2, 24.1, 3.7, -100.5, -48.1],
|
||||||
|
[ -0.3, 13.5, -2.9, 80.0, -49.8],
|
||||||
|
[ 47.2, -25.6, -74.4, 61.2, -18.4],
|
||||||
|
[ 4.6, -69.5, 27.9, 66.5, -88.1],
|
||||||
|
// 4th column on next row; torch is 4.2
|
||||||
|
[ -12.0, 79.2, -40.0, 4.1, -97.1],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[ -42.2, -36.5, -51.1, 7.5, 32.3],
|
||||||
|
[ 74.1, -44.6, -68.8, 19.5, 7.7],
|
||||||
|
[ 137.1, 54.2, 153.8, -58.0, 45.5],
|
||||||
|
[ 24.4, -56.8, 9.7, -41.0, -14.5],
|
||||||
|
[ -3.7, 72.6, 8.3, 134.8, 40.5],
|
||||||
|
[ 43.2, -56.9, -47.5, -89.4, -95.4],
|
||||||
|
[ 68.2, 108.1, -80.0, 57.0, -121.1]
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[ 31.1, -11.4, -34.8, 33.1, -44.2],
|
||||||
|
[ 29.4, -31.6, -40.2, 13.7, 13.1],
|
||||||
|
[ -0.8, -83.8, -7.8, -17.3, 78.2],
|
||||||
|
[ 12.0, -118.7, 137.5, -76.7, 50.8],
|
||||||
|
[ -28.7, -114.2, -3.7, -96.3, -13.8],
|
||||||
|
[ -31.8, 28.5, -14.3, 4.6, 13.4],
|
||||||
|
[ 28.0, -0.2, -38.9, -29.7, -59.0]
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[ -16.8, 38.5, 15.5, 26.6, 48.9],
|
||||||
|
[ 14.5, 49.6, -24.8, 65.6, 61.7],
|
||||||
|
[ 22.1, -64.7, -4.3, -51.0, 36.3],
|
||||||
|
[ 31.0, -88.9, 47.1, -123.5, -3.8],
|
||||||
|
[ -14.8, -39.8, 128.2, -110.3, 42.6],
|
||||||
|
// 1st column on next row; torch is -7.2
|
||||||
|
[ -7.1, 95.3, -21.3, -58.7, -13.9],
|
||||||
|
[ 26.9, 21.3, 16.1, 70.3, 32.1]
|
||||||
|
]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
#[rustfmt::skip]
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?,
|
||||||
|
[
|
||||||
|
// 2nd value; torch gets -3.2, 3rd value; torch gets 221.8
|
||||||
|
-2.460e+01, -3.100e+00, 2.219e+02, 7.400e+00, 5.620e+01,
|
||||||
|
7.420e+01, 7.830e+01, 8.900e+00, 1.050e+01, 2.810e+01,
|
||||||
|
5.100e+00, -1.046e+02, -1.572e+02, 8.710e+01, -9.840e+01,
|
||||||
|
-4.230e+01, -1.898e+02, 1.860e+01, -3.570e+01, 9.810e+01,
|
||||||
|
4.680e+01, 1.182e+02, 4.020e+01, -1.900e+00, 1.508e+02,
|
||||||
|
1.094e+02, 1.018e+02, -4.620e+01, 1.591e+02, -2.320e+01,
|
||||||
|
// 5th value; torch gets 7.1
|
||||||
|
-8.450e+01, -4.600e+00, 6.330e+01, 1.123e+02, -7.000e+00,
|
||||||
|
1.101e+02, -6.620e+01, 2.090e+01, -5.120e+01, 8.990e+01,
|
||||||
|
9.050e+01, -6.990e+01, 6.800e+01, -9.250e+01, 1.380e+02,
|
||||||
|
4.720e+01, 4.710e+01, 6.210e+01, 8.870e+01, 2.098e+02,
|
||||||
|
3.870e+01, -1.390e+01, 6.270e+01, 1.484e+02, -9.920e+01,
|
||||||
|
-4.200e+01, -1.505e+02, -1.480e+01, -2.620e+01, 8.220e+01,
|
||||||
|
-3.350e+01, -2.260e+01, -1.198e+02, -5.080e+01, 1.259e+02,
|
||||||
|
5.600e+01, 9.270e+01, 1.209e+02, 6.590e+01, -8.330e+01,
|
||||||
|
7.000e+00, -2.600e+01, -1.133e+02, 3.870e+01, 4.020e+01,
|
||||||
|
-6.300e+00, -8.710e+01, -5.150e+01, -8.510e+01, 2.000e-01,
|
||||||
|
3.640e+01, -6.100e+00, 6.590e+01, -2.700e+00, 6.550e+01,
|
||||||
|
// 4th value; torch gets 3.8
|
||||||
|
5.300e+00, -6.760e+01, -4.270e+01, -3.900e+00, 2.880e+01,
|
||||||
|
5.260e+01, 6.170e+01, -1.203e+02, -1.610e+01, 7.740e+01,
|
||||||
|
-1.008e+02, -1.070e+01, -9.900e+00, 3.300e+00, -2.620e+01,
|
||||||
|
-4.440e+01, 2.580e+01, -6.920e+01, -4.220e+01, 1.108e+02,
|
||||||
|
1.240e+01, -3.440e+01, -2.800e+00, 7.880e+01, -6.690e+01,
|
||||||
|
1.480e+01, 2.310e+01, -4.260e+01, -1.500e+00, -4.760e+01,
|
||||||
|
5.350e+01, -2.260e+01, 8.000e-01, -3.840e+01, -2.500e+00
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,3 +112,34 @@ fn custom_op1_with_backward() -> Result<()> {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl candle_core::InplaceOp1 for Elu {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"elu"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cpu_fwd(&self, s: &mut CpuStorage, _l: &Layout) -> Result<()> {
|
||||||
|
let alpha = self.alpha;
|
||||||
|
match s {
|
||||||
|
CpuStorage::BF16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
|
||||||
|
CpuStorage::F16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
|
||||||
|
CpuStorage::F32(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
|
||||||
|
CpuStorage::F64(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
|
||||||
|
_ => candle_core::bail!("unsupported dtype for inplace elu"),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn inplace_op1() -> Result<()> {
|
||||||
|
let cpu = &Device::Cpu;
|
||||||
|
let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
|
||||||
|
let t = (t - 5.)?;
|
||||||
|
t.inplace_op1(&Elu { alpha: 1. })?;
|
||||||
|
assert_eq!(
|
||||||
|
to_vec1_round(&t, 4)?,
|
||||||
|
&[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
BIN
candle-core/tests/fortran_tensor_3d.pth
Normal file
BIN
candle-core/tests/fortran_tensor_3d.pth
Normal file
Binary file not shown.
@ -1,3 +1,4 @@
|
|||||||
|
#![allow(clippy::approx_constant)]
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
|
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
|
||||||
|
|
||||||
@ -96,24 +97,24 @@ fn unary_grad(device: &Device) -> Result<()> {
|
|||||||
let grads = y.backward()?;
|
let grads = y.backward()?;
|
||||||
let grad_x = grads.get(x).context("no grad for x")?;
|
let grad_x = grads.get(x).context("no grad for x")?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
y.to_vec1::<f32>()?,
|
test_utils::to_vec1_round(&y, 4)?,
|
||||||
[20.085537, 2.7182817, 54.59815, 1.1618342]
|
[20.0855, 2.7183, 54.5982, 1.1618]
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
grad_x.to_vec1::<f32>()?,
|
test_utils::to_vec1_round(grad_x, 4)?,
|
||||||
[20.085537, 2.7182817, 54.59815, 1.1618342]
|
[20.0855, 2.7183, 54.5982, 1.1618]
|
||||||
);
|
);
|
||||||
let y = x.exp()?.sqr()?;
|
let y = x.exp()?.sqr()?;
|
||||||
let grads = y.backward()?;
|
let grads = y.backward()?;
|
||||||
let grad_x = grads.get(x).context("no grad for x")?;
|
let grad_x = grads.get(x).context("no grad for x")?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
y.to_vec1::<f32>()?,
|
test_utils::to_vec1_round(&y, 3)?,
|
||||||
[403.4288, 7.3890557, 2980.9578, 1.3498588]
|
[403.429, 7.389, 2980.958, 1.35]
|
||||||
);
|
);
|
||||||
// exp(x)^2 = exp(2*x)
|
// exp(x)^2 = exp(2*x)
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
grad_x.to_vec1::<f32>()?,
|
test_utils::to_vec1_round(grad_x, 2)?,
|
||||||
[806.8576, 14.778111, 5961.9155, 2.6997175]
|
[806.86, 14.78, 5961.92, 2.7]
|
||||||
);
|
);
|
||||||
let y = x.sin()?;
|
let y = x.sin()?;
|
||||||
let grads = y.backward()?;
|
let grads = y.backward()?;
|
||||||
@ -261,6 +262,7 @@ fn unary_grad(device: &Device) -> Result<()> {
|
|||||||
let y = elu_x.elu(2.)?;
|
let y = elu_x.elu(2.)?;
|
||||||
let grads = y.backward()?;
|
let grads = y.backward()?;
|
||||||
let grad_x = grads.get(&elu_x).context("no grad for x")?;
|
let grad_x = grads.get(&elu_x).context("no grad for x")?;
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec1_round(&y, 4)?,
|
test_utils::to_vec1_round(&y, 4)?,
|
||||||
[-1.2642, 0.0000, -1.7293, 3.0000]
|
[-1.2642, 0.0000, -1.7293, 3.0000]
|
||||||
@ -270,19 +272,51 @@ fn unary_grad(device: &Device) -> Result<()> {
|
|||||||
[0.7358, 2.0000, 0.2707, 1.0000]
|
[0.7358, 2.0000, 0.2707, 1.0000]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// testing compared to pytorch nn.Silu()
|
||||||
|
let y = x.silu()?;
|
||||||
|
let grads = y.backward()?;
|
||||||
|
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(&y, 4)?,
|
||||||
|
[2.8577, 0.7311, 3.9281, 0.0806]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(grad_x, 4)?,
|
||||||
|
[1.0881, 0.9277, 1.0527, 0.5747],
|
||||||
|
);
|
||||||
|
|
||||||
|
if device.is_cpu() {
|
||||||
|
let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?;
|
||||||
|
let y = x.interpolate1d(12)?.reshape(36)?;
|
||||||
|
|
||||||
|
let z = Tensor::new(
|
||||||
|
&[
|
||||||
|
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16.,
|
||||||
|
17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.,
|
||||||
|
33., 34., 35., 36.,
|
||||||
|
],
|
||||||
|
device,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
||||||
|
let grads = loss.backward()?;
|
||||||
|
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec3_round(grad_x, 4)?,
|
||||||
|
[[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// manually checked: see comments
|
// manually checked: see comments
|
||||||
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
||||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||||
|
|
||||||
#[rustfmt::skip]
|
|
||||||
let z = Tensor::new(
|
let z = Tensor::new(
|
||||||
&[
|
&[
|
||||||
1_f32, 02., 03., 04., 05., 06.,
|
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
|
||||||
07., 08., 09., 10., 11., 12.,
|
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
|
||||||
13., 14., 15., 16., 17., 18.,
|
35., 36.,
|
||||||
19., 20., 21., 22., 23., 24.,
|
|
||||||
25., 26., 27., 28., 29., 30.,
|
|
||||||
31., 32., 33., 34., 35., 36.,
|
|
||||||
],
|
],
|
||||||
device,
|
device,
|
||||||
)?;
|
)?;
|
||||||
@ -313,15 +347,11 @@ fn unary_grad(device: &Device) -> Result<()> {
|
|||||||
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
|
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
|
||||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||||
|
|
||||||
#[rustfmt::skip]
|
|
||||||
let z = Tensor::new(
|
let z = Tensor::new(
|
||||||
&[
|
&[
|
||||||
1_f32, 02., 03., 04., 05., 06.,
|
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
|
||||||
07., 08., 09., 10., 11., 12.,
|
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
|
||||||
13., 14., 15., 16., 17., 18.,
|
35., 36.,
|
||||||
19., 20., 21., 22., 23., 24.,
|
|
||||||
25., 26., 27., 28., 29., 30.,
|
|
||||||
31., 32., 33., 34., 35., 36.,
|
|
||||||
],
|
],
|
||||||
device,
|
device,
|
||||||
)?;
|
)?;
|
||||||
|
@ -88,7 +88,7 @@ fn strided_blocks() -> Result<()> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||||
let tensor = tensor.i((.., 1))?;
|
let tensor = tensor.i((.., 1))?.contiguous()?;
|
||||||
match tensor.strided_blocks() {
|
match tensor.strided_blocks() {
|
||||||
candle::StridedBlocks::SingleBlock { start_offset, len } => {
|
candle::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||||
assert_eq!(start_offset, 0);
|
assert_eq!(start_offset, 0);
|
||||||
@ -100,6 +100,20 @@ fn strided_blocks() -> Result<()> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||||
|
let tensor = tensor.i((.., 1))?;
|
||||||
|
match tensor.strided_blocks() {
|
||||||
|
candle::StridedBlocks::SingleBlock { .. } => {
|
||||||
|
panic!("unexpected block structure")
|
||||||
|
}
|
||||||
|
candle::StridedBlocks::MultipleBlocks {
|
||||||
|
block_len,
|
||||||
|
block_start_index,
|
||||||
|
} => {
|
||||||
|
assert_eq!(block_len, 4);
|
||||||
|
assert_eq!(block_start_index.collect::<Vec<_>>(), &[4, 16])
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||||
match tensor.t()?.strided_blocks() {
|
match tensor.t()?.strided_blocks() {
|
||||||
candle::StridedBlocks::SingleBlock { .. } => {
|
candle::StridedBlocks::SingleBlock { .. } => {
|
||||||
panic!("unexpected block structure")
|
panic!("unexpected block structure")
|
||||||
|
126
candle-core/tests/matmul_tests.rs
Normal file
126
candle-core/tests/matmul_tests.rs
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
|
||||||
|
|
||||||
|
fn matmul(device: &Device) -> Result<()> {
|
||||||
|
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||||
|
let a = Tensor::from_slice(&data, (2, 2), device)?;
|
||||||
|
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||||
|
let b = Tensor::from_slice(&data, (2, 2), device)?;
|
||||||
|
|
||||||
|
let c = a.matmul(&b)?;
|
||||||
|
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
|
||||||
|
|
||||||
|
let data = vec![1.0f32, 2.0];
|
||||||
|
let a = Tensor::from_slice(&data, (2, 1), device)?;
|
||||||
|
let data = vec![3.0f32, 4.0];
|
||||||
|
let b = Tensor::from_slice(&data, (1, 2), device)?;
|
||||||
|
let c = a.matmul(&b)?;
|
||||||
|
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
|
||||||
|
|
||||||
|
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
|
||||||
|
let a = Tensor::from_slice(&data, (2, 3), device)?;
|
||||||
|
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
|
||||||
|
let b = Tensor::from_slice(&data, (3, 2), device)?;
|
||||||
|
let c = a.matmul(&b)?;
|
||||||
|
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
|
||||||
|
|
||||||
|
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
|
||||||
|
let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
|
||||||
|
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
|
||||||
|
let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
|
||||||
|
let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];
|
||||||
|
|
||||||
|
let c = a.matmul(&b)?;
|
||||||
|
assert_eq!(c.to_vec3::<f32>()?, &expected);
|
||||||
|
|
||||||
|
// Also perform the matmul on contiguous transposed versions.
|
||||||
|
let a_tt = a.t()?.contiguous()?.t()?;
|
||||||
|
assert!(!a_tt.is_contiguous());
|
||||||
|
assert_eq!(a.dims(), a_tt.dims());
|
||||||
|
assert_eq!(a_tt.stride(), &[6, 1, 2]);
|
||||||
|
|
||||||
|
let b_tt = b.t()?.contiguous()?.t()?;
|
||||||
|
assert!(!b_tt.is_contiguous());
|
||||||
|
assert_eq!(b.dims(), b_tt.dims());
|
||||||
|
assert_eq!(b_tt.stride(), &[6, 1, 3]);
|
||||||
|
|
||||||
|
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
|
||||||
|
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
|
||||||
|
assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn matmul_bf16(device: &Device) -> Result<()> {
|
||||||
|
if !device.supports_bf16() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||||
|
let a = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?;
|
||||||
|
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||||
|
let b = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?;
|
||||||
|
|
||||||
|
let c = a.matmul(&b)?.to_dtype(DType::F32)?;
|
||||||
|
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn broadcast_matmul(device: &Device) -> Result<()> {
|
||||||
|
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
|
||||||
|
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
|
||||||
|
let out = lhs.broadcast_matmul(&rhs)?;
|
||||||
|
assert_eq!(out.dims(), &[3, 6, 4, 2]);
|
||||||
|
for idx1 in 0..3 {
|
||||||
|
for idx2 in 0..6 {
|
||||||
|
let out = out.i((idx1, idx2))?;
|
||||||
|
let lhs = lhs.i((idx1, 0))?;
|
||||||
|
let rhs = rhs.i(idx2)?;
|
||||||
|
let out2 = lhs.matmul(&rhs);
|
||||||
|
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
|
||||||
|
// With cuda, we see errors of up to ~1e-12.
|
||||||
|
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/candle/issues/1948
|
||||||
|
fn squeeze_mm(device: &Device) -> Result<()> {
|
||||||
|
let seq_len = 8_usize;
|
||||||
|
let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
|
||||||
|
let x = a.i((.., seq_len - 1, ..))?;
|
||||||
|
let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
|
||||||
|
let x = x.matmul(&w)?;
|
||||||
|
assert_eq!(x.dims(), &[1, 32]);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/candle/issues/1992
|
||||||
|
fn mm_layout(device: &Device) -> Result<()> {
|
||||||
|
let a = Tensor::arange(0f32, 16f32, device)?.reshape((1, 1, 4, 4))?;
|
||||||
|
let b = Tensor::arange(0f32, 8f32, device)?.reshape((1, 1, 4, 2))?;
|
||||||
|
let mm1 = a.matmul(&b)?;
|
||||||
|
// Forces the layout to be:
|
||||||
|
// shape: [1, 1, 4, 2], stride: [8, 2, 2, 1], start_offset: 0
|
||||||
|
// This is still a contiguous matrix but matmul checks are only the two last dimensions have
|
||||||
|
// non 1 sizes but matmul check may be reluctant to handle it.
|
||||||
|
let b = b.transpose(1, 2)?.force_contiguous()?.transpose(1, 2)?;
|
||||||
|
let mm2 = a.matmul(&b)?;
|
||||||
|
let diff = (mm1 - mm2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
|
||||||
|
test_device!(
|
||||||
|
matmul_bf16,
|
||||||
|
matmul_bf16_cpu,
|
||||||
|
matmul_bf16_gpu,
|
||||||
|
matmul_bf16_metal
|
||||||
|
);
|
||||||
|
test_device!(
|
||||||
|
broadcast_matmul,
|
||||||
|
broadcast_matmul_cpu,
|
||||||
|
broadcast_matmul_gpu,
|
||||||
|
broadcast_matmul_metal
|
||||||
|
);
|
||||||
|
test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal);
|
||||||
|
test_device!(mm_layout, mm_layout_cpu, mm_layout_gpu, mm_layout_metal);
|
@ -43,6 +43,9 @@ res = torch.nn.functional.avg_pool2d(t, 2)
|
|||||||
print(res)
|
print(res)
|
||||||
*/
|
*/
|
||||||
fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
|
fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
|
||||||
|
if dev.is_metal() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
let t = Tensor::new(
|
let t = Tensor::new(
|
||||||
&[
|
&[
|
||||||
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
|
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
|
||||||
|
37
candle-core/tests/pth.py
Normal file
37
candle-core/tests/pth.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
import torch
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
# Write a trivial tensor to a pt file
|
||||||
|
a= torch.tensor([[1,2,3,4], [5,6,7,8]])
|
||||||
|
o = OrderedDict()
|
||||||
|
o["test"] = a
|
||||||
|
|
||||||
|
# Write a trivial tensor to a pt file
|
||||||
|
torch.save(o, "test.pt")
|
||||||
|
|
||||||
|
############################################################################################################
|
||||||
|
# Write a trivial tensor to a pt file with a key
|
||||||
|
torch.save({"model_state_dict": o}, "test_with_key.pt")
|
||||||
|
|
||||||
|
############################################################################################################
|
||||||
|
# Create a tensor with fortran contiguous memory layout
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Step 1: Create a 3D NumPy array with Fortran order using a range of numbers
|
||||||
|
# For example, creating a 2x3x4 array
|
||||||
|
array_fortran = np.asfortranarray(np.arange(1, 2*3*4 + 1).reshape(2, 3, 4))
|
||||||
|
|
||||||
|
# Verify the memory order
|
||||||
|
print("Is Fortran contiguous (F order):", array_fortran.flags['F_CONTIGUOUS']) # Should be True
|
||||||
|
print("Is C contiguous (C order):", array_fortran.flags['C_CONTIGUOUS']) # Should be False
|
||||||
|
|
||||||
|
# Step 2: Convert the NumPy array to a PyTorch tensor
|
||||||
|
tensor_fortran = torch.from_numpy(array_fortran)
|
||||||
|
|
||||||
|
# Verify the tensor layout
|
||||||
|
print("Tensor stride:", tensor_fortran.stride()) # Stride will reflect the Fortran memory layout
|
||||||
|
|
||||||
|
# Step 3: Save the PyTorch tensor to a .pth file
|
||||||
|
torch.save({"tensor_fortran": tensor_fortran}, 'fortran_tensor_3d.pth')
|
||||||
|
|
||||||
|
print("3D Tensor saved with Fortran layout.")
|
31
candle-core/tests/pth_tests.rs
Normal file
31
candle-core/tests/pth_tests.rs
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
/// Regression test for pth files not loading on Windows.
|
||||||
|
#[test]
|
||||||
|
fn test_pth() {
|
||||||
|
let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
|
||||||
|
tensors.get("test").unwrap().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pth_with_key() {
|
||||||
|
let tensors =
|
||||||
|
candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
|
||||||
|
.unwrap();
|
||||||
|
tensors.get("test").unwrap().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pth_fortran_congiguous() {
|
||||||
|
let tensors =
|
||||||
|
candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
|
||||||
|
let tensor = tensors.get("tensor_fortran").unwrap().unwrap();
|
||||||
|
|
||||||
|
assert_eq!(tensor.dims3().unwrap(), (2, 3, 4));
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
tensor.to_vec3::<i64>().unwrap(),
|
||||||
|
[
|
||||||
|
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
|
||||||
|
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
@ -3,7 +3,7 @@ use candle_core::{
|
|||||||
quantized::{self, GgmlDType},
|
quantized::{self, GgmlDType},
|
||||||
test_device,
|
test_device,
|
||||||
test_utils::to_vec2_round,
|
test_utils::to_vec2_round,
|
||||||
Device, Module, Result, Tensor,
|
DType, Device, IndexOp, Module, Result, Tensor,
|
||||||
};
|
};
|
||||||
use quantized::{k_quants, GgmlType};
|
use quantized::{k_quants, GgmlType};
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
@ -47,18 +47,14 @@ fn test_matmul(
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantized_matmul(device: &Device) -> Result<()> {
|
fn quantized_matmul(device: &Device) -> Result<()> {
|
||||||
// TODO Enable this later when we enable cuda.
|
|
||||||
if device.is_cuda() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let (m, k, n) = (3, 64, 4);
|
let (m, k, n) = (3, 64, 4);
|
||||||
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
|
let lhs_s = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
|
let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?;
|
||||||
let mut dst = vec![42.; 3 * 4];
|
let mut dst = vec![42.; 3 * 4];
|
||||||
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
||||||
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
|
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
||||||
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
|
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
|
||||||
&[
|
&[
|
||||||
@ -67,7 +63,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
|
|||||||
]
|
]
|
||||||
);
|
);
|
||||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
|
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
|
||||||
let mm = tensor_lhs.matmul(&tensor_rhs)?;
|
let mm = lhs.matmul(&tensor_rhs)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
mm.to_vec2::<f32>()?,
|
mm.to_vec2::<f32>()?,
|
||||||
&[
|
&[
|
||||||
@ -79,7 +75,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
|
|||||||
|
|
||||||
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
|
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
|
||||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||||
let res = matmul.forward(&tensor_lhs)?;
|
let res = matmul.forward(&lhs)?;
|
||||||
match device {
|
match device {
|
||||||
Device::Metal(_) => assert_eq!(
|
Device::Metal(_) => assert_eq!(
|
||||||
to_vec2_round(&res, 0)?,
|
to_vec2_round(&res, 0)?,
|
||||||
@ -89,7 +85,15 @@ fn quantized_matmul(device: &Device) -> Result<()> {
|
|||||||
[341970.0, 994574.0, 1656181.0, 2302182.0]
|
[341970.0, 994574.0, 1656181.0, 2302182.0]
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
_ => assert_eq!(
|
Device::Cuda(_) => assert_eq!(
|
||||||
|
to_vec2_round(&res, 0)?,
|
||||||
|
&[
|
||||||
|
[84866.0, 214045.0, 344676.0, 473707.0],
|
||||||
|
[213425.0, 604313.0, 1000431.0, 1387960.0],
|
||||||
|
[342030.0, 994630.0, 1656248.0, 2302250.0]
|
||||||
|
]
|
||||||
|
),
|
||||||
|
Device::Cpu => assert_eq!(
|
||||||
to_vec2_round(&res, 0)?,
|
to_vec2_round(&res, 0)?,
|
||||||
&[
|
&[
|
||||||
[85120.0, 214562.0, 345455.0, 474748.0],
|
[85120.0, 214562.0, 345455.0, 474748.0],
|
||||||
@ -98,22 +102,16 @@ fn quantized_matmul(device: &Device) -> Result<()> {
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
|
test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
||||||
// TODO Enable this later when we enable cuda.
|
|
||||||
if device.is_cuda() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let (m, k, n) = (3, 64, 4);
|
let (m, k, n) = (3, 64, 4);
|
||||||
let lhs = (0..(m * k))
|
let lhs_s = (0..(m * k))
|
||||||
.map(|v| v as f32 - (m * k) as f32 / 2.0)
|
.map(|v| v as f32 - (m * k) as f32 / 2.0)
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
|
let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?;
|
||||||
let mut dst = vec![42.; 3 * 4];
|
let mut dst = vec![42.; 3 * 4];
|
||||||
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
||||||
let rhs = (0..k * n)
|
let rhs = (0..k * n)
|
||||||
@ -121,7 +119,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
|||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
|
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
|
||||||
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
||||||
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
|
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
|
||||||
&[
|
&[
|
||||||
@ -129,7 +127,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
|||||||
-196472.0, 63012.0, 324585.0, 587902.0
|
-196472.0, 63012.0, 324585.0, 587902.0
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
let mm = tensor_lhs.matmul(&tensor_rhs)?;
|
let mm = lhs.matmul(&tensor_rhs)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
to_vec2_round(&mm, 0)?,
|
to_vec2_round(&mm, 0)?,
|
||||||
&[
|
&[
|
||||||
@ -141,7 +139,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
|||||||
|
|
||||||
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
|
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
|
||||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||||
let res = matmul.forward(&tensor_lhs)?;
|
let res = matmul.forward(&lhs)?;
|
||||||
match device {
|
match device {
|
||||||
Device::Metal(_) => assert_eq!(
|
Device::Metal(_) => assert_eq!(
|
||||||
to_vec2_round(&res, 0)?,
|
to_vec2_round(&res, 0)?,
|
||||||
@ -151,7 +149,15 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
|||||||
[-196102.0, 63022.0, 324233.0, 587191.0]
|
[-196102.0, 63022.0, 324233.0, 587191.0]
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
_ => assert_eq!(
|
Device::Cuda(_) => assert_eq!(
|
||||||
|
to_vec2_round(&res, 0)?,
|
||||||
|
&[
|
||||||
|
[243740.0, -19762.0, -285476.0, -550498.0],
|
||||||
|
[23774.0, 21645.0, 19395.0, 18364.0],
|
||||||
|
[-196045.0, 63030.0, 324120.0, 587079.0]
|
||||||
|
]
|
||||||
|
),
|
||||||
|
Device::Cpu => assert_eq!(
|
||||||
to_vec2_round(&res, 0)?,
|
to_vec2_round(&res, 0)?,
|
||||||
&[
|
&[
|
||||||
[243524.0, -19596.0, -285051.0, -549815.0],
|
[243524.0, -19596.0, -285051.0, -549815.0],
|
||||||
@ -160,33 +166,72 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
let lhs2 = Tensor::stack(&[&lhs, &lhs], 0)?;
|
||||||
|
let res2 = matmul.forward(&lhs2)?;
|
||||||
|
let res2 = res2.i(1)?;
|
||||||
|
let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
if device.is_cuda() {
|
||||||
|
assert!(diff < 0.1);
|
||||||
|
} else {
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
test_device!(
|
fn qmm_batch(dev: &Device) -> Result<()> {
|
||||||
quantized_matmul,
|
let (lhs, rhs, _mm) = get_random_tensors(2, 256, 6, dev)?;
|
||||||
quantized_matmul_cpu,
|
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
|
||||||
quantized_matmul_cuda,
|
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||||
quantized_matmul_metal
|
let mm = rhs.forward(&lhs)?;
|
||||||
);
|
assert_eq!(mm.shape().dims(), [2, 6]);
|
||||||
test_device!(
|
let lhs2 = Tensor::cat(&[&lhs, &lhs], 0)?;
|
||||||
quantized_matmul_neg,
|
let mm2 = rhs.forward(&lhs2)?;
|
||||||
quantized_matmul_neg_cpu,
|
assert_eq!(mm2.shape().dims(), [4, 6]);
|
||||||
quantized_matmul_neg_cuda,
|
let diff2 = (mm2.i(2..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
quantized_matmul_neg_metal
|
assert_eq!(diff2, 0.0);
|
||||||
);
|
let lhs3 = Tensor::cat(&[&lhs2, &lhs], 0)?;
|
||||||
|
let mm3 = rhs.forward(&lhs3)?;
|
||||||
|
assert_eq!(mm3.shape().dims(), [6, 6]);
|
||||||
|
let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff3, 0.0);
|
||||||
|
let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff3, 0.0);
|
||||||
|
let lhs4 = Tensor::cat(&[&lhs3, &lhs3], 0)?;
|
||||||
|
let mm4 = rhs.forward(&lhs4)?;
|
||||||
|
assert_eq!(mm4.shape().dims(), [12, 6]);
|
||||||
|
let diff4 = (mm4.i(..6)? - &mm3)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
if dev.is_cuda() {
|
||||||
|
// We use a different kernel for sizes from 1 to 8 on cuda which explains
|
||||||
|
// the difference here.
|
||||||
|
assert!(0. < diff4 && diff4 < 1e-4)
|
||||||
|
} else {
|
||||||
|
assert_eq!(diff4, 0.0)
|
||||||
|
};
|
||||||
|
let diff4 = (mm4.i(6..)? - &mm4.i(..6)?)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff4, 0.0);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal);
|
||||||
|
test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal);
|
||||||
|
test_device!(qmm_batch, qmm_b_cpu, qmm_b_cuda, qmm_b_metal);
|
||||||
|
|
||||||
fn quantize_q4_0(device: &Device) -> Result<()> {
|
fn quantize_q4_0(device: &Device) -> Result<()> {
|
||||||
// TODO Enable this later when we enable cuda.
|
|
||||||
if device.is_cuda() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
|
|
||||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
|
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
dst.to_vec1::<f32>()?,
|
dst.to_vec1::<f32>()?,
|
||||||
&[
|
&[
|
||||||
@ -209,14 +254,17 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q4_1(device: &Device) -> Result<()> {
|
fn quantize_q4_1(device: &Device) -> Result<()> {
|
||||||
// TODO Enable this later when we enable cuda.
|
|
||||||
if device.is_cuda() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
|
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
round_vector(&dst.to_vec1::<f32>()?),
|
round_vector(&dst.to_vec1::<f32>()?),
|
||||||
&[
|
&[
|
||||||
@ -239,14 +287,17 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q5_0(device: &Device) -> Result<()> {
|
fn quantize_q5_0(device: &Device) -> Result<()> {
|
||||||
// TODO Enable this later when we enable cuda.
|
|
||||||
if device.is_cuda() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
|
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
round_vector(&dst.to_vec1::<f32>()?),
|
round_vector(&dst.to_vec1::<f32>()?),
|
||||||
&[
|
&[
|
||||||
@ -269,14 +320,17 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q5_1(device: &Device) -> Result<()> {
|
fn quantize_q5_1(device: &Device) -> Result<()> {
|
||||||
// TODO Enable this later when we enable cuda.
|
|
||||||
if device.is_cuda() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
|
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
round_vector(&dst.to_vec1::<f32>()?),
|
round_vector(&dst.to_vec1::<f32>()?),
|
||||||
&[
|
&[
|
||||||
@ -361,6 +415,13 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
|
|||||||
let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
|
let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
|
let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
|
||||||
if error > max_error {
|
if error > max_error {
|
||||||
bail!(
|
bail!(
|
||||||
@ -373,15 +434,18 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q2k(device: &Device) -> Result<()> {
|
fn quantize_q2k(device: &Device) -> Result<()> {
|
||||||
// TODO Enable this later when we enable cuda.
|
|
||||||
if device.is_cuda() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let dtype = GgmlDType::Q2K;
|
let dtype = GgmlDType::Q2K;
|
||||||
|
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let src = src.to_vec1::<f32>()?;
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let dst = dst.to_vec1::<f32>()?;
|
||||||
@ -401,6 +465,13 @@ fn quantize_q2k(device: &Device) -> Result<()> {
|
|||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
let dst_big = quant_big.dequantize(device)?;
|
||||||
|
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||||
|
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
let src_big = src_big.to_vec1::<f32>()?;
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||||
@ -411,14 +482,17 @@ fn quantize_q2k(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q3k(device: &Device) -> Result<()> {
|
fn quantize_q3k(device: &Device) -> Result<()> {
|
||||||
// TODO Enable this later when we enable cuda.
|
|
||||||
if device.is_cuda() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let dtype = GgmlDType::Q3K;
|
let dtype = GgmlDType::Q3K;
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let src = src.to_vec1::<f32>()?;
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let dst = dst.to_vec1::<f32>()?;
|
||||||
@ -438,6 +512,13 @@ fn quantize_q3k(device: &Device) -> Result<()> {
|
|||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
let dst_big = quant_big.dequantize(device)?;
|
||||||
|
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||||
|
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
let src_big = src_big.to_vec1::<f32>()?;
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||||
@ -448,14 +529,17 @@ fn quantize_q3k(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q4k(device: &Device) -> Result<()> {
|
fn quantize_q4k(device: &Device) -> Result<()> {
|
||||||
// TODO Enable this later when we enable cuda.
|
|
||||||
if device.is_cuda() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let dtype = GgmlDType::Q4K;
|
let dtype = GgmlDType::Q4K;
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let src = src.to_vec1::<f32>()?;
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let dst = dst.to_vec1::<f32>()?;
|
||||||
@ -475,6 +559,13 @@ fn quantize_q4k(device: &Device) -> Result<()> {
|
|||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
let dst_big = quant_big.dequantize(device)?;
|
||||||
|
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||||
|
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
let src_big = src_big.to_vec1::<f32>()?;
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||||
@ -485,14 +576,17 @@ fn quantize_q4k(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q5k(device: &Device) -> Result<()> {
|
fn quantize_q5k(device: &Device) -> Result<()> {
|
||||||
// TODO Enable this later when we enable cuda.
|
|
||||||
if device.is_cuda() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let dtype = GgmlDType::Q5K;
|
let dtype = GgmlDType::Q5K;
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let src = src.to_vec1::<f32>()?;
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let dst = dst.to_vec1::<f32>()?;
|
||||||
@ -512,6 +606,13 @@ fn quantize_q5k(device: &Device) -> Result<()> {
|
|||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
let dst_big = quant_big.dequantize(device)?;
|
||||||
|
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||||
|
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
let src_big = src_big.to_vec1::<f32>()?;
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||||
@ -522,14 +623,17 @@ fn quantize_q5k(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q6k(device: &Device) -> Result<()> {
|
fn quantize_q6k(device: &Device) -> Result<()> {
|
||||||
// TODO Enable this later when we enable cuda.
|
|
||||||
if device.is_cuda() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let dtype = GgmlDType::Q6K;
|
let dtype = GgmlDType::Q6K;
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let src = src.to_vec1::<f32>()?;
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let dst = dst.to_vec1::<f32>()?;
|
||||||
@ -549,6 +653,13 @@ fn quantize_q6k(device: &Device) -> Result<()> {
|
|||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
let dst_big = quant_big.dequantize(device)?;
|
||||||
|
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||||
|
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
let src_big = src_big.to_vec1::<f32>()?;
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||||
@ -559,14 +670,17 @@ fn quantize_q6k(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q8k(device: &Device) -> Result<()> {
|
fn quantize_q8k(device: &Device) -> Result<()> {
|
||||||
// TODO Enable this later when we enable cuda.
|
|
||||||
if device.is_cuda() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let dtype = GgmlDType::Q8K;
|
let dtype = GgmlDType::Q8K;
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let src = src.to_vec1::<f32>()?;
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let dst = dst.to_vec1::<f32>()?;
|
||||||
@ -586,6 +700,13 @@ fn quantize_q8k(device: &Device) -> Result<()> {
|
|||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
let dst_big = quant_big.dequantize(device)?;
|
||||||
|
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||||
|
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
let src_big = src_big.to_vec1::<f32>()?;
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||||
@ -778,10 +899,6 @@ macro_rules! quantized_matmul {
|
|||||||
// stable. https://github.com/rust-lang/rust/issues/29599
|
// stable. https://github.com/rust-lang/rust/issues/29599
|
||||||
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
|
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
|
||||||
fn $fn_name(device: &Device) -> Result<()> {
|
fn $fn_name(device: &Device) -> Result<()> {
|
||||||
if device.is_cuda() {
|
|
||||||
// TODO Enable Cuda GGML sometime maybe.
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
test_matmul(device, (1, 3, 4, 256), $dtype)?;
|
test_matmul(device, (1, 3, 4, 256), $dtype)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,31 @@
|
|||||||
use candle_core::{DType, Result, Tensor};
|
use candle_core::{DType, Result, Tensor};
|
||||||
|
|
||||||
|
struct TmpFile(std::path::PathBuf);
|
||||||
|
|
||||||
|
impl TmpFile {
|
||||||
|
fn create(base: &str) -> TmpFile {
|
||||||
|
let filename = std::env::temp_dir().join(format!(
|
||||||
|
"candle-{}-{}-{:?}",
|
||||||
|
base,
|
||||||
|
std::process::id(),
|
||||||
|
std::thread::current().id(),
|
||||||
|
));
|
||||||
|
TmpFile(filename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::convert::AsRef<std::path::Path> for TmpFile {
|
||||||
|
fn as_ref(&self) -> &std::path::Path {
|
||||||
|
self.0.as_path()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for TmpFile {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
std::fs::remove_file(&self.0).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn npy() -> Result<()> {
|
fn npy() -> Result<()> {
|
||||||
let npy = Tensor::read_npy("tests/test.npy")?;
|
let npy = Tensor::read_npy("tests/test.npy")?;
|
||||||
@ -22,3 +48,24 @@ fn npz() -> Result<()> {
|
|||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn safetensors() -> Result<()> {
|
||||||
|
use candle_core::safetensors::Load;
|
||||||
|
|
||||||
|
let tmp_file = TmpFile::create("st");
|
||||||
|
let t = Tensor::arange(0f32, 24f32, &candle_core::Device::Cpu)?;
|
||||||
|
t.save_safetensors("t", &tmp_file)?;
|
||||||
|
// Load from file.
|
||||||
|
let st = candle_core::safetensors::load(&tmp_file, &candle_core::Device::Cpu)?;
|
||||||
|
let t2 = st.get("t").unwrap();
|
||||||
|
let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0f32);
|
||||||
|
// Load from bytes.
|
||||||
|
let bytes = std::fs::read(tmp_file)?;
|
||||||
|
let st = candle_core::safetensors::SliceSafetensors::new(&bytes)?;
|
||||||
|
let t2 = st.get("t").unwrap().load(&candle_core::Device::Cpu);
|
||||||
|
let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0f32);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
@ -96,6 +96,40 @@ fn clamp(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn asort(device: &Device) -> Result<()> {
|
||||||
|
let data = &[[3f32, 1., 4., 1.1, 5.], [2.1, 1., 7., 8., 2.]];
|
||||||
|
let tensor = Tensor::new(data, device)?;
|
||||||
|
let indexes = tensor.arg_sort_last_dim(true)?;
|
||||||
|
assert_eq!(
|
||||||
|
indexes.to_vec2::<u32>()?,
|
||||||
|
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
|
||||||
|
);
|
||||||
|
let indexes = tensor.arg_sort_last_dim(false)?;
|
||||||
|
assert_eq!(
|
||||||
|
indexes.to_vec2::<u32>()?,
|
||||||
|
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
|
||||||
|
);
|
||||||
|
let (sorted, indexes) = tensor.sort_last_dim(true)?;
|
||||||
|
assert_eq!(
|
||||||
|
indexes.to_vec2::<u32>()?,
|
||||||
|
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
sorted.to_vec2::<f32>()?,
|
||||||
|
[[1.0, 1.1, 3.0, 4.0, 5.0], [1.0, 2.0, 2.1, 7.0, 8.0]]
|
||||||
|
);
|
||||||
|
let (sorted, indexes) = tensor.sort_last_dim(false)?;
|
||||||
|
assert_eq!(
|
||||||
|
indexes.to_vec2::<u32>()?,
|
||||||
|
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
sorted.to_vec2::<f32>()?,
|
||||||
|
[[5.0, 4.0, 3.0, 1.1, 1.0], [8.0, 7.0, 2.1, 2.0, 1.0]]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn unary_op(device: &Device) -> Result<()> {
|
fn unary_op(device: &Device) -> Result<()> {
|
||||||
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
|
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
|
||||||
let tensor = Tensor::new(data, device)?;
|
let tensor = Tensor::new(data, device)?;
|
||||||
@ -106,6 +140,9 @@ fn unary_op(device: &Device) -> Result<()> {
|
|||||||
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
|
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
|
||||||
|
let max_diff = (tensor.gelu()? - t_f16)?.flatten_all()?.max(0)?;
|
||||||
|
assert!(max_diff.to_vec0::<f32>()? < 5e-3);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
|
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
|
||||||
[
|
[
|
||||||
@ -120,6 +157,13 @@ fn unary_op(device: &Device) -> Result<()> {
|
|||||||
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
|
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec2_round(&tensor.silu()?, 4)?,
|
||||||
|
[
|
||||||
|
[-0.1423, 0.7311, 3.9281, -0.0475, 0.3112],
|
||||||
|
[2.53, -0.2553, -0.1205, 1.5447, 2.6395]
|
||||||
|
]
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
|
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
|
||||||
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
|
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
|
||||||
@ -141,6 +185,27 @@ fn unary_op(device: &Device) -> Result<()> {
|
|||||||
test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
|
test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
|
||||||
[3000.0, 300.]
|
[3000.0, 300.]
|
||||||
);
|
);
|
||||||
|
let tensor = Tensor::new(
|
||||||
|
&[-1.01f32, -0.9, -0.1, 0.0, -0.0, 0.1, 0.9, 1.0, 1.1],
|
||||||
|
device,
|
||||||
|
)?;
|
||||||
|
assert_eq!(
|
||||||
|
tensor.sign()?.to_vec1::<f32>()?,
|
||||||
|
[-1., -1., -1., 0., 0., 1., 1., 1., 1.]
|
||||||
|
);
|
||||||
|
let tensor = Tensor::new(&[-1.0f32, 0., -2., 3.], device)?;
|
||||||
|
let y = tensor.elu(2.)?;
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(&y, 4)?,
|
||||||
|
[-1.2642, 0.0000, -1.7293, 3.0000]
|
||||||
|
);
|
||||||
|
// This test failed on metal prior to the following PR:
|
||||||
|
// https://github.com/huggingface/candle/pull/2490
|
||||||
|
let y = tensor.reshape((2, 2))?.t()?.elu(2.)?.flatten_all()?;
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(&y, 4)?,
|
||||||
|
[-1.2642, -1.7293, 0.0000, 3.0000]
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -613,6 +678,30 @@ fn broadcast(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn slice_set(device: &Device) -> Result<()> {
|
||||||
|
let (b, h, max_t, d) = (2, 4, 7, 3);
|
||||||
|
let cache = Tensor::zeros((b, h, max_t, d), DType::F32, device)?;
|
||||||
|
let tensor = Tensor::randn(0f32, 1f32, (b, h, 4, d), device)?;
|
||||||
|
cache.slice_set(&tensor, 2, 0)?;
|
||||||
|
let cache_t = cache.narrow(2, 0, 4)?;
|
||||||
|
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
cache.slice_set(&tensor, 2, 1)?;
|
||||||
|
let cache_t = cache.narrow(2, 1, 4)?;
|
||||||
|
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
let ones = Tensor::ones((b, h, 1, d), DType::F32, device)?;
|
||||||
|
cache.slice_set(&ones, 2, 6)?;
|
||||||
|
let diff = cache.narrow(2, 5, 1)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
let diff = (cache.narrow(2, 6, 1)? - 1.)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn cat(device: &Device) -> Result<()> {
|
fn cat(device: &Device) -> Result<()> {
|
||||||
// 1D
|
// 1D
|
||||||
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
|
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
|
||||||
@ -665,6 +754,31 @@ fn cat(device: &Device) -> Result<()> {
|
|||||||
[2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0]
|
[2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// 3D
|
||||||
|
let t1 = Tensor::arange(0, 48i64, device)?.reshape((2, 6, 4))?;
|
||||||
|
let t2 = Tensor::arange(100, 124i64, device)?.reshape((2, 3, 4))?;
|
||||||
|
let t3 = Tensor::arange(10000, 10032i64, device)?.reshape((2, 4, 4))?;
|
||||||
|
|
||||||
|
let t_cat = Tensor::cat(&[&t1, &t2, &t3], 1)?;
|
||||||
|
|
||||||
|
let t1 = t1.t()?.contiguous()?.t()?;
|
||||||
|
let t2 = t2.t()?.contiguous()?.t()?;
|
||||||
|
let t3 = t3.t()?.contiguous()?.t()?;
|
||||||
|
let t_cat2 = Tensor::cat(&[&t1, &t2, &t3], 1)?;
|
||||||
|
|
||||||
|
let diff = t_cat.eq(&t_cat2)?.to_dtype(DType::F32)?.sum_all()?;
|
||||||
|
assert_eq!(diff.to_vec0::<f32>()?, 104.0);
|
||||||
|
assert_eq!(t_cat.i((0, 0, 0))?.to_vec0::<i64>()?, 0);
|
||||||
|
assert_eq!(t_cat.i((0, 4, 0))?.to_vec0::<i64>()?, 16);
|
||||||
|
assert_eq!(t_cat.i((0, 5, 0))?.to_vec0::<i64>()?, 20);
|
||||||
|
assert_eq!(t_cat.i((1, 5, 0))?.to_vec0::<i64>()?, 44);
|
||||||
|
assert_eq!(t_cat.i((0, 6, 0))?.to_vec0::<i64>()?, 100);
|
||||||
|
assert_eq!(t_cat.i((1, 6, 0))?.to_vec0::<i64>()?, 112);
|
||||||
|
assert_eq!(t_cat.i((0, 6, 1))?.to_vec0::<i64>()?, 101);
|
||||||
|
assert_eq!(t_cat.i((0, 7, 1))?.to_vec0::<i64>()?, 105);
|
||||||
|
assert_eq!(t_cat.i((0, 12, 1))?.to_vec0::<i64>()?, 10013);
|
||||||
|
assert_eq!(t_cat.i((1, 12, 3))?.to_vec0::<i64>()?, 10031);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -675,6 +789,8 @@ fn embeddings(device: &Device) -> Result<()> {
|
|||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||||
let hs = t.index_select(&ids, 0)?;
|
let hs = t.index_select(&ids, 0)?;
|
||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||||
|
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
|
||||||
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -702,44 +818,47 @@ fn index_select(device: &Device) -> Result<()> {
|
|||||||
[9.0, 10.0, 11.0]
|
[9.0, 10.0, 11.0]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
let hs = t.index_select(&ids, 1)?;
|
for dtype in [DType::U8, DType::U32, DType::I64] {
|
||||||
assert_eq!(
|
let ids = ids.to_dtype(dtype)?;
|
||||||
hs.to_vec2::<f32>()?,
|
let hs = t.index_select(&ids, 1)?;
|
||||||
&[
|
assert_eq!(
|
||||||
[0.0, 2.0, 1.0],
|
hs.to_vec2::<f32>()?,
|
||||||
[3.0, 5.0, 4.0],
|
&[
|
||||||
[6.0, 8.0, 7.0],
|
[0.0, 2.0, 1.0],
|
||||||
[9.0, 11.0, 10.0]
|
[3.0, 5.0, 4.0],
|
||||||
]
|
[6.0, 8.0, 7.0],
|
||||||
);
|
[9.0, 11.0, 10.0]
|
||||||
let hs = t.index_select(&ids, 0)?;
|
]
|
||||||
assert_eq!(
|
);
|
||||||
hs.to_vec2::<f32>()?,
|
let hs = t.index_select(&ids, 0)?;
|
||||||
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
|
assert_eq!(
|
||||||
);
|
hs.to_vec2::<f32>()?,
|
||||||
// Prior to https://github.com/huggingface/candle/pull/1022
|
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
|
||||||
// There would be a bug where the last values in the result tensor would be set to 0.
|
);
|
||||||
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
|
// Prior to https://github.com/huggingface/candle/pull/1022
|
||||||
let hs = t.index_select(&ids, 0)?;
|
// There would be a bug where the last values in the result tensor would be set to 0.
|
||||||
assert_eq!(
|
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
|
||||||
hs.to_vec2::<f32>()?,
|
let hs = t.index_select(&ids, 0)?;
|
||||||
&[
|
assert_eq!(
|
||||||
[0.0, 1.0, 2.0],
|
hs.to_vec2::<f32>()?,
|
||||||
[6.0, 7.0, 8.0],
|
&[
|
||||||
[3.0, 4.0, 5.0],
|
[0.0, 1.0, 2.0],
|
||||||
[0.0, 1.0, 2.0],
|
[6.0, 7.0, 8.0],
|
||||||
[6.0, 7.0, 8.0],
|
[3.0, 4.0, 5.0],
|
||||||
[3.0, 4.0, 5.0],
|
[0.0, 1.0, 2.0],
|
||||||
]
|
[6.0, 7.0, 8.0],
|
||||||
);
|
[3.0, 4.0, 5.0],
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
// Test when selecting dim > 0 with ids size different from elem count of
|
// Test when selecting dim > 0 with ids size different from elem count of
|
||||||
// target dim in source/input.
|
// target dim in source/input.
|
||||||
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
|
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
|
||||||
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
|
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
|
||||||
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
|
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
|
||||||
let hs = t.index_select(&ids, 1)?;
|
let hs = t.index_select(&ids, 1)?;
|
||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -901,74 +1020,6 @@ fn gather(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn matmul(device: &Device) -> Result<()> {
|
|
||||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
|
||||||
let a = Tensor::from_slice(&data, (2, 2), device)?;
|
|
||||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
|
||||||
let b = Tensor::from_slice(&data, (2, 2), device)?;
|
|
||||||
|
|
||||||
let c = a.matmul(&b)?;
|
|
||||||
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
|
|
||||||
|
|
||||||
let data = vec![1.0f32, 2.0];
|
|
||||||
let a = Tensor::from_slice(&data, (2, 1), device)?;
|
|
||||||
let data = vec![3.0f32, 4.0];
|
|
||||||
let b = Tensor::from_slice(&data, (1, 2), device)?;
|
|
||||||
let c = a.matmul(&b)?;
|
|
||||||
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
|
|
||||||
|
|
||||||
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
|
|
||||||
let a = Tensor::from_slice(&data, (2, 3), device)?;
|
|
||||||
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
|
|
||||||
let b = Tensor::from_slice(&data, (3, 2), device)?;
|
|
||||||
let c = a.matmul(&b)?;
|
|
||||||
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
|
|
||||||
|
|
||||||
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
|
|
||||||
let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
|
|
||||||
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
|
|
||||||
let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
|
|
||||||
let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];
|
|
||||||
|
|
||||||
let c = a.matmul(&b)?;
|
|
||||||
assert_eq!(c.to_vec3::<f32>()?, &expected);
|
|
||||||
|
|
||||||
// Also perform the matmul on contiguous transposed versions.
|
|
||||||
let a_tt = a.t()?.contiguous()?.t()?;
|
|
||||||
assert!(!a_tt.is_contiguous());
|
|
||||||
assert_eq!(a.dims(), a_tt.dims());
|
|
||||||
assert_eq!(a_tt.stride(), &[6, 1, 2]);
|
|
||||||
|
|
||||||
let b_tt = b.t()?.contiguous()?.t()?;
|
|
||||||
assert!(!b_tt.is_contiguous());
|
|
||||||
assert_eq!(b.dims(), b_tt.dims());
|
|
||||||
assert_eq!(b_tt.stride(), &[6, 1, 3]);
|
|
||||||
|
|
||||||
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
|
|
||||||
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
|
|
||||||
assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn broadcast_matmul(device: &Device) -> Result<()> {
|
|
||||||
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
|
|
||||||
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
|
|
||||||
let out = lhs.broadcast_matmul(&rhs)?;
|
|
||||||
assert_eq!(out.dims(), &[3, 6, 4, 2]);
|
|
||||||
for idx1 in 0..3 {
|
|
||||||
for idx2 in 0..6 {
|
|
||||||
let out = out.i((idx1, idx2))?;
|
|
||||||
let lhs = lhs.i((idx1, 0))?;
|
|
||||||
let rhs = rhs.i(idx2)?;
|
|
||||||
let out2 = lhs.matmul(&rhs);
|
|
||||||
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
|
|
||||||
// With cuda, we see errors of up to ~1e-12.
|
|
||||||
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn broadcasting(device: &Device) -> Result<()> {
|
fn broadcasting(device: &Device) -> Result<()> {
|
||||||
let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
|
let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
|
||||||
let t2 = Tensor::new(&[100f32, 200f32], device)?;
|
let t2 = Tensor::new(&[100f32, 200f32], device)?;
|
||||||
@ -1073,8 +1124,54 @@ fn broadcasting(device: &Device) -> Result<()> {
|
|||||||
fn randn(device: &Device) -> Result<()> {
|
fn randn(device: &Device) -> Result<()> {
|
||||||
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
|
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
|
||||||
assert_eq!(tensor.dims(), [5, 3]);
|
assert_eq!(tensor.dims(), [5, 3]);
|
||||||
|
// Check that the seed gets updated by checking that
|
||||||
|
// a new series of numbers is generated each time
|
||||||
|
let tensor2 = Tensor::randn(0f32, 1f32, (5, 3), device)?;
|
||||||
|
assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
|
||||||
let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
|
let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
|
||||||
assert_eq!(tensor.dims(), [5, 3]);
|
assert_eq!(tensor.dims(), [5, 3]);
|
||||||
|
// Check that the seed gets updated by checking that
|
||||||
|
// a new series of numbers is generated each time
|
||||||
|
let tensor2 = Tensor::rand(0f32, 1f32, (5, 3), device)?;
|
||||||
|
assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
|
||||||
|
// We do not expect deterministic elements at any index.
|
||||||
|
// There once was a bug that had a deterministic zero element in evenly sized tensors.
|
||||||
|
const N: usize = 2;
|
||||||
|
let v = (0..100)
|
||||||
|
.map(|_| Tensor::randn(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
assert!(
|
||||||
|
(0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
|
||||||
|
"There are deterministic values in the randn tensors"
|
||||||
|
);
|
||||||
|
let v = (0..100)
|
||||||
|
.map(|_| Tensor::rand(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
assert!(
|
||||||
|
(0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
|
||||||
|
"There are deterministic values in the rand tensors"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn zero_dim(device: &Device) -> Result<()> {
|
||||||
|
let t = Tensor::zeros((4, 0, 1), DType::F32, device)?;
|
||||||
|
assert_eq!(t.dims3()?, (4, 0, 1));
|
||||||
|
let t2 = Tensor::zeros((4, 3, 1), DType::F32, device)?;
|
||||||
|
let t_cat = Tensor::cat(&[&t, &t2], 1)?;
|
||||||
|
assert_eq!(t_cat.dims3()?, (4, 3, 1));
|
||||||
|
let t_cat = Tensor::cat(&[&t, &t], 1)?;
|
||||||
|
assert_eq!(t_cat.dims3()?, (4, 0, 1));
|
||||||
|
let t_unary = t.sqrt()?;
|
||||||
|
assert_eq!(t_unary.dims3()?, (4, 0, 1));
|
||||||
|
let t_plus = (&t + 1.)?;
|
||||||
|
assert_eq!(t_plus.dims3()?, (4, 0, 1));
|
||||||
|
let t_mm = t2.matmul(&t.t()?)?;
|
||||||
|
assert_eq!(t_mm.dims3()?, (4, 3, 0));
|
||||||
|
let t_mm = t.matmul(&t2.t()?)?;
|
||||||
|
assert_eq!(t_mm.dims3()?, (4, 0, 3));
|
||||||
|
let t_mm = t.t()?.matmul(&t)?;
|
||||||
|
assert_eq!(t_mm.dims3()?, (4, 1, 1));
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1086,6 +1183,7 @@ test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
|
|||||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
|
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
|
||||||
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
|
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
|
||||||
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
|
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
|
||||||
|
test_device!(slice_set, ss_cpu, ss_gpu, ss_metal);
|
||||||
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
|
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
|
||||||
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
|
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
|
||||||
test_device!(min, min_cpu, min_gpu, min_metal);
|
test_device!(min, min_cpu, min_gpu, min_metal);
|
||||||
@ -1097,13 +1195,6 @@ test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
|
|||||||
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
|
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
|
||||||
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
|
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
|
||||||
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
|
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
|
||||||
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
|
|
||||||
test_device!(
|
|
||||||
broadcast_matmul,
|
|
||||||
broadcast_matmul_cpu,
|
|
||||||
broadcast_matmul_gpu,
|
|
||||||
broadcast_matmul_metal
|
|
||||||
);
|
|
||||||
test_device!(
|
test_device!(
|
||||||
broadcasting,
|
broadcasting,
|
||||||
broadcasting_cpu,
|
broadcasting_cpu,
|
||||||
@ -1132,7 +1223,9 @@ test_device!(
|
|||||||
);
|
);
|
||||||
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
|
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
|
||||||
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
|
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
|
||||||
|
test_device!(asort, asort_cpu, asort_gpu, asort_metal);
|
||||||
test_device!(var, var_cpu, var_gpu, var_metal);
|
test_device!(var, var_cpu, var_gpu, var_metal);
|
||||||
|
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
|
||||||
|
|
||||||
// There was originally a bug on the CPU implementation for randn
|
// There was originally a bug on the CPU implementation for randn
|
||||||
// https://github.com/huggingface/candle/issues/381
|
// https://github.com/huggingface/candle/issues/381
|
||||||
@ -1246,11 +1339,29 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn log_sum_exp() -> Result<()> {
|
fn log_sum_exp() -> Result<()> {
|
||||||
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
let input = Tensor::new(
|
||||||
|
&[
|
||||||
|
[[1f64, 2., 3.], [4., 5., 6.]],
|
||||||
|
[[-1000.0, -999.0, -1001.0], [1000.0, 999.0, 1001.0]],
|
||||||
|
],
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
let output = input.log_sum_exp(D::Minus1)?;
|
let output = input.log_sum_exp(D::Minus1)?;
|
||||||
// The expectations obtained from pytorch.
|
// The expectations obtained from pytorch.
|
||||||
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
|
let expected = Tensor::new(&[[3.4076, 6.4076], [-998.5924, 1001.4076]], &Device::Cpu)?;
|
||||||
assert_close(&output, &expected, 0.00001)?;
|
assert_eq!(output.dims(), expected.dims());
|
||||||
|
assert_close(&output.flatten_all()?, &expected.flatten_all()?, 0.00001)?;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
input.log_sum_exp((0, 1))?.to_vec1::<f64>()?,
|
||||||
|
[1000.0, 999.0, 1001.0]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
input.log_sum_exp(())?.to_vec3::<f64>()?,
|
||||||
|
input.to_vec3::<f64>()?
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1260,8 +1371,8 @@ fn pow() -> Result<()> {
|
|||||||
let rhs = (&lhs - 2.)?;
|
let rhs = (&lhs - 2.)?;
|
||||||
let res = lhs.pow(&rhs)?;
|
let res = lhs.pow(&rhs)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec2_round(&res, 4)?,
|
test_utils::to_vec2_round(&res, 3)?,
|
||||||
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
|
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0]]
|
||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
BIN
candle-core/tests/test.pt
Normal file
BIN
candle-core/tests/test.pt
Normal file
Binary file not shown.
BIN
candle-core/tests/test_with_key.pt
Normal file
BIN
candle-core/tests/test_with_key.pt
Normal file
Binary file not shown.
@ -89,7 +89,7 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
|
|||||||
|
|
||||||
pub fn load() -> Result<crate::vision::Dataset> {
|
pub fn load() -> Result<crate::vision::Dataset> {
|
||||||
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
|
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
|
||||||
let dataset_id = "mnist".to_string();
|
let dataset_id = "ylecun/mnist".to_string();
|
||||||
let repo = Repo::with_revision(
|
let repo = Repo::with_revision(
|
||||||
dataset_id,
|
dataset_id,
|
||||||
RepoType::Dataset,
|
RepoType::Dataset,
|
||||||
|
@ -12,7 +12,7 @@ readme = "README.md"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
candle = { workspace = true }
|
candle = { workspace = true }
|
||||||
candle-datasets = { workspace = true }
|
candle-datasets = { workspace = true, optional = true }
|
||||||
candle-nn = { workspace = true }
|
candle-nn = { workspace = true }
|
||||||
candle-transformers = { workspace = true }
|
candle-transformers = { workspace = true }
|
||||||
candle-flash-attn = { workspace = true, optional = true }
|
candle-flash-attn = { workspace = true, optional = true }
|
||||||
@ -21,16 +21,21 @@ candle-onnx = { workspace = true, optional = true }
|
|||||||
csv = "1.3.0"
|
csv = "1.3.0"
|
||||||
cudarc = { workspace = true, optional = true }
|
cudarc = { workspace = true, optional = true }
|
||||||
half = { workspace = true, optional = true }
|
half = { workspace = true, optional = true }
|
||||||
hf-hub = { workspace = true, features=["tokio"]}
|
hf-hub = { workspace = true, features = ["tokio"] }
|
||||||
image = { workspace = true }
|
image = { workspace = true }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
|
palette = { version = "0.7.6", optional = true }
|
||||||
|
enterpolation = { version = "0.2.1", optional = true}
|
||||||
|
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
|
||||||
rayon = { workspace = true }
|
rayon = { workspace = true }
|
||||||
|
rubato = { version = "0.15.0", optional = true }
|
||||||
safetensors = { workspace = true }
|
safetensors = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
|
symphonia = { version = "0.5.3", features = ["all"], optional = true }
|
||||||
tokenizers = { workspace = true, features = ["onig"] }
|
tokenizers = { workspace = true, features = ["onig"] }
|
||||||
|
cpal = { version = "0.15.2", optional = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
@ -39,11 +44,10 @@ clap = { workspace = true }
|
|||||||
imageproc = { workspace = true }
|
imageproc = { workspace = true }
|
||||||
memmap2 = { workspace = true }
|
memmap2 = { workspace = true }
|
||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
rusttype = { workspace = true }
|
ab_glyph = { workspace = true }
|
||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
tracing-chrome = { workspace = true }
|
tracing-chrome = { workspace = true }
|
||||||
tracing-subscriber = { workspace = true }
|
tracing-subscriber = { workspace = true }
|
||||||
wav = { workspace = true }
|
|
||||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||||
tokio = "1.29.1"
|
tokio = "1.29.1"
|
||||||
|
|
||||||
@ -61,6 +65,10 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
|
|||||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||||
onnx = ["candle-onnx"]
|
onnx = ["candle-onnx"]
|
||||||
metal = ["candle/metal", "candle-nn/metal"]
|
metal = ["candle/metal", "candle-nn/metal"]
|
||||||
|
microphone = ["cpal"]
|
||||||
|
encodec = ["cpal", "symphonia", "rubato"]
|
||||||
|
mimi = ["cpal", "symphonia", "rubato"]
|
||||||
|
depth_anything_v2 = ["palette", "enterpolation"]
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
name = "llama_multiprocess"
|
name = "llama_multiprocess"
|
||||||
@ -77,3 +85,35 @@ required-features = ["onnx"]
|
|||||||
[[example]]
|
[[example]]
|
||||||
name = "onnx_basics"
|
name = "onnx_basics"
|
||||||
required-features = ["onnx"]
|
required-features = ["onnx"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "whisper"
|
||||||
|
required-features = ["symphonia"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "whisper-microphone"
|
||||||
|
required-features = ["microphone"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "mnist-training"
|
||||||
|
required-features = ["candle-datasets"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "llama2-c"
|
||||||
|
required-features = ["candle-datasets"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "mimi"
|
||||||
|
required-features = ["mimi"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "encodec"
|
||||||
|
required-features = ["encodec"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "depth_anything_v2"
|
||||||
|
required-features = ["depth_anything_v2"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "silero-vad"
|
||||||
|
required-features = ["onnx"]
|
||||||
|
20
candle-examples/examples/based/README.md
Normal file
20
candle-examples/examples/based/README.md
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# candle-based
|
||||||
|
|
||||||
|
Experimental, not instruction-tuned small LLM from the Hazy Research group, combining local and linear attention layers.
|
||||||
|
|
||||||
|
[Blogpost](https://hazyresearch.stanford.edu/blog/2024-03-03-based)
|
||||||
|
|
||||||
|
[Simple linear attention language models balance the recall-throughput tradeoff](https://arxiv.org/abs/2402.18668)
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example based --release -- --prompt "Flying monkeys are" --which 1b-50b --sample-len 100
|
||||||
|
|
||||||
|
Flying monkeys are a common sight in the wild, but they are also a threat to humans.
|
||||||
|
|
||||||
|
The new study, published today (July 31) in the journal Science Advances, shows that the monkeys are using their brains to solve the problem of how to get around the problem.
|
||||||
|
|
||||||
|
"We found that the monkeys were using a strategy called 'cognitive mapping' - they would use their brains to map out the route ahead," says lead author Dr. David J. Smith from the University of California
|
||||||
|
|
||||||
|
```
|
275
candle-examples/examples/based/main.rs
Normal file
275
candle-examples/examples/based/main.rs
Normal file
@ -0,0 +1,275 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::{Error as E, Result};
|
||||||
|
use clap::{Parser, ValueEnum};
|
||||||
|
|
||||||
|
use candle_transformers::models::based::Model;
|
||||||
|
|
||||||
|
use candle::{DType, Device, Tensor};
|
||||||
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
struct TextGeneration {
|
||||||
|
model: Model,
|
||||||
|
device: Device,
|
||||||
|
tokenizer: TokenOutputStream,
|
||||||
|
logits_processor: LogitsProcessor,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextGeneration {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn new(
|
||||||
|
model: Model,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
seed: u64,
|
||||||
|
temp: Option<f64>,
|
||||||
|
top_p: Option<f64>,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
device: &Device,
|
||||||
|
) -> Self {
|
||||||
|
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
tokenizer: TokenOutputStream::new(tokenizer),
|
||||||
|
logits_processor,
|
||||||
|
repeat_penalty,
|
||||||
|
repeat_last_n,
|
||||||
|
device: device.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||||
|
use std::io::Write;
|
||||||
|
self.tokenizer.clear();
|
||||||
|
let mut tokens = self
|
||||||
|
.tokenizer
|
||||||
|
.tokenizer()
|
||||||
|
.encode(prompt, true)
|
||||||
|
.map_err(E::msg)?
|
||||||
|
.get_ids()
|
||||||
|
.to_vec();
|
||||||
|
for &t in tokens.iter() {
|
||||||
|
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||||
|
print!("{t}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
|
||||||
|
let mut generated_tokens = 0usize;
|
||||||
|
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
||||||
|
Some(token) => token,
|
||||||
|
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
||||||
|
};
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
for index in 0..sample_len {
|
||||||
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
|
let start_pos = tokens.len().saturating_sub(context_size);
|
||||||
|
let ctxt = &tokens[start_pos..];
|
||||||
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
|
let logits = self.model.forward(&input, start_pos)?;
|
||||||
|
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
let logits = if self.repeat_penalty == 1. {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||||
|
candle_transformers::utils::apply_repeat_penalty(
|
||||||
|
&logits,
|
||||||
|
self.repeat_penalty,
|
||||||
|
&tokens[start_at..],
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
|
tokens.push(next_token);
|
||||||
|
generated_tokens += 1;
|
||||||
|
if next_token == eos_token {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||||
|
print!("{t}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let dt = start_gen.elapsed();
|
||||||
|
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||||
|
print!("{rest}");
|
||||||
|
}
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
println!(
|
||||||
|
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||||
|
generated_tokens as f64 / dt.as_secs_f64(),
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "360m")]
|
||||||
|
W360m,
|
||||||
|
#[value(name = "1b")]
|
||||||
|
W1b,
|
||||||
|
#[value(name = "1b-50b")]
|
||||||
|
W1b50b,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
/// The temperature used to generate samples.
|
||||||
|
#[arg(long)]
|
||||||
|
temperature: Option<f64>,
|
||||||
|
|
||||||
|
/// Nucleus sampling probability cutoff.
|
||||||
|
#[arg(long)]
|
||||||
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
|
/// The length of the sample to generate (in tokens).
|
||||||
|
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||||
|
sample_len: usize,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "refs/pr/1")]
|
||||||
|
revision: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
config_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
weight_files: Option<String>,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.1)]
|
||||||
|
repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
repeat_last_n: usize,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "360m")]
|
||||||
|
which: Which,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||||
|
candle::utils::with_avx(),
|
||||||
|
candle::utils::with_neon(),
|
||||||
|
candle::utils::with_simd128(),
|
||||||
|
candle::utils::with_f16c()
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||||
|
args.temperature.unwrap_or(0.),
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n
|
||||||
|
);
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let api = Api::new()?;
|
||||||
|
let model_id = match args.model_id {
|
||||||
|
Some(model_id) => model_id,
|
||||||
|
None => match args.which {
|
||||||
|
Which::W360m => "hazyresearch/based-360m".to_string(),
|
||||||
|
Which::W1b => "hazyresearch/based-1b".to_string(),
|
||||||
|
Which::W1b50b => "hazyresearch/based-1b-50b".to_string(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let repo = api.repo(Repo::with_revision(
|
||||||
|
model_id,
|
||||||
|
RepoType::Model,
|
||||||
|
args.revision,
|
||||||
|
));
|
||||||
|
let config_file = match args.config_file {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => repo.get("config.json")?,
|
||||||
|
};
|
||||||
|
let filenames = match args.weight_files {
|
||||||
|
Some(files) => files
|
||||||
|
.split(',')
|
||||||
|
.map(std::path::PathBuf::from)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
None => vec![repo.get("model.safetensors")?],
|
||||||
|
};
|
||||||
|
|
||||||
|
let repo = api.model("openai-community/gpt2".to_string());
|
||||||
|
let tokenizer_file = match args.tokenizer_file {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => repo.get("tokenizer.json")?,
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let dtype = if device.is_cuda() {
|
||||||
|
DType::BF16
|
||||||
|
} else {
|
||||||
|
DType::F32
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
|
if args.which == Which::W1b50b {
|
||||||
|
vb = vb.pp("model");
|
||||||
|
};
|
||||||
|
|
||||||
|
let model = Model::new(&config, vb)?;
|
||||||
|
|
||||||
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
let mut pipeline = TextGeneration::new(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
args.seed,
|
||||||
|
args.temperature,
|
||||||
|
args.top_p,
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n,
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
pipeline.run(&args.prompt, args.sample_len)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
20
candle-examples/examples/beit/README.md
Normal file
20
candle-examples/examples/beit/README.md
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# candle-beit
|
||||||
|
|
||||||
|
[Beit](https://arxiv.org/abs/2106.08254) is a computer vision model.
|
||||||
|
In this example, it is used as an ImageNet classifier: the model returns the
|
||||||
|
probability for the image to belong to each of the 1000 ImageNet categories.
|
||||||
|
|
||||||
|
## Running some example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example beit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
|
> mountain bike, all-terrain bike, off-roader: 56.16%
|
||||||
|
> bicycle-built-for-two, tandem bicycle, tandem: 3.08%
|
||||||
|
> maillot : 2.23%
|
||||||
|
> alp : 0.88%
|
||||||
|
> crash helmet : 0.85%
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|

|
79
candle-examples/examples/beit/main.rs
Normal file
79
candle-examples/examples/beit/main.rs
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
//! BEiT: BERT Pre-Training of Image Transformers
|
||||||
|
//! https://github.com/microsoft/unilm/tree/master/beit
|
||||||
|
|
||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||||
|
use candle_nn::{Module, VarBuilder};
|
||||||
|
use candle_transformers::models::beit;
|
||||||
|
|
||||||
|
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
||||||
|
/// (3, 384, 384). Beit special normalization is applied.
|
||||||
|
pub fn load_image384_beit_norm<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
|
||||||
|
let img = image::ImageReader::open(p)?
|
||||||
|
.decode()
|
||||||
|
.map_err(candle::Error::wrap)?
|
||||||
|
.resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
|
||||||
|
let img = img.to_rgb8();
|
||||||
|
let data = img.into_raw();
|
||||||
|
let data = Tensor::from_vec(data, (384, 384, 3), &Device::Cpu)?.permute((2, 0, 1))?;
|
||||||
|
let mean = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||||
|
let std = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||||
|
(data.to_dtype(candle::DType::F32)? / 255.)?
|
||||||
|
.broadcast_sub(&mean)?
|
||||||
|
.broadcast_div(&std)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
struct Args {
|
||||||
|
#[arg(long)]
|
||||||
|
model: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
image: String,
|
||||||
|
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn main() -> anyhow::Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
|
let image = load_image384_beit_norm(args.image)?.to_device(&device)?;
|
||||||
|
println!("loaded image {image:?}");
|
||||||
|
|
||||||
|
let model_file = match args.model {
|
||||||
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let api = api.model("vincent-espitalier/candle-beit".into());
|
||||||
|
api.get("beit_base_patch16_384.in22k_ft_in22k_in1k.safetensors")?
|
||||||
|
}
|
||||||
|
Some(model) => model.into(),
|
||||||
|
};
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||||
|
let model = beit::vit_base(vb)?;
|
||||||
|
println!("model built");
|
||||||
|
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||||
|
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||||
|
.i(0)?
|
||||||
|
.to_vec1::<f32>()?;
|
||||||
|
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||||
|
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||||
|
for &(category_idx, pr) in prs.iter().take(5) {
|
||||||
|
println!(
|
||||||
|
"{:24}: {:.2}%",
|
||||||
|
candle_examples::imagenet::CLASSES[category_idx],
|
||||||
|
100. * pr
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -126,7 +126,7 @@ fn main() -> Result<()> {
|
|||||||
println!("Loaded and encoded {:?}", start.elapsed());
|
println!("Loaded and encoded {:?}", start.elapsed());
|
||||||
for idx in 0..args.n {
|
for idx in 0..args.n {
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let ys = model.forward(&token_ids, &token_type_ids)?;
|
let ys = model.forward(&token_ids, &token_type_ids, None)?;
|
||||||
if idx == 0 {
|
if idx == 0 {
|
||||||
println!("{ys}");
|
println!("{ys}");
|
||||||
}
|
}
|
||||||
@ -163,11 +163,19 @@ fn main() -> Result<()> {
|
|||||||
Ok(Tensor::new(tokens.as_slice(), device)?)
|
Ok(Tensor::new(tokens.as_slice(), device)?)
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let attention_mask = tokens
|
||||||
|
.iter()
|
||||||
|
.map(|tokens| {
|
||||||
|
let tokens = tokens.get_attention_mask().to_vec();
|
||||||
|
Ok(Tensor::new(tokens.as_slice(), device)?)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
|
||||||
let token_ids = Tensor::stack(&token_ids, 0)?;
|
let token_ids = Tensor::stack(&token_ids, 0)?;
|
||||||
|
let attention_mask = Tensor::stack(&attention_mask, 0)?;
|
||||||
let token_type_ids = token_ids.zeros_like()?;
|
let token_type_ids = token_ids.zeros_like()?;
|
||||||
println!("running inference on batch {:?}", token_ids.shape());
|
println!("running inference on batch {:?}", token_ids.shape());
|
||||||
let embeddings = model.forward(&token_ids, &token_type_ids)?;
|
let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
|
||||||
println!("generated embeddings {:?}", embeddings.shape());
|
println!("generated embeddings {:?}", embeddings.shape());
|
||||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||||
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
|
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
|
||||||
|
@ -55,7 +55,7 @@ const SEP_TOKEN_ID: u32 = 102;
|
|||||||
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
||||||
/// (3, 384, 384). OpenAI normalization is applied.
|
/// (3, 384, 384). OpenAI normalization is applied.
|
||||||
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
|
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
|
||||||
let img = image::io::Reader::open(p)?
|
let img = image::ImageReader::open(p)?
|
||||||
.decode()
|
.decode()
|
||||||
.map_err(candle::Error::wrap)?
|
.map_err(candle::Error::wrap)?
|
||||||
.resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
|
.resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
|
||||||
|
237
candle-examples/examples/chatglm/main.rs
Normal file
237
candle-examples/examples/chatglm/main.rs
Normal file
@ -0,0 +1,237 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::{Error as E, Result};
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle_transformers::models::chatglm::{Config, Model};
|
||||||
|
|
||||||
|
use candle::{DType, Device, Tensor};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
struct TextGeneration {
|
||||||
|
model: Model,
|
||||||
|
device: Device,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
logits_processor: LogitsProcessor,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
verbose_prompt: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextGeneration {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn new(
|
||||||
|
model: Model,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
seed: u64,
|
||||||
|
temp: Option<f64>,
|
||||||
|
top_p: Option<f64>,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
verbose_prompt: bool,
|
||||||
|
device: &Device,
|
||||||
|
) -> Self {
|
||||||
|
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
logits_processor,
|
||||||
|
repeat_penalty,
|
||||||
|
repeat_last_n,
|
||||||
|
verbose_prompt,
|
||||||
|
device: device.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||||
|
use std::io::Write;
|
||||||
|
println!("starting the inference loop");
|
||||||
|
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
|
||||||
|
if tokens.is_empty() {
|
||||||
|
anyhow::bail!("Empty prompts are not supported in the chatglm model.")
|
||||||
|
}
|
||||||
|
if self.verbose_prompt {
|
||||||
|
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
||||||
|
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
||||||
|
println!("{id:7} -> '{token}'");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut tokens = tokens.get_ids().to_vec();
|
||||||
|
let mut generated_tokens = 0usize;
|
||||||
|
let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
|
||||||
|
Some(token) => *token,
|
||||||
|
None => anyhow::bail!("cannot find the endoftext token"),
|
||||||
|
};
|
||||||
|
print!("{prompt}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
for index in 0..sample_len {
|
||||||
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||||
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
|
let logits = self.model.forward(&input)?;
|
||||||
|
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
let logits = if self.repeat_penalty == 1. {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||||
|
candle_transformers::utils::apply_repeat_penalty(
|
||||||
|
&logits,
|
||||||
|
self.repeat_penalty,
|
||||||
|
&tokens[start_at..],
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
|
tokens.push(next_token);
|
||||||
|
generated_tokens += 1;
|
||||||
|
if next_token == eos_token {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||||
|
print!("{token}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
let dt = start_gen.elapsed();
|
||||||
|
println!(
|
||||||
|
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||||
|
generated_tokens as f64 / dt.as_secs_f64(),
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
/// Display the token for the specified prompt.
|
||||||
|
#[arg(long)]
|
||||||
|
verbose_prompt: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
/// The temperature used to generate samples.
|
||||||
|
#[arg(long)]
|
||||||
|
temperature: Option<f64>,
|
||||||
|
|
||||||
|
/// Nucleus sampling probability cutoff.
|
||||||
|
#[arg(long)]
|
||||||
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
|
/// The length of the sample to generate (in tokens).
|
||||||
|
#[arg(long, short = 'n', default_value_t = 5000)]
|
||||||
|
sample_len: usize,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
revision: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
weight_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.1)]
|
||||||
|
repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
repeat_last_n: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||||
|
candle::utils::with_avx(),
|
||||||
|
candle::utils::with_neon(),
|
||||||
|
candle::utils::with_simd128(),
|
||||||
|
candle::utils::with_f16c()
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||||
|
args.temperature.unwrap_or(0.),
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n
|
||||||
|
);
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let api = Api::new()?;
|
||||||
|
let model_id = match args.model_id {
|
||||||
|
Some(model_id) => model_id.to_string(),
|
||||||
|
None => "THUDM/chatglm3-6b".to_string(),
|
||||||
|
};
|
||||||
|
let revision = match args.revision {
|
||||||
|
Some(rev) => rev.to_string(),
|
||||||
|
None => "main".to_string(),
|
||||||
|
};
|
||||||
|
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||||
|
let tokenizer_filename = match args.tokenizer {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => api
|
||||||
|
.model("lmz/candle-chatglm".to_string())
|
||||||
|
.get("chatglm-tokenizer.json")?,
|
||||||
|
};
|
||||||
|
let filenames = match args.weight_file {
|
||||||
|
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
|
||||||
|
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||||
|
};
|
||||||
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let config = Config::glm3_6b();
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||||
|
let model = Model::new(&config, vb)?;
|
||||||
|
|
||||||
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
let mut pipeline = TextGeneration::new(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
args.seed,
|
||||||
|
args.temperature,
|
||||||
|
args.top_p,
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n,
|
||||||
|
args.verbose_prompt,
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
pipeline.run(&args.prompt, args.sample_len)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
46
candle-examples/examples/clip/README.md
Normal file
46
candle-examples/examples/clip/README.md
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
# candle-clip
|
||||||
|
|
||||||
|
Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
|
||||||
|
pairs of images with related texts.
|
||||||
|
|
||||||
|
https://github.com/openai/CLIP
|
||||||
|
|
||||||
|
https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
|
||||||
|
|
||||||
|
## Running on an example on cpu
|
||||||
|
|
||||||
|
```
|
||||||
|
$ cargo run --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
|
||||||
|
|
||||||
|
|
||||||
|
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
|
||||||
|
|
||||||
|
INFO clip: Probability: 0.0000% Text: a cycling race
|
||||||
|
INFO clip: Probability: 0.0000% Text: a photo of two cats
|
||||||
|
INFO clip: Probability: 100.0000% Text: a robot holding a candle
|
||||||
|
|
||||||
|
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
|
INFO clip: Probability: 99.9999% Text: a cycling race
|
||||||
|
INFO clip: Probability: 0.0001% Text: a photo of two cats
|
||||||
|
INFO clip: Probability: 0.0000% Text: a robot holding a candle
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running on an example with metal feature (mac)
|
||||||
|
|
||||||
|
```
|
||||||
|
$ cargo run --features metal --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
|
||||||
|
|
||||||
|
|
||||||
|
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
|
||||||
|
|
||||||
|
INFO clip: Probability: 0.0000% Text: a cycling race
|
||||||
|
INFO clip: Probability: 0.0000% Text: a photo of two cats
|
||||||
|
INFO clip: Probability: 100.0000% Text: a robot holding a candle
|
||||||
|
|
||||||
|
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
|
INFO clip: Probability: 99.9999% Text: a cycling race
|
||||||
|
INFO clip: Probability: 0.0001% Text: a photo of two cats
|
||||||
|
INFO clip: Probability: 0.0000% Text: a robot holding a candle
|
||||||
|
```
|
164
candle-examples/examples/clip/main.rs
Normal file
164
candle-examples/examples/clip/main.rs
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::Error as E;
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle::{DType, Device, Tensor};
|
||||||
|
use candle_nn::{ops::softmax, VarBuilder};
|
||||||
|
use candle_transformers::models::clip;
|
||||||
|
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
struct Args {
|
||||||
|
#[arg(long)]
|
||||||
|
model: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, use_value_delimiter = true)]
|
||||||
|
images: Option<Vec<String>>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
#[arg(long, use_value_delimiter = true)]
|
||||||
|
sequences: Option<Vec<String>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
|
||||||
|
let img = image::ImageReader::open(path)?.decode()?;
|
||||||
|
let (height, width) = (image_size, image_size);
|
||||||
|
let img = img.resize_to_fill(
|
||||||
|
width as u32,
|
||||||
|
height as u32,
|
||||||
|
image::imageops::FilterType::Triangle,
|
||||||
|
);
|
||||||
|
let img = img.to_rgb8();
|
||||||
|
let img = img.into_raw();
|
||||||
|
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
|
||||||
|
.permute((2, 0, 1))?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.affine(2. / 255., -1.)?;
|
||||||
|
Ok(img)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_images<T: AsRef<std::path::Path>>(
|
||||||
|
paths: &Vec<T>,
|
||||||
|
image_size: usize,
|
||||||
|
) -> anyhow::Result<Tensor> {
|
||||||
|
let mut images = vec![];
|
||||||
|
for path in paths {
|
||||||
|
let tensor = load_image(path, image_size)?;
|
||||||
|
images.push(tensor);
|
||||||
|
}
|
||||||
|
let images = Tensor::stack(&images, 0)?;
|
||||||
|
Ok(images)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn main() -> anyhow::Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
let model_file = match args.model {
|
||||||
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
|
||||||
|
let api = api.repo(hf_hub::Repo::with_revision(
|
||||||
|
"openai/clip-vit-base-patch32".to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
"refs/pr/15".to_string(),
|
||||||
|
));
|
||||||
|
|
||||||
|
api.get("model.safetensors")?
|
||||||
|
}
|
||||||
|
Some(model) => model.into(),
|
||||||
|
};
|
||||||
|
let tokenizer = get_tokenizer(args.tokenizer)?;
|
||||||
|
let config = clip::ClipConfig::vit_base_patch32();
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let vec_imgs = match args.images {
|
||||||
|
Some(imgs) => imgs,
|
||||||
|
None => vec![
|
||||||
|
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
|
||||||
|
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
|
||||||
|
],
|
||||||
|
};
|
||||||
|
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
|
||||||
|
let vb =
|
||||||
|
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
|
||||||
|
let model = clip::ClipModel::new(vb, &config)?;
|
||||||
|
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
|
||||||
|
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
|
||||||
|
let softmax_image = softmax(&logits_per_image, 1)?;
|
||||||
|
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
|
||||||
|
println!("softmax_image_vec: {:?}", softmax_image_vec);
|
||||||
|
let probability_vec = softmax_image_vec
|
||||||
|
.iter()
|
||||||
|
.map(|v| v * 100.0)
|
||||||
|
.collect::<Vec<f32>>();
|
||||||
|
let probability_per_image = probability_vec.len() / vec_imgs.len();
|
||||||
|
for (i, img) in vec_imgs.iter().enumerate() {
|
||||||
|
let start = i * probability_per_image;
|
||||||
|
let end = start + probability_per_image;
|
||||||
|
let prob = &probability_vec[start..end];
|
||||||
|
println!("\n\nResults for image: {}\n", img);
|
||||||
|
for (i, p) in prob.iter().enumerate() {
|
||||||
|
println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
|
||||||
|
let tokenizer = match tokenizer {
|
||||||
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let api = api.repo(hf_hub::Repo::with_revision(
|
||||||
|
"openai/clip-vit-base-patch32".to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
"refs/pr/15".to_string(),
|
||||||
|
));
|
||||||
|
api.get("tokenizer.json")?
|
||||||
|
}
|
||||||
|
Some(file) => file.into(),
|
||||||
|
};
|
||||||
|
Tokenizer::from_file(tokenizer).map_err(E::msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tokenize_sequences(
|
||||||
|
sequences: Option<Vec<String>>,
|
||||||
|
tokenizer: &Tokenizer,
|
||||||
|
device: &Device,
|
||||||
|
) -> anyhow::Result<(Tensor, Vec<String>)> {
|
||||||
|
let pad_id = *tokenizer
|
||||||
|
.get_vocab(true)
|
||||||
|
.get("<|endoftext|>")
|
||||||
|
.ok_or(E::msg("No pad token"))?;
|
||||||
|
let vec_seq = match sequences {
|
||||||
|
Some(seq) => seq,
|
||||||
|
None => vec![
|
||||||
|
"a cycling race".to_string(),
|
||||||
|
"a photo of two cats".to_string(),
|
||||||
|
"a robot holding a candle".to_string(),
|
||||||
|
],
|
||||||
|
};
|
||||||
|
let mut tokens = vec![];
|
||||||
|
for seq in vec_seq.clone() {
|
||||||
|
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
|
||||||
|
tokens.push(encoding.get_ids().to_vec());
|
||||||
|
}
|
||||||
|
let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
|
||||||
|
// Pad the sequences to have the same length
|
||||||
|
for token_vec in tokens.iter_mut() {
|
||||||
|
let len_diff = max_len - token_vec.len();
|
||||||
|
if len_diff > 0 {
|
||||||
|
token_vec.extend(vec![pad_id; len_diff]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let input_ids = Tensor::new(tokens, device)?;
|
||||||
|
Ok((input_ids, vec_seq))
|
||||||
|
}
|
96
candle-examples/examples/codegeex4-9b/README.org
Normal file
96
candle-examples/examples/codegeex4-9b/README.org
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
* candle-codegeex4_9b
|
||||||
|
THUDM/CodeGeeX4 is a versatile model for all AI software development scenarios, including code completion, code interpreter, web search, function calling, repository-level Q&A and much more.
|
||||||
|
|
||||||
|
- [[https://github.com/THUDM/CodeGeeX4][Github]]
|
||||||
|
- [[https://codegeex.cn/][HomePage]]
|
||||||
|
- [[https://huggingface.co/THUDM/codegeex4-all-9b][huggingface]]
|
||||||
|
|
||||||
|
** Running with ~cuda~
|
||||||
|
|
||||||
|
#+begin_src shell
|
||||||
|
cargo run --example codegeex4-9b --release --features cuda -- --prompt "please write a insertion sort in rust" --sample-len 300
|
||||||
|
#+end_src
|
||||||
|
|
||||||
|
** Running with ~cpu~
|
||||||
|
#+begin_src shell
|
||||||
|
cargo run --example codegeex4-9b --release --cpu -- --prompt "please write a insertion sort in rust" --sample-len 300
|
||||||
|
#+end_src
|
||||||
|
|
||||||
|
** Output_Example
|
||||||
|
*** Input
|
||||||
|
#+begin_src shell
|
||||||
|
cargo run --release --features cuda -- --prompt 'please write a FFT in rust' --sample-len 500 --cache /root/autodl-tmp
|
||||||
|
#+end_src
|
||||||
|
|
||||||
|
*** Output
|
||||||
|
#+begin_src shell
|
||||||
|
avx: false, neon: false, simd128: false, f16c: false
|
||||||
|
temp: 0.95 repeat-penalty: 1.10 repeat-last-n: 64
|
||||||
|
cache path /root/autodl-tmp
|
||||||
|
Prompt: [please write a FFT in rust]
|
||||||
|
Using Seed 11511762269791786684
|
||||||
|
DType is BF16
|
||||||
|
transofrmer layers create
|
||||||
|
模型加载完毕 4
|
||||||
|
starting the inference loop
|
||||||
|
|
||||||
|
开始生成
|
||||||
|
samplelen 500
|
||||||
|
|
||||||
|
500 tokens generated (34.60 token/s)
|
||||||
|
Result:
|
||||||
|
|
||||||
|
Sure, I can help you with that. Here's an example of a Fast Fourier Transform (FFT) implementation in Rust:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use num_complex::Complex;
|
||||||
|
|
||||||
|
fn fft(input: &[Complex<f64> > ] ) -> Vec<Complex<f64> > > {
|
||||||
|
let n = input.len();
|
||||||
|
|
||||||
|
if n == 1 {
|
||||||
|
return vec![input[0]]];
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut even = vec![];
|
||||||
|
let mut odd = vec![];
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
|
||||||
|
if i % 2 == 0 {
|
||||||
|
even.push(input[i]);
|
||||||
|
} else {
|
||||||
|
odd.push(input[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let even_fft = fft(&even);
|
||||||
|
let odd_fft = fft(&odd);
|
||||||
|
|
||||||
|
let mut output = vec![];
|
||||||
|
|
||||||
|
for k in 0..n/2 {
|
||||||
|
let t = Complex::new(0.0, -2.0 * std::f64::consts::PI * (k as f64) / (n as f64))) ).exp();
|
||||||
|
|
||||||
|
output.push(even_fft[k] + odd_fft[k] * t]);
|
||||||
|
output.push(even_fft[k] - odd_fft[k] * t]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
This implementation uses the Cooley-Tukey algorithm to perform the FFT. The function takes an array of complex numbers and returns an array of complex numbers which is the result of the FFT.
|
||||||
|
#+end_src
|
||||||
|
|
||||||
|
|
||||||
|
* Citation
|
||||||
|
#+begin_src
|
||||||
|
@inproceedings{zheng2023codegeex,
|
||||||
|
title={CodeGeeX: A Pre-Trained Model for Code Generation with Multilingual Benchmarking on HumanEval-X},
|
||||||
|
author={Qinkai Zheng and Xiao Xia and Xu Zou and Yuxiao Dong and Shan Wang and Yufei Xue and Zihan Wang and Lei Shen and Andi Wang and Yang Li and Teng Su and Zhilin Yang and Jie Tang},
|
||||||
|
booktitle={Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
|
||||||
|
pages={5673--5684},
|
||||||
|
year={2023}
|
||||||
|
}
|
||||||
|
#+end_src
|
252
candle-examples/examples/codegeex4-9b/main.rs
Normal file
252
candle-examples/examples/codegeex4-9b/main.rs
Normal file
@ -0,0 +1,252 @@
|
|||||||
|
use candle_transformers::models::codegeex4_9b::*;
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle::{DType, Device, Tensor};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
|
use hf_hub::{Repo, RepoType};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
struct TextGeneration {
|
||||||
|
model: Model,
|
||||||
|
device: Device,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
logits_processor: LogitsProcessor,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
verbose_prompt: bool,
|
||||||
|
dtype: DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextGeneration {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn new(
|
||||||
|
model: Model,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
seed: u64,
|
||||||
|
temp: Option<f64>,
|
||||||
|
top_p: Option<f64>,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
verbose_prompt: bool,
|
||||||
|
device: &Device,
|
||||||
|
dtype: DType,
|
||||||
|
) -> Self {
|
||||||
|
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
logits_processor,
|
||||||
|
repeat_penalty,
|
||||||
|
repeat_last_n,
|
||||||
|
verbose_prompt,
|
||||||
|
device: device.clone(),
|
||||||
|
dtype,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(&mut self, prompt: &str, sample_len: usize) -> anyhow::Result<()> {
|
||||||
|
use std::io::Write;
|
||||||
|
println!("starting the inference loop");
|
||||||
|
let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
|
||||||
|
if tokens.is_empty() {
|
||||||
|
panic!("Empty prompts are not supported in the chatglm model.")
|
||||||
|
}
|
||||||
|
if self.verbose_prompt {
|
||||||
|
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
||||||
|
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
||||||
|
println!("{id:7} -> '{token}'");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
|
||||||
|
Some(token) => *token,
|
||||||
|
None => panic!("cannot find the endoftext token"),
|
||||||
|
};
|
||||||
|
let mut tokens = tokens.get_ids().to_vec();
|
||||||
|
let mut generated_tokens = 0usize;
|
||||||
|
|
||||||
|
print!("{prompt}");
|
||||||
|
std::io::stdout().flush().expect("output flush error");
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
|
||||||
|
println!("\n start_gen");
|
||||||
|
println!("samplelen {}", sample_len);
|
||||||
|
let mut count = 0;
|
||||||
|
let mut result = vec![];
|
||||||
|
for index in 0..sample_len {
|
||||||
|
count += 1;
|
||||||
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||||
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
|
let logits = self.model.forward(&input)?;
|
||||||
|
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
|
||||||
|
let logits = if self.repeat_penalty == 1. {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||||
|
candle_transformers::utils::apply_repeat_penalty(
|
||||||
|
&logits,
|
||||||
|
self.repeat_penalty,
|
||||||
|
&tokens[start_at..],
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
|
tokens.push(next_token);
|
||||||
|
generated_tokens += 1;
|
||||||
|
if next_token == eos_token {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let token = self
|
||||||
|
.tokenizer
|
||||||
|
.decode(&[next_token], true)
|
||||||
|
.expect("Token error");
|
||||||
|
if self.verbose_prompt {
|
||||||
|
println!(
|
||||||
|
"[Count: {}] [Raw Token: {}] [Decode Token: {}]",
|
||||||
|
count, next_token, token
|
||||||
|
);
|
||||||
|
}
|
||||||
|
result.push(token);
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
let dt = start_gen.elapsed();
|
||||||
|
println!(
|
||||||
|
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||||
|
generated_tokens as f64 / dt.as_secs_f64(),
|
||||||
|
);
|
||||||
|
println!("Result:");
|
||||||
|
for tokens in result {
|
||||||
|
print!("{tokens}");
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(name = "cache", short, long, default_value = ".")]
|
||||||
|
cache_path: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Display the token for the specified prompt.
|
||||||
|
#[arg(long)]
|
||||||
|
verbose_prompt: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
/// The temperature used to generate samples.
|
||||||
|
#[arg(long)]
|
||||||
|
temperature: Option<f64>,
|
||||||
|
|
||||||
|
/// Nucleus sampling probability cutoff.
|
||||||
|
#[arg(long)]
|
||||||
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
|
/// The length of the sample to generate (in tokens).
|
||||||
|
#[arg(long, short = 'n', default_value_t = 5000)]
|
||||||
|
sample_len: usize,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
revision: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
weight_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.1)]
|
||||||
|
repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
repeat_last_n: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> anyhow::Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
println!(
|
||||||
|
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||||
|
candle::utils::with_avx(),
|
||||||
|
candle::utils::with_neon(),
|
||||||
|
candle::utils::with_simd128(),
|
||||||
|
candle::utils::with_f16c()
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||||
|
args.temperature.unwrap_or(0.95),
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n
|
||||||
|
);
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
println!("cache path {}", args.cache_path);
|
||||||
|
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
|
||||||
|
.build()
|
||||||
|
.map_err(anyhow::Error::msg)?;
|
||||||
|
|
||||||
|
let model_id = match args.model_id {
|
||||||
|
Some(model_id) => model_id.to_string(),
|
||||||
|
None => "THUDM/codegeex4-all-9b".to_string(),
|
||||||
|
};
|
||||||
|
let revision = match args.revision {
|
||||||
|
Some(rev) => rev.to_string(),
|
||||||
|
None => "main".to_string(),
|
||||||
|
};
|
||||||
|
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||||
|
let tokenizer_filename = match args.tokenizer {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => api
|
||||||
|
.model("THUDM/codegeex4-all-9b".to_string())
|
||||||
|
.get("tokenizer.json")
|
||||||
|
.map_err(anyhow::Error::msg)?,
|
||||||
|
};
|
||||||
|
let filenames = match args.weight_file {
|
||||||
|
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
|
||||||
|
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||||
|
};
|
||||||
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let config = Config::codegeex4();
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let dtype = if device.is_cuda() {
|
||||||
|
DType::BF16
|
||||||
|
} else {
|
||||||
|
DType::F32
|
||||||
|
};
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
|
let model = Model::new(&config, vb)?;
|
||||||
|
|
||||||
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
let mut pipeline = TextGeneration::new(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
args.seed,
|
||||||
|
args.temperature,
|
||||||
|
args.top_p,
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n,
|
||||||
|
args.verbose_prompt,
|
||||||
|
&device,
|
||||||
|
dtype,
|
||||||
|
);
|
||||||
|
pipeline.run(&args.prompt, args.sample_len)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -28,7 +28,7 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
||||||
println!("loaded image {image:?}");
|
println!("loaded image {image:?}");
|
||||||
|
|
||||||
let model_file = match args.model {
|
let model_file = match args.model {
|
||||||
|
23
candle-examples/examples/convnext/README.md
Normal file
23
candle-examples/examples/convnext/README.md
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# candle-convnext
|
||||||
|
|
||||||
|
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and
|
||||||
|
[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808).
|
||||||
|
|
||||||
|
This candle implementation uses a pre-trained ConvNeXt network for inference. The
|
||||||
|
classification head has been trained on the ImageNet dataset and returns the
|
||||||
|
probabilities for the top-5 classes.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```
|
||||||
|
$ cargo run --example convnext --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which tiny
|
||||||
|
|
||||||
|
loaded image Tensor[dims 3, 224, 224; f32]
|
||||||
|
model built
|
||||||
|
mountain bike, all-terrain bike, off-roader: 84.09%
|
||||||
|
bicycle-built-for-two, tandem bicycle, tandem: 4.15%
|
||||||
|
maillot : 0.74%
|
||||||
|
crash helmet : 0.54%
|
||||||
|
unicycle, monocycle : 0.44%
|
||||||
|
|
||||||
|
```
|
126
candle-examples/examples/convnext/main.rs
Normal file
126
candle-examples/examples/convnext/main.rs
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use clap::{Parser, ValueEnum};
|
||||||
|
|
||||||
|
use candle::{DType, IndexOp, D};
|
||||||
|
use candle_nn::{Module, VarBuilder};
|
||||||
|
use candle_transformers::models::convnext;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
Atto,
|
||||||
|
Femto,
|
||||||
|
Pico,
|
||||||
|
Nano,
|
||||||
|
Tiny,
|
||||||
|
Small,
|
||||||
|
Base,
|
||||||
|
Large,
|
||||||
|
AttoV2,
|
||||||
|
FemtoV2,
|
||||||
|
PicoV2,
|
||||||
|
NanoV2,
|
||||||
|
TinyV2,
|
||||||
|
BaseV2,
|
||||||
|
LargeV2,
|
||||||
|
XLarge,
|
||||||
|
Huge,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Which {
|
||||||
|
fn model_filename(&self) -> String {
|
||||||
|
let name = match self {
|
||||||
|
Self::Atto => "convnext_atto.d2_in1k",
|
||||||
|
Self::Femto => "convnext_femto.d1_in1k",
|
||||||
|
Self::Pico => "convnext_pico.d1_in1k",
|
||||||
|
Self::Nano => "convnext_nano.d1h_in1k",
|
||||||
|
Self::Tiny => "convnext_tiny.fb_in1k",
|
||||||
|
Self::Small => "convnext_small.fb_in1k",
|
||||||
|
Self::Base => "convnext_base.fb_in1k",
|
||||||
|
Self::Large => "convnext_large.fb_in1k",
|
||||||
|
Self::AttoV2 => "convnextv2_atto.fcmae_ft_in1k",
|
||||||
|
Self::FemtoV2 => "convnextv2_femto.fcmae_ft_in1k",
|
||||||
|
Self::PicoV2 => "convnextv2_pico.fcmae_ft_in1k",
|
||||||
|
Self::NanoV2 => "convnextv2_nano.fcmae_ft_in1k",
|
||||||
|
Self::TinyV2 => "convnextv2_tiny.fcmae_ft_in1k",
|
||||||
|
Self::BaseV2 => "convnextv2_base.fcmae_ft_in1k",
|
||||||
|
Self::LargeV2 => "convnextv2_large.fcmae_ft_in1k",
|
||||||
|
Self::XLarge => "convnext_xlarge.fb_in22k_ft_in1k",
|
||||||
|
Self::Huge => "convnextv2_huge.fcmae_ft_in1k",
|
||||||
|
};
|
||||||
|
|
||||||
|
format!("timm/{name}")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn config(&self) -> convnext::Config {
|
||||||
|
match self {
|
||||||
|
Self::Atto | Self::AttoV2 => convnext::Config::atto(),
|
||||||
|
Self::Femto | Self::FemtoV2 => convnext::Config::femto(),
|
||||||
|
Self::Pico | Self::PicoV2 => convnext::Config::pico(),
|
||||||
|
Self::Nano | Self::NanoV2 => convnext::Config::nano(),
|
||||||
|
Self::Tiny | Self::TinyV2 => convnext::Config::tiny(),
|
||||||
|
Self::Small => convnext::Config::small(),
|
||||||
|
Self::Base | Self::BaseV2 => convnext::Config::base(),
|
||||||
|
Self::Large | Self::LargeV2 => convnext::Config::large(),
|
||||||
|
Self::XLarge => convnext::Config::xlarge(),
|
||||||
|
Self::Huge => convnext::Config::huge(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
struct Args {
|
||||||
|
#[arg(long)]
|
||||||
|
model: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
image: String,
|
||||||
|
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
#[arg(value_enum, long, default_value_t=Which::Tiny)]
|
||||||
|
which: Which,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn main() -> anyhow::Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
|
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
||||||
|
println!("loaded image {image:?}");
|
||||||
|
|
||||||
|
let model_file = match args.model {
|
||||||
|
None => {
|
||||||
|
let model_name = args.which.model_filename();
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let api = api.model(model_name);
|
||||||
|
api.get("model.safetensors")?
|
||||||
|
}
|
||||||
|
Some(model) => model.into(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||||
|
let model = convnext::convnext(&args.which.config(), 1000, vb)?;
|
||||||
|
println!("model built");
|
||||||
|
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||||
|
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||||
|
.i(0)?
|
||||||
|
.to_vec1::<f32>()?;
|
||||||
|
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||||
|
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||||
|
for &(category_idx, pr) in prs.iter().take(5) {
|
||||||
|
println!(
|
||||||
|
"{:24}: {:.2}%",
|
||||||
|
candle_examples::imagenet::CLASSES[category_idx],
|
||||||
|
100. * pr
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -0,0 +1 @@
|
|||||||
|
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));
|
||||||
|
13
candle-examples/examples/depth_anything_v2/README.md
Normal file
13
candle-examples/examples/depth_anything_v2/README.md
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# candle-dinov2
|
||||||
|
|
||||||
|
[Depth Anything V2] is a model for Monocular Depth Estimation (MDE, i.e. just using a single image) which
|
||||||
|
builds on the [DINOv2](https://github.com/facebookresearch/dinov2) vision transformer.
|
||||||
|
|
||||||
|
This example first instantiates the DINOv2 model and then proceeds to create DepthAnythingV2 and run it.
|
||||||
|
|
||||||
|
## Running an example with color map and CUDA
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --features cuda,depth_anything_v2 --package candle-examples --example depth_anything_v2 -- --color-map --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
```
|
||||||
|
|
50
candle-examples/examples/depth_anything_v2/color_map.rs
Normal file
50
candle-examples/examples/depth_anything_v2/color_map.rs
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
use enterpolation::linear::ConstEquidistantLinear;
|
||||||
|
use enterpolation::Generator;
|
||||||
|
use palette::LinSrgb;
|
||||||
|
|
||||||
|
use candle::Tensor;
|
||||||
|
|
||||||
|
pub struct SpectralRColormap {
|
||||||
|
gradient: ConstEquidistantLinear<f32, LinSrgb, 9>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SpectralRColormap {
|
||||||
|
pub(crate) fn new() -> Self {
|
||||||
|
// Define a colormap similar to 'Spectral_r' by specifying key colors.
|
||||||
|
// got the colors from ChatGPT-4o
|
||||||
|
let gradient = ConstEquidistantLinear::<f32, _, 9>::equidistant_unchecked([
|
||||||
|
LinSrgb::new(0.3686, 0.3098, 0.6353), // Dark blue
|
||||||
|
LinSrgb::new(0.1961, 0.5333, 0.7412), // Blue
|
||||||
|
LinSrgb::new(0.4000, 0.7608, 0.6471), // Cyan
|
||||||
|
LinSrgb::new(0.6706, 0.8667, 0.6431), // Green
|
||||||
|
LinSrgb::new(0.9020, 0.9608, 0.5961), // Yellow
|
||||||
|
LinSrgb::new(0.9961, 0.8784, 0.5451), // Orange
|
||||||
|
LinSrgb::new(0.9922, 0.6824, 0.3804), // Red
|
||||||
|
LinSrgb::new(0.9569, 0.4275, 0.2627), // Dark red
|
||||||
|
LinSrgb::new(0.8353, 0.2431, 0.3098), // Dark purple
|
||||||
|
]);
|
||||||
|
Self { gradient }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_color(&self, value: f32) -> LinSrgb {
|
||||||
|
self.gradient.gen(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn gray2color(&self, gray: &Tensor) -> candle::Result<Tensor> {
|
||||||
|
println!("Gray: {:?}", gray.dims());
|
||||||
|
let gray_values: Vec<f32> = gray.flatten_all()?.to_vec1()?;
|
||||||
|
let rgb_values: Vec<f32> = gray_values
|
||||||
|
.iter()
|
||||||
|
.map(|g| self.get_color(*g))
|
||||||
|
.flat_map(|rgb| [rgb.red, rgb.green, rgb.blue])
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let [.., height, width] = gray.dims() else {
|
||||||
|
candle::bail!("Not enough dims!")
|
||||||
|
};
|
||||||
|
|
||||||
|
let color = Tensor::from_vec(rgb_values, (*height, *width, 3), gray.device())?;
|
||||||
|
|
||||||
|
color.permute((2, 0, 1))
|
||||||
|
}
|
||||||
|
}
|
187
candle-examples/examples/depth_anything_v2/main.rs
Normal file
187
candle-examples/examples/depth_anything_v2/main.rs
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
//! Depth Anything V2
|
||||||
|
//! https://huggingface.co/spaces/depth-anything/Depth-Anything-V2
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
use std::ffi::OsString;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle::DType::{F32, U8};
|
||||||
|
use candle::{DType, Device, Module, Result, Tensor};
|
||||||
|
use candle_examples::{load_image, load_image_and_resize, save_image};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config};
|
||||||
|
use candle_transformers::models::dinov2;
|
||||||
|
|
||||||
|
use crate::color_map::SpectralRColormap;
|
||||||
|
|
||||||
|
mod color_map;
|
||||||
|
|
||||||
|
// taken these from: https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/depth_anything_v2/dpt.py#L207
|
||||||
|
const MAGIC_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
|
||||||
|
const MAGIC_STD: [f32; 3] = [0.229, 0.224, 0.225];
|
||||||
|
|
||||||
|
const DINO_IMG_SIZE: usize = 518;
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
struct Args {
|
||||||
|
#[arg(long)]
|
||||||
|
dinov2_model: Option<PathBuf>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
depth_anything_v2_model: Option<PathBuf>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
image: PathBuf,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
output_dir: Option<PathBuf>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
color_map: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn main() -> anyhow::Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
|
let dinov2_model_file = match args.dinov2_model {
|
||||||
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let api = api.model("lmz/candle-dino-v2".into());
|
||||||
|
api.get("dinov2_vits14.safetensors")?
|
||||||
|
}
|
||||||
|
Some(dinov2_model) => dinov2_model,
|
||||||
|
};
|
||||||
|
println!("Using file {:?}", dinov2_model_file);
|
||||||
|
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[dinov2_model_file], F32, &device)? };
|
||||||
|
let dinov2 = dinov2::vit_small(vb)?;
|
||||||
|
println!("DinoV2 model built");
|
||||||
|
|
||||||
|
let depth_anything_model_file = match args.depth_anything_v2_model {
|
||||||
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let api = api.model("jeroenvlek/depth-anything-v2-safetensors".into());
|
||||||
|
api.get("depth_anything_v2_vits.safetensors")?
|
||||||
|
}
|
||||||
|
Some(depth_anything_model) => depth_anything_model,
|
||||||
|
};
|
||||||
|
println!("Using file {:?}", depth_anything_model_file);
|
||||||
|
|
||||||
|
let vb = unsafe {
|
||||||
|
VarBuilder::from_mmaped_safetensors(&[depth_anything_model_file], DType::F32, &device)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let config = DepthAnythingV2Config::vit_small();
|
||||||
|
let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;
|
||||||
|
|
||||||
|
let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;
|
||||||
|
|
||||||
|
println!("Loaded image {image:?}");
|
||||||
|
|
||||||
|
let depth = depth_anything.forward(&image)?;
|
||||||
|
|
||||||
|
println!("Got predictions {:?}", depth.shape());
|
||||||
|
|
||||||
|
let output_image = post_process_image(&depth, original_height, original_width, args.color_map)?;
|
||||||
|
|
||||||
|
let output_path = full_output_path(&args.image, &args.output_dir);
|
||||||
|
println!("Saving image to {}", output_path.to_string_lossy());
|
||||||
|
save_image(&output_image, output_path)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn full_output_path(image_path: &PathBuf, output_dir: &Option<PathBuf>) -> PathBuf {
|
||||||
|
let input_file_name = image_path.file_name().unwrap();
|
||||||
|
let mut output_file_name = OsString::from("depth_");
|
||||||
|
output_file_name.push(input_file_name);
|
||||||
|
let mut output_path = match output_dir {
|
||||||
|
None => image_path.parent().unwrap().to_path_buf(),
|
||||||
|
Some(output_path) => output_path.clone(),
|
||||||
|
};
|
||||||
|
output_path.push(output_file_name);
|
||||||
|
|
||||||
|
output_path
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_and_prep_image(
|
||||||
|
image_path: &PathBuf,
|
||||||
|
device: &Device,
|
||||||
|
) -> anyhow::Result<(usize, usize, Tensor)> {
|
||||||
|
let (_original_image, original_height, original_width) = load_image(&image_path, None)?;
|
||||||
|
|
||||||
|
let image = load_image_and_resize(&image_path, DINO_IMG_SIZE, DINO_IMG_SIZE)?
|
||||||
|
.unsqueeze(0)?
|
||||||
|
.to_dtype(F32)?
|
||||||
|
.to_device(&device)?;
|
||||||
|
|
||||||
|
let max_pixel_val = Tensor::try_from(255.0f32)?
|
||||||
|
.to_device(&device)?
|
||||||
|
.broadcast_as(image.shape())?;
|
||||||
|
let image = (image / max_pixel_val)?;
|
||||||
|
let image = normalize_image(&image, &MAGIC_MEAN, &MAGIC_STD)?;
|
||||||
|
|
||||||
|
Ok((original_height, original_width, image))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize_image(image: &Tensor, mean: &[f32; 3], std: &[f32; 3]) -> Result<Tensor> {
|
||||||
|
let mean_tensor =
|
||||||
|
Tensor::from_vec(mean.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
|
||||||
|
let std_tensor =
|
||||||
|
Tensor::from_vec(std.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
|
||||||
|
image.sub(&mean_tensor)?.div(&std_tensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn post_process_image(
|
||||||
|
image: &Tensor,
|
||||||
|
original_height: usize,
|
||||||
|
original_width: usize,
|
||||||
|
color_map: bool,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let out = image.interpolate2d(original_height, original_width)?;
|
||||||
|
let out = scale_image(&out)?;
|
||||||
|
|
||||||
|
let out = if color_map {
|
||||||
|
let spectral_r = SpectralRColormap::new();
|
||||||
|
spectral_r.gray2color(&out)?
|
||||||
|
} else {
|
||||||
|
let rgb_slice = [&out, &out, &out];
|
||||||
|
Tensor::cat(&rgb_slice, 0)?.squeeze(1)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let max_pixel_val = Tensor::try_from(255.0f32)?
|
||||||
|
.to_device(out.device())?
|
||||||
|
.broadcast_as(out.shape())?;
|
||||||
|
let out = (out * max_pixel_val)?;
|
||||||
|
|
||||||
|
out.to_dtype(U8)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scale_image(depth: &Tensor) -> Result<Tensor> {
|
||||||
|
let flat_values: Vec<f32> = depth.flatten_all()?.to_vec1()?;
|
||||||
|
|
||||||
|
let min_val = flat_values.iter().min_by(|a, b| a.total_cmp(b)).unwrap();
|
||||||
|
let max_val = flat_values.iter().max_by(|a, b| a.total_cmp(b)).unwrap();
|
||||||
|
|
||||||
|
let min_val_tensor = Tensor::try_from(*min_val)?
|
||||||
|
.to_device(depth.device())?
|
||||||
|
.broadcast_as(depth.shape())?;
|
||||||
|
let depth = (depth - min_val_tensor)?;
|
||||||
|
|
||||||
|
let range = max_val - min_val;
|
||||||
|
let range_tensor = Tensor::try_from(range)?
|
||||||
|
.to_device(depth.device())?
|
||||||
|
.broadcast_as(depth.shape())?;
|
||||||
|
|
||||||
|
depth / range_tensor
|
||||||
|
}
|
@ -31,7 +31,7 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
||||||
println!("loaded image {image:?}");
|
println!("loaded image {image:?}");
|
||||||
|
|
||||||
let model_file = match args.model {
|
let model_file = match args.model {
|
||||||
|
25
candle-examples/examples/dinov2reg4/README.md
Normal file
25
candle-examples/examples/dinov2reg4/README.md
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# candle-dinov2-reg4
|
||||||
|
|
||||||
|
[DINOv2-reg4](https://arxiv.org/abs/2309.16588) is the lastest version of DINOv2 with registers.
|
||||||
|
In this example, it is used as an plant species classifier: the model returns the
|
||||||
|
probability for the image to belong to each of the 7806 PlantCLEF2024 categories.
|
||||||
|
|
||||||
|
## Running some example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Download classes names and a plant picture to identify
|
||||||
|
curl https://huggingface.co/vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights/raw/main/species_id_mapping.txt --output candle-examples/examples/dinov2reg4/species_id_mapping.txt
|
||||||
|
curl https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c --output candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg
|
||||||
|
|
||||||
|
# Perform inference
|
||||||
|
cargo run --example dinov2reg4 --release -- --image candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg
|
||||||
|
|
||||||
|
> Orchis simia Lam. : 45.55%
|
||||||
|
> Orchis × bergonii Nanteuil: 9.80%
|
||||||
|
> Orchis italica Poir. : 9.66%
|
||||||
|
> Orchis × angusticruris Franch.: 2.76%
|
||||||
|
> Orchis × bivonae Tod. : 2.54%
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|

|
70
candle-examples/examples/dinov2reg4/main.rs
Normal file
70
candle-examples/examples/dinov2reg4/main.rs
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
//! DINOv2 reg4 finetuned on PlantCLEF 2024
|
||||||
|
//! https://arxiv.org/abs/2309.16588
|
||||||
|
//! https://huggingface.co/spaces/BVRA/PlantCLEF2024
|
||||||
|
//! https://zenodo.org/records/10848263
|
||||||
|
|
||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle::{DType, IndexOp, D};
|
||||||
|
use candle_nn::{Module, VarBuilder};
|
||||||
|
use candle_transformers::models::dinov2reg4;
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
struct Args {
|
||||||
|
#[arg(long)]
|
||||||
|
model: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
image: String,
|
||||||
|
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn main() -> anyhow::Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
|
let image = candle_examples::imagenet::load_image518(args.image)?.to_device(&device)?;
|
||||||
|
println!("loaded image {image:?}");
|
||||||
|
|
||||||
|
let f_species_id_mapping = "candle-examples/examples/dinov2reg4/species_id_mapping.txt";
|
||||||
|
let classes: Vec<String> = std::fs::read_to_string(f_species_id_mapping)
|
||||||
|
.expect("missing classes file")
|
||||||
|
.split('\n')
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let model_file = match args.model {
|
||||||
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let api =
|
||||||
|
api.model("vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights".into());
|
||||||
|
api.get(
|
||||||
|
"vit_base_patch14_reg4_dinov2_lvd142m_pc24_onlyclassifier_then_all.safetensors",
|
||||||
|
)?
|
||||||
|
}
|
||||||
|
Some(model) => model.into(),
|
||||||
|
};
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||||
|
let model = dinov2reg4::vit_base(vb)?;
|
||||||
|
println!("model built");
|
||||||
|
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||||
|
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||||
|
.i(0)?
|
||||||
|
.to_vec1::<f32>()?;
|
||||||
|
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||||
|
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||||
|
for &(category_idx, pr) in prs.iter().take(5) {
|
||||||
|
println!("{:24}: {:.2}%", classes[category_idx], 100. * pr);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -47,7 +47,7 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
||||||
println!("loaded image {image:?}");
|
println!("loaded image {image:?}");
|
||||||
|
|
||||||
let model_file = match args.model {
|
let model_file = match args.model {
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user