mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Compare commits
139 Commits
wasm-llama
...
initialize
Author | SHA1 | Date | |
---|---|---|---|
0bb344f798 | |||
965597a873 | |||
ca449f9ee1 | |||
b8263aa15c | |||
e68b2accb4 | |||
08effe3762 | |||
8ad4a21ffc | |||
5e49922be2 | |||
ebcfd96d94 | |||
5b1690fffa | |||
3cc87058b7 | |||
531f23b4d0 | |||
495e0b7580 | |||
90374097dc | |||
c84883ecf2 | |||
a094dc503d | |||
34f4b3187e | |||
eab54e4490 | |||
9e7e6e0288 | |||
8bd2b22b33 | |||
d379a76a9e | |||
9af438ac1b | |||
b1ff78f762 | |||
5a63b51f14 | |||
6d694554b8 | |||
9aca398a4f | |||
60cd1551ca | |||
a0908d212c | |||
972078e1ae | |||
16b89f5b83 | |||
0741ebbd51 | |||
0c3f109faa | |||
2ba6b2826f | |||
1d0157bbc4 | |||
91dbf907d3 | |||
e12372021b | |||
55e428c8ae | |||
01ea57da8c | |||
662db45fc3 | |||
906c0f3eb5 | |||
e29c7809ec | |||
a325c1aa50 | |||
b6cf26e48e | |||
379eadc68e | |||
7e4fbc1e17 | |||
80f0482f26 | |||
94eff56aee | |||
a55133effd | |||
ff53f38467 | |||
4a95d34c83 | |||
7f710a573d | |||
c8039579a5 | |||
0b0fa56978 | |||
385f0d261c | |||
b765f2c37f | |||
66d1c093e0 | |||
de7c31bfe9 | |||
8e7ef96588 | |||
f3fe730a30 | |||
c7f92f985e | |||
3bbc08a8df | |||
6a2137af4f | |||
0dc1e5f387 | |||
bd2fb6216b | |||
3542b26143 | |||
a690f14a77 | |||
90d778c059 | |||
171fcbe539 | |||
07e83c55c0 | |||
25ec2d9f6b | |||
da26e2832c | |||
fcfdcbd337 | |||
653ec5abc1 | |||
c3a0761e62 | |||
0cef3998fd | |||
e5f510d209 | |||
0dd94eff4c | |||
a3b1699409 | |||
5b79b38bc7 | |||
a5c5a893aa | |||
e6ce47f9e0 | |||
1892bd139c | |||
749c8c7f51 | |||
d9b4fef189 | |||
8fa329aca2 | |||
cd225bd3b1 | |||
a4f6977087 | |||
dece0b8a76 | |||
b80348d22f | |||
3a62aee91f | |||
be21d7e75a | |||
9c4cf6804b | |||
dbc6f281c9 | |||
47a5bee249 | |||
cf965ecaa8 | |||
b9864e1357 | |||
608b2358c6 | |||
1e6dbeac01 | |||
13ce68ff9b | |||
89d3926c9b | |||
ab35684326 | |||
b5bb5e056d | |||
d0d7010682 | |||
fc265d9dcf | |||
2345b8ce3f | |||
f53a333ea9 | |||
e72ba0b9e7 | |||
5bb2fce998 | |||
2c9f605976 | |||
141df4ad2b | |||
166bfd5847 | |||
1c062bf06b | |||
d34039e352 | |||
93cfe5642f | |||
88bd3b604a | |||
b278834267 | |||
0b175fcbbd | |||
620f83cf66 | |||
f7b2a0391d | |||
8b6f5be1cc | |||
df6667ba88 | |||
a79286885c | |||
74845a4dcd | |||
aa76b783eb | |||
25564357f7 | |||
634700d84a | |||
e635f18eda | |||
dba31473d4 | |||
1b2b32e58d | |||
166f4d1101 | |||
ae68635af9 | |||
c11e78b334 | |||
1b705a426f | |||
a70b95f9e7 | |||
a44471a305 | |||
45642a8530 | |||
82464166e4 | |||
52414ba5c8 | |||
186c308d51 |
2
.github/workflows/book-cd.yml
vendored
2
.github/workflows/book-cd.yml
vendored
@ -1,7 +1,5 @@
|
||||
name: Deploy Rust book
|
||||
on:
|
||||
# TODO put this back only when merging after this PR lands.
|
||||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
87
.github/workflows/ci_cuda.yaml
vendored
Normal file
87
.github/workflows/ci_cuda.yaml
vendored
Normal file
@ -0,0 +1,87 @@
|
||||
name: CI / cuda
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
start-runner:
|
||||
name: Start self-hosted EC2 runner
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AWS_REGION: us-east-1
|
||||
EC2_AMI_ID: ami-03cfed9ea28f4b002
|
||||
EC2_INSTANCE_TYPE: g5.xlarge
|
||||
EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
|
||||
EC2_SECURITY_GROUP: sg-030175c435ac141d6
|
||||
outputs:
|
||||
label: ${{ steps.start-ec2-runner.outputs.label }}
|
||||
ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
|
||||
steps:
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v1
|
||||
with:
|
||||
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
- name: Start EC2 runner
|
||||
id: start-ec2-runner
|
||||
uses: philschmid/philschmid-ec2-github-runner@main
|
||||
with:
|
||||
mode: start
|
||||
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
|
||||
ec2-image-id: ${{ env.EC2_AMI_ID }}
|
||||
ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
|
||||
subnet-id: ${{ env.EC2_SUBNET_ID }}
|
||||
security-group-id: ${{ env.EC2_SECURITY_GROUP }}
|
||||
aws-resource-tags: > # optional, requires additional permissions
|
||||
[
|
||||
{"Key": "Name", "Value": "ec2-tgi-github-runner"},
|
||||
{"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
|
||||
]
|
||||
|
||||
test-cuda:
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
needs: start-runner # required to start the main job when the runner is ready
|
||||
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
|
||||
permissions:
|
||||
contents: write
|
||||
packages: write
|
||||
# This is used to complete the identity challenge
|
||||
# with sigstore/fulcio when running outside of PRs.
|
||||
id-token: write
|
||||
security-events: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
- name: Install Rust Stable
|
||||
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: apt update -y && apt install libssl-dev -y
|
||||
- name: Test (cuda)
|
||||
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
|
||||
stop-runner:
|
||||
name: Stop self-hosted EC2 runner
|
||||
needs:
|
||||
- start-runner
|
||||
- test-cuda
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AWS_REGION: us-east-1
|
||||
if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
|
||||
steps:
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v1
|
||||
with:
|
||||
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
- name: Stop EC2 runner
|
||||
uses: philschmid/philschmid-ec2-github-runner@main
|
||||
with:
|
||||
mode: stop
|
||||
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
|
||||
label: ${{ needs.start-runner.outputs.label }}
|
||||
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -20,6 +20,7 @@ Cargo.lock
|
||||
|
||||
perf.data
|
||||
flamegraph.svg
|
||||
*.dylib
|
||||
*.so
|
||||
*.swp
|
||||
trace-*.json
|
||||
|
17
Cargo.toml
17
Cargo.toml
@ -1,6 +1,7 @@
|
||||
[workspace]
|
||||
members = [
|
||||
"candle-core",
|
||||
"candle-datasets",
|
||||
"candle-examples",
|
||||
"candle-nn",
|
||||
"candle-pyo3",
|
||||
@ -14,23 +15,25 @@ exclude = [
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
version = "0.1.0"
|
||||
version = "0.1.1"
|
||||
edition = "2021"
|
||||
description = "Minimalist ML framework."
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
keywords = ["blas", "tensor", "machine-learning"]
|
||||
categories = ["science"]
|
||||
license = "MIT/Apache-2.0"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
accelerate-src = { version = "0.3.2" }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = "1.4.3"
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
cudarc = { version = "0.9.13", features = ["f16"] }
|
||||
cudarc = { version = "0.9.14", features = ["f16"] }
|
||||
# TODO: Switch back to the official gemm implementation once it has caught up.
|
||||
gemm = { version = "0.15.5", package = "candle-gemm" }
|
||||
gemm = { version = "0.15.6", package = "candle-gemm" }
|
||||
hf-hub = "0.2.0"
|
||||
half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
|
||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
|
||||
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
||||
libc = { version = "0.2.147" }
|
||||
log = "0.4"
|
||||
@ -38,11 +41,13 @@ memmap2 = "0.7.1"
|
||||
num_cpus = "1.15.0"
|
||||
num-traits = "0.2.15"
|
||||
rand = "0.8.5"
|
||||
rand_distr = "0.4.3"
|
||||
rayon = "1.7.0"
|
||||
safetensors = "0.3.1"
|
||||
serde = { version = "1.0.171", features = ["derive"] }
|
||||
serde_json = "1.0.99"
|
||||
thiserror = "1"
|
||||
tokenizers = { version = "0.13.3", default-features = false }
|
||||
tokenizers = { version = "0.13.4", default-features = false }
|
||||
tracing = "0.1.37"
|
||||
tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
|
201
LICENSE-APACHE
Normal file
201
LICENSE-APACHE
Normal file
@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
23
LICENSE-MIT
Normal file
23
LICENSE-MIT
Normal file
@ -0,0 +1,23 @@
|
||||
Permission is hereby granted, free of charge, to any
|
||||
person obtaining a copy of this software and associated
|
||||
documentation files (the "Software"), to deal in the
|
||||
Software without restriction, including without
|
||||
limitation the rights to use, copy, modify, merge,
|
||||
publish, distribute, sublicense, and/or sell copies of
|
||||
the Software, and to permit persons to whom the Software
|
||||
is furnished to do so, subject to the following
|
||||
conditions:
|
||||
|
||||
The above copyright notice and this permission notice
|
||||
shall be included in all copies or substantial portions
|
||||
of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
|
||||
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
|
||||
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
|
||||
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
|
||||
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
|
||||
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
6
Makefile
6
Makefile
@ -2,6 +2,8 @@ clean-ptx:
|
||||
find target -name "*.ptx" -type f -delete
|
||||
echo "" > candle-kernels/src/lib.rs
|
||||
touch candle-kernels/build.rs
|
||||
touch candle-examples/build.rs
|
||||
touch candle-flash-attn/build.rs
|
||||
|
||||
clean:
|
||||
cargo clean
|
||||
@ -9,4 +11,8 @@ clean:
|
||||
test:
|
||||
cargo test
|
||||
|
||||
pyo3-test:
|
||||
cargo build --profile=release-with-debug --package candle-pyo3
|
||||
python3 candle-pyo3/test.py
|
||||
|
||||
all: test
|
||||
|
92
README.md
92
README.md
@ -1,10 +1,11 @@
|
||||
# candle
|
||||
[](https://discord.com/channels/879548962464493619/1136218819447238726)
|
||||
[](https://crates.io/crates/candle-core)
|
||||
[](https://docs.rs/candle-core)
|
||||

|
||||
|
||||
Candle is a minimalist ML framework for Rust with a focus on easiness of use and
|
||||
on performance (including GPU support). Try our online demos:
|
||||
Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
|
||||
and ease of use. Try our online demos:
|
||||
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
|
||||
[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
|
||||
|
||||
@ -26,6 +27,8 @@ Check out our [examples](./candle-examples/examples/):
|
||||
- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
|
||||
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
|
||||
generation.
|
||||
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
|
||||
image generative model, yet to be optimized.
|
||||
|
||||
Run them using the following commands:
|
||||
```
|
||||
@ -34,6 +37,7 @@ cargo run --example llama --release
|
||||
cargo run --example falcon --release
|
||||
cargo run --example bert --release
|
||||
cargo run --example bigcode --release
|
||||
cargo run --example stable-diffusion --release --features image -- --prompt "a rusty robot holding a fire torch"
|
||||
```
|
||||
|
||||
In order to use **CUDA** add `--features cuda` to the example command line.
|
||||
@ -48,37 +52,40 @@ For llama2, run the following command to retrieve the weight files and start a
|
||||
test server:
|
||||
```bash
|
||||
cd candle-wasm-examples/llama2-c
|
||||
wget https://karpathy.ai/llama2c/model.bin
|
||||
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
|
||||
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
|
||||
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
|
||||
trunk serve --release --public-url /candle-llama2/ --port 8081
|
||||
```
|
||||
And then browse to
|
||||
And then head over to
|
||||
[http://localhost:8081/candle-llama2](http://localhost:8081/candle-llama2).
|
||||
|
||||
<!--- ANCHOR: features --->
|
||||
|
||||
## Features
|
||||
|
||||
- Simple syntax, looks and like PyTorch.
|
||||
- CPU and Cuda backends, m1, f16, bf16.
|
||||
- Enable serverless (CPU), small and fast deployments
|
||||
- WASM support, run your models in a browser.
|
||||
- Model training.
|
||||
- Distributed computing using NCCL.
|
||||
- Models out of the box: Llama, Whisper, Falcon, StarCoder...
|
||||
- Embed user-defined ops/kernels, such as [flash-attention
|
||||
v2](https://github.com/LaurentMazare/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
|
||||
- Simple syntax, looks and feels like PyTorch.
|
||||
- Model training.
|
||||
- Embed user-defined ops/kernels, such as [flash-attention v2](https://github.com/huggingface/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
|
||||
- Backends.
|
||||
- Optimized CPU backend with optional MKL support for x86 and Accelerate for macs.
|
||||
- CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
|
||||
- WASM support, run your models in a browser.
|
||||
- Model support out of the box.
|
||||
- LLMs: Llama v1 and v2, Falcon, StarCoder.
|
||||
- Whisper.
|
||||
- Stable Diffusion.
|
||||
- Serverless (on CPU), small and fast deployments.
|
||||
|
||||
<!--- ANCHOR_END: features --->
|
||||
|
||||
## How to use ?
|
||||
## How to use
|
||||
|
||||
<!--- ANCHOR: cheatsheet --->
|
||||
Cheatsheet:
|
||||
|
||||
| | Using PyTorch | Using Candle |
|
||||
|------------|------------------------------------------|------------------------------------------------------------------|
|
||||
| Creation | `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(&[[1f32, 2.]], [3., 4.]], &Device::Cpu)?` |
|
||||
| Creation | `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?` |
|
||||
| Creation | `torch.zeros((2, 2))` | `Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?` |
|
||||
| Indexing | `tensor[:, :4]` | `tensor.i((.., ..4))?` |
|
||||
| Operations | `tensor.view((2, 2))` | `tensor.reshape((2, 2))?` |
|
||||
@ -95,43 +102,46 @@ Cheatsheet:
|
||||
## Structure
|
||||
|
||||
- [candle-core](./candle-core): Core ops, devices, and `Tensor` struct definition
|
||||
- [candle-nn](./candle-nn/): Facilities to build real models
|
||||
- [candle-examples](./candle-examples/): Real-world like examples on how to use the library in real settings
|
||||
- [candle-nn](./candle-nn/): Tools to build real models
|
||||
- [candle-examples](./candle-examples/): Examples of using the library in realistic settings
|
||||
- [candle-kernels](./candle-kernels/): CUDA custom kernels
|
||||
|
||||
|
||||
- [candle-datasets](./candle-datasets/): Datasets and data loaders.
|
||||
- [candle-transformers](./candle-transformers): transformers-related utilities.
|
||||
- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
|
||||
|
||||
## FAQ
|
||||
|
||||
### Why Candle?
|
||||
### Why should I use Candle?
|
||||
|
||||
Candle stems from the need to reduce binary size in order to *enable serverless*
|
||||
possible by making the whole engine smaller than PyTorch very large library volume.
|
||||
This enables creating runtimes on a cluster much faster.
|
||||
Candle's core goal is to *make serverless inference possible*. Full machine learning frameworks like PyTorch
|
||||
are very large, which makes creating instances on a cluster slow. Candle allows deployment of lightweight
|
||||
binaries.
|
||||
|
||||
And simply *removing Python* from production workloads.
|
||||
Python can really add overhead in more complex workflows and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-future/) is a notorious source of headaches.
|
||||
Secondly, Candle lets you *remove Python* from production workloads. Python overhead can seriously hurt performance,
|
||||
and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-future/) is a notorious source of headaches.
|
||||
|
||||
Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers).
|
||||
Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers).
|
||||
|
||||
|
||||
### Other ML frameworks
|
||||
|
||||
- [dfdx](https://github.com/coreylowman/dfdx) is a formidable crate, with shapes being included
|
||||
in types preventing a lot of headaches by getting compiler to complain about shape mismatch right off the bat
|
||||
However we found that some features still require nightly and writing code can be a bit dauting for non rust experts.
|
||||
in types. This prevents a lot of headaches by getting the compiler to complain about shape mismatches right off the bat.
|
||||
However, we found that some features still require nightly, and writing code can be a bit daunting for non rust experts.
|
||||
|
||||
We're leveraging and contributing to other core crates for the runtime so hopefully both crates can benefit from each
|
||||
other
|
||||
other.
|
||||
|
||||
- [burn](https://github.com/burn-rs/burn) is a general crate that can leverage multiple backends so you can choose the best
|
||||
engine for your workload
|
||||
engine for your workload.
|
||||
|
||||
- [tch-rs](https://github.com/LaurentMazare/tch-rs.git) Bindings to the torch library in Rust. Extremely versatile, but they
|
||||
do bring in the entire torch library into the runtime. The main contributor of `tch-rs` is also involved in the development
|
||||
bring in the entire torch library into the runtime. The main contributor of `tch-rs` is also involved in the development
|
||||
of `candle`.
|
||||
|
||||
### Missing symbols when compiling with the mkl feature.
|
||||
### Common Errors
|
||||
|
||||
#### Missing symbols when compiling with the mkl feature.
|
||||
|
||||
If you get some missing symbols when compiling binaries/tests using the mkl
|
||||
features, e.g.:
|
||||
@ -144,13 +154,25 @@ features, e.g.:
|
||||
= note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo (see https://doc.rust-lang.org/cargo/reference/build-scripts.html#cargorustc-link-libkindname)
|
||||
```
|
||||
|
||||
This is likely due to some missing linker flag that enable the mkl library. You
|
||||
This is likely due to a missing linker flag that was needed to enable the mkl library. You
|
||||
can try adding the following at the top of your binary:
|
||||
```
|
||||
extern crate intel_mkl_src;
|
||||
```
|
||||
|
||||
### How to know where an error comes from.
|
||||
#### Cannot run llama example : access to source requires login credentials
|
||||
|
||||
```
|
||||
Error: request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401
|
||||
```
|
||||
|
||||
This is likely because you're not permissioned for the llama-v2 model. To fix
|
||||
this, you have to register on the huggingface-hub, accept the [llama-v2 model
|
||||
conditions](https://huggingface.co/meta-llama/Llama-2-7b-hf), and set up your
|
||||
authentication token. See issue
|
||||
[#350](https://github.com/huggingface/candle/issues/350) for more details.
|
||||
|
||||
#### Tracking down errors
|
||||
|
||||
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
|
||||
error is generated.
|
||||
|
@ -12,16 +12,16 @@
|
||||
|
||||
- [Running a model](inference/README.md)
|
||||
- [Using the hub](inference/hub.md)
|
||||
- [Serialization](inference/serialization.md)
|
||||
- [Advanced Cuda usage](inference/cuda/README.md)
|
||||
- [Writing a custom kernel](inference/cuda/writing.md)
|
||||
- [Porting a custom kernel](inference/cuda/porting.md)
|
||||
- [Error management](error_manage.md)
|
||||
- [Creating apps](apps/README.md)
|
||||
- [Creating a WASM app](apps/wasm.md)
|
||||
- [Creating a REST api webserver](apps/rest.md)
|
||||
- [Creating a desktop Tauri app](apps/dekstop.md)
|
||||
- [Training](training/README.md)
|
||||
- [MNIST](training/mnist.md)
|
||||
- [Fine-tuning](training/finetuning.md)
|
||||
- [Using MKL](advanced/mkl.md)
|
||||
- [Error management]()
|
||||
- [Advanced Cuda usage]()
|
||||
- [Writing a custom kernel]()
|
||||
- [Porting a custom kernel]()
|
||||
- [Using MKL]()
|
||||
- [Creating apps]()
|
||||
- [Creating a WASM app]()
|
||||
- [Creating a REST api webserver]()
|
||||
- [Creating a desktop Tauri app]()
|
||||
- [Training]()
|
||||
- [MNIST]()
|
||||
- [Fine-tuning]()
|
||||
- [Serialization]()
|
||||
|
1
candle-book/src/cuda/README.md
Normal file
1
candle-book/src/cuda/README.md
Normal file
@ -0,0 +1 @@
|
||||
# Advanced Cuda usage
|
1
candle-book/src/cuda/porting.md
Normal file
1
candle-book/src/cuda/porting.md
Normal file
@ -0,0 +1 @@
|
||||
# Porting a custom kernel
|
1
candle-book/src/cuda/writing.md
Normal file
1
candle-book/src/cuda/writing.md
Normal file
@ -0,0 +1 @@
|
||||
# Writing a custom kernel
|
@ -1 +1,51 @@
|
||||
# Error management
|
||||
|
||||
You might have seen in the code base a lot of `.unwrap()` or `?`.
|
||||
If you're unfamiliar with Rust check out the [Rust book](https://doc.rust-lang.org/book/ch09-02-recoverable-errors-with-result.html)
|
||||
for more information.
|
||||
|
||||
What's important to know though, is that if you want to know *where* a particular operation failed
|
||||
You can simply use `RUST_BACKTRACE=1` to get the location of where the model actually failed.
|
||||
|
||||
Let's see on failing code:
|
||||
|
||||
```rust,ignore
|
||||
let x = Tensor::zeros((1, 784), DType::F32, &device)?;
|
||||
let y = Tensor::zeros((1, 784), DType::F32, &device)?;
|
||||
let z = x.matmul(&y)?;
|
||||
```
|
||||
|
||||
Will print at runtime:
|
||||
|
||||
```bash
|
||||
Error: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }
|
||||
```
|
||||
|
||||
|
||||
After adding `RUST_BACKTRACE=1`:
|
||||
|
||||
|
||||
```bash
|
||||
Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
|
||||
```
|
||||
|
||||
Not super pretty at the moment, but we can see error occured on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
|
||||
|
||||
|
||||
Another thing to note, is that since Rust is compiled it is not necessarily as easy to recover proper stacktraces
|
||||
especially in release builds. We're using [`anyhow`](https://docs.rs/anyhow/latest/anyhow/) for that.
|
||||
The library is still young, please [report](https://github.com/LaurentMazare/candle/issues) any issues detecting where an error is coming from.
|
||||
|
||||
## Cuda error management
|
||||
|
||||
When running a model on Cuda, you might get a stacktrace not really representing the error.
|
||||
The reason is that CUDA is async by nature, and therefore the error might be caught while you were sending totally different kernels.
|
||||
|
||||
One way to avoid this is to use `CUDA_LAUNCH_BLOCKING=1` as an environment variable. This will force every kernel to be launched sequentially.
|
||||
You might still however see the error happening on other kernels as the faulty kernel might exit without an error but spoiling some pointer for which the error will happen when dropping the `CudaSlice` only.
|
||||
|
||||
|
||||
If this occurs, you can use [`compute-sanitizer`](https://docs.nvidia.com/compute-sanitizer/ComputeSanitizer/index.html)
|
||||
This tool is like `valgrind` but for cuda. It will help locate the errors in the kernels.
|
||||
|
||||
|
||||
|
@ -128,17 +128,17 @@ fn main() -> Result<()> {
|
||||
```
|
||||
|
||||
Now it works, it is a great way to create your own layers.
|
||||
But most of the classical layers are already implemented in [candle-nn](https://github.com/LaurentMazare/candle/tree/main/candle-nn).
|
||||
But most of the classical layers are already implemented in [candle-nn](https://github.com/huggingface/candle/tree/main/candle-nn).
|
||||
|
||||
## Using `candle_nn`.
|
||||
|
||||
For instance [Linear](https://github.com/LaurentMazare/candle/blob/main/candle-nn/src/linear.rs) is already there.
|
||||
For instance [Linear](https://github.com/huggingface/candle/blob/main/candle-nn/src/linear.rs) is already there.
|
||||
This Linear is coded with PyTorch layout in mind, to reuse better existing models out there, so it uses the transpose of the weights and not the weights directly.
|
||||
|
||||
So instead we can simplify our example:
|
||||
|
||||
```bash
|
||||
cargo add --git https://github.com/LaurentMazare/candle.git candle-nn
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-nn
|
||||
```
|
||||
|
||||
And rewrite our examples using it
|
||||
|
@ -5,13 +5,13 @@ Start by creating a new app:
|
||||
```bash
|
||||
cargo new myapp
|
||||
cd myapp
|
||||
cargo add --git https://github.com/LaurentMazare/candle.git candle
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core
|
||||
```
|
||||
|
||||
At this point, candle will be built **without** CUDA support.
|
||||
To get CUDA support use the `cuda` feature
|
||||
```bash
|
||||
cargo add --git https://github.com/LaurentMazare/candle.git candle --features cuda
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core --features cuda
|
||||
```
|
||||
|
||||
You can check everything works properly:
|
||||
|
@ -1 +1,7 @@
|
||||
# Running a model
|
||||
|
||||
|
||||
In order to run an existing model, you will need to download and use existing weights.
|
||||
Most models are already available on https://huggingface.co/ in [`safetensors`](https://github.com/huggingface/safetensors) format.
|
||||
|
||||
Let's get started by running an old model : `bert-base-uncased`.
|
||||
|
@ -1 +1,104 @@
|
||||
# Using the hub
|
||||
|
||||
Install the [`hf-hub`](https://github.com/huggingface/hf-hub) crate:
|
||||
|
||||
```bash
|
||||
cargo add hf-hub
|
||||
```
|
||||
|
||||
Then let's start by downloading the [model file](https://huggingface.co/bert-base-uncased/tree/main).
|
||||
|
||||
|
||||
```rust
|
||||
# extern crate candle_core;
|
||||
# extern crate hf_hub;
|
||||
use hf_hub::api::sync::Api;
|
||||
use candle_core::Device;
|
||||
|
||||
let api = Api::new().unwrap();
|
||||
let repo = api.model("bert-base-uncased".to_string());
|
||||
|
||||
let weights = repo.get("model.safetensors").unwrap();
|
||||
|
||||
let weights = candle_core::safetensors::load(weights, &Device::Cpu);
|
||||
```
|
||||
|
||||
We now have access to all the [tensors](https://huggingface.co/bert-base-uncased?show_tensors=true) within the file.
|
||||
|
||||
You can check all the names of the tensors [here](https://huggingface.co/bert-base-uncased?show_tensors=true)
|
||||
|
||||
|
||||
## Using async
|
||||
|
||||
`hf-hub` comes with an async API.
|
||||
|
||||
```bash
|
||||
cargo add hf-hub --features tokio
|
||||
```
|
||||
|
||||
```rust,ignore
|
||||
# This is tested directly in examples crate because it needs external dependencies unfortunately:
|
||||
# See [this](https://github.com/rust-lang/mdBook/issues/706)
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_hub_1}}
|
||||
```
|
||||
|
||||
|
||||
## Using in a real model.
|
||||
|
||||
Now that we have our weights, we can use them in our bert architecture:
|
||||
|
||||
```rust
|
||||
# extern crate candle_core;
|
||||
# extern crate candle_nn;
|
||||
# extern crate hf_hub;
|
||||
# use hf_hub::api::sync::Api;
|
||||
#
|
||||
# let api = Api::new().unwrap();
|
||||
# let repo = api.model("bert-base-uncased".to_string());
|
||||
#
|
||||
# let weights = repo.get("model.safetensors").unwrap();
|
||||
use candle_core::{Device, Tensor, DType};
|
||||
use candle_nn::Linear;
|
||||
|
||||
let weights = candle_core::safetensors::load(weights, &Device::Cpu).unwrap();
|
||||
|
||||
let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap();
|
||||
let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap();
|
||||
|
||||
let linear = Linear::new(weight.clone(), Some(bias.clone()));
|
||||
|
||||
let input_ids = Tensor::zeros((3, 768), DType::F32, &Device::Cpu).unwrap();
|
||||
let output = linear.forward(&input_ids).unwrap();
|
||||
```
|
||||
|
||||
For a full reference, you can check out the full [bert](https://github.com/LaurentMazare/candle/tree/main/candle-examples/examples/bert) example.
|
||||
|
||||
## Memory mapping
|
||||
|
||||
For more efficient loading, instead of reading the file, you could use [`memmap2`](https://docs.rs/memmap2/latest/memmap2/)
|
||||
|
||||
**Note**: Be careful about memory mapping it seems to cause issues on [Windows, WSL](https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/5893)
|
||||
and will definitely be slower on network mounted disk, because it will issue more read calls.
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_hub_2}}
|
||||
```
|
||||
|
||||
**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
|
||||
In practice model files should never be modified, and the mmaps should be mostly READONLY anyway, so the caveat most likely does not apply, but always keep it in mind.
|
||||
|
||||
|
||||
## Tensor Parallel Sharding
|
||||
|
||||
When using multiple GPUs to use in Tensor Parallel in order to get good latency, you can load only the part of the Tensor you need.
|
||||
|
||||
For that you need to use [`safetensors`](https://crates.io/crates/safetensors) directly.
|
||||
|
||||
```bash
|
||||
cargo add safetensors
|
||||
```
|
||||
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_hub_3}}
|
||||
```
|
||||
|
@ -10,8 +10,9 @@ license.workspace = true
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
byteorder = { workspace = true }
|
||||
candle-kernels = { path = "../candle-kernels", version = "0.1.0", optional = true }
|
||||
candle-kernels = { path = "../candle-kernels", version = "0.1.1", optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
gemm = { workspace = true }
|
||||
half = { workspace = true }
|
||||
@ -21,14 +22,19 @@ memmap2 = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
num_cpus = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
rand_distr = { workspace = true }
|
||||
rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
zip = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
cuda = ["dep:cudarc", "dep:candle-kernels"]
|
||||
cuda = ["cudarc", "dep:candle-kernels"]
|
||||
cudnn = ["cuda", "cudarc/cudnn"]
|
||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||
accelerate = ["dep:libc", "dep:accelerate-src"]
|
||||
|
@ -1,29 +1,18 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle_core::{Device, Tensor};
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
|
||||
let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
|
||||
let c = a.matmul(&b)?;
|
||||
println!("{a} {b} {c}");
|
||||
|
||||
let data = &[[3f32, 1., 4., 1., 5.], [2., 7., 1., 8., 2.]];
|
||||
let t1 = Tensor::new(data, &Device::Cpu)?;
|
||||
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 7., 1., 8., 2.]];
|
||||
let t2 = Tensor::new(data2, &Device::Cpu)?;
|
||||
assert_eq!(
|
||||
Tensor::cat(&[&t1.t()?, &t2.t()?], 1)?
|
||||
.t()?
|
||||
.to_vec2::<f32>()?,
|
||||
[
|
||||
[3.0, 1.0, 4.0, 1.0, 5.0],
|
||||
[2.0, 7.0, 1.0, 8.0, 2.0],
|
||||
[5.0, 5.0, 5.0, 5.0, 5.0],
|
||||
[2.0, 7.0, 1.0, 8.0, 2.0]
|
||||
]
|
||||
);
|
||||
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
|
||||
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
|
||||
let start = std::time::Instant::now();
|
||||
let res = inp.conv2d(&w, 0, 1);
|
||||
println!("{:?}", start.elapsed());
|
||||
println!("{res:?}");
|
||||
Ok(())
|
||||
}
|
||||
|
142
candle-core/examples/cpu_benchmarks.rs
Normal file
142
candle-core/examples/cpu_benchmarks.rs
Normal file
@ -0,0 +1,142 @@
|
||||
/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle_core::{Device, Result, Tensor, D};
|
||||
use clap::{Parser, Subcommand};
|
||||
|
||||
fn softmax<D: candle_core::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
|
||||
let dim = dim.to_index(xs.shape(), "softmax")?;
|
||||
let max = xs.max_keepdim(dim)?;
|
||||
let diff = xs.broadcast_sub(&max)?;
|
||||
let num = diff.exp()?;
|
||||
let den = num.sum_keepdim(dim)?;
|
||||
num.broadcast_div(&den)
|
||||
}
|
||||
|
||||
trait Benchmark {
|
||||
type PreProcessData;
|
||||
type RunResult;
|
||||
|
||||
fn preprocess() -> Result<Self::PreProcessData>;
|
||||
fn run_one(_: &Self::PreProcessData) -> Result<Self::RunResult>;
|
||||
|
||||
const ITERS: usize;
|
||||
}
|
||||
|
||||
// Conv1d example as used in whisper.
|
||||
struct Conv1d;
|
||||
impl Benchmark for Conv1d {
|
||||
type PreProcessData = (Tensor, Tensor);
|
||||
type RunResult = Tensor;
|
||||
fn preprocess() -> Result<Self::PreProcessData> {
|
||||
let inp = Tensor::randn(0f32, 1., (1, 384, 3000), &Device::Cpu)?;
|
||||
let w = Tensor::randn(0f32, 1., (384, 384, 3), &Device::Cpu)?;
|
||||
Ok((inp, w))
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
d.0.conv1d(&d.1, 0, 1)
|
||||
}
|
||||
|
||||
const ITERS: usize = 5;
|
||||
}
|
||||
|
||||
// Conv2d example as used in stable-diffusion.
|
||||
struct Conv2d;
|
||||
impl Benchmark for Conv2d {
|
||||
type PreProcessData = (Tensor, Tensor);
|
||||
type RunResult = Tensor;
|
||||
|
||||
fn preprocess() -> Result<Self::PreProcessData> {
|
||||
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
|
||||
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
|
||||
Ok((inp, w))
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
d.0.conv2d(&d.1, 0, 1)
|
||||
}
|
||||
|
||||
const ITERS: usize = 1;
|
||||
}
|
||||
|
||||
struct Matmul;
|
||||
impl Benchmark for Matmul {
|
||||
type PreProcessData = (Tensor, Tensor);
|
||||
type RunResult = Tensor;
|
||||
fn preprocess() -> Result<Self::PreProcessData> {
|
||||
let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
|
||||
let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
|
||||
Ok((lhs, rhs))
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
d.0.matmul(&d.1)
|
||||
}
|
||||
|
||||
const ITERS: usize = 100;
|
||||
}
|
||||
|
||||
struct Softmax;
|
||||
impl Benchmark for Softmax {
|
||||
type PreProcessData = Tensor;
|
||||
type RunResult = Tensor;
|
||||
fn preprocess() -> Result<Self::PreProcessData> {
|
||||
// Typical whisper tiny size.
|
||||
let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
|
||||
Ok(x)
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
softmax(d, D::Minus1)
|
||||
}
|
||||
|
||||
const ITERS: usize = 100;
|
||||
}
|
||||
|
||||
fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
|
||||
use std::hint::black_box;
|
||||
|
||||
let iters = iters.unwrap_or(B::ITERS);
|
||||
let d = B::preprocess()?;
|
||||
let start = std::time::Instant::now();
|
||||
for _iter in 0..iters {
|
||||
let _res = black_box(B::run_one(black_box(&d))?);
|
||||
}
|
||||
println!("{:?}", start.elapsed() / iters as u32);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug, Clone)]
|
||||
enum Task {
|
||||
Conv1d,
|
||||
Conv2d,
|
||||
Matmul,
|
||||
Softmax,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub struct Args {
|
||||
/// The benchmark to be run.
|
||||
#[command(subcommand)]
|
||||
task: Task,
|
||||
|
||||
#[arg(long)]
|
||||
iters: Option<usize>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
match args.task {
|
||||
Task::Conv1d => run::<Conv1d>(args.iters)?,
|
||||
Task::Conv2d => run::<Conv2d>(args.iters)?,
|
||||
Task::Matmul => run::<Matmul>(args.iters)?,
|
||||
Task::Softmax => run::<Softmax>(args.iters)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -1,3 +1,6 @@
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
@ -6,10 +9,9 @@ use candle_core::{Device, Tensor};
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let device = Device::new_cuda(0)?;
|
||||
let t = Tensor::new(&[[1f32, 2., 3., 4.2]], &device)?;
|
||||
let sum = t.sum_keepdim(0)?;
|
||||
println!("{sum}");
|
||||
let sum = t.sum_keepdim(1)?;
|
||||
println!("{sum}");
|
||||
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
|
||||
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
|
||||
let res = t.conv2d(&w, 1, 1)?;
|
||||
println!("{res:?}");
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,6 +1,9 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use anyhow::Result;
|
||||
|
350
candle-core/src/accelerate.rs
Normal file
350
candle-core/src/accelerate.rs
Normal file
@ -0,0 +1,350 @@
|
||||
#![allow(dead_code)]
|
||||
use libc::{c_char, c_double, c_float, c_int, c_long, c_ulong};
|
||||
|
||||
mod ffi {
|
||||
use super::*;
|
||||
extern "C" {
|
||||
// It would be nice to be able to switch to the NEWLAPACK version of the function but this
|
||||
// seems to trigger some link error. Available function names can be seen here:
|
||||
// /Library/Developer/CommandLineTools/SDKs/MacOSX13.3.sdk/System/Library/Frameworks/Accelerate.framework/Versions/A/Accelerate.tbd
|
||||
#[link_name = "sgemm_"]
|
||||
pub fn sgemm_ffi(
|
||||
transa: *const c_char,
|
||||
transb: *const c_char,
|
||||
m: *const c_int,
|
||||
n: *const c_int,
|
||||
k: *const c_int,
|
||||
alpha: *const c_float,
|
||||
a: *const c_float,
|
||||
lda: *const c_int,
|
||||
b: *const c_float,
|
||||
ldb: *const c_int,
|
||||
beta: *const c_float,
|
||||
c: *mut c_float,
|
||||
ldc: *const c_int,
|
||||
);
|
||||
#[link_name = "dgemm_"]
|
||||
pub fn dgemm_ffi(
|
||||
transa: *const c_char,
|
||||
transb: *const c_char,
|
||||
m: *const c_int,
|
||||
n: *const c_int,
|
||||
k: *const c_int,
|
||||
alpha: *const c_double,
|
||||
a: *const c_double,
|
||||
lda: *const c_int,
|
||||
b: *const c_double,
|
||||
ldb: *const c_int,
|
||||
beta: *const c_double,
|
||||
c: *mut c_double,
|
||||
ldc: *const c_int,
|
||||
);
|
||||
|
||||
pub fn vvexpf(dst: *mut c_float, src: *const c_float, len: *const c_int);
|
||||
pub fn vvexp(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
||||
pub fn vvsqrtf(dst: *mut c_float, src: *const c_float, len: *const c_int);
|
||||
pub fn vvsqrt(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
||||
pub fn vvsinf(dst: *mut c_float, src: *const c_float, len: *const c_int);
|
||||
pub fn vvsin(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
||||
pub fn vvcosf(dst: *mut c_float, src: *const c_float, len: *const c_int);
|
||||
pub fn vvcos(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
||||
pub fn vvlogf(dst: *mut c_float, src: *const c_float, len: *const c_int);
|
||||
pub fn vvlog(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
||||
|
||||
pub fn vDSP_vaddD(
|
||||
_: *const c_double,
|
||||
_: c_long,
|
||||
_: *const c_double,
|
||||
_: c_long,
|
||||
_: *mut c_double,
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
pub fn vDSP_vadd(
|
||||
_: *const c_float,
|
||||
_: c_long,
|
||||
_: *const c_float,
|
||||
_: c_long,
|
||||
_: *mut c_float,
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
pub fn vDSP_vsubD(
|
||||
_: *const c_double,
|
||||
_: c_long,
|
||||
_: *const c_double,
|
||||
_: c_long,
|
||||
_: *mut c_double,
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
pub fn vDSP_vsub(
|
||||
_: *const c_float,
|
||||
_: c_long,
|
||||
_: *const c_float,
|
||||
_: c_long,
|
||||
_: *mut c_float,
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
pub fn vDSP_vmulD(
|
||||
_: *const c_double,
|
||||
_: c_long,
|
||||
_: *const c_double,
|
||||
_: c_long,
|
||||
_: *mut c_double,
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
pub fn vDSP_vmul(
|
||||
_: *const c_float,
|
||||
_: c_long,
|
||||
_: *const c_float,
|
||||
_: c_long,
|
||||
_: *mut c_float,
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
pub fn vDSP_vdivD(
|
||||
_: *const c_double,
|
||||
_: c_long,
|
||||
_: *const c_double,
|
||||
_: c_long,
|
||||
_: *mut c_double,
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
pub fn vDSP_vdiv(
|
||||
_: *const c_float,
|
||||
_: c_long,
|
||||
_: *const c_float,
|
||||
_: c_long,
|
||||
_: *mut c_float,
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[inline]
|
||||
pub unsafe fn sgemm(
|
||||
transa: u8,
|
||||
transb: u8,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
alpha: f32,
|
||||
a: &[f32],
|
||||
lda: i32,
|
||||
b: &[f32],
|
||||
ldb: i32,
|
||||
beta: f32,
|
||||
c: &mut [f32],
|
||||
ldc: i32,
|
||||
) {
|
||||
ffi::sgemm_ffi(
|
||||
&(transa as c_char),
|
||||
&(transb as c_char),
|
||||
&m,
|
||||
&n,
|
||||
&k,
|
||||
&alpha,
|
||||
a.as_ptr(),
|
||||
&lda,
|
||||
b.as_ptr(),
|
||||
&ldb,
|
||||
&beta,
|
||||
c.as_mut_ptr(),
|
||||
&ldc,
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[inline]
|
||||
pub unsafe fn dgemm(
|
||||
transa: u8,
|
||||
transb: u8,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
alpha: f64,
|
||||
a: &[f64],
|
||||
lda: i32,
|
||||
b: &[f64],
|
||||
ldb: i32,
|
||||
beta: f64,
|
||||
c: &mut [f64],
|
||||
ldc: i32,
|
||||
) {
|
||||
ffi::dgemm_ffi(
|
||||
&(transa as c_char),
|
||||
&(transb as c_char),
|
||||
&m,
|
||||
&n,
|
||||
&k,
|
||||
&alpha,
|
||||
a.as_ptr(),
|
||||
&lda,
|
||||
b.as_ptr(),
|
||||
&ldb,
|
||||
&beta,
|
||||
c.as_mut_ptr(),
|
||||
&ldc,
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_exp(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vvexpf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_exp(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vvexp(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_sqrt(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vvsqrtf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_sqrt(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vvsqrt(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_sin(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vvsinf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_sin(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vvsin(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
#[inline]
|
||||
pub fn vs_cos(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vvcosf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_cos(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vvcos(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
#[inline]
|
||||
pub fn vs_ln(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vvlogf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_ln(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vvlog(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_sqr(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
|
||||
}
|
||||
|
||||
macro_rules! binary_op {
|
||||
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
|
||||
#[inline]
|
||||
pub fn $fn_name(a: &[$ty], b: &[$ty], y: &mut [$ty]) {
|
||||
let a_len = a.len();
|
||||
let b_len = b.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len || b_len != y_len {
|
||||
panic!(
|
||||
"{} a,b,y len mismatch {a_len} {b_len} {y_len}",
|
||||
stringify!($fn_name)
|
||||
);
|
||||
}
|
||||
unsafe {
|
||||
// Weird quirk of accelerate, the rhs comes before the lhs.
|
||||
ffi::$accelerate_name(
|
||||
b.as_ptr(),
|
||||
1,
|
||||
a.as_ptr(),
|
||||
1,
|
||||
y.as_mut_ptr(),
|
||||
1,
|
||||
a_len as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
binary_op!(vs_add, f32, vDSP_vadd);
|
||||
binary_op!(vd_add, f64, vDSP_vaddD);
|
||||
binary_op!(vs_sub, f32, vDSP_vsub);
|
||||
binary_op!(vd_sub, f64, vDSP_vsubD);
|
||||
binary_op!(vs_mul, f32, vDSP_vmul);
|
||||
binary_op!(vd_mul, f64, vDSP_vmulD);
|
||||
binary_op!(vs_div, f32, vDSP_vdiv);
|
||||
binary_op!(vd_div, f64, vDSP_vdivD);
|
@ -37,6 +37,18 @@ pub trait BackendStorage: Sized {
|
||||
_params: &crate::conv::ParamsConv1D,
|
||||
) -> Result<Self>;
|
||||
|
||||
fn conv2d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &crate::conv::ParamsConv2D,
|
||||
) -> Result<Self>;
|
||||
|
||||
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
|
||||
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
|
||||
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
|
||||
|
||||
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
|
||||
fn scatter_add(
|
||||
&self,
|
||||
|
@ -55,6 +55,11 @@ impl Tensor {
|
||||
kernel: rhs,
|
||||
..
|
||||
}
|
||||
| Op::Conv2D {
|
||||
arg: lhs,
|
||||
kernel: rhs,
|
||||
..
|
||||
}
|
||||
| Op::CustomOp2(lhs, rhs, _)
|
||||
| Op::Binary(lhs, rhs, _)
|
||||
| Op::Gather(lhs, rhs, _)
|
||||
@ -81,6 +86,9 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
Op::Reshape(node)
|
||||
| Op::UpsampleNearest2D(node)
|
||||
| Op::AvgPool2D { arg: node, .. }
|
||||
| Op::MaxPool2D { arg: node, .. }
|
||||
| Op::Copy(node)
|
||||
| Op::Broadcast(node)
|
||||
| Op::Cmp(node, _)
|
||||
@ -163,6 +171,12 @@ impl Tensor {
|
||||
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
||||
}
|
||||
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
|
||||
Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
|
||||
Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
|
||||
Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?,
|
||||
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "upsample-nearest2d",
|
||||
})?,
|
||||
Op::Gather(arg, indexes, dim) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
|
||||
@ -291,6 +305,11 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.sub(&grad)?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Recip) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let grad = (grad / arg.sqr()?)?;
|
||||
*sum_grad = sum_grad.sub(&grad)?
|
||||
}
|
||||
&Op::Narrow(ref arg, dim, start_idx, len) => {
|
||||
let arg_dims = arg.dims();
|
||||
let left_pad = if start_idx == 0 {
|
||||
|
@ -1,6 +1,6 @@
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ParamsConv1D {
|
||||
pub(crate) b_size: Option<usize>,
|
||||
pub(crate) b_size: usize,
|
||||
// Maybe we should have a version without l_in as this bit depends on the input and not only on
|
||||
// the weights.
|
||||
pub(crate) l_in: usize,
|
||||
@ -19,9 +19,35 @@ impl ParamsConv1D {
|
||||
|
||||
pub(crate) fn out_dims(&self) -> Vec<usize> {
|
||||
let l_out = self.l_out();
|
||||
match self.b_size {
|
||||
None => vec![self.c_out, l_out],
|
||||
Some(n) => vec![n, self.c_out, l_out],
|
||||
}
|
||||
vec![self.b_size, self.c_out, l_out]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ParamsConv2D {
|
||||
pub(crate) b_size: usize,
|
||||
pub(crate) i_h: usize,
|
||||
pub(crate) i_w: usize,
|
||||
pub(crate) k_h: usize,
|
||||
pub(crate) k_w: usize,
|
||||
pub(crate) c_out: usize,
|
||||
pub(crate) c_in: usize,
|
||||
pub(crate) padding: usize,
|
||||
pub(crate) stride: usize,
|
||||
}
|
||||
|
||||
impl ParamsConv2D {
|
||||
pub(crate) fn out_h(&self) -> usize {
|
||||
let dilation = 1;
|
||||
(self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
|
||||
}
|
||||
|
||||
pub(crate) fn out_w(&self) -> usize {
|
||||
let dilation = 1;
|
||||
(self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
|
||||
}
|
||||
|
||||
pub(crate) fn out_dims(&self) -> Vec<usize> {
|
||||
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
//! Implement conversion traits for tensors
|
||||
use crate::{Device, Error, Tensor, WithDType};
|
||||
use half::{bf16, f16};
|
||||
use crate::{DType, Device, Error, Tensor, WithDType};
|
||||
use half::{bf16, f16, slice::HalfFloatSliceExt};
|
||||
use std::convert::TryFrom;
|
||||
|
||||
impl<T: WithDType> TryFrom<&Tensor> for Vec<T> {
|
||||
@ -94,3 +94,46 @@ from_tensor!(f16);
|
||||
from_tensor!(bf16);
|
||||
from_tensor!(u32);
|
||||
from_tensor!(u8);
|
||||
|
||||
impl Tensor {
|
||||
pub fn write_bytes<W: std::io::Write>(&self, f: &mut W) -> crate::Result<()> {
|
||||
use byteorder::{LittleEndian, WriteBytesExt};
|
||||
|
||||
let vs = self.flatten_all()?;
|
||||
match self.dtype() {
|
||||
DType::BF16 => {
|
||||
let vs = vs.to_vec1::<bf16>()?;
|
||||
for &v in vs.reinterpret_cast() {
|
||||
f.write_u16::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
let vs = vs.to_vec1::<f16>()?;
|
||||
for &v in vs.reinterpret_cast() {
|
||||
f.write_u16::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F32 => {
|
||||
// TODO: Avoid using a buffer when data is already on the CPU.
|
||||
for v in vs.to_vec1::<f32>()? {
|
||||
f.write_f32::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F64 => {
|
||||
for v in vs.to_vec1::<f64>()? {
|
||||
f.write_f64::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::U32 => {
|
||||
for v in vs.to_vec1::<u32>()? {
|
||||
f.write_u32::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::U8 => {
|
||||
let vs = vs.to_vec1::<u8>()?;
|
||||
f.write_all(&vs)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
148
candle-core/src/cpu/avx.rs
Normal file
148
candle-core/src/cpu/avx.rs
Normal file
@ -0,0 +1,148 @@
|
||||
use super::{Cpu, CpuF16};
|
||||
#[cfg(target_arch = "x86")]
|
||||
use core::arch::x86::*;
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
use core::arch::x86_64::*;
|
||||
|
||||
use half::f16;
|
||||
|
||||
pub struct CurrentCpu {}
|
||||
|
||||
const STEP: usize = 32;
|
||||
const EPR: usize = 8;
|
||||
const ARR: usize = STEP / EPR;
|
||||
|
||||
impl Cpu<ARR> for CurrentCpu {
|
||||
type Unit = __m256;
|
||||
type Array = [__m256; ARR];
|
||||
|
||||
const STEP: usize = STEP;
|
||||
const EPR: usize = EPR;
|
||||
|
||||
fn n() -> usize {
|
||||
ARR
|
||||
}
|
||||
|
||||
unsafe fn zero() -> Self::Unit {
|
||||
_mm256_setzero_ps()
|
||||
}
|
||||
|
||||
unsafe fn zero_array() -> Self::Array {
|
||||
[Self::zero(); ARR]
|
||||
}
|
||||
|
||||
unsafe fn from_f32(v: f32) -> Self::Unit {
|
||||
_mm256_set1_ps(v)
|
||||
}
|
||||
|
||||
unsafe fn load(mem_addr: *const f32) -> Self::Unit {
|
||||
_mm256_loadu_ps(mem_addr)
|
||||
}
|
||||
|
||||
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
|
||||
_mm256_add_ps(a, b)
|
||||
}
|
||||
|
||||
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
|
||||
_mm256_add_ps(_mm256_mul_ps(b, c), a)
|
||||
}
|
||||
|
||||
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
|
||||
_mm256_storeu_ps(mem_addr, a);
|
||||
}
|
||||
|
||||
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
|
||||
for i in 0..ARR / 2 {
|
||||
x[2 * i] = _mm256_add_ps(x[2 * i], x[2 * i + 1]);
|
||||
}
|
||||
for i in 0..ARR / 4 {
|
||||
x[4 * i] = _mm256_add_ps(x[4 * i], x[4 * i + 2]);
|
||||
}
|
||||
#[allow(clippy::reversed_empty_ranges)]
|
||||
for i in 0..ARR / 8 {
|
||||
x[8 * i] = _mm256_add_ps(x[8 * i], x[8 * i + 4]);
|
||||
}
|
||||
let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
|
||||
let t1 = _mm_hadd_ps(t0, t0);
|
||||
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CurrentCpuF16 {}
|
||||
impl CpuF16<ARR> for CurrentCpuF16 {
|
||||
type Unit = __m256;
|
||||
type Array = [__m256; ARR];
|
||||
|
||||
const STEP: usize = STEP;
|
||||
const EPR: usize = EPR;
|
||||
|
||||
fn n() -> usize {
|
||||
ARR
|
||||
}
|
||||
|
||||
unsafe fn zero() -> Self::Unit {
|
||||
_mm256_setzero_ps()
|
||||
}
|
||||
|
||||
unsafe fn zero_array() -> Self::Array {
|
||||
[Self::zero(); ARR]
|
||||
}
|
||||
|
||||
unsafe fn from_f32(v: f32) -> Self::Unit {
|
||||
_mm256_set1_ps(v)
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "f16c")]
|
||||
unsafe fn load(mem_addr: *const f16) -> Self::Unit {
|
||||
_mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
|
||||
}
|
||||
|
||||
#[cfg(not(target_feature = "f16c"))]
|
||||
unsafe fn load(mem_addr: *const f16) -> Self::Unit {
|
||||
let mut tmp = [0.0f32; 8];
|
||||
for i in 0..8 {
|
||||
tmp[i] = (*mem_addr.add(i)).to_f32();
|
||||
}
|
||||
_mm_loadu_ps(tmp.as_ptr())
|
||||
}
|
||||
|
||||
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
|
||||
_mm256_add_ps(a, b)
|
||||
}
|
||||
|
||||
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
|
||||
_mm256_add_ps(_mm256_mul_ps(b, c), a)
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "f16c")]
|
||||
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
|
||||
_mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
|
||||
}
|
||||
|
||||
#[cfg(not(target_feature = "f16c"))]
|
||||
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
|
||||
let mut tmp = [0.0f32; 8];
|
||||
_mm256_storeu_ps(tmp.as_mut_ptr(), a);
|
||||
for i in 0..8 {
|
||||
*mem_addr.add(i) = f16::from_f32(tmp[i]);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
|
||||
let mut offset = ARR >> 1;
|
||||
for i in 0..offset {
|
||||
x[i] = _mm256_add_ps(x[i], x[offset + i]);
|
||||
}
|
||||
offset >>= 1;
|
||||
for i in 0..offset {
|
||||
x[i] = _mm256_add_ps(x[i], x[offset + i]);
|
||||
}
|
||||
offset >>= 1;
|
||||
for i in 0..offset {
|
||||
x[i] = _mm256_add_ps(x[i], x[offset + i]);
|
||||
}
|
||||
let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
|
||||
let t1 = _mm_hadd_ps(t0, t0);
|
||||
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
|
||||
}
|
||||
}
|
89
candle-core/src/cpu/kernels.rs
Normal file
89
candle-core/src/cpu/kernels.rs
Normal file
@ -0,0 +1,89 @@
|
||||
pub trait VecOps: num_traits::NumAssign + Copy {
|
||||
/// Dot-product of two vectors.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The length of `lhs` and `rhs` have to be at least `len`. `res` has to point to a valid
|
||||
/// element.
|
||||
#[inline(always)]
|
||||
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
|
||||
*res = Self::zero();
|
||||
for i in 0..len {
|
||||
*res += *lhs.add(i) * *rhs.add(i)
|
||||
}
|
||||
}
|
||||
|
||||
/// Sum of all elements in a vector.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The length of `xs` must be at least `len`. `res` has to point to a valid
|
||||
/// element.
|
||||
#[inline(always)]
|
||||
unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
|
||||
*res = Self::zero();
|
||||
for i in 0..len {
|
||||
*res += *xs.add(i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VecOps for f32 {
|
||||
#[inline(always)]
|
||||
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
|
||||
super::vec_dot_f32(lhs, rhs, res, len)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
|
||||
super::vec_sum(xs, res, len)
|
||||
}
|
||||
}
|
||||
|
||||
impl VecOps for half::f16 {
|
||||
#[inline(always)]
|
||||
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
|
||||
let mut res_f32 = 0f32;
|
||||
super::vec_dot_f16(lhs, rhs, &mut res_f32, len);
|
||||
*res = half::f16::from_f32(res_f32);
|
||||
}
|
||||
}
|
||||
|
||||
impl VecOps for f64 {}
|
||||
impl VecOps for half::bf16 {}
|
||||
impl VecOps for u8 {}
|
||||
impl VecOps for u32 {}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
|
||||
if n_threads == 1 {
|
||||
func(0)
|
||||
} else {
|
||||
rayon::scope(|s| {
|
||||
for thread_idx in 0..n_threads {
|
||||
let func = &func;
|
||||
s.spawn(move |_| func(thread_idx));
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn par_range(lo: usize, up: usize, n_threads: usize, func: impl Fn(usize) + Send + Sync) {
|
||||
if n_threads == 1 {
|
||||
for i in lo..up {
|
||||
func(i)
|
||||
}
|
||||
} else {
|
||||
rayon::scope(|s| {
|
||||
for thread_idx in 0..n_threads {
|
||||
let func = &func;
|
||||
s.spawn(move |_| {
|
||||
for i in (thread_idx..up).step_by(n_threads) {
|
||||
func(i)
|
||||
}
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
179
candle-core/src/cpu/mod.rs
Normal file
179
candle-core/src/cpu/mod.rs
Normal file
@ -0,0 +1,179 @@
|
||||
pub mod kernels;
|
||||
|
||||
trait Cpu<const ARR: usize> {
|
||||
type Unit;
|
||||
type Array;
|
||||
const STEP: usize;
|
||||
const EPR: usize;
|
||||
|
||||
fn n() -> usize;
|
||||
unsafe fn zero() -> Self::Unit;
|
||||
unsafe fn zero_array() -> Self::Array;
|
||||
unsafe fn load(mem_addr: *const f32) -> Self::Unit;
|
||||
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
|
||||
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
|
||||
unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
|
||||
unsafe fn from_f32(v: f32) -> Self::Unit;
|
||||
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
|
||||
}
|
||||
|
||||
trait CpuF16<const ARR: usize> {
|
||||
type Unit;
|
||||
type Array;
|
||||
const STEP: usize;
|
||||
const EPR: usize;
|
||||
|
||||
fn n() -> usize;
|
||||
unsafe fn zero() -> Self::Unit;
|
||||
unsafe fn zero_array() -> Self::Array;
|
||||
unsafe fn load(mem_addr: *const f16) -> Self::Unit;
|
||||
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
|
||||
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
|
||||
unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
|
||||
unsafe fn from_f32(v: f32) -> Self::Unit;
|
||||
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit);
|
||||
}
|
||||
use half::f16;
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(target_feature = "avx")]
|
||||
pub mod avx;
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(target_feature = "avx")]
|
||||
pub use avx::{CurrentCpu, CurrentCpuF16};
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
#[cfg(target_feature = "simd128")]
|
||||
pub mod simd128;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
#[cfg(target_feature = "simd128")]
|
||||
pub use simd128::CurrentCpu;
|
||||
|
||||
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
|
||||
#[cfg(target_feature = "neon")]
|
||||
pub mod neon;
|
||||
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
|
||||
#[cfg(target_feature = "neon")]
|
||||
pub use neon::CurrentCpu;
|
||||
|
||||
#[cfg(any(
|
||||
target_feature = "neon",
|
||||
target_feature = "avx",
|
||||
target_feature = "simd128"
|
||||
))]
|
||||
#[inline(always)]
|
||||
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
|
||||
let np = k & !(CurrentCpu::STEP - 1);
|
||||
|
||||
let mut sum = CurrentCpu::zero_array();
|
||||
let mut ax = CurrentCpu::zero_array();
|
||||
let mut ay = CurrentCpu::zero_array();
|
||||
|
||||
for i in (0..np).step_by(CurrentCpu::STEP) {
|
||||
for j in 0..CurrentCpu::n() {
|
||||
ax[j] = CurrentCpu::load(a_row.add(i + j * CurrentCpu::EPR));
|
||||
ay[j] = CurrentCpu::load(b_row.add(i + j * CurrentCpu::EPR));
|
||||
|
||||
sum[j] = CurrentCpu::vec_fma(sum[j], ax[j], ay[j]);
|
||||
}
|
||||
}
|
||||
|
||||
CurrentCpu::vec_reduce(sum, c);
|
||||
|
||||
// leftovers
|
||||
for i in np..k {
|
||||
*c += *a_row.add(i) * (*b_row.add(i));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(any(
|
||||
target_feature = "neon",
|
||||
target_feature = "avx",
|
||||
target_feature = "simd128"
|
||||
)))]
|
||||
#[inline(always)]
|
||||
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
|
||||
// leftovers
|
||||
for i in 0..k {
|
||||
*c += *a_row.add(i) * (*b_row.add(i));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(
|
||||
target_feature = "neon",
|
||||
target_feature = "avx",
|
||||
target_feature = "simd128"
|
||||
))]
|
||||
#[inline(always)]
|
||||
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
|
||||
let np = k & !(CurrentCpu::STEP - 1);
|
||||
|
||||
let mut sum = CurrentCpu::zero_array();
|
||||
let mut x = CurrentCpu::zero_array();
|
||||
|
||||
for i in (0..np).step_by(CurrentCpu::STEP) {
|
||||
for j in 0..CurrentCpu::n() {
|
||||
x[j] = CurrentCpu::load(row.add(i + j * CurrentCpu::EPR));
|
||||
sum[j] = CurrentCpu::vec_add(sum[j], x[j]);
|
||||
}
|
||||
}
|
||||
|
||||
CurrentCpu::vec_reduce(sum, b);
|
||||
|
||||
// leftovers
|
||||
for i in np..k {
|
||||
*b += *row.add(i)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(any(
|
||||
target_feature = "neon",
|
||||
target_feature = "avx",
|
||||
target_feature = "simd128"
|
||||
)))]
|
||||
#[inline(always)]
|
||||
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
|
||||
*b = 0f32;
|
||||
for i in 0..k {
|
||||
*b += *row.add(i)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "avx")]
|
||||
#[inline(always)]
|
||||
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
|
||||
let mut sumf = 0.0f32;
|
||||
let np = k & !(CurrentCpuF16::STEP - 1);
|
||||
|
||||
let mut sum = CurrentCpuF16::zero_array();
|
||||
let mut ax = CurrentCpuF16::zero_array();
|
||||
let mut ay = CurrentCpuF16::zero_array();
|
||||
|
||||
for i in (0..np).step_by(CurrentCpuF16::STEP) {
|
||||
for j in 0..CurrentCpuF16::n() {
|
||||
ax[j] = CurrentCpuF16::load(a_row.add(i + j * CurrentCpuF16::EPR));
|
||||
ay[j] = CurrentCpuF16::load(b_row.add(i + j * CurrentCpuF16::EPR));
|
||||
|
||||
sum[j] = CurrentCpuF16::vec_fma(sum[j], ax[j], ay[j]);
|
||||
}
|
||||
}
|
||||
|
||||
CurrentCpuF16::vec_reduce(sum, &mut sumf);
|
||||
|
||||
// leftovers
|
||||
for i in np..k {
|
||||
sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
|
||||
}
|
||||
*c = sumf;
|
||||
}
|
||||
|
||||
#[cfg(not(target_feature = "avx"))]
|
||||
#[inline(always)]
|
||||
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
|
||||
// leftovers
|
||||
let mut sum = 0.0;
|
||||
for i in 0..k {
|
||||
sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
|
||||
}
|
||||
*c = sum;
|
||||
}
|
74
candle-core/src/cpu/neon.rs
Normal file
74
candle-core/src/cpu/neon.rs
Normal file
@ -0,0 +1,74 @@
|
||||
use super::Cpu;
|
||||
#[cfg(target_arch = "arm")]
|
||||
use core::arch::arm::*;
|
||||
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
use core::arch::aarch64::*;
|
||||
|
||||
pub struct CurrentCpu {}
|
||||
|
||||
const STEP: usize = 16;
|
||||
const EPR: usize = 4;
|
||||
const ARR: usize = STEP / EPR;
|
||||
|
||||
impl CurrentCpu {
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
unsafe fn reduce_one(x: float32x4_t) -> f32 {
|
||||
vaddvq_f32(x)
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "arm")]
|
||||
unsafe fn reduce_one(x: float32x4_t) -> f32 {
|
||||
vgetq_lane_f32(x, 0) + vgetq_lane_f32(x, 1) + vgetq_lane_f32(x, 2) + vgetq_lane_f32(x, 3)
|
||||
}
|
||||
}
|
||||
|
||||
impl Cpu<ARR> for CurrentCpu {
|
||||
type Unit = float32x4_t;
|
||||
type Array = [float32x4_t; ARR];
|
||||
|
||||
const STEP: usize = STEP;
|
||||
const EPR: usize = EPR;
|
||||
|
||||
fn n() -> usize {
|
||||
ARR
|
||||
}
|
||||
|
||||
unsafe fn zero() -> Self::Unit {
|
||||
vdupq_n_f32(0.0)
|
||||
}
|
||||
|
||||
unsafe fn from_f32(x: f32) -> Self::Unit {
|
||||
vdupq_n_f32(x)
|
||||
}
|
||||
|
||||
unsafe fn zero_array() -> Self::Array {
|
||||
[Self::zero(); ARR]
|
||||
}
|
||||
|
||||
unsafe fn load(mem_addr: *const f32) -> Self::Unit {
|
||||
vld1q_f32(mem_addr)
|
||||
}
|
||||
|
||||
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
|
||||
vaddq_f32(a, b)
|
||||
}
|
||||
|
||||
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
|
||||
vfmaq_f32(a, b, c)
|
||||
}
|
||||
|
||||
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
|
||||
vst1q_f32(mem_addr, a);
|
||||
}
|
||||
|
||||
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
|
||||
for i in 0..ARR / 2 {
|
||||
x[2 * i] = vaddq_f32(x[2 * i], x[2 * i + 1]);
|
||||
}
|
||||
for i in 0..ARR / 4 {
|
||||
x[4 * i] = vaddq_f32(x[4 * i], x[4 * i + 2]);
|
||||
}
|
||||
*y = Self::reduce_one(x[0]);
|
||||
}
|
||||
}
|
64
candle-core/src/cpu/simd128.rs
Normal file
64
candle-core/src/cpu/simd128.rs
Normal file
@ -0,0 +1,64 @@
|
||||
use super::Cpu;
|
||||
use core::arch::wasm32::*;
|
||||
|
||||
pub struct CurrentCpu {}
|
||||
|
||||
const STEP: usize = 16;
|
||||
const EPR: usize = 4;
|
||||
const ARR: usize = STEP / EPR;
|
||||
|
||||
impl Cpu<ARR> for CurrentCpu {
|
||||
type Unit = v128;
|
||||
type Array = [v128; ARR];
|
||||
|
||||
const STEP: usize = STEP;
|
||||
const EPR: usize = EPR;
|
||||
|
||||
fn n() -> usize {
|
||||
ARR
|
||||
}
|
||||
|
||||
unsafe fn zero() -> Self::Unit {
|
||||
f32x4_splat(0.0)
|
||||
}
|
||||
|
||||
unsafe fn zero_array() -> Self::Array {
|
||||
[Self::zero(); ARR]
|
||||
}
|
||||
|
||||
unsafe fn from_f32(v: f32) -> Self::Unit {
|
||||
f32x4_splat(v)
|
||||
}
|
||||
|
||||
unsafe fn load(mem_addr: *const f32) -> Self::Unit {
|
||||
v128_load(mem_addr as *mut v128)
|
||||
}
|
||||
|
||||
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
|
||||
f32x4_add(a, b)
|
||||
}
|
||||
|
||||
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
|
||||
f32x4_add(f32x4_mul(b, c), a)
|
||||
}
|
||||
|
||||
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
|
||||
v128_store(mem_addr as *mut v128, a);
|
||||
}
|
||||
|
||||
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
|
||||
for i in 0..ARR / 2 {
|
||||
x[2 * i] = f32x4_add(x[2 * i], x[2 * i + 1]);
|
||||
}
|
||||
for i in 0..ARR / 4 {
|
||||
x[4 * i] = f32x4_add(x[4 * i], x[4 * i + 2]);
|
||||
}
|
||||
for i in 0..ARR / 8 {
|
||||
x[8 * i] = f32x4_add(x[8 * i], x[8 * i + 4]);
|
||||
}
|
||||
*y = f32x4_extract_lane::<0>(x[0])
|
||||
+ f32x4_extract_lane::<1>(x[0])
|
||||
+ f32x4_extract_lane::<2>(x[0])
|
||||
+ f32x4_extract_lane::<3>(x[0]);
|
||||
}
|
||||
}
|
@ -278,17 +278,17 @@ impl Map1Any for ReduceIndex {
|
||||
}
|
||||
}
|
||||
|
||||
struct Reduce<'a> {
|
||||
struct ReduceSum<'a> {
|
||||
dst_shape: &'a Shape,
|
||||
reduce_dims: &'a [usize],
|
||||
reduce_dims_and_stride: Vec<(usize, usize)>,
|
||||
}
|
||||
|
||||
impl<'a> Reduce<'a> {
|
||||
impl<'a> ReduceSum<'a> {
|
||||
#[inline(always)]
|
||||
fn fold_impl<T, F>(&self, src: &[T], src_l: &Layout, start_elt: T, f: F) -> Result<Vec<T>>
|
||||
where
|
||||
T: Clone + Copy,
|
||||
T: WithDType,
|
||||
F: Fn(T, T) -> T,
|
||||
{
|
||||
let mut dst = vec![start_elt; self.dst_shape.elem_count()];
|
||||
@ -310,12 +310,15 @@ impl<'a> Reduce<'a> {
|
||||
.iter()
|
||||
.map(|(u, _)| u)
|
||||
.product::<usize>();
|
||||
let mut src_i = 0;
|
||||
for dst_v in dst.iter_mut() {
|
||||
for &s in src[src_i..src_i + reduce_sz].iter() {
|
||||
*dst_v = f(*dst_v, s)
|
||||
}
|
||||
src_i += reduce_sz
|
||||
for (dst_i, dst_v) in dst.iter_mut().enumerate() {
|
||||
let src_i = dst_i * reduce_sz;
|
||||
unsafe {
|
||||
T::vec_reduce_sum(
|
||||
src[src_i..src_i + reduce_sz].as_ptr(),
|
||||
dst_v,
|
||||
reduce_sz,
|
||||
)
|
||||
};
|
||||
}
|
||||
return Ok(dst);
|
||||
};
|
||||
@ -347,7 +350,7 @@ impl<'a> Reduce<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Map1 for Reduce<'a> {
|
||||
impl<'a> Map1 for ReduceSum<'a> {
|
||||
#[inline(always)]
|
||||
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
self.fold_impl(src, src_l, T::zero(), |x, y| x + y)
|
||||
@ -633,6 +636,126 @@ impl Map1 for Affine {
|
||||
}
|
||||
}
|
||||
|
||||
struct AvgPool2D((usize, usize), (usize, usize));
|
||||
|
||||
impl Map1 for AvgPool2D {
|
||||
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
|
||||
let (k_h, k_w) = self.0;
|
||||
let (s_h, s_w) = self.1;
|
||||
let (b_sz, c, h, w) = layout.shape().dims4()?;
|
||||
let stride = layout.stride();
|
||||
let (stride_h, stride_w) = (stride[2], stride[3]);
|
||||
let h_out = (h - k_h) / s_h + 1;
|
||||
let w_out = (w - k_w) / s_w + 1;
|
||||
let src_index = layout.start_offset();
|
||||
let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
|
||||
let scale = 1f64 / (k_h * k_w) as f64;
|
||||
let scale = T::from_f64(scale);
|
||||
for b_idx in 0..b_sz {
|
||||
let dst = &mut dst[b_idx * c * h_out * w_out..];
|
||||
let src_index = src_index + b_idx * stride[0];
|
||||
for c_idx in 0..c {
|
||||
let dst = &mut dst[c_idx * h_out * w_out..];
|
||||
let src_index = src_index + c_idx * stride[1];
|
||||
for h_idx in 0..h_out {
|
||||
for w_idx in 0..w_out {
|
||||
let mut sum = T::zero();
|
||||
for m in 0..k_h {
|
||||
for n in 0..k_w {
|
||||
let m = s_h * h_idx + m;
|
||||
let n = s_w * w_idx + n;
|
||||
sum += src[src_index + m * stride_h + n * stride_w]
|
||||
}
|
||||
}
|
||||
dst[h_idx * w_out + w_idx] = sum * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct MaxPool2D((usize, usize), (usize, usize));
|
||||
|
||||
impl Map1 for MaxPool2D {
|
||||
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
|
||||
let (k_h, k_w) = self.0;
|
||||
let (s_h, s_w) = self.1;
|
||||
let (b_sz, c, h, w) = layout.shape().dims4()?;
|
||||
let stride = layout.stride();
|
||||
let (stride_h, stride_w) = (stride[2], stride[3]);
|
||||
let h_out = (h - k_h) / s_h + 1;
|
||||
let w_out = (w - k_w) / s_w + 1;
|
||||
let src_index = layout.start_offset();
|
||||
let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
|
||||
for b_idx in 0..b_sz {
|
||||
let dst = &mut dst[b_idx * c * h_out * w_out..];
|
||||
let src_index = src_index + b_idx * stride[0];
|
||||
for c_idx in 0..c {
|
||||
let dst = &mut dst[c_idx * h_out * w_out..];
|
||||
let src_index = src_index + c_idx * stride[1];
|
||||
for h_idx in 0..h_out {
|
||||
for w_idx in 0..w_out {
|
||||
let mut largest =
|
||||
src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
|
||||
for m in 0..k_h {
|
||||
for n in 0..k_w {
|
||||
let m = s_h * h_idx + m;
|
||||
let n = s_w * w_idx + n;
|
||||
if largest < src[src_index + m * stride_h + n * stride_w] {
|
||||
largest = src[src_index + m * stride_h + n * stride_w]
|
||||
}
|
||||
}
|
||||
}
|
||||
dst[h_idx * w_out + w_idx] = largest;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct UpsampleNearest2D(usize, usize);
|
||||
|
||||
impl Map1 for UpsampleNearest2D {
|
||||
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||
// TODO: Specialized implementation for the case 2*h, 2*w?
|
||||
let (dst_h, dst_w) = (self.0, self.1);
|
||||
let (b_sz, c, src_h, src_w) = layout.shape().dims4()?;
|
||||
let stride = layout.stride();
|
||||
let (stride_h, stride_w) = (stride[2], stride[3]);
|
||||
let src_index = layout.start_offset();
|
||||
let scale_h = src_h as f64 / dst_h as f64;
|
||||
let scale_w = src_w as f64 / dst_w as f64;
|
||||
let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
|
||||
let src_h_idxs = (0..dst_h)
|
||||
.map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
|
||||
.collect::<Vec<_>>();
|
||||
let src_w_idxs = (0..dst_w)
|
||||
.map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
|
||||
.collect::<Vec<_>>();
|
||||
for b_idx in 0..b_sz {
|
||||
let dst = &mut dst[b_idx * c * dst_h * dst_w..];
|
||||
let src_index = src_index + b_idx * stride[0];
|
||||
for c_idx in 0..c {
|
||||
let dst = &mut dst[c_idx * dst_h * dst_w..];
|
||||
let src_index = src_index + c_idx * stride[1];
|
||||
for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() {
|
||||
for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() {
|
||||
let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w;
|
||||
dst[h_idx * dst_w + w_idx] = src[src_index]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct Gather<'a, I: IntDType> {
|
||||
ids: &'a [I],
|
||||
ids_l: &'a Layout,
|
||||
@ -903,56 +1026,152 @@ struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
|
||||
|
||||
impl<'a> Map2 for Conv1D<'a> {
|
||||
const OP: &'static str = "conv1d";
|
||||
fn f<T: 'static + num_traits::NumAssign + Copy>(
|
||||
&self,
|
||||
inp: &[T],
|
||||
inp_l: &Layout,
|
||||
k: &[T],
|
||||
k_l: &Layout,
|
||||
) -> Result<Vec<T>> {
|
||||
// TODO: Optimize this (proper algorithm, simd, multithread, remove bound checks, etc).
|
||||
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
||||
let p = self.0;
|
||||
let inp = &inp[inp_l.start_offset()..];
|
||||
let k = &k[k_l.start_offset()..];
|
||||
let inp_stride = inp_l.stride();
|
||||
let (inp_stride0, inp_stride) = if inp_stride.len() == 3 {
|
||||
(inp_stride[0], &inp_stride[1..])
|
||||
} else {
|
||||
(0, inp_stride) // This value never gets used anyway
|
||||
};
|
||||
let k_stride = k_l.stride();
|
||||
let k_over_2 = p.k_size / 2;
|
||||
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
|
||||
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
|
||||
let l_out = p.l_out();
|
||||
let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
|
||||
let mut dst = vec![T::zero(); dst_elems];
|
||||
let dst_elems = p.c_out * l_out * p.b_size;
|
||||
// The output shape is [b_size, c_out, l_out]
|
||||
for b_idx in 0..p.b_size.unwrap_or(1) {
|
||||
let inp_idx = b_idx * inp_stride0;
|
||||
let dst_idx = b_idx * p.c_out * l_out;
|
||||
for dst_c_idx in 0..p.c_out {
|
||||
let dst_idx = dst_idx + dst_c_idx * l_out;
|
||||
for dst_l in 0..l_out {
|
||||
let dst_idx = dst_idx + dst_l;
|
||||
let mut d = T::zero();
|
||||
for offset in 0..p.k_size {
|
||||
let src_l_plus = p.stride * dst_l + offset;
|
||||
// inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
|
||||
if k_over_2 <= src_l_plus && src_l_plus < k_over_2 + p.l_in {
|
||||
let src_l = src_l_plus - k_over_2;
|
||||
for src_c_idx in 0..p.c_in {
|
||||
let inp_idx =
|
||||
inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
|
||||
let k_idx = dst_c_idx * k_stride[0]
|
||||
+ src_c_idx * k_stride[1]
|
||||
+ offset * k_stride[2];
|
||||
d += inp[inp_idx] * k[k_idx]
|
||||
}
|
||||
}
|
||||
}
|
||||
dst[dst_idx] = d
|
||||
let dst = vec![T::zero(); dst_elems];
|
||||
|
||||
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
|
||||
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
|
||||
for b_idx in 0..p.b_size {
|
||||
for src_l in 0..p.l_in {
|
||||
for src_c_idx in 0..p.c_in {
|
||||
let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
|
||||
inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let num_threads = crate::utils::get_num_threads();
|
||||
|
||||
for offset in 0..p.k_size {
|
||||
crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
|
||||
let dst_idx = dst_c_idx * l_out;
|
||||
let k_cont = (0..p.c_in)
|
||||
.map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
|
||||
.collect::<Vec<_>>();
|
||||
for b_idx in 0..p.b_size {
|
||||
let dst_idx = dst_idx + b_idx * p.c_out * l_out;
|
||||
for dst_l in 0..l_out {
|
||||
let dst_idx = dst_idx + dst_l;
|
||||
let src_l = p.stride * dst_l + offset;
|
||||
if src_l < p.padding || src_l >= p.padding + p.l_in {
|
||||
continue;
|
||||
}
|
||||
let src_l = src_l - p.padding;
|
||||
let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
|
||||
assert!(inp_cont.len() >= p.c_in);
|
||||
assert!(k_cont.len() >= p.c_in);
|
||||
let mut d = T::zero();
|
||||
unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
|
||||
let dst_p = dst.as_ptr();
|
||||
// Safety: dst_idx are uniques per dst_c_idx which is used to parallelise
|
||||
// the different tasks so no two threads can try to write at the same
|
||||
// location.
|
||||
unsafe {
|
||||
let ptr = dst_p.add(dst_idx) as *mut T;
|
||||
*ptr += d
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
|
||||
|
||||
impl<'a> Map2 for Conv2D<'a> {
|
||||
const OP: &'static str = "conv2d";
|
||||
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
||||
let p = self.0;
|
||||
let inp = &inp[inp_l.start_offset()..];
|
||||
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
|
||||
let k = &k[k_l.start_offset()..];
|
||||
let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
|
||||
let (out_h, out_w) = (p.out_h(), p.out_w());
|
||||
|
||||
// Output shape: [b_size, c_out, out_h, out_w].
|
||||
let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
|
||||
|
||||
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
|
||||
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
|
||||
let cont_s0 = p.i_h * p.i_w * p.c_in;
|
||||
let cont_s1 = p.i_w * p.c_in;
|
||||
let cont_s2 = p.c_in;
|
||||
for b_idx in 0..p.b_size {
|
||||
for h_idx in 0..p.i_h {
|
||||
for w_idx in 0..p.i_w {
|
||||
for c_idx in 0..p.c_in {
|
||||
let src_idx =
|
||||
b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
|
||||
let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
|
||||
inp_cont[dst_idx] = inp[src_idx]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let num_threads = crate::utils::get_num_threads();
|
||||
|
||||
for offset_h in 0..p.k_h {
|
||||
for offset_w in 0..p.k_w {
|
||||
crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
|
||||
let dst_idx = dst_c_idx * out_w * out_h;
|
||||
let k_cont = (0..p.c_in)
|
||||
.map(|c_in_idx| {
|
||||
k[dst_c_idx * k_s0
|
||||
+ c_in_idx * k_s1
|
||||
+ offset_h * k_s2
|
||||
+ offset_w * k_s3]
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
for b_idx in 0..p.b_size {
|
||||
let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
|
||||
for dst_h in 0..out_h {
|
||||
let dst_idx = dst_idx + dst_h * out_w;
|
||||
let src_h = p.stride * dst_h + offset_h;
|
||||
if src_h < p.padding || src_h >= p.i_h + p.padding {
|
||||
continue;
|
||||
}
|
||||
let src_h = src_h - p.padding;
|
||||
for dst_w in 0..out_w {
|
||||
let dst_idx = dst_idx + dst_w;
|
||||
let src_w = p.stride * dst_w + offset_w;
|
||||
if src_w < p.padding || src_w >= p.i_w + p.padding {
|
||||
continue;
|
||||
}
|
||||
let src_w = src_w - p.padding;
|
||||
let inp_cont = &inp_cont
|
||||
[b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..];
|
||||
assert!(inp_cont.len() >= p.c_in);
|
||||
assert!(k_cont.len() >= p.c_in);
|
||||
let mut d = T::zero();
|
||||
unsafe {
|
||||
T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
|
||||
}
|
||||
let dst_p = dst.as_ptr();
|
||||
// Safety: dst_idx are uniques per dst_c_idx which is used to parallelise
|
||||
// the different tasks so no two threads can try to write at the same
|
||||
// location.
|
||||
unsafe {
|
||||
let ptr = dst_p.add(dst_idx) as *mut T;
|
||||
*ptr += d
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
@ -974,7 +1193,7 @@ impl MatMul {
|
||||
impl Map2 for MatMul {
|
||||
const OP: &'static str = "mat_mul";
|
||||
|
||||
#[cfg(not(feature = "mkl"))]
|
||||
#[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
|
||||
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
|
||||
&self,
|
||||
lhs: &[T],
|
||||
@ -1053,6 +1272,109 @@ impl Map2 for MatMul {
|
||||
Ok(dst)
|
||||
}
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
|
||||
&self,
|
||||
lhs: &[T],
|
||||
lhs_l: &Layout,
|
||||
rhs: &[T],
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Vec<T>> {
|
||||
let (b, m, n, k) = self.0;
|
||||
let lhs = &lhs[lhs_l.start_offset()..];
|
||||
let rhs = &rhs[rhs_l.start_offset()..];
|
||||
|
||||
let lhs_stride = lhs_l.stride();
|
||||
let rhs_stride = rhs_l.stride();
|
||||
let rank = lhs_stride.len();
|
||||
|
||||
let a_skip: usize = match lhs_stride[..rank - 2] {
|
||||
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => m * k,
|
||||
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
|
||||
};
|
||||
let b_skip: usize = match rhs_stride[..rank - 2] {
|
||||
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => n * k,
|
||||
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
|
||||
};
|
||||
let c_skip: usize = m * n;
|
||||
|
||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||
|
||||
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
|
||||
(n as i32, b'N')
|
||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||
(k as i32, b'T')
|
||||
} else {
|
||||
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
|
||||
};
|
||||
// The b tensor has dims batching, m, k (lhs)
|
||||
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
||||
(k as i32, b'N')
|
||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||
(m as i32, b'T')
|
||||
} else {
|
||||
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
|
||||
};
|
||||
|
||||
let mut dst = vec![T::zero(); b * m * n];
|
||||
match T::DTYPE {
|
||||
DType::F16 => {
|
||||
crate::bail!("the accelerate backend does not support f16 matmul")
|
||||
}
|
||||
DType::F32 => {
|
||||
for step in 0..b {
|
||||
let lhs_p = &lhs[step * a_skip..];
|
||||
let rhs_p = &rhs[step * b_skip..];
|
||||
let dst_p = &mut dst[step * c_skip..];
|
||||
unsafe {
|
||||
let a = rhs_p.as_ptr() as *const f32;
|
||||
let b = lhs_p.as_ptr() as *const f32;
|
||||
let c = dst_p.as_mut_ptr() as *mut f32;
|
||||
let a = std::slice::from_raw_parts(a, a_skip);
|
||||
let b = std::slice::from_raw_parts(b, b_skip);
|
||||
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
||||
crate::accelerate::sgemm(
|
||||
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
|
||||
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
|
||||
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
|
||||
/* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
DType::F64 => {
|
||||
for step in 0..b {
|
||||
let lhs_p = &lhs[step * a_skip..];
|
||||
let rhs_p = &rhs[step * b_skip..];
|
||||
let dst_p = &mut dst[step * c_skip..];
|
||||
unsafe {
|
||||
let a = rhs_p.as_ptr() as *const f64;
|
||||
let b = lhs_p.as_ptr() as *const f64;
|
||||
let c = dst_p.as_mut_ptr() as *mut f64;
|
||||
let a = std::slice::from_raw_parts(a, a_skip);
|
||||
let b = std::slice::from_raw_parts(b, b_skip);
|
||||
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
||||
crate::accelerate::dgemm(
|
||||
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
|
||||
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
|
||||
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
|
||||
/* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
|
||||
&self,
|
||||
@ -1379,7 +1701,7 @@ impl BackendStorage for CpuStorage {
|
||||
.iter()
|
||||
.map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
|
||||
.collect();
|
||||
Reduce {
|
||||
ReduceSum {
|
||||
dst_shape: &dst_shape,
|
||||
reduce_dims: &reduce_dims,
|
||||
reduce_dims_and_stride,
|
||||
@ -1426,6 +1748,28 @@ impl BackendStorage for CpuStorage {
|
||||
Affine(mul, add).map(self, layout)
|
||||
}
|
||||
|
||||
fn avg_pool2d(
|
||||
&self,
|
||||
layout: &Layout,
|
||||
kernel_size: (usize, usize),
|
||||
stride: (usize, usize),
|
||||
) -> Result<Self> {
|
||||
AvgPool2D(kernel_size, stride).map(self, layout)
|
||||
}
|
||||
|
||||
fn max_pool2d(
|
||||
&self,
|
||||
layout: &Layout,
|
||||
kernel_size: (usize, usize),
|
||||
stride: (usize, usize),
|
||||
) -> Result<Self> {
|
||||
MaxPool2D(kernel_size, stride).map(self, layout)
|
||||
}
|
||||
|
||||
fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
|
||||
UpsampleNearest2D(h, w).map(self, layout)
|
||||
}
|
||||
|
||||
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
|
||||
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
|
||||
match self {
|
||||
@ -1612,6 +1956,16 @@ impl BackendStorage for CpuStorage {
|
||||
Conv1D(params).map(self, l, kernel, kernel_l)
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConv2D,
|
||||
) -> Result<Self> {
|
||||
Conv2D(params).map(self, l, kernel, kernel_l)
|
||||
}
|
||||
|
||||
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
match ids {
|
||||
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
||||
@ -1767,35 +2121,36 @@ impl BackendDevice for CpuDevice {
|
||||
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
|
||||
DType::BF16 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let std = bf16::from_f64(std);
|
||||
let mean = bf16::from_f64(mean);
|
||||
let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
|
||||
.map_err(Error::wrap)?;
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean)
|
||||
data.push(normal.sample(&mut rng))
|
||||
}
|
||||
Ok(CpuStorage::BF16(data))
|
||||
}
|
||||
DType::F16 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let std = f16::from_f64(std);
|
||||
let mean = f16::from_f64(mean);
|
||||
let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
|
||||
.map_err(Error::wrap)?;
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean)
|
||||
data.push(normal.sample(&mut rng))
|
||||
}
|
||||
Ok(CpuStorage::F16(data))
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let std = std as f32;
|
||||
let mean = mean as f32;
|
||||
let normal =
|
||||
rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
|
||||
data.push(normal.sample(&mut rng))
|
||||
}
|
||||
Ok(CpuStorage::F32(data))
|
||||
}
|
||||
DType::F64 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
|
||||
data.push(normal.sample(&mut rng))
|
||||
}
|
||||
Ok(CpuStorage::F64(data))
|
||||
}
|
||||
|
@ -64,7 +64,7 @@ impl From<CudaError> for crate::Error {
|
||||
|
||||
/// Unique identifier for cuda devices.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub(crate) struct DeviceId(usize);
|
||||
pub struct DeviceId(usize);
|
||||
|
||||
impl DeviceId {
|
||||
fn new() -> Self {
|
||||
@ -111,6 +111,14 @@ impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
|
||||
}
|
||||
|
||||
impl CudaDevice {
|
||||
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
|
||||
self.device.clone()
|
||||
}
|
||||
|
||||
pub fn id(&self) -> DeviceId {
|
||||
self.id
|
||||
}
|
||||
|
||||
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
@ -897,14 +905,13 @@ impl<'a> Map2 for Conv1D<'a> {
|
||||
// Kernel shape: (c_out, c_in_k, k_size)
|
||||
// Input shape: (b_size, c_in, l_in) or (c_in, l_in)
|
||||
let p = &self.0;
|
||||
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(k_l.start_offset()..);
|
||||
let shape = inp_l.shape();
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let l_out = p.l_out();
|
||||
let dst_el = p.c_out * l_out * p.b_size.unwrap_or(1);
|
||||
let dst_el = p.c_out * l_out * p.b_size;
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), kernels::CONV)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
@ -917,7 +924,136 @@ impl<'a> Map2 for Conv1D<'a> {
|
||||
panic!("unexpected input shape for conv1d {dims:?}")
|
||||
};
|
||||
let ds = dev.htod_copy(ds).w()?;
|
||||
let params = (el, l_out, p.stride, &ds, inp, k, &out);
|
||||
let params = (el, l_out, p.stride, p.padding, &ds, inp, k, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
|
||||
impl<'a> Map2 for Conv2D<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
inp: &CudaSlice<T>,
|
||||
inp_l: &Layout,
|
||||
k: &CudaSlice<T>,
|
||||
k_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
// Kernel shape: (c_out, c_in_k, w_k, h_k)
|
||||
// Input shape: (b_size, c_in, w_in, c_in)
|
||||
let p = &self.0;
|
||||
let (out_w, out_h) = (p.out_w(), p.out_h());
|
||||
let dst_el = p.c_out * out_w * out_h * p.b_size;
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(k_l.start_offset()..);
|
||||
let shape = inp_l.shape();
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("conv2d"), kernels::CONV)?;
|
||||
let ds = if dims.len() == 4 {
|
||||
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
||||
} else {
|
||||
panic!("unexpected input shape for conv1d {dims:?}")
|
||||
};
|
||||
let ds = dev.htod_copy(ds).w()?;
|
||||
let params = (el, out_w, out_h, p.stride, p.padding, &ds, inp, k, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
enum PoolOp {
|
||||
Max,
|
||||
Avg,
|
||||
}
|
||||
|
||||
struct Pool2D {
|
||||
w_k: usize,
|
||||
h_k: usize,
|
||||
w_stride: usize,
|
||||
h_stride: usize,
|
||||
op: PoolOp,
|
||||
}
|
||||
|
||||
impl Map1 for Pool2D {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
inp: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
inp_l: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
// Input shape: (b_size, c, h, w)
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let shape = inp_l.shape();
|
||||
let dims = shape.dims();
|
||||
let ds = if dims.len() == 4 {
|
||||
[dims, inp_l.stride()].concat()
|
||||
} else {
|
||||
panic!("unexpected input shape for conv1d {dims:?}")
|
||||
};
|
||||
let el = shape.elem_count();
|
||||
let out_w = (dims[2] - self.w_k) / self.w_stride + 1;
|
||||
let out_h = (dims[3] - self.h_k) / self.h_stride + 1;
|
||||
let dst_el = out_w * out_h * dims[0] * dims[1];
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let kname = match self.op {
|
||||
PoolOp::Max => "max_pool2d",
|
||||
PoolOp::Avg => "avg_pool2d",
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(kname), kernels::CONV)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let ds = dev.htod_copy(ds).w()?;
|
||||
let params = (
|
||||
el,
|
||||
self.w_k,
|
||||
self.h_k,
|
||||
self.w_stride,
|
||||
self.h_stride,
|
||||
&ds,
|
||||
inp,
|
||||
&out,
|
||||
);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
struct UpsampleNearest2D(usize, usize);
|
||||
impl Map1 for UpsampleNearest2D {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
inp: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
inp_l: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
// Input shape: (b_size, c, h, w)
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let shape = inp_l.shape();
|
||||
let dims = shape.dims();
|
||||
let ds = if dims.len() == 4 {
|
||||
[dims, inp_l.stride()].concat()
|
||||
} else {
|
||||
panic!("unexpected input shape for conv1d {dims:?}")
|
||||
};
|
||||
let (out_w, out_h) = (self.0, self.1);
|
||||
let dst_el = out_w * out_h * dims[0] * dims[1];
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("upsample_nearest2d"), kernels::CONV)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let ds = dev.htod_copy(ds).w()?;
|
||||
let scale_w = dims[2] as f64 / out_w as f64;
|
||||
let scale_h = dims[3] as f64 / out_h as f64;
|
||||
let params = (out_w, out_h, scale_w, scale_h, &ds, inp, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
@ -1381,6 +1517,114 @@ impl BackendStorage for CudaStorage {
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cudnn"))]
|
||||
fn conv2d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConv2D,
|
||||
) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
#[cfg(feature = "cudnn")]
|
||||
fn conv2d(
|
||||
&self,
|
||||
inp_l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConv2D,
|
||||
) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
if !kernel_l.is_contiguous() {
|
||||
let slice = Conv2D(params).map(&self.slice, inp_l, &kernel.slice, kernel_l, &device)?;
|
||||
return Ok(Self { slice, device });
|
||||
}
|
||||
let (out_w, out_h) = (params.out_w(), params.out_h());
|
||||
let dst_el = params.c_out * out_w * out_h * params.b_size;
|
||||
let slice = match (&self.slice, &kernel.slice) {
|
||||
(S::U8(inp), S::U8(k)) => {
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(kernel_l.start_offset()..);
|
||||
let mut out = unsafe { device.alloc::<u8>(dst_el) }.w()?;
|
||||
crate::cudnn::launch_conv2d::<u8>(inp, inp_l, k, &mut out, params, &device)
|
||||
.map_err(crate::Error::wrap)?;
|
||||
S::U8(out)
|
||||
}
|
||||
(S::BF16(inp), S::BF16(k)) => {
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(kernel_l.start_offset()..);
|
||||
let mut out = unsafe { device.alloc::<bf16>(dst_el) }.w()?;
|
||||
crate::cudnn::launch_conv2d::<bf16>(inp, inp_l, k, &mut out, params, &device)
|
||||
.map_err(crate::Error::wrap)?;
|
||||
S::BF16(out)
|
||||
}
|
||||
(S::F16(inp), S::F16(k)) => {
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(kernel_l.start_offset()..);
|
||||
let mut out = unsafe { device.alloc::<f16>(dst_el) }.w()?;
|
||||
crate::cudnn::launch_conv2d::<f16>(inp, inp_l, k, &mut out, params, &device)
|
||||
.map_err(crate::Error::wrap)?;
|
||||
S::F16(out)
|
||||
}
|
||||
|
||||
(S::F32(inp), S::F32(k)) => {
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(kernel_l.start_offset()..);
|
||||
let mut out = unsafe { device.alloc::<f32>(dst_el) }.w()?;
|
||||
crate::cudnn::launch_conv2d::<f32>(inp, inp_l, k, &mut out, params, &device)
|
||||
.map_err(crate::Error::wrap)?;
|
||||
S::F32(out)
|
||||
}
|
||||
(S::F64(inp), S::F64(k)) => {
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(kernel_l.start_offset()..);
|
||||
let mut out = unsafe { device.alloc::<f64>(dst_el) }.w()?;
|
||||
crate::cudnn::launch_conv2d::<f64>(inp, inp_l, k, &mut out, params, &device)
|
||||
.map_err(crate::Error::wrap)?;
|
||||
S::F64(out)
|
||||
}
|
||||
(S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?,
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?,
|
||||
};
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = Pool2D {
|
||||
w_k: k.0,
|
||||
h_k: k.1,
|
||||
w_stride: stride.0,
|
||||
h_stride: stride.1,
|
||||
op: PoolOp::Avg,
|
||||
}
|
||||
.map(&self.slice, &device, l)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn max_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = Pool2D {
|
||||
w_k: k.0,
|
||||
h_k: k.1,
|
||||
w_stride: stride.0,
|
||||
h_stride: stride.1,
|
||||
op: PoolOp::Max,
|
||||
}
|
||||
.map(&self.slice, &device, l)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
|
||||
|
107
candle-core/src/cudnn.rs
Normal file
107
candle-core/src/cudnn.rs
Normal file
@ -0,0 +1,107 @@
|
||||
use crate::WithDType;
|
||||
use cudarc;
|
||||
use cudarc::cudnn::safe::{Conv2dForward, Cudnn};
|
||||
use cudarc::driver::{CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
// The cudnn handles are stored per thread here rather than on the CudaDevice as they are neither
|
||||
// send nor sync.
|
||||
thread_local! {
|
||||
static CUDNN: RefCell<HashMap<crate::cuda_backend::DeviceId, Arc<Cudnn>>> = HashMap::new().into();
|
||||
}
|
||||
|
||||
impl From<cudarc::cudnn::CudnnError> for crate::Error {
|
||||
fn from(err: cudarc::cudnn::CudnnError) -> Self {
|
||||
crate::Error::wrap(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<cudarc::driver::DriverError> for crate::Error {
|
||||
fn from(err: cudarc::driver::DriverError) -> Self {
|
||||
crate::Error::wrap(err)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn launch_conv2d<
|
||||
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
|
||||
>(
|
||||
src: &CudaView<T>,
|
||||
src_l: &crate::Layout,
|
||||
filter: &CudaView<T>,
|
||||
dst: &mut CudaSlice<T>,
|
||||
params: &crate::conv::ParamsConv2D,
|
||||
dev: &crate::cuda_backend::CudaDevice,
|
||||
) -> crate::Result<()> {
|
||||
let device_id = dev.id();
|
||||
let cudnn = CUDNN.with(|cudnn| {
|
||||
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
|
||||
return Ok(cudnn.clone());
|
||||
}
|
||||
let c = Cudnn::new(dev.cuda_device());
|
||||
if let Ok(c) = &c {
|
||||
cudnn.borrow_mut().insert(device_id, c.clone());
|
||||
}
|
||||
c
|
||||
})?;
|
||||
let conv = cudnn.create_conv2d::<T>(
|
||||
/* pad */ [params.padding as i32, params.padding as i32],
|
||||
/* stride */ [params.stride as i32, params.stride as i32],
|
||||
/* dilation */ [1, 1],
|
||||
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
|
||||
)?;
|
||||
let x_shape = [
|
||||
params.b_size as i32,
|
||||
params.c_in as i32,
|
||||
params.i_w as i32,
|
||||
params.i_h as i32,
|
||||
];
|
||||
// Note that `src` already starts at the proper offset.
|
||||
let x = if src_l.is_contiguous() {
|
||||
cudnn.create_4d_tensor(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
x_shape,
|
||||
)?
|
||||
} else {
|
||||
let s = src_l.stride();
|
||||
cudnn.create_4d_tensor_ex(
|
||||
x_shape,
|
||||
[s[0] as i32, s[1] as i32, s[2] as i32, s[3] as i32],
|
||||
)?
|
||||
};
|
||||
let w = cudnn.create_4d_filter(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
[
|
||||
params.c_out as i32,
|
||||
params.c_in as i32,
|
||||
params.k_w as i32,
|
||||
params.k_h as i32,
|
||||
],
|
||||
)?;
|
||||
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
|
||||
let y = cudnn.create_4d_tensor(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
[params.b_size as i32, params.c_out as i32, w_out, h_out],
|
||||
)?;
|
||||
let conv2d = Conv2dForward {
|
||||
conv: &conv,
|
||||
x: &x,
|
||||
w: &w,
|
||||
y: &y,
|
||||
};
|
||||
let alg = conv2d.pick_algorithm()?;
|
||||
let workspace_size = conv2d.get_workspace_size(alg)?;
|
||||
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
|
||||
unsafe {
|
||||
conv2d.launch::<CudaSlice<u8>, _, _, _>(
|
||||
alg,
|
||||
Some(&mut workspace),
|
||||
(T::one(), T::zero()),
|
||||
src,
|
||||
filter,
|
||||
dst,
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -101,6 +101,13 @@ impl Device {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_cpu(&self) -> bool {
|
||||
match self {
|
||||
Self::Cpu => true,
|
||||
Self::Cuda(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_cuda(&self) -> bool {
|
||||
match self {
|
||||
Self::Cpu => false,
|
||||
|
@ -43,7 +43,7 @@ impl DType {
|
||||
|
||||
pub fn size_in_bytes(&self) -> usize {
|
||||
match self {
|
||||
Self::U8 => 4,
|
||||
Self::U8 => 1,
|
||||
Self::U32 => 4,
|
||||
Self::BF16 => 2,
|
||||
Self::F16 => 2,
|
||||
@ -53,7 +53,17 @@ impl DType {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait WithDType: Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + 'static {
|
||||
pub trait WithDType:
|
||||
Sized
|
||||
+ Copy
|
||||
+ num_traits::NumAssign
|
||||
+ std::cmp::PartialOrd
|
||||
+ std::fmt::Display
|
||||
+ 'static
|
||||
+ Send
|
||||
+ Sync
|
||||
+ crate::cpu::kernels::VecOps
|
||||
{
|
||||
const DTYPE: DType;
|
||||
|
||||
fn from_f64(v: f64) -> Self;
|
||||
|
@ -75,6 +75,16 @@ impl crate::backend::BackendStorage for CudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
&self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &crate::conv::ParamsConv2D,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
@ -119,6 +129,18 @@ impl crate::backend::BackendStorage for CudaStorage {
|
||||
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::backend::BackendDevice for CudaDevice {
|
||||
|
@ -185,6 +185,13 @@ pub enum Error {
|
||||
#[error(transparent)]
|
||||
Wrapped(Box<dyn std::error::Error + Send + Sync>),
|
||||
|
||||
/// Adding path information to an error.
|
||||
#[error("path: {path:?} {inner}")]
|
||||
WithPath {
|
||||
inner: Box<Self>,
|
||||
path: std::path::PathBuf,
|
||||
},
|
||||
|
||||
#[error("{inner}\n{backtrace}")]
|
||||
WithBacktrace {
|
||||
inner: Box<Self>,
|
||||
@ -214,6 +221,13 @@ impl Error {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self {
|
||||
Self::WithPath {
|
||||
inner: Box::new(self),
|
||||
path: p.as_ref().to_path_buf(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
|
@ -33,13 +33,18 @@
|
||||
//!
|
||||
//! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
mod accelerate;
|
||||
pub mod backend;
|
||||
pub mod backprop;
|
||||
mod conv;
|
||||
mod convert;
|
||||
pub mod cpu;
|
||||
pub mod cpu_backend;
|
||||
#[cfg(feature = "cuda")]
|
||||
pub mod cuda_backend;
|
||||
#[cfg(feature = "cudnn")]
|
||||
pub mod cudnn;
|
||||
mod device;
|
||||
pub mod display;
|
||||
mod dtype;
|
||||
@ -51,6 +56,7 @@ pub mod layout;
|
||||
mod mkl;
|
||||
pub mod npy;
|
||||
mod op;
|
||||
pub mod quantized;
|
||||
pub mod safetensors;
|
||||
pub mod shape;
|
||||
mod storage;
|
||||
|
@ -26,7 +26,7 @@
|
||||
//! values = np.loadz("test.npz")
|
||||
//! ```
|
||||
use crate::{DType, Device, Error, Result, Shape, Tensor};
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use half::{bf16, f16, slice::HalfFloatSliceExt};
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
@ -307,42 +307,7 @@ impl Tensor {
|
||||
header.push('\n');
|
||||
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
|
||||
f.write_all(header.as_bytes())?;
|
||||
let elem_count = self.elem_count();
|
||||
match self.dtype() {
|
||||
DType::BF16 => {
|
||||
let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
|
||||
for &v in vs.reinterpret_cast() {
|
||||
f.write_u16::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
|
||||
for &v in vs.reinterpret_cast() {
|
||||
f.write_u16::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F32 => {
|
||||
// TODO: Avoid using a buffer when data is already on the CPU.
|
||||
for v in self.reshape(elem_count)?.to_vec1::<f32>()? {
|
||||
f.write_f32::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F64 => {
|
||||
for v in self.reshape(elem_count)?.to_vec1::<f64>()? {
|
||||
f.write_f64::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::U32 => {
|
||||
for v in self.reshape(elem_count)?.to_vec1::<u32>()? {
|
||||
f.write_u32::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::U8 => {
|
||||
let data = self.reshape(elem_count)?.to_vec1::<u8>()?;
|
||||
f.write_all(&data)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
self.write_bytes(f)
|
||||
}
|
||||
|
||||
/// Writes a multi-dimensional array in the npy format.
|
||||
@ -373,7 +338,7 @@ pub struct NpzTensors {
|
||||
index_per_name: HashMap<String, usize>,
|
||||
path: std::path::PathBuf,
|
||||
// We do not store a zip reader as it needs mutable access to extract data. Instead we
|
||||
// re-create a zip reader each time.
|
||||
// re-create a zip reader for each tensor.
|
||||
}
|
||||
|
||||
impl NpzTensors {
|
||||
|
@ -51,6 +51,7 @@ pub enum UnaryOp {
|
||||
Cos,
|
||||
Abs,
|
||||
Neg,
|
||||
Recip,
|
||||
Sqr,
|
||||
Sqrt,
|
||||
Gelu,
|
||||
@ -79,6 +80,28 @@ pub enum Op {
|
||||
stride: usize,
|
||||
},
|
||||
|
||||
#[allow(dead_code)]
|
||||
Conv2D {
|
||||
arg: Tensor,
|
||||
kernel: Tensor,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
},
|
||||
|
||||
AvgPool2D {
|
||||
arg: Tensor,
|
||||
kernel_size: (usize, usize),
|
||||
stride: (usize, usize),
|
||||
},
|
||||
|
||||
MaxPool2D {
|
||||
arg: Tensor,
|
||||
kernel_size: (usize, usize),
|
||||
stride: (usize, usize),
|
||||
},
|
||||
|
||||
UpsampleNearest2D(Tensor),
|
||||
|
||||
Cat(Vec<Tensor>, usize),
|
||||
|
||||
#[allow(dead_code)] // add is currently unused.
|
||||
@ -264,6 +287,7 @@ pub(crate) struct Sin;
|
||||
pub(crate) struct Cos;
|
||||
pub(crate) struct Abs;
|
||||
pub(crate) struct Neg;
|
||||
pub(crate) struct Recip;
|
||||
pub(crate) struct Sqr;
|
||||
pub(crate) struct Sqrt;
|
||||
pub(crate) struct Gelu;
|
||||
@ -314,6 +338,21 @@ macro_rules! bin_op {
|
||||
fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
|
||||
crate::mkl::$f64_vec(xs1, xs2, ys)
|
||||
}
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
const F32_VEC: bool = true;
|
||||
#[cfg(feature = "accelerate")]
|
||||
const F64_VEC: bool = true;
|
||||
#[cfg(feature = "accelerate")]
|
||||
#[inline(always)]
|
||||
fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
|
||||
crate::accelerate::$f32_vec(xs1, xs2, ys)
|
||||
}
|
||||
#[cfg(feature = "accelerate")]
|
||||
#[inline(always)]
|
||||
fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
|
||||
crate::accelerate::$f64_vec(xs1, xs2, ys)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
@ -400,6 +439,21 @@ macro_rules! unary_op {
|
||||
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||
crate::mkl::$f64_vec(xs, ys)
|
||||
}
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
const F32_VEC: bool = true;
|
||||
#[cfg(feature = "accelerate")]
|
||||
const F64_VEC: bool = true;
|
||||
#[cfg(feature = "accelerate")]
|
||||
#[inline(always)]
|
||||
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
||||
crate::accelerate::$f32_vec(xs, ys)
|
||||
}
|
||||
#[cfg(feature = "accelerate")]
|
||||
#[inline(always)]
|
||||
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||
crate::accelerate::$f64_vec(xs, ys)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
@ -410,6 +464,7 @@ unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
|
||||
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
|
||||
unary_op!(Abs, "abs", v, v.abs());
|
||||
unary_op!(Neg, "neg", v, -v);
|
||||
unary_op!(Recip, "recip", v, v.recip());
|
||||
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
|
||||
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
|
||||
|
||||
|
215
candle-core/src/quantized/ggml_file.rs
Normal file
215
candle-core/src/quantized/ggml_file.rs
Normal file
@ -0,0 +1,215 @@
|
||||
//! Support for the GGML file format.
|
||||
|
||||
use super::{k_quants, GgmlDType};
|
||||
use crate::Result;
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum Magic {
|
||||
Ggjt,
|
||||
Ggla,
|
||||
Ggmf,
|
||||
Ggml,
|
||||
Ggsn,
|
||||
}
|
||||
|
||||
impl TryFrom<u32> for Magic {
|
||||
type Error = crate::Error;
|
||||
fn try_from(value: u32) -> Result<Self> {
|
||||
let magic = match value {
|
||||
0x67676a74 => Self::Ggjt,
|
||||
0x67676c61 => Self::Ggla,
|
||||
0x67676d66 => Self::Ggmf,
|
||||
0x67676d6c => Self::Ggml,
|
||||
0x6767736e => Self::Ggsn,
|
||||
_ => crate::bail!("unknown magic {value:08x}"),
|
||||
};
|
||||
Ok(magic)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum VersionedMagic {
|
||||
GgmlUnversioned,
|
||||
GgmfV1,
|
||||
GgjtV1,
|
||||
GgjtV2,
|
||||
GgjtV3,
|
||||
}
|
||||
|
||||
impl VersionedMagic {
|
||||
fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
|
||||
let magic = reader.read_u32::<LittleEndian>()?;
|
||||
let magic = Magic::try_from(magic)?;
|
||||
if magic == Magic::Ggml {
|
||||
return Ok(Self::GgmlUnversioned);
|
||||
}
|
||||
let version = reader.read_u32::<LittleEndian>()?;
|
||||
let versioned_magic = match (magic, version) {
|
||||
(Magic::Ggmf, 1) => Self::GgmfV1,
|
||||
(Magic::Ggjt, 1) => Self::GgjtV1,
|
||||
(Magic::Ggjt, 2) => Self::GgjtV2,
|
||||
(Magic::Ggjt, 3) => Self::GgjtV3,
|
||||
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
|
||||
};
|
||||
Ok(versioned_magic)
|
||||
}
|
||||
|
||||
fn align32(&self) -> bool {
|
||||
match self {
|
||||
Self::GgmlUnversioned | Self::GgmfV1 => false,
|
||||
Self::GgjtV1 | Self::GgjtV2 | Self::GgjtV3 => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct HParams {
|
||||
pub n_vocab: u32,
|
||||
pub n_embd: u32,
|
||||
pub n_mult: u32,
|
||||
pub n_head: u32,
|
||||
pub n_layer: u32,
|
||||
pub n_rot: u32,
|
||||
pub ftype: u32,
|
||||
}
|
||||
|
||||
impl HParams {
|
||||
fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
|
||||
let n_vocab = reader.read_u32::<LittleEndian>()?;
|
||||
let n_embd = reader.read_u32::<LittleEndian>()?;
|
||||
let n_mult = reader.read_u32::<LittleEndian>()?;
|
||||
let n_head = reader.read_u32::<LittleEndian>()?;
|
||||
let n_layer = reader.read_u32::<LittleEndian>()?;
|
||||
let n_rot = reader.read_u32::<LittleEndian>()?;
|
||||
let ftype = reader.read_u32::<LittleEndian>()?;
|
||||
Ok(Self {
|
||||
n_vocab,
|
||||
n_embd,
|
||||
n_mult,
|
||||
n_head,
|
||||
n_layer,
|
||||
n_rot,
|
||||
ftype,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Vocab {
|
||||
pub token_score_pairs: Vec<(Vec<u8>, f32)>,
|
||||
}
|
||||
|
||||
impl Vocab {
|
||||
fn read<R: std::io::Read>(reader: &mut R, n_vocab: usize) -> Result<Self> {
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L556
|
||||
let mut token_score_pairs = Vec::with_capacity(n_vocab);
|
||||
for _index in 0..n_vocab {
|
||||
let len = reader.read_u32::<LittleEndian>()? as usize;
|
||||
let mut word = vec![0u8; len];
|
||||
reader.read_exact(&mut word)?;
|
||||
let score = reader.read_f32::<LittleEndian>()?;
|
||||
token_score_pairs.push((word, score))
|
||||
}
|
||||
Ok(Self { token_score_pairs })
|
||||
}
|
||||
}
|
||||
|
||||
fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
|
||||
raw_data: &[u8],
|
||||
size_in_bytes: usize,
|
||||
dims: Vec<usize>,
|
||||
) -> Result<super::QTensor> {
|
||||
let raw_data_ptr = raw_data.as_ptr();
|
||||
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
|
||||
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
||||
Ok(super::QTensor::new(data.to_vec(), dims))
|
||||
}
|
||||
|
||||
/// Creates a [Tensor] from a raw GGML tensor.
|
||||
pub fn qtensor_from_ggml(
|
||||
ggml_dtype: GgmlDType,
|
||||
raw_data: &[u8],
|
||||
dims: Vec<usize>,
|
||||
) -> Result<super::QTensor> {
|
||||
let tensor_elems = dims.iter().product::<usize>();
|
||||
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
|
||||
|
||||
match ggml_dtype {
|
||||
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
|
||||
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
|
||||
}
|
||||
}
|
||||
|
||||
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||
reader: &mut R,
|
||||
magic: VersionedMagic,
|
||||
) -> Result<(String, super::QTensor)> {
|
||||
let n_dims = reader.read_u32::<LittleEndian>()?;
|
||||
let name_len = reader.read_u32::<LittleEndian>()?;
|
||||
let ggml_dtype = reader.read_u32::<LittleEndian>()?;
|
||||
let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
|
||||
let mut dims = vec![0u32; n_dims as usize];
|
||||
reader.read_u32_into::<LittleEndian>(&mut dims)?;
|
||||
let mut name = vec![0u8; name_len as usize];
|
||||
reader.read_exact(&mut name)?;
|
||||
let name = String::from_utf8_lossy(&name).into_owned();
|
||||
|
||||
if magic.align32() {
|
||||
let pos = reader.stream_position()?;
|
||||
reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?;
|
||||
}
|
||||
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
|
||||
let tensor_elems = dims.iter().product::<usize>();
|
||||
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
|
||||
println!("{name} {ggml_dtype:?} {dims:?}");
|
||||
// TODO: Mmap version to avoid copying the data around?
|
||||
let mut raw_data = vec![0u8; size_in_bytes];
|
||||
reader.read_exact(&mut raw_data)?;
|
||||
match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
|
||||
Ok(tensor) => Ok((name, tensor)),
|
||||
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Content {
|
||||
pub magic: VersionedMagic,
|
||||
pub hparams: HParams,
|
||||
pub vocab: Vocab,
|
||||
pub tensors: Vec<(String, super::QTensor)>,
|
||||
}
|
||||
|
||||
impl Content {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
|
||||
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
|
||||
reader.seek(std::io::SeekFrom::Start(0))?;
|
||||
let magic = VersionedMagic::read(reader)?;
|
||||
let hparams = HParams::read(reader)?;
|
||||
let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
|
||||
let mut tensors = vec![];
|
||||
|
||||
while reader.stream_position()? != last_position {
|
||||
let (name, tensor) = read_one_tensor(reader, magic)?;
|
||||
tensors.push((name, tensor))
|
||||
}
|
||||
Ok(Self {
|
||||
magic,
|
||||
hparams,
|
||||
vocab,
|
||||
tensors,
|
||||
})
|
||||
}
|
||||
}
|
802
candle-core/src/quantized/k_quants.rs
Normal file
802
candle-core/src/quantized/k_quants.rs
Normal file
@ -0,0 +1,802 @@
|
||||
use super::GgmlDType;
|
||||
use crate::Result;
|
||||
use half::f16;
|
||||
|
||||
// Default to QK_K 256 rather than 64.
|
||||
pub const QK_K: usize = 256;
|
||||
pub const K_SCALE_SIZE: usize = 12;
|
||||
|
||||
pub const QK4_0: usize = 32;
|
||||
pub const QK4_1: usize = 32;
|
||||
pub const QK5_0: usize = 32;
|
||||
pub const QK5_1: usize = 32;
|
||||
pub const QK8_0: usize = 32;
|
||||
pub const QK8_1: usize = 32;
|
||||
|
||||
pub trait GgmlType: Sized + Clone {
|
||||
const DTYPE: GgmlDType;
|
||||
const BLCK_SIZE: usize;
|
||||
type VecDotType: GgmlType;
|
||||
|
||||
// This is only safe for types that include immediate values such as float/int/...
|
||||
fn zeros() -> Self {
|
||||
unsafe { std::mem::MaybeUninit::zeroed().assume_init() }
|
||||
}
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()>;
|
||||
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()>;
|
||||
|
||||
/// Dot product used as a building block for quantized mat-mul.
|
||||
/// n is the number of elements to be considered.
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[repr(C)]
|
||||
pub struct BlockQ4_0 {
|
||||
d: f16,
|
||||
qs: [u8; QK4_0 / 2],
|
||||
}
|
||||
const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[repr(C)]
|
||||
pub struct BlockQ4_1 {
|
||||
d: f16,
|
||||
m: f16,
|
||||
qs: [u8; QK4_1 / 2],
|
||||
}
|
||||
const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[repr(C)]
|
||||
pub struct BlockQ5_0 {
|
||||
d: f16,
|
||||
qh: [u8; 4],
|
||||
qs: [u8; QK5_0 / 2],
|
||||
}
|
||||
const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[repr(C)]
|
||||
pub struct BlockQ5_1 {
|
||||
d: f16,
|
||||
m: f16,
|
||||
qh: [u8; 4],
|
||||
qs: [u8; QK5_1 / 2],
|
||||
}
|
||||
const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[repr(C)]
|
||||
pub struct BlockQ8_0 {
|
||||
d: f16,
|
||||
qs: [u8; QK8_0],
|
||||
}
|
||||
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[repr(C)]
|
||||
pub struct BlockQ8_1 {
|
||||
d: f16,
|
||||
s: f16,
|
||||
qs: [u8; QK8_1],
|
||||
}
|
||||
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[repr(C)]
|
||||
pub struct BlockQ2K {
|
||||
scales: [u8; QK_K / 16],
|
||||
qs: [u8; QK_K / 4],
|
||||
d: f16,
|
||||
dmin: f16,
|
||||
}
|
||||
const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>());
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[repr(C)]
|
||||
pub struct BlockQ3K {
|
||||
hmask: [u8; QK_K / 8],
|
||||
qs: [u8; QK_K / 4],
|
||||
scales: [u8; 12],
|
||||
d: f16,
|
||||
}
|
||||
const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ3K>());
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82
|
||||
#[repr(C)]
|
||||
pub struct BlockQ4K {
|
||||
d: f16,
|
||||
dmin: f16,
|
||||
scales: [u8; K_SCALE_SIZE],
|
||||
qs: [u8; QK_K / 2],
|
||||
}
|
||||
const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>());
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[repr(C)]
|
||||
pub struct BlockQ5K {
|
||||
d: f16,
|
||||
dmin: f16,
|
||||
scales: [u8; K_SCALE_SIZE],
|
||||
qh: [u8; QK_K / 8],
|
||||
qs: [u8; QK_K / 2],
|
||||
}
|
||||
const _: () =
|
||||
assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>());
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[repr(C)]
|
||||
pub struct BlockQ6K {
|
||||
ql: [u8; QK_K / 2],
|
||||
qh: [u8; QK_K / 4],
|
||||
scales: [i8; QK_K / 16],
|
||||
d: f16,
|
||||
}
|
||||
const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>());
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[repr(C)]
|
||||
pub struct BlockQ8K {
|
||||
d: f32,
|
||||
qs: [i8; QK_K],
|
||||
bsums: [i16; QK_K / 16],
|
||||
}
|
||||
const _: () = assert!(4 + QK_K + QK_K / 16 * 2 == std::mem::size_of::<BlockQ8K>());
|
||||
|
||||
impl GgmlType for BlockQ4_1 {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q4_1;
|
||||
const BLCK_SIZE: usize = QK4_1;
|
||||
type VecDotType = BlockQ8_1;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1545
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK4_1 != 0 {
|
||||
crate::bail!("dequantize_row_q4_1: {k} is not divisible by {QK4_1}");
|
||||
}
|
||||
|
||||
let nb = k / QK4_1;
|
||||
for i in 0..nb {
|
||||
let d = xs[i].d.to_f32();
|
||||
let m = xs[i].m.to_f32();
|
||||
|
||||
for j in 0..(QK4_1 / 2) {
|
||||
let x0 = xs[i].qs[j] & 0x0F;
|
||||
let x1 = xs[i].qs[j] >> 4;
|
||||
|
||||
ys[i * QK4_1 + j] = (x0 as f32) * d + m;
|
||||
ys[i * QK4_1 + j + QK4_1 / 2] = (x1 as f32) * d + m;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl GgmlType for BlockQ5_0 {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q5_0;
|
||||
const BLCK_SIZE: usize = QK5_0;
|
||||
type VecDotType = BlockQ8_0;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1566
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK5_0 != 0 {
|
||||
crate::bail!("dequantize_row_q5_0: {k} is not divisible by {QK5_0}");
|
||||
}
|
||||
|
||||
let nb = k / QK5_0;
|
||||
for i in 0..nb {
|
||||
let d = xs[i].d.to_f32();
|
||||
let qh: u32 = unsafe { std::mem::transmute_copy(&xs[i].qh) };
|
||||
|
||||
for j in 0..(QK5_0 / 2) {
|
||||
let xh_0 = (((qh >> j) << 4) & 0x10) as u8;
|
||||
let xh_1 = ((qh >> (j + 12)) & 0x10) as u8;
|
||||
|
||||
let x0 = ((xs[i].qs[j] & 0x0F) | xh_0) as i32 - 16;
|
||||
let x1 = ((xs[i].qs[j] >> 4) | xh_1) as i32 - 16;
|
||||
|
||||
ys[i * QK5_0 + j] = (x0 as f32) * d;
|
||||
ys[i * QK5_0 + j + QK5_0 / 2] = (x1 as f32) * d;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl GgmlType for BlockQ5_1 {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q5_1;
|
||||
const BLCK_SIZE: usize = QK5_1;
|
||||
type VecDotType = BlockQ8_1;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1592
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK5_1 != 0 {
|
||||
crate::bail!("dequantize_row_q5_1: {k} is not divisible by {QK5_1}");
|
||||
}
|
||||
|
||||
let nb = k / QK5_1;
|
||||
for i in 0..nb {
|
||||
let d = xs[i].d.to_f32();
|
||||
let m = xs[i].m.to_f32();
|
||||
let qh: u32 = unsafe { std::mem::transmute_copy(&xs[i].qh) };
|
||||
|
||||
for j in 0..(QK5_1 / 2) {
|
||||
let xh_0 = (((qh >> j) << 4) & 0x10) as u8;
|
||||
let xh_1 = ((qh >> (j + 12)) & 0x10) as u8;
|
||||
|
||||
let x0 = (xs[i].qs[j] & 0x0F) | xh_0;
|
||||
let x1 = (xs[i].qs[j] >> 4) | xh_1;
|
||||
|
||||
ys[i * QK5_1 + j] = (x0 as f32) * d + m;
|
||||
ys[i * QK5_1 + j + QK5_1 / 2] = (x1 as f32) * d + m;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl GgmlType for BlockQ2K {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q2K;
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
type VecDotType = BlockQ8K;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK_K != 0 {
|
||||
crate::bail!("dequantize_row_q2k: {k} is not divisible by {QK_K}")
|
||||
}
|
||||
let mut ys_index = 0;
|
||||
for x in xs {
|
||||
let d = x.d.to_f32();
|
||||
let min = x.dmin.to_f32();
|
||||
let q = &x.qs;
|
||||
|
||||
let mut is = 0;
|
||||
for n in (0..QK_K).step_by(128) {
|
||||
// Step by 32 over q.
|
||||
let q = &q[n / 4..];
|
||||
let mut shift = 0;
|
||||
for _j in 0..4 {
|
||||
let sc = x.scales[is];
|
||||
is += 1;
|
||||
let dl = d * (sc & 0xF) as f32;
|
||||
let ml = min * (sc >> 4) as f32;
|
||||
for q in &q[..16] {
|
||||
let y = dl * ((q >> shift) & 3) as i8 as f32 - ml;
|
||||
ys[ys_index] = y;
|
||||
ys_index += 1;
|
||||
}
|
||||
|
||||
let sc = x.scales[is];
|
||||
is += 1;
|
||||
let dl = d * (sc & 0xF) as f32;
|
||||
let ml = min * (sc >> 4) as f32;
|
||||
for q in &q[16..32] {
|
||||
let y = dl * ((q >> shift) & 3) as i8 as f32 - ml;
|
||||
ys[ys_index] = y;
|
||||
ys_index += 1;
|
||||
}
|
||||
|
||||
shift += 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
|
||||
if j < 4 {
|
||||
let d = q[j] & 63;
|
||||
let m = q[j + 4] & 63;
|
||||
(d, m)
|
||||
} else {
|
||||
let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
|
||||
let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
|
||||
(d, m)
|
||||
}
|
||||
}
|
||||
|
||||
impl GgmlType for BlockQ4K {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q4K;
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
type VecDotType = BlockQ8K;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK_K != 0 {
|
||||
crate::bail!("dequantize_row_q4k: {k} is not divisible by {QK_K}")
|
||||
}
|
||||
let mut ys_index = 0;
|
||||
for x in xs.iter() {
|
||||
let d = x.d.to_f32();
|
||||
let min = x.dmin.to_f32();
|
||||
let q = &x.qs;
|
||||
let mut is = 0;
|
||||
for j in (0..QK_K).step_by(64) {
|
||||
let q = &q[j / 2..j / 2 + 32];
|
||||
let (sc, m) = get_scale_min_k4(is, &x.scales);
|
||||
let d1 = d * sc as f32;
|
||||
let m1 = min * m as f32;
|
||||
let (sc, m) = get_scale_min_k4(is + 1, &x.scales);
|
||||
let d2 = d * sc as f32;
|
||||
let m2 = min * m as f32;
|
||||
for q in q {
|
||||
let y = d1 * (q & 0xF) as f32 - m1;
|
||||
ys[ys_index] = y;
|
||||
ys_index += 1;
|
||||
}
|
||||
for q in q {
|
||||
let y = d2 * (q >> 4) as f32 - m2;
|
||||
ys[ys_index] = y;
|
||||
ys_index += 1;
|
||||
}
|
||||
is += 2;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl GgmlType for BlockQ3K {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q3K;
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
type VecDotType = BlockQ8K;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
|
||||
fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928
|
||||
impl GgmlType for BlockQ5K {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q5K;
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
type VecDotType = BlockQ8K;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK_K != 0 {
|
||||
crate::bail!("dequantize_row_q5k: {k} is not divisible by {QK_K}")
|
||||
}
|
||||
let mut ys_index = 0;
|
||||
for x in xs.iter() {
|
||||
let d = x.d.to_f32();
|
||||
let min = x.dmin.to_f32();
|
||||
let ql = &x.qs;
|
||||
let qh = &x.qh;
|
||||
let mut is = 0;
|
||||
let mut u1 = 1;
|
||||
let mut u2 = 2;
|
||||
for j in (0..QK_K).step_by(64) {
|
||||
let ql = &ql[j / 2..j / 2 + 32];
|
||||
let (sc, m) = get_scale_min_k4(is, &x.scales);
|
||||
let d1 = d * sc as f32;
|
||||
let m1 = min * m as f32;
|
||||
let (sc, m) = get_scale_min_k4(is + 1, &x.scales);
|
||||
let d2 = d * sc as f32;
|
||||
let m2 = min * m as f32;
|
||||
for (ql, qh) in ql.iter().zip(qh) {
|
||||
let to_add = if qh & u1 != 0 { 16 } else { 1 };
|
||||
let y = d1 * ((ql & 0xF) + to_add) as f32 - m1;
|
||||
ys[ys_index] = y;
|
||||
ys_index += 1;
|
||||
}
|
||||
for (ql, qh) in ql.iter().zip(qh) {
|
||||
let to_add = if qh & u2 != 0 { 16 } else { 1 };
|
||||
let y = d2 * ((ql >> 4) + to_add) as f32 - m2;
|
||||
ys[ys_index] = y;
|
||||
ys_index += 1;
|
||||
}
|
||||
is += 2;
|
||||
u1 <<= 2;
|
||||
u2 <<= 2;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl GgmlType for BlockQ6K {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q6K;
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
type VecDotType = BlockQ8K;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK_K != 0 {
|
||||
crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}")
|
||||
}
|
||||
for x in xs.iter() {
|
||||
let d = x.d.to_f32();
|
||||
let ql = &x.ql;
|
||||
let qh = &x.qh;
|
||||
let sc = &x.scales;
|
||||
for n in (0..QK_K).step_by(128) {
|
||||
let idx = n / 128;
|
||||
let ys = &mut ys[n..];
|
||||
let sc = &sc[8 * idx..];
|
||||
let ql = &ql[64 * idx..];
|
||||
let qh = &qh[32 * idx..];
|
||||
for l in 0..32 {
|
||||
let is = l / 16;
|
||||
let q1 = ((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i8 - 32;
|
||||
let q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i8 - 32;
|
||||
let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i8 - 32;
|
||||
let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i8 - 32;
|
||||
ys[l] = d * sc[is] as f32 * q1 as f32;
|
||||
ys[l + 32] = d * sc[is + 2] as f32 * q2 as f32;
|
||||
ys[l + 64] = d * sc[is + 4] as f32 * q3 as f32;
|
||||
ys[l + 96] = d * sc[is + 6] as f32 * q4 as f32;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl GgmlType for BlockQ8K {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q8K;
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
type VecDotType = BlockQ8K;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
|
||||
fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl GgmlType for BlockQ4_0 {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q4_0;
|
||||
const BLCK_SIZE: usize = QK4_0;
|
||||
type VecDotType = BlockQ8_0;
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1525
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK4_0 != 0 {
|
||||
crate::bail!("dequantize_row_q4_0: {k} is not divisible by {QK4_0}")
|
||||
}
|
||||
|
||||
let nb = k / QK4_0;
|
||||
for i in 0..nb {
|
||||
let d = xs[i].d.to_f32();
|
||||
|
||||
for j in 0..(QK4_0 / 2) {
|
||||
let x0 = (xs[i].qs[j] & 0x0F) as i16 - 8;
|
||||
let x1 = (xs[i].qs[j] >> 4) as i16 - 8;
|
||||
|
||||
ys[i * QK4_0 + j] = (x0 as f32) * d;
|
||||
ys[i * QK4_0 + j + QK4_0 / 2] = (x1 as f32) * d;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
|
||||
// quantize_row_q4_0
|
||||
let qk = Self::BLCK_SIZE;
|
||||
let k = xs.len();
|
||||
if k % qk != 0 {
|
||||
crate::bail!("{k} is not divisible by {}", qk);
|
||||
};
|
||||
let nb = k / qk;
|
||||
if ys.len() != nb {
|
||||
crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,)
|
||||
}
|
||||
for (i, ys) in ys.iter_mut().enumerate() {
|
||||
let mut amax = 0f32;
|
||||
let mut max = 0f32;
|
||||
|
||||
let xs = &xs[i * qk..(i + 1) * qk];
|
||||
for &x in xs.iter() {
|
||||
if amax < x.abs() {
|
||||
amax = x.abs();
|
||||
max = x;
|
||||
}
|
||||
}
|
||||
let d = max / -8.0;
|
||||
let id = if d != 0f32 { 1. / d } else { 0. };
|
||||
ys.d = f16::from_f32(d);
|
||||
|
||||
for (j, q) in ys.qs.iter_mut().enumerate() {
|
||||
let x0 = xs[j] * id;
|
||||
let x1 = xs[qk / 2 + j] * id;
|
||||
let xi0 = u8::min(15, (x0 + 8.5) as u8);
|
||||
let xi1 = u8::min(15, (x1 + 8.5) as u8);
|
||||
*q = xi0 | (xi1 << 4)
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L2361C10-L2361C122
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
let nb = n / qk;
|
||||
if n % QK8_0 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
|
||||
}
|
||||
if nb % 2 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
|
||||
}
|
||||
|
||||
// Generic implementation.
|
||||
let mut sumf = 0f32;
|
||||
for i in 0..nb {
|
||||
let mut sum_i = 0;
|
||||
for j in 0..qk / 2 {
|
||||
let v0 = (xs[i].qs[j] & 0x0F) as i32 - 8;
|
||||
let v1 = (xs[i].qs[j] >> 4) as i32 - 8;
|
||||
sum_i += v0 * ys[i].qs[j] as i32 + v1 * ys[i].qs[j + qk / 2] as i32
|
||||
}
|
||||
sumf += sum_i as f32 * f16::to_f32(xs[i].d) * f16::to_f32(ys[i].d)
|
||||
}
|
||||
Ok(sumf)
|
||||
}
|
||||
}
|
||||
|
||||
impl GgmlType for BlockQ8_0 {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q8_0;
|
||||
const BLCK_SIZE: usize = QK8_0;
|
||||
type VecDotType = BlockQ8_0;
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1619
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK8_0 != 0 {
|
||||
crate::bail!("dequantize_row_q8_0: {k} is not divisible by {QK8_0}");
|
||||
}
|
||||
|
||||
let nb = k / QK8_0;
|
||||
|
||||
for i in 0..nb {
|
||||
let d = xs[i].d.to_f32();
|
||||
|
||||
for j in 0..QK8_0 {
|
||||
ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
|
||||
// quantize_row_q8_0
|
||||
let k = xs.len();
|
||||
if k % Self::BLCK_SIZE != 0 {
|
||||
crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE);
|
||||
};
|
||||
let nb = k / Self::BLCK_SIZE;
|
||||
if ys.len() != nb {
|
||||
crate::bail!(
|
||||
"size mismatch {} {} {}",
|
||||
xs.len(),
|
||||
ys.len(),
|
||||
Self::BLCK_SIZE
|
||||
)
|
||||
}
|
||||
for (i, ys) in ys.iter_mut().enumerate() {
|
||||
let mut amax = 0f32;
|
||||
let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
|
||||
for &x in xs.iter() {
|
||||
amax = amax.max(x.abs())
|
||||
}
|
||||
let d = amax / ((1 << 7) - 1) as f32;
|
||||
let id = if d != 0f32 { 1. / d } else { 0. };
|
||||
ys.d = f16::from_f32(d);
|
||||
for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) {
|
||||
*y = f32::round(x * id) as u8
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn vec_dot(_: usize, _: &[Self], _: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl GgmlType for BlockQ8_1 {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q3K;
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
type VecDotType = BlockQ8_1;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
|
||||
fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605
|
||||
pub fn matmul<T: GgmlType>(
|
||||
mkn: (usize, usize, usize),
|
||||
lhs: &[f32],
|
||||
rhs_t: &[T],
|
||||
dst: &mut [f32],
|
||||
) -> Result<()> {
|
||||
let (m, k, n) = mkn;
|
||||
if m * k != lhs.len() {
|
||||
crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len());
|
||||
}
|
||||
|
||||
let k_in_lhs_blocks = (k + T::BLCK_SIZE - 1) / T::BLCK_SIZE;
|
||||
let k_in_rhs_blocks = (k + T::VecDotType::BLCK_SIZE - 1) / T::VecDotType::BLCK_SIZE;
|
||||
// TODO: Do not make this copy if the DotType is f32.
|
||||
// TODO: Pre-allocate this.
|
||||
let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks];
|
||||
for row_idx in 0..m {
|
||||
let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
|
||||
let lhs = &lhs[row_idx * k..(row_idx + 1) * k];
|
||||
T::VecDotType::from_float(lhs, lhs_b)?
|
||||
}
|
||||
let lhs_b = lhs_b.as_slice();
|
||||
|
||||
for row_idx in 0..m {
|
||||
let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
|
||||
let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n];
|
||||
for (col_idx, dst) in dst_row.iter_mut().enumerate() {
|
||||
let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks];
|
||||
*dst = T::vec_dot(k, rhs_col, lhs_row)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl GgmlType for f32 {
|
||||
const DTYPE: GgmlDType = GgmlDType::F32;
|
||||
const BLCK_SIZE: usize = 1;
|
||||
type VecDotType = f32;
|
||||
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
if xs.len() < n {
|
||||
crate::bail!("size mismatch {} < {n}", xs.len())
|
||||
}
|
||||
if ys.len() < n {
|
||||
crate::bail!("size mismatch {} < {n}", ys.len())
|
||||
}
|
||||
let mut res = 0f32;
|
||||
unsafe { crate::cpu::vec_dot_f32(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
|
||||
if xs.len() != ys.len() {
|
||||
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
|
||||
}
|
||||
ys.copy_from_slice(xs);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
if xs.len() != ys.len() {
|
||||
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
|
||||
}
|
||||
ys.copy_from_slice(xs);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl GgmlType for f16 {
|
||||
const DTYPE: GgmlDType = GgmlDType::F16;
|
||||
const BLCK_SIZE: usize = 1;
|
||||
type VecDotType = f16;
|
||||
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
if xs.len() < n {
|
||||
crate::bail!("size mismatch {} < {n}", xs.len())
|
||||
}
|
||||
if ys.len() < n {
|
||||
crate::bail!("size mismatch {} < {n}", ys.len())
|
||||
}
|
||||
let mut res = 0f32;
|
||||
unsafe { crate::cpu::vec_dot_f16(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
|
||||
if xs.len() != ys.len() {
|
||||
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
|
||||
}
|
||||
// TODO: vectorize
|
||||
for (x, y) in xs.iter().zip(ys.iter_mut()) {
|
||||
*y = f16::from_f32(*x)
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
if xs.len() != ys.len() {
|
||||
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
|
||||
}
|
||||
// TODO: vectorize
|
||||
for (x, y) in xs.iter().zip(ys.iter_mut()) {
|
||||
*y = x.to_f32()
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
194
candle-core/src/quantized/mod.rs
Normal file
194
candle-core/src/quantized/mod.rs
Normal file
@ -0,0 +1,194 @@
|
||||
use crate::{Device, Result, Shape, Tensor};
|
||||
|
||||
pub mod ggml_file;
|
||||
pub mod k_quants;
|
||||
|
||||
pub use k_quants::GgmlType;
|
||||
|
||||
pub struct QTensor {
|
||||
data: Box<dyn QuantizedType>,
|
||||
shape: Shape,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum GgmlDType {
|
||||
F32,
|
||||
F16,
|
||||
Q4_0,
|
||||
Q4_1,
|
||||
Q5_0,
|
||||
Q5_1,
|
||||
Q8_0,
|
||||
Q8_1,
|
||||
Q2K,
|
||||
Q3K,
|
||||
Q4K,
|
||||
Q5K,
|
||||
Q6K,
|
||||
Q8K,
|
||||
}
|
||||
|
||||
impl GgmlDType {
|
||||
pub(crate) fn from_u32(u: u32) -> Result<Self> {
|
||||
let dtype = match u {
|
||||
0 => Self::F32,
|
||||
1 => Self::F16,
|
||||
2 => Self::Q4_0,
|
||||
3 => Self::Q4_1,
|
||||
6 => Self::Q5_0,
|
||||
7 => Self::Q5_1,
|
||||
8 => Self::Q8_0,
|
||||
9 => Self::Q8_1,
|
||||
10 => Self::Q2K,
|
||||
11 => Self::Q3K,
|
||||
12 => Self::Q4K,
|
||||
13 => Self::Q5K,
|
||||
14 => Self::Q6K,
|
||||
15 => Self::Q8K,
|
||||
_ => crate::bail!("unknown dtype for tensor {u}"),
|
||||
};
|
||||
Ok(dtype)
|
||||
}
|
||||
|
||||
fn type_size(&self) -> usize {
|
||||
use k_quants::*;
|
||||
match self {
|
||||
Self::F32 => 4,
|
||||
Self::F16 => 2,
|
||||
Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
|
||||
Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
|
||||
Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
|
||||
Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
|
||||
Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
|
||||
Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
|
||||
Self::Q2K => std::mem::size_of::<BlockQ2K>(),
|
||||
Self::Q3K => std::mem::size_of::<BlockQ3K>(),
|
||||
Self::Q4K => std::mem::size_of::<BlockQ4K>(),
|
||||
Self::Q5K => std::mem::size_of::<BlockQ5K>(),
|
||||
Self::Q6K => std::mem::size_of::<BlockQ6K>(),
|
||||
Self::Q8K => std::mem::size_of::<BlockQ8K>(),
|
||||
}
|
||||
}
|
||||
|
||||
fn blck_size(&self) -> usize {
|
||||
match self {
|
||||
Self::F32 => 1,
|
||||
Self::F16 => 1,
|
||||
Self::Q4_0 => k_quants::QK4_0,
|
||||
Self::Q4_1 => k_quants::QK4_1,
|
||||
Self::Q5_0 => k_quants::QK5_0,
|
||||
Self::Q5_1 => k_quants::QK5_1,
|
||||
Self::Q8_0 => k_quants::QK8_0,
|
||||
Self::Q8_1 => k_quants::QK8_1,
|
||||
Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A version of GgmlType without `vec_dot` so that it can be dyn boxed.
|
||||
pub trait QuantizedType: Send + Sync {
|
||||
fn dtype(&self) -> GgmlDType;
|
||||
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
|
||||
fn to_float(&self, ys: &mut [f32]) -> Result<()>;
|
||||
}
|
||||
|
||||
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
||||
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
|
||||
k_quants::matmul(mkn, lhs, self.as_slice(), dst)
|
||||
}
|
||||
|
||||
fn dtype(&self) -> GgmlDType {
|
||||
T::DTYPE
|
||||
}
|
||||
|
||||
fn to_float(&self, ys: &mut [f32]) -> Result<()> {
|
||||
T::to_float(self.as_slice(), ys)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for QTensor {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
|
||||
}
|
||||
}
|
||||
|
||||
impl QTensor {
|
||||
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
|
||||
data: Vec<T>,
|
||||
shape: S,
|
||||
) -> Self {
|
||||
Self {
|
||||
data: Box::new(data),
|
||||
shape: shape.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.data.dtype()
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> &Shape {
|
||||
&self.shape
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
|
||||
let mut f32_data = vec![0f32; self.shape.elem_count()];
|
||||
self.data.to_float(&mut f32_data)?;
|
||||
Tensor::from_vec(f32_data, &self.shape, device)
|
||||
}
|
||||
|
||||
pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
|
||||
self.data.matmul_t(mkn, lhs, dst)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QMatMul(std::sync::Arc<QTensor>);
|
||||
|
||||
impl QMatMul {
|
||||
pub fn new(qtensor: std::sync::Arc<QTensor>) -> Self {
|
||||
Self(qtensor)
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::CustomOp1 for QMatMul {
|
||||
fn name(&self) -> &'static str {
|
||||
"qmatmul"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
storage: &crate::CpuStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(crate::CpuStorage, Shape)> {
|
||||
if !layout.is_contiguous() {
|
||||
crate::bail!("input tensor is not contiguous {layout:?}")
|
||||
}
|
||||
let src_shape = layout.shape();
|
||||
let (k, n) = self.0.shape.dims2()?;
|
||||
if src_shape.rank() < 2 {
|
||||
crate::bail!("input tensor has only one dimension {layout:?}")
|
||||
}
|
||||
let mut dst_shape = src_shape.dims().to_vec();
|
||||
let last_k = dst_shape.pop().unwrap();
|
||||
if last_k != k {
|
||||
crate::bail!(
|
||||
"input tensor {layout:?} incompatible with {:?}",
|
||||
self.0.shape
|
||||
)
|
||||
}
|
||||
dst_shape.push(n);
|
||||
let dst_shape = Shape::from(dst_shape);
|
||||
let storage = storage.as_slice::<f32>()?;
|
||||
let storage =
|
||||
&storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||
let mut dst_storage = vec![0f32; dst_shape.elem_count()];
|
||||
self.0.matmul_t(
|
||||
(dst_shape.elem_count() / n, k, n),
|
||||
storage,
|
||||
&mut dst_storage,
|
||||
)?;
|
||||
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
|
||||
}
|
||||
}
|
@ -242,18 +242,28 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
|
||||
|
||||
pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
|
||||
let data = std::fs::read(filename.as_ref())?;
|
||||
let st = safetensors::SafeTensors::deserialize(&data)?;
|
||||
load_buffer(&data[..], device)
|
||||
}
|
||||
|
||||
pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>> {
|
||||
let st = safetensors::SafeTensors::deserialize(data)?;
|
||||
st.tensors()
|
||||
.into_iter()
|
||||
.map(|(name, view)| Ok((name, view.load(device)?)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn save<P: AsRef<Path>>(tensors: &HashMap<&str, Tensor>, filename: P) -> Result<()> {
|
||||
pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
|
||||
tensors: &HashMap<K, Tensor>,
|
||||
filename: P,
|
||||
) -> Result<()> {
|
||||
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
|
||||
}
|
||||
|
||||
pub struct MmapedFile(memmap2::Mmap);
|
||||
pub struct MmapedFile {
|
||||
path: std::path::PathBuf,
|
||||
inner: memmap2::Mmap,
|
||||
}
|
||||
|
||||
impl MmapedFile {
|
||||
/// Creates a wrapper around a memory mapped file from which you can retrieve
|
||||
@ -263,13 +273,20 @@ impl MmapedFile {
|
||||
///
|
||||
/// The unsafe is inherited from [`memmap2::MmapOptions`].
|
||||
pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
|
||||
let file = std::fs::File::open(p)?;
|
||||
let mmap = memmap2::MmapOptions::new().map(&file)?;
|
||||
Ok(Self(mmap))
|
||||
let p = p.as_ref();
|
||||
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
|
||||
let inner = memmap2::MmapOptions::new()
|
||||
.map(&file)
|
||||
.map_err(|e| Error::from(e).with_path(p))?;
|
||||
Ok(Self {
|
||||
inner,
|
||||
path: p.to_path_buf(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
|
||||
let st = safetensors::SafeTensors::deserialize(&self.0)?;
|
||||
let st = safetensors::SafeTensors::deserialize(&self.inner)
|
||||
.map_err(|e| Error::from(e).with_path(&self.path))?;
|
||||
Ok(st)
|
||||
}
|
||||
}
|
||||
|
@ -79,20 +79,25 @@ impl From<Vec<usize>> for Shape {
|
||||
|
||||
macro_rules! extract_dims {
|
||||
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
|
||||
impl Shape {
|
||||
pub fn $fn_name(&self) -> Result<$out_type> {
|
||||
if self.0.len() != $cnt {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: $cnt,
|
||||
got: self.0.len(),
|
||||
shape: self.clone(),
|
||||
}
|
||||
.bt())
|
||||
} else {
|
||||
Ok($dims(&self.0))
|
||||
pub fn $fn_name(dims: &[usize]) -> Result<$out_type> {
|
||||
if dims.len() != $cnt {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: $cnt,
|
||||
got: dims.len(),
|
||||
shape: Shape::from(dims),
|
||||
}
|
||||
.bt())
|
||||
} else {
|
||||
Ok($dims(dims))
|
||||
}
|
||||
}
|
||||
|
||||
impl Shape {
|
||||
pub fn $fn_name(&self) -> Result<$out_type> {
|
||||
$fn_name(self.0.as_slice())
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::Tensor {
|
||||
pub fn $fn_name(&self) -> Result<$out_type> {
|
||||
self.shape().$fn_name()
|
||||
@ -340,7 +345,7 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
|
||||
}
|
||||
}
|
||||
|
||||
extract_dims!(dims0, 0, |_: &Vec<usize>| (), ());
|
||||
extract_dims!(dims0, 0, |_: &[usize]| (), ());
|
||||
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
|
||||
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
|
||||
extract_dims!(
|
||||
|
@ -266,6 +266,82 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn conv2d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConv2D,
|
||||
) -> Result<Self> {
|
||||
self.same_device(kernel, "conv2d")?;
|
||||
self.same_dtype(kernel, "conv2d")?;
|
||||
match (self, &kernel) {
|
||||
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
|
||||
let s = inp.conv2d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cpu(s))
|
||||
}
|
||||
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
|
||||
let s = inp.conv2d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cuda(s))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
op: "conv2d",
|
||||
}
|
||||
.bt()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn avg_pool2d(
|
||||
&self,
|
||||
layout: &Layout,
|
||||
kernel_size: (usize, usize),
|
||||
stride: (usize, usize),
|
||||
) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn max_pool2d(
|
||||
&self,
|
||||
layout: &Layout,
|
||||
kernel_size: (usize, usize),
|
||||
stride: (usize, usize),
|
||||
) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let storage = storage.upsample_nearest2d(layout, h, w)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.upsample_nearest2d(layout, h, w)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn where_cond(
|
||||
&self,
|
||||
layout: &Layout,
|
||||
|
@ -269,6 +269,10 @@ impl Tensor {
|
||||
Self::rand_impl(lo, up, s, device, false)
|
||||
}
|
||||
|
||||
pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self> {
|
||||
Tensor::rand_f64_impl(lo, up, self.shape(), self.dtype(), self.device(), false)
|
||||
}
|
||||
|
||||
pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
|
||||
mean: T,
|
||||
std: T,
|
||||
@ -296,6 +300,17 @@ impl Tensor {
|
||||
Ok(from_storage(storage, s, none, is_variable))
|
||||
}
|
||||
|
||||
pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self> {
|
||||
Tensor::randn_f64_impl(
|
||||
mean,
|
||||
stdev,
|
||||
self.shape(),
|
||||
self.dtype(),
|
||||
self.device(),
|
||||
false,
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a new tensor initialized with values sampled from a normal distribution with the
|
||||
/// specified `mean` and standard deviation `std`.
|
||||
pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
|
||||
@ -474,6 +489,7 @@ impl Tensor {
|
||||
broadcast_binary_op!(broadcast_sub, sub);
|
||||
broadcast_binary_op!(broadcast_div, div);
|
||||
|
||||
unary_op!(recip, Recip);
|
||||
unary_op!(neg, Neg);
|
||||
unary_op!(exp, Exp);
|
||||
unary_op!(log, Log);
|
||||
@ -548,6 +564,32 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Split a tensor into the specified number of chunks, this may return less chunks than
|
||||
/// specificed.
|
||||
pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
|
||||
let dim = dim.to_index(self.shape(), "chunk")?;
|
||||
let size = self.dim(dim)?;
|
||||
if size < chunks {
|
||||
(0..size).map(|i| self.narrow(dim, i, 1)).collect()
|
||||
} else {
|
||||
let chunk_size = size / chunks;
|
||||
let cnt_additional = size % chunks;
|
||||
let mut tensors = vec![];
|
||||
let mut sum_chunk_size = 0;
|
||||
for i in 0..chunks {
|
||||
let chunk_size = if i < cnt_additional {
|
||||
chunk_size + 1
|
||||
} else {
|
||||
chunk_size
|
||||
};
|
||||
let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
|
||||
tensors.push(tensor);
|
||||
sum_chunk_size += chunk_size
|
||||
}
|
||||
Ok(tensors)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
||||
/// ranges from `start` to `start + len`.
|
||||
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
|
||||
@ -731,18 +773,7 @@ impl Tensor {
|
||||
/// Applies a 1D convolution over the input tensor.
|
||||
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
||||
let (c_out, c_in_k, k_size) = kernel.dims3()?;
|
||||
let (b_size, c_in, l_in) = match *self.dims() {
|
||||
[b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
|
||||
[c_in, l_in] => (None, c_in, l_in),
|
||||
_ => Err(Error::Conv1dInvalidArgs {
|
||||
inp_shape: self.shape().clone(),
|
||||
k_shape: kernel.shape().clone(),
|
||||
padding,
|
||||
stride,
|
||||
msg: "input rank is not 2 or 3",
|
||||
}
|
||||
.bt())?,
|
||||
};
|
||||
let (b_size, c_in, l_in) = self.dims3()?;
|
||||
if c_in != c_in_k {
|
||||
Err(Error::Conv1dInvalidArgs {
|
||||
inp_shape: self.shape().clone(),
|
||||
@ -775,6 +806,77 @@ impl Tensor {
|
||||
Ok(from_storage(storage, out_dims, op, false))
|
||||
}
|
||||
|
||||
pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
||||
let (b_size, c_in, i_h, i_w) = self.dims4()?;
|
||||
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
|
||||
if c_in != c_in_k {
|
||||
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
|
||||
}
|
||||
let params = crate::conv::ParamsConv2D {
|
||||
b_size,
|
||||
i_h,
|
||||
i_w,
|
||||
k_h,
|
||||
k_w,
|
||||
c_out,
|
||||
c_in,
|
||||
padding,
|
||||
stride,
|
||||
};
|
||||
let storage =
|
||||
self.storage()
|
||||
.conv2d(self.layout(), &kernel.storage(), kernel.layout(), ¶ms)?;
|
||||
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
|
||||
arg,
|
||||
kernel,
|
||||
padding,
|
||||
stride,
|
||||
});
|
||||
let out_dims = params.out_dims();
|
||||
Ok(from_storage(storage, out_dims, op, false))
|
||||
}
|
||||
|
||||
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
|
||||
let (n, c, _h, _w) = self.dims4()?;
|
||||
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
|
||||
let storage = self
|
||||
.storage()
|
||||
.upsample_nearest2d(self.layout(), target_h, target_w)?;
|
||||
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
|
||||
}
|
||||
|
||||
pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
|
||||
let (n, c, h, w) = self.dims4()?;
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
|
||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||
let w_out = (w - kernel_size.1) / stride.1 + 1;
|
||||
let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
|
||||
arg,
|
||||
kernel_size,
|
||||
stride,
|
||||
});
|
||||
let storage = self
|
||||
.storage()
|
||||
.avg_pool2d(self.layout(), kernel_size, stride)?;
|
||||
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
|
||||
}
|
||||
|
||||
pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
|
||||
let (n, c, h, w) = self.dims4()?;
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
|
||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||
let w_out = (w - kernel_size.1) / stride.1 + 1;
|
||||
let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
|
||||
arg,
|
||||
kernel_size,
|
||||
stride,
|
||||
});
|
||||
let storage = self
|
||||
.storage()
|
||||
.max_pool2d(self.layout(), kernel_size, stride)?;
|
||||
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
|
||||
}
|
||||
|
||||
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
@ -1717,6 +1819,32 @@ impl Tensor {
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
|
||||
if left == 0 && right == 0 {
|
||||
Ok(self.clone())
|
||||
} else if left == 0 {
|
||||
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
|
||||
let mut dims = self.dims().to_vec();
|
||||
dims[dim] = right;
|
||||
let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
|
||||
Tensor::cat(&[self, &right], dim)
|
||||
} else if right == 0 {
|
||||
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
|
||||
let mut dims = self.dims().to_vec();
|
||||
dims[dim] = left;
|
||||
let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
|
||||
Tensor::cat(&[&left, self], dim)
|
||||
} else {
|
||||
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
|
||||
let mut dims = self.dims().to_vec();
|
||||
dims[dim] = left;
|
||||
let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
|
||||
dims[dim] = right;
|
||||
let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
|
||||
Tensor::cat(&[&left, self, &right], dim)
|
||||
}
|
||||
}
|
||||
|
||||
fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
|
||||
self.storage.read().unwrap()
|
||||
}
|
||||
|
@ -11,16 +11,14 @@ pub fn get_num_threads() -> usize {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_accelerate() -> bool {
|
||||
cfg!(feature = "accelerate")
|
||||
}
|
||||
|
||||
pub fn has_mkl() -> bool {
|
||||
#[cfg(feature = "mkl")]
|
||||
return true;
|
||||
#[cfg(not(feature = "mkl"))]
|
||||
return false;
|
||||
cfg!(feature = "mkl")
|
||||
}
|
||||
|
||||
pub fn cuda_is_available() -> bool {
|
||||
#[cfg(feature = "cuda")]
|
||||
return true;
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
return false;
|
||||
cfg!(feature = "cuda")
|
||||
}
|
||||
|
178
candle-core/tests/conv_tests.rs
Normal file
178
candle-core/tests/conv_tests.rs
Normal file
@ -0,0 +1,178 @@
|
||||
mod test_utils;
|
||||
use anyhow::Result;
|
||||
use candle_core::{Device, Tensor};
|
||||
|
||||
/* This test is based on the following script.
|
||||
import torch
|
||||
torch.manual_seed(4242)
|
||||
|
||||
t = torch.randn((1, 4, 5))
|
||||
w = torch.randn((2, 4, 3))
|
||||
print(t.flatten())
|
||||
print(w.flatten())
|
||||
res = torch.nn.functional.conv1d(t, w)
|
||||
print(res.flatten())
|
||||
res = torch.nn.functional.conv1d(t, w, padding=1)
|
||||
print(res.flatten())
|
||||
*/
|
||||
fn conv1d(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(
|
||||
&[
|
||||
0.4056f32, -0.8689, -0.0773, -1.5630, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145,
|
||||
1.8025, -0.1536, 2.2013, -0.6836, 0.2477, 1.3127, -0.6957, 0.3278, -1.0124, 0.5599,
|
||||
],
|
||||
dev,
|
||||
)?
|
||||
.reshape((1, 4, 5))?;
|
||||
let w = Tensor::new(
|
||||
&[
|
||||
-0.8404f32, -0.3490, 0.0130, 1.3123, 0.1763, -1.9249, 1.4270, 0.9421, 0.8670, -0.7181,
|
||||
-1.1111, 0.8869, -1.2429, 1.8357, 1.6052, -1.3844, 0.3951, -1.2036, 0.6686, 1.6261,
|
||||
-0.6451, -0.0840, -1.4247, 0.5512,
|
||||
],
|
||||
dev,
|
||||
)?
|
||||
.reshape((2, 4, 3))?;
|
||||
let res = t.conv1d(&w, 0, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 3]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
|
||||
);
|
||||
let res = t.conv1d(&w, /*padding*/ 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 5]);
|
||||
// Same as pytorch default padding: use zeros.
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn conv1d_small(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
|
||||
let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
|
||||
let res = t.conv1d(&w, 0, 1)?;
|
||||
assert_eq!(res.dims(), [1, 1, 2]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[0.4056, -0.8689]
|
||||
);
|
||||
let res = t.conv1d(&w, /*padding*/ 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 1, 4]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[0.0, 0.4056, -0.8689, -0.0773],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/* This test is based on the following script.
|
||||
import torch
|
||||
torch.manual_seed(4242)
|
||||
|
||||
t = torch.randn((1, 4, 5, 5))
|
||||
w = torch.randn((2, 4, 3, 3))
|
||||
print(t.flatten())
|
||||
print(w.flatten())
|
||||
res = torch.nn.functional.conv2d(t, w)
|
||||
print(res.flatten())
|
||||
*/
|
||||
fn conv2d(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(
|
||||
&[
|
||||
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
|
||||
1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
|
||||
1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
|
||||
0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
|
||||
1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
|
||||
0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
|
||||
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
|
||||
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
|
||||
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
|
||||
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
|
||||
],
|
||||
dev,
|
||||
)?;
|
||||
let w = Tensor::new(
|
||||
&[
|
||||
-0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
|
||||
-2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
|
||||
-0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
|
||||
0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
|
||||
0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
|
||||
-0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
|
||||
1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
|
||||
0.5583, 0.4623, 0.6026,
|
||||
],
|
||||
dev,
|
||||
)?;
|
||||
let t = t.reshape((1, 4, 5, 5))?;
|
||||
let w = w.reshape((2, 4, 3, 3))?;
|
||||
let res = t.conv2d(&w, 0, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 3, 3]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[
|
||||
-4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
|
||||
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/* This test is based on the following script.
|
||||
import torch
|
||||
torch.manual_seed(4242)
|
||||
|
||||
t = torch.randn((1, 2, 3, 3))
|
||||
w = torch.randn((1, 2, 1, 1))
|
||||
print(t.flatten())
|
||||
print(w.flatten())
|
||||
res = torch.nn.functional.conv2d(t, w)
|
||||
print(res.flatten())
|
||||
*/
|
||||
fn conv2d_small(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(
|
||||
&[
|
||||
0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145,
|
||||
-0.6266, 0.3529, 2.2013, -0.6836, 0.2477, 1.3127, -0.6957, 0.3278,
|
||||
],
|
||||
dev,
|
||||
)?;
|
||||
let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?;
|
||||
let t = t.reshape((1, 2, 3, 3))?;
|
||||
let w = w.reshape((1, 2, 1, 1))?;
|
||||
let res = t.conv2d(&w, 0, 1)?;
|
||||
assert_eq!(res.dims(), [1, 1, 3, 3]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn conv2d_smaller(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(
|
||||
&[
|
||||
0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866,
|
||||
],
|
||||
dev,
|
||||
)?;
|
||||
let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?;
|
||||
let t = t.reshape((1, 1, 3, 3))?;
|
||||
let w = w.reshape((1, 1, 3, 3))?;
|
||||
let res = t.conv2d(&w, 0, 1)?;
|
||||
assert_eq!(res.dims(), [1, 1, 1, 1]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[-0.6197]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
|
||||
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
|
||||
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
|
||||
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
|
||||
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
|
@ -85,8 +85,14 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
let y = (x.log()? + 1.)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(y.to_vec1::<f32>()?, [2.0986123, 1.0, 2.3862944, -0.89712]);
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[2.0986, 1.0, 2.3863, -0.8971]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[0.3333, 1.0, 0.25, 6.6667]
|
||||
);
|
||||
let y = x.exp()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
@ -141,7 +147,7 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(y.to_vec1::<f32>()?, [3.0, 1.0, 4.0, 0.15]);
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, [1.0, 1.0, 1.0, 1.0]);
|
||||
assert_eq!(test_utils::to_vec1_round(grad_x, 4)?, [1.0, 1.0, 1.0, 1.0]);
|
||||
let y = x.neg()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
@ -155,7 +161,10 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
let y = Tensor::new(1f32, device)?.broadcast_div(x)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(y.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[0.3333, 1.0, 0.25, 6.6667]
|
||||
);
|
||||
assert_eq!(
|
||||
grad_x.to_vec1::<f32>()?,
|
||||
[-0.11111111, -1.0, -0.0625, -44.444443],
|
||||
|
89
candle-core/tests/pool_tests.rs
Normal file
89
candle-core/tests/pool_tests.rs
Normal file
@ -0,0 +1,89 @@
|
||||
mod test_utils;
|
||||
use candle_core::{Device, IndexOp, Result, Tensor};
|
||||
|
||||
// https://github.com/huggingface/candle/issues/364
|
||||
fn avg_pool2d(dev: &Device) -> Result<()> {
|
||||
let data: Vec<f32> = vec![
|
||||
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
|
||||
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn max_pool2d(dev: &Device) -> Result<()> {
|
||||
let data: Vec<f32> = vec![
|
||||
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
|
||||
];
|
||||
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
|
||||
|
||||
let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/* This test corresponds to the following PyTorch script.
|
||||
import torch
|
||||
torch.manual_seed(4242)
|
||||
|
||||
t = torch.randn((1, 2, 4, 4))
|
||||
print(t.flatten())
|
||||
res = torch.nn.functional.avg_pool2d(t, 2)
|
||||
print(res)
|
||||
*/
|
||||
fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(
|
||||
&[
|
||||
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
|
||||
1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
|
||||
1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
|
||||
0.2477, 1.3127,
|
||||
],
|
||||
dev,
|
||||
)?
|
||||
.reshape((1, 2, 4, 4))?;
|
||||
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(pool, 4)?,
|
||||
[
|
||||
[[-1.1926, -0.0395], [0.2688, 0.1871]],
|
||||
[[0.1835, -0.1606], [0.6249, 0.3217]]
|
||||
]
|
||||
);
|
||||
let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?;
|
||||
assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn upsample_nearest2d(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::arange(0f32, 6f32, dev)?.reshape((1, 1, 2, 3))?;
|
||||
let upsampled = t.upsample_nearest2d(4, 6)?.i(0)?.i(0)?;
|
||||
assert_eq!(
|
||||
t.i(0)?.i(0)?.to_vec2::<f32>()?,
|
||||
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
|
||||
);
|
||||
assert_eq!(
|
||||
upsampled.to_vec2::<f32>()?,
|
||||
[
|
||||
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
|
||||
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
|
||||
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
|
||||
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
|
||||
test_device!(
|
||||
avg_pool2d_pytorch,
|
||||
avg_pool2d_pytorch_cpu,
|
||||
avg_pool2d_pytorch_gpu
|
||||
);
|
||||
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
|
||||
test_device!(
|
||||
upsample_nearest2d,
|
||||
upsample_nearest2d_cpu,
|
||||
upsample_nearest2d_gpu
|
||||
);
|
46
candle-core/tests/quantized_tests.rs
Normal file
46
candle-core/tests/quantized_tests.rs
Normal file
@ -0,0 +1,46 @@
|
||||
use candle_core::{quantized, Device, Result, Tensor};
|
||||
use quantized::{k_quants, GgmlType};
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul() -> Result<()> {
|
||||
let cpu = &Device::Cpu;
|
||||
let (m, k, n) = (3, 64, 4);
|
||||
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
|
||||
let mut dst = vec![42.; 3 * 4];
|
||||
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
||||
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
|
||||
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
||||
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
||||
assert_eq!(
|
||||
dst,
|
||||
&[
|
||||
85120.43, 214561.61, 345454.9, 474748.1, 213474.94, 604465.25, 1000686.4, 1388317.3,
|
||||
341875.88, 994283.0, 1655708.8, 2301518.3
|
||||
]
|
||||
);
|
||||
let mm = tensor_lhs.matmul(&tensor_rhs)?;
|
||||
assert_eq!(
|
||||
mm.to_vec2::<f32>()?,
|
||||
&[
|
||||
[85344.0, 214368.0, 343392.0, 472416.0],
|
||||
[214368.0, 605536.0, 996704.0, 1387872.0],
|
||||
[343392.0, 996704.0, 1650016.0, 2303328.0]
|
||||
]
|
||||
);
|
||||
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (64, 4));
|
||||
let op = quantized::QMatMul::new(std::sync::Arc::new(qtensor));
|
||||
let res = tensor_lhs.custom_op1(op)?;
|
||||
assert_eq!(
|
||||
res.to_vec2::<f32>()?,
|
||||
&[
|
||||
[85120.43, 214561.61, 345454.9, 474748.1],
|
||||
[213474.94, 604465.25, 1000686.4, 1388317.3],
|
||||
[341875.88, 994283.0, 1655708.8, 2301518.3]
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
@ -869,3 +869,14 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
|
||||
test_device!(index_add, index_add_cpu, index_add_gpu);
|
||||
test_device!(gather, gather_cpu, gather_gpu);
|
||||
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
|
||||
|
||||
// There was originally a bug on the CPU implementation for randn
|
||||
// https://github.com/huggingface/candle/issues/381
|
||||
#[test]
|
||||
fn randn_hasneg() -> Result<()> {
|
||||
let t = Tensor::randn(0f32, 1f32, 200, &Device::Cpu)?.to_vec1::<f32>()?;
|
||||
if t.iter().all(|&v| v >= 0.) {
|
||||
candle_core::bail!("all values in tensors are non-negative")
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,5 +1,8 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle_core::{Result, Tensor};
|
||||
|
||||
#[macro_export]
|
||||
|
20
candle-datasets/Cargo.toml
Normal file
20
candle-datasets/Cargo.toml
Normal file
@ -0,0 +1,20 @@
|
||||
[package]
|
||||
name = "candle-datasets"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
license.workspace = true
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
byteorder = { workspace = true }
|
||||
candle = { path = "../candle-core", version = "0.1.1", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.1.1" }
|
||||
hf-hub = { workspace = true}
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
memmap2 = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
rand = { workspace = true }
|
1
candle-datasets/README.md
Normal file
1
candle-datasets/README.md
Normal file
@ -0,0 +1 @@
|
||||
# candle-datasets
|
6
candle-datasets/src/lib.rs
Normal file
6
candle-datasets/src/lib.rs
Normal file
@ -0,0 +1,6 @@
|
||||
//! Datasets & Dataloaders for Candle
|
||||
pub mod batcher;
|
||||
pub mod nlp;
|
||||
pub mod vision;
|
||||
|
||||
pub use batcher::Batcher;
|
1
candle-datasets/src/nlp/mod.rs
Normal file
1
candle-datasets/src/nlp/mod.rs
Normal file
@ -0,0 +1 @@
|
||||
pub mod tinystories;
|
122
candle-datasets/src/nlp/tinystories.rs
Normal file
122
candle-datasets/src/nlp/tinystories.rs
Normal file
@ -0,0 +1,122 @@
|
||||
//! Helper functions for the tinystories dataset. This uses the pre-tokenized version as generated
|
||||
//! by the tools from https://github.com/karpathy/llama2.c
|
||||
use candle::{Device, Result, Tensor};
|
||||
|
||||
pub struct Dataset {
|
||||
valid_tokens: Vec<memmap2::Mmap>,
|
||||
train_tokens: Vec<memmap2::Mmap>,
|
||||
}
|
||||
|
||||
fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
|
||||
let file = std::fs::File::open(p)?;
|
||||
let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
|
||||
Ok(mmap)
|
||||
}
|
||||
|
||||
impl Dataset {
|
||||
pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
|
||||
let dir = dir.as_ref();
|
||||
let mut bin_files = vec![];
|
||||
for file in std::fs::read_dir(dir)?.flatten() {
|
||||
let file = file.path();
|
||||
if let Some(extension) = file.extension() {
|
||||
if extension == "bin" {
|
||||
bin_files.push(file)
|
||||
}
|
||||
}
|
||||
}
|
||||
if bin_files.len() < 2 {
|
||||
candle::bail!("found less than two bin files in {:?}", dir)
|
||||
}
|
||||
bin_files.sort();
|
||||
let valid_tokens = mmap_file(&bin_files[0])?;
|
||||
let train_tokens = bin_files[1..]
|
||||
.iter()
|
||||
.map(mmap_file)
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self {
|
||||
valid_tokens: vec![valid_tokens],
|
||||
train_tokens,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn train_tokens(&self) -> usize {
|
||||
self.train_tokens.len()
|
||||
}
|
||||
|
||||
pub fn valid_tokens(&self) -> usize {
|
||||
self.valid_tokens.len()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DatasetRandomIter<'a> {
|
||||
all_tokens: &'a [memmap2::Mmap],
|
||||
tokens: Vec<&'a memmap2::Mmap>,
|
||||
current_tokens: &'a memmap2::Mmap,
|
||||
indexes_in_bytes: Vec<usize>,
|
||||
seq_len: usize,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl<'a> DatasetRandomIter<'a> {
|
||||
pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
let all_tokens = if valid {
|
||||
&ds.valid_tokens
|
||||
} else {
|
||||
&ds.train_tokens
|
||||
};
|
||||
let mut tokens = all_tokens.iter().collect::<Vec<_>>();
|
||||
tokens.shuffle(&mut thread_rng());
|
||||
let current_tokens = tokens.pop().unwrap();
|
||||
let seq_len_in_bytes = seq_len * 2;
|
||||
let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
|
||||
.step_by(seq_len_in_bytes)
|
||||
.collect::<Vec<_>>();
|
||||
indexes_in_bytes.shuffle(&mut thread_rng());
|
||||
Self {
|
||||
all_tokens,
|
||||
tokens,
|
||||
current_tokens,
|
||||
indexes_in_bytes,
|
||||
seq_len,
|
||||
device,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for DatasetRandomIter<'a> {
|
||||
type Item = Result<(Tensor, Tensor)>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
let seq_len = self.seq_len;
|
||||
if self.indexes_in_bytes.is_empty() {
|
||||
if self.tokens.is_empty() {
|
||||
self.tokens = self.all_tokens.iter().collect();
|
||||
self.tokens.shuffle(&mut thread_rng());
|
||||
}
|
||||
self.current_tokens = self.tokens.pop().unwrap();
|
||||
let seq_len_in_bytes = self.seq_len * 2;
|
||||
self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
|
||||
.step_by(seq_len_in_bytes)
|
||||
.collect::<Vec<_>>();
|
||||
self.indexes_in_bytes.shuffle(&mut thread_rng());
|
||||
}
|
||||
let start_idx = self.indexes_in_bytes.pop().unwrap();
|
||||
let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
|
||||
let mut tokens = vec![0u16; bytes.len() / 2];
|
||||
if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
|
||||
return Some(Err(err.into()));
|
||||
}
|
||||
let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
|
||||
let inputs = Tensor::new(&tokens[..seq_len], &self.device);
|
||||
let targets = Tensor::new(&tokens[1..], &self.device);
|
||||
Some(candle::error::zip(inputs, targets))
|
||||
}
|
||||
}
|
@ -10,10 +10,12 @@ license.workspace = true
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.1.0" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.1.0" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.0", optional = true }
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.1.1", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.1.1" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.1.1" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.1.1" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.1", optional = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
@ -21,12 +23,13 @@ num-traits = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
half = { workspace = true, optional = true }
|
||||
image = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
byteorder = { workspace = true }
|
||||
hf-hub = { workspace = true, features=["tokio"]}
|
||||
clap = { workspace = true }
|
||||
hf-hub = { workspace = true }
|
||||
memmap2 = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
@ -34,13 +37,17 @@ tracing = { workspace = true }
|
||||
tracing-chrome = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
wav = { workspace = true }
|
||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||
tokio = "1.29.1"
|
||||
|
||||
[build-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
||||
cudnn = ["candle/cudnn"]
|
||||
flash-attn = ["cuda", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||
@ -48,3 +55,7 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||
[[example]]
|
||||
name = "llama_multiprocess"
|
||||
required-features = ["cuda", "nccl", "flash-attn"]
|
||||
|
||||
[[example]]
|
||||
name = "stable-diffusion"
|
||||
required-features = ["image"]
|
||||
|
@ -39,6 +39,10 @@ struct Args {
|
||||
/// The number of times to run the prompt.
|
||||
#[arg(long, default_value = "1")]
|
||||
n: usize,
|
||||
|
||||
/// L2 normalization for embeddings.
|
||||
#[arg(long, default_value = "true")]
|
||||
normalize_embeddings: bool,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
@ -107,7 +111,10 @@ fn main() -> Result<()> {
|
||||
let device = &model.device;
|
||||
|
||||
if let Some(prompt) = args.prompt {
|
||||
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
||||
let tokenizer = tokenizer
|
||||
.with_padding(None)
|
||||
.with_truncation(None)
|
||||
.map_err(E::msg)?;
|
||||
let tokens = tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
@ -164,7 +171,13 @@ fn main() -> Result<()> {
|
||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
|
||||
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
|
||||
let embeddings = if args.normalize_embeddings {
|
||||
normalize_l2(&embeddings)?
|
||||
} else {
|
||||
embeddings
|
||||
};
|
||||
println!("pooled embeddings {:?}", embeddings.shape());
|
||||
|
||||
let mut similarities = vec![];
|
||||
for i in 0..n_sentences {
|
||||
let e_i = embeddings.get(i)?;
|
||||
@ -184,3 +197,7 @@ fn main() -> Result<()> {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
|
||||
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
|
||||
}
|
||||
|
@ -65,10 +65,7 @@ impl TextGeneration {
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
new_tokens.push(next_token);
|
||||
let token = self
|
||||
.tokenizer
|
||||
.decode(vec![next_token], true)
|
||||
.map_err(E::msg)?;
|
||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||
print!("{token}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
|
@ -1,5 +1,8 @@
|
||||
// TODO: Add an offline mode.
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
@ -69,16 +72,14 @@ impl TextGeneration {
|
||||
"{} token: {} '{}'",
|
||||
index + 1,
|
||||
next_token,
|
||||
self.tokenizer
|
||||
.decode(vec![next_token], true)
|
||||
.map_err(E::msg)?
|
||||
self.tokenizer.decode(&[next_token], true).map_err(E::msg)?
|
||||
);
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"{sample_len} tokens generated ({} token/s)\n----\n{}\n----",
|
||||
sample_len as f64 / dt.as_secs_f64(),
|
||||
self.tokenizer.decode(new_tokens, true).map_err(E::msg)?
|
||||
self.tokenizer.decode(&new_tokens, true).map_err(E::msg)?
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
28
candle-examples/examples/ggml/main.rs
Normal file
28
candle-examples/examples/ggml/main.rs
Normal file
@ -0,0 +1,28 @@
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use std::fs::File;
|
||||
|
||||
use candle::quantized::ggml_file::Content;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
|
||||
#[arg(long)]
|
||||
model: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let mut file = File::open(args.model)?;
|
||||
let start = std::time::Instant::now();
|
||||
let model = Content::read(&mut file)?;
|
||||
|
||||
println!(
|
||||
"Loaded {:?} tensors in {:?}",
|
||||
model.tensors.len(),
|
||||
start.elapsed()
|
||||
);
|
||||
Ok(())
|
||||
}
|
@ -1,199 +0,0 @@
|
||||
# Adapted from:
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
|
||||
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
"""
|
||||
Sample usage:
|
||||
|
||||
```
|
||||
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
|
||||
--input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
|
||||
```
|
||||
"""
|
||||
|
||||
INTERMEDIATE_SIZE_MAP = {
|
||||
"7B": 11008,
|
||||
"13B": 13824,
|
||||
"30B": 17920,
|
||||
"65B": 22016,
|
||||
}
|
||||
NUM_SHARDS = {
|
||||
"7B": 1,
|
||||
"13B": 2,
|
||||
"30B": 4,
|
||||
"65B": 8,
|
||||
}
|
||||
|
||||
|
||||
def compute_intermediate_size(n):
|
||||
return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
|
||||
|
||||
|
||||
def read_json(path):
|
||||
with open(path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(text, path):
|
||||
with open(path, "w") as f:
|
||||
json.dump(text, f)
|
||||
|
||||
|
||||
def write_model(model_path, input_base_path, model_size):
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
|
||||
params = read_json(os.path.join(input_base_path, "params.json"))
|
||||
num_shards = NUM_SHARDS[model_size]
|
||||
n_layers = params["n_layers"]
|
||||
n_heads = params["n_heads"]
|
||||
n_heads_per_shard = n_heads // num_shards
|
||||
dim = params["dim"]
|
||||
dims_per_head = dim // n_heads
|
||||
base = 10000.0
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
|
||||
|
||||
# permute for sliced rotary
|
||||
def permute(w):
|
||||
return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
|
||||
|
||||
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
|
||||
# Load weights
|
||||
if model_size == "7B":
|
||||
# Not sharded
|
||||
# (The sharded implementation would also work, but this is simpler.)
|
||||
loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
|
||||
else:
|
||||
# Sharded
|
||||
loaded = [
|
||||
torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
|
||||
for i in range(num_shards)
|
||||
]
|
||||
param_count = 0
|
||||
all_dicts = {}
|
||||
for layer_i in range(n_layers):
|
||||
if model_size == "7B":
|
||||
# Unsharded
|
||||
state_dict = {
|
||||
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
|
||||
loaded[f"layers.{layer_i}.attention.wq.weight"]
|
||||
),
|
||||
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
|
||||
loaded[f"layers.{layer_i}.attention.wk.weight"]
|
||||
),
|
||||
f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
|
||||
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
|
||||
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
|
||||
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
|
||||
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
|
||||
f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
|
||||
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
|
||||
}
|
||||
else:
|
||||
# Sharded
|
||||
# Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
|
||||
# the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
|
||||
# redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.
|
||||
|
||||
state_dict = {
|
||||
f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
|
||||
f"layers.{layer_i}.attention_norm.weight"
|
||||
].clone(),
|
||||
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
|
||||
f"layers.{layer_i}.ffn_norm.weight"
|
||||
].clone(),
|
||||
}
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
|
||||
torch.cat(
|
||||
[
|
||||
loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
|
||||
for i in range(num_shards)
|
||||
],
|
||||
dim=0,
|
||||
).reshape(dim, dim)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
|
||||
torch.cat(
|
||||
[
|
||||
loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(n_heads_per_shard, dims_per_head, dim)
|
||||
for i in range(num_shards)
|
||||
],
|
||||
dim=0,
|
||||
).reshape(dim, dim)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
|
||||
[
|
||||
loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(n_heads_per_shard, dims_per_head, dim)
|
||||
for i in range(num_shards)
|
||||
],
|
||||
dim=0,
|
||||
).reshape(dim, dim)
|
||||
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
|
||||
)
|
||||
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
|
||||
all_dicts |= state_dict
|
||||
|
||||
if model_size == "7B":
|
||||
# Unsharded
|
||||
state_dict = {
|
||||
"model.embed_tokens.weight": loaded["tok_embeddings.weight"],
|
||||
"model.norm.weight": loaded["norm.weight"],
|
||||
"lm_head.weight": loaded["output.weight"],
|
||||
}
|
||||
else:
|
||||
state_dict = {
|
||||
"model.norm.weight": loaded[0]["norm.weight"],
|
||||
"model.embed_tokens.weight": torch.cat(
|
||||
[loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
|
||||
),
|
||||
"lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
|
||||
}
|
||||
all_dicts |= state_dict
|
||||
all_dicts = {k: v.numpy() for k, v in all_dicts.items()}
|
||||
np.savez(os.path.join(model_path, "llama.npz"), **all_dicts)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--input_dir",
|
||||
help="Location of LLaMA weights, which contains tokenizer.model and model folders",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_size",
|
||||
choices=["7B", "13B", "30B", "65B"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
help="Location to write HF model and tokenizer",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
write_model(
|
||||
model_path=args.output_dir,
|
||||
input_base_path=os.path.join(args.input_dir, args.model_size),
|
||||
model_size=args.model_size,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -5,9 +5,9 @@
|
||||
//
|
||||
// The tokenizer config can be retrieved from:
|
||||
// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json
|
||||
//
|
||||
// In order to convert the llama weights to a .npz file, run:
|
||||
// python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
@ -19,62 +19,14 @@ use candle::{DType, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::api::sync::Api;
|
||||
use std::io::Write;
|
||||
|
||||
mod model;
|
||||
use model::{Config, Llama};
|
||||
|
||||
const EOS_TOKEN: &str = "</s>";
|
||||
const MAX_SEQ_LEN: usize = 4096;
|
||||
const DEFAULT_PROMPT: &str = r"
|
||||
EDWARD:
|
||||
I wonder how our princely father 'scaped,
|
||||
Or whether he be 'scaped away or no
|
||||
From Clifford's and Northumberland's pursuit:
|
||||
Had he been ta'en, we should have heard the news;
|
||||
Had he been slain, we should have heard the news;
|
||||
Or had he 'scaped, methinks we should have heard
|
||||
The happy tidings of his good escape.
|
||||
How fares my brother? why is he so sad?
|
||||
|
||||
RICHARD:
|
||||
I cannot joy, until I be resolved
|
||||
Where our right valiant father is become.
|
||||
I saw him in the battle range about;
|
||||
And watch'd him how he singled Clifford forth.
|
||||
Methought he bore him in the thickest troop
|
||||
As doth a lion in a herd of neat;
|
||||
Or as a bear, encompass'd round with dogs,
|
||||
Who having pinch'd a few and made them cry,
|
||||
The rest stand all aloof, and bark at him.
|
||||
So fared our father with his enemies;
|
||||
So fled his enemies my warlike father:
|
||||
Methinks, 'tis prize enough to be his son.
|
||||
See how the morning opes her golden gates,
|
||||
And takes her farewell of the glorious sun!
|
||||
How well resembles it the prime of youth,
|
||||
Trimm'd like a younker prancing to his love!
|
||||
|
||||
EDWARD:
|
||||
Dazzle mine eyes, or do I see three suns?
|
||||
|
||||
RICHARD:
|
||||
Three glorious suns, each one a perfect sun;
|
||||
Not separated with the racking clouds,
|
||||
But sever'd in a pale clear-shining sky.
|
||||
See, see! they join, embrace, and seem to kiss,
|
||||
As if they vow'd some league inviolable:
|
||||
Now are they but one lamp, one light, one sun.
|
||||
In this the heaven figures some event.
|
||||
|
||||
EDWARD:
|
||||
'Tis wondrous strange, the like yet never heard of.
|
||||
I think it cites us, brother, to the field,
|
||||
That we, the sons of brave Plantagenet,
|
||||
Each one already blazing by our meeds,
|
||||
Should notwithstanding join our lights together
|
||||
And over-shine the earth as this the world.
|
||||
Whate'er it bodes, henceforward will I bear
|
||||
Upon my target three fair-shining suns.
|
||||
";
|
||||
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
@ -111,6 +63,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
use_f32: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
@ -119,12 +75,27 @@ struct Args {
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
/// The folder name that contains safetensor weights and json files
|
||||
/// (same structure as huggingface online)
|
||||
#[arg(long)]
|
||||
local_weights: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
println!("tracing...");
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let config = if args.v1 {
|
||||
@ -151,14 +122,26 @@ fn main() -> Result<()> {
|
||||
});
|
||||
println!("loading the model weights from {model_id}");
|
||||
let api = api.model(model_id);
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
|
||||
let tokenizer_filename = match &args.local_weights {
|
||||
Some(path) => (path.to_owned() + "tokenizer.json").into(),
|
||||
_ => api.get("tokenizer.json")?,
|
||||
};
|
||||
|
||||
let mut filenames = vec![];
|
||||
for rfilename in [
|
||||
"model-00001-of-00002.safetensors",
|
||||
"model-00002-of-00002.safetensors",
|
||||
] {
|
||||
let filename = api.get(rfilename)?;
|
||||
filenames.push(filename);
|
||||
match &args.local_weights {
|
||||
Some(path) => {
|
||||
filenames.push((path.to_owned() + rfilename).into());
|
||||
}
|
||||
_ => {
|
||||
let filename = api.get(rfilename)?;
|
||||
filenames.push(filename);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
println!("building the model");
|
||||
@ -176,6 +159,7 @@ fn main() -> Result<()> {
|
||||
}
|
||||
};
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
|
||||
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
|
||||
let mut tokens = tokenizer
|
||||
.encode(prompt, true)
|
||||
@ -184,12 +168,12 @@ fn main() -> Result<()> {
|
||||
.to_vec();
|
||||
|
||||
println!("starting the inference loop");
|
||||
print!("{prompt}");
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
|
||||
let mut new_tokens = vec![];
|
||||
let start_gen = std::time::Instant::now();
|
||||
let mut index_pos = 0;
|
||||
let mut token_generated = 0;
|
||||
for index in 0..args.sample_len {
|
||||
let start_gen = std::time::Instant::now();
|
||||
let context_size = if cache.use_kv_cache && index > 0 {
|
||||
1
|
||||
} else {
|
||||
@ -202,22 +186,27 @@ fn main() -> Result<()> {
|
||||
index_pos += ctxt.len();
|
||||
|
||||
let next_token = logits_processor.sample(&logits)?;
|
||||
token_generated += 1;
|
||||
tokens.push(next_token);
|
||||
new_tokens.push(next_token);
|
||||
println!("> {:?}", start_gen.elapsed());
|
||||
println!(
|
||||
"{} token: {} '{}'",
|
||||
index + 1,
|
||||
next_token,
|
||||
tokenizer.decode(vec![next_token], true).map_err(E::msg)?
|
||||
);
|
||||
|
||||
// Extracting the last token as a string is complicated, here we just apply some simple
|
||||
// heuristics as it seems to work well enough for this example. See the following for more
|
||||
// details:
|
||||
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
|
||||
if let Some(text) = tokenizer.id_to_token(next_token) {
|
||||
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
|
||||
print!("{text}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
if Some(next_token) == eos_token_id {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"{} tokens generated ({} token/s)\n----\n{}\n----",
|
||||
args.sample_len,
|
||||
args.sample_len as f64 / dt.as_secs_f64(),
|
||||
tokenizer.decode(new_tokens, true).map_err(E::msg)?
|
||||
"\n\n{} tokens generated ({} token/s)\n",
|
||||
token_generated,
|
||||
token_generated as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, Linear, VarBuilder};
|
||||
use candle_nn::{Embedding, VarBuilder};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
@ -47,6 +47,21 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
|
||||
// model.
|
||||
#[derive(Debug)]
|
||||
pub struct Linear {
|
||||
inner: candle_nn::Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Cache {
|
||||
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
||||
@ -106,8 +121,9 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
|
||||
}
|
||||
|
||||
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
Ok(Linear::new(weight, None))
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
|
||||
Ok(Linear { inner, span })
|
||||
}
|
||||
|
||||
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
|
||||
@ -118,15 +134,18 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
|
||||
struct RmsNorm {
|
||||
scale: Tensor,
|
||||
eps: f64,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl RmsNorm {
|
||||
fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
|
||||
let scale = vb.get(size, "weight")?;
|
||||
Ok(Self { scale, eps })
|
||||
Ok(Self { scale, eps, span })
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let in_dtype = x.dtype();
|
||||
// This is a no-op if x's dtype is already f32.
|
||||
let x = x.to_dtype(DType::F32)?;
|
||||
@ -155,6 +174,8 @@ struct CausalSelfAttention {
|
||||
head_dim: usize,
|
||||
cache: Cache,
|
||||
use_flash_attn: bool,
|
||||
span: tracing::Span,
|
||||
span_rot: tracing::Span,
|
||||
}
|
||||
|
||||
#[cfg(feature = "flash-attn")]
|
||||
@ -175,6 +196,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
|
||||
|
||||
impl CausalSelfAttention {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let _enter = self.span_rot.enter();
|
||||
let (b_sz, _, seq_len, n_embd) = x.dims4()?;
|
||||
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
|
||||
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
|
||||
@ -188,6 +210,7 @@ impl CausalSelfAttention {
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
||||
let q = self.q_proj.forward(x)?;
|
||||
let k = self.k_proj.forward(x)?;
|
||||
@ -269,6 +292,8 @@ impl CausalSelfAttention {
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
let size_in = cfg.hidden_size;
|
||||
let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
|
||||
let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
|
||||
@ -286,6 +311,8 @@ impl CausalSelfAttention {
|
||||
head_dim: cfg.hidden_size / cfg.n_head,
|
||||
cache: cache.clone(),
|
||||
use_flash_attn: cfg.use_flash_attn,
|
||||
span,
|
||||
span_rot,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -301,15 +328,18 @@ struct Mlp {
|
||||
c_fc1: Linear,
|
||||
c_fc2: Linear,
|
||||
c_proj: Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
|
||||
self.c_proj.forward(&x)
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "mlp");
|
||||
let h_size = cfg.hidden_size;
|
||||
let i_size = cfg.intermediate_size;
|
||||
let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
|
||||
@ -319,6 +349,7 @@ impl Mlp {
|
||||
c_fc1,
|
||||
c_fc2,
|
||||
c_proj,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -328,10 +359,12 @@ struct Block {
|
||||
attn: CausalSelfAttention,
|
||||
rms_2: RmsNorm,
|
||||
mlp: Mlp,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Block {
|
||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let residual = x;
|
||||
let x = self.rms_1.forward(x)?;
|
||||
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
|
||||
@ -341,6 +374,7 @@ impl Block {
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "block");
|
||||
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
|
||||
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
|
||||
let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
||||
@ -354,6 +388,7 @@ impl Block {
|
||||
attn,
|
||||
rms_2,
|
||||
mlp,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,8 @@
|
||||
// https://github.com/karpathy/llama2.c
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
@ -27,7 +30,7 @@ struct InferenceCmd {
|
||||
#[arg(long, default_value = "")]
|
||||
prompt: String,
|
||||
|
||||
/// Config file in binary format.
|
||||
/// Config file in binary or safetensors format.
|
||||
#[arg(long)]
|
||||
config: Option<String>,
|
||||
|
||||
@ -200,7 +203,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
|
||||
Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
|
||||
}
|
||||
});
|
||||
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||
for inp_tgt in batch_iter {
|
||||
let (inp, tgt) = inp_tgt?;
|
||||
let logits = model.forward(&inp, 0)?;
|
||||
@ -225,11 +228,22 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
|
||||
let device = candle_examples::device(common_args.cpu)?;
|
||||
|
||||
let mut file = std::fs::File::open(config_path)?;
|
||||
let config = Config::from_reader(&mut file)?;
|
||||
println!("{config:?}");
|
||||
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
||||
let vb = weights.var_builder(&config, &device)?;
|
||||
let is_safetensors = config_path
|
||||
.extension()
|
||||
.map_or(false, |v| v == "safetensors");
|
||||
let (vb, config) = if is_safetensors {
|
||||
let config = Config::tiny();
|
||||
let tensors = candle::safetensors::load(config_path, &device)?;
|
||||
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
|
||||
(vb, config)
|
||||
} else {
|
||||
let mut file = std::fs::File::open(config_path)?;
|
||||
let config = Config::from_reader(&mut file)?;
|
||||
println!("{config:?}");
|
||||
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
||||
let vb = weights.var_builder(&config, &device)?;
|
||||
(vb, config)
|
||||
};
|
||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||
let model = Llama::load(vb, &cache, config)?;
|
||||
|
||||
|
@ -1,118 +1,6 @@
|
||||
#![allow(dead_code)]
|
||||
#![allow(unused)]
|
||||
use crate::model::{Cache, Config, Llama};
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
|
||||
pub struct Dataset {
|
||||
valid_tokens: Vec<memmap2::Mmap>,
|
||||
train_tokens: Vec<memmap2::Mmap>,
|
||||
}
|
||||
|
||||
fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
|
||||
let file = std::fs::File::open(p)?;
|
||||
let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
|
||||
Ok(mmap)
|
||||
}
|
||||
|
||||
impl Dataset {
|
||||
pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
|
||||
let dir = dir.as_ref();
|
||||
let mut bin_files = vec![];
|
||||
for file in std::fs::read_dir(dir)?.flatten() {
|
||||
let file = file.path();
|
||||
if let Some(extension) = file.extension() {
|
||||
if extension == "bin" {
|
||||
bin_files.push(file)
|
||||
}
|
||||
}
|
||||
}
|
||||
if bin_files.len() < 2 {
|
||||
candle::bail!("found less than two bin files in {:?}", dir)
|
||||
}
|
||||
bin_files.sort();
|
||||
let valid_tokens = mmap_file(&bin_files[0])?;
|
||||
let train_tokens = bin_files[1..]
|
||||
.iter()
|
||||
.map(mmap_file)
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self {
|
||||
valid_tokens: vec![valid_tokens],
|
||||
train_tokens,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct DatasetRandomIter<'a> {
|
||||
all_tokens: &'a [memmap2::Mmap],
|
||||
tokens: Vec<&'a memmap2::Mmap>,
|
||||
current_tokens: &'a memmap2::Mmap,
|
||||
indexes_in_bytes: Vec<usize>,
|
||||
seq_len: usize,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl<'a> DatasetRandomIter<'a> {
|
||||
pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
let all_tokens = if valid {
|
||||
&ds.valid_tokens
|
||||
} else {
|
||||
&ds.train_tokens
|
||||
};
|
||||
let mut tokens = all_tokens.iter().collect::<Vec<_>>();
|
||||
tokens.shuffle(&mut thread_rng());
|
||||
let current_tokens = tokens.pop().unwrap();
|
||||
let seq_len_in_bytes = seq_len * 2;
|
||||
let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
|
||||
.step_by(seq_len_in_bytes)
|
||||
.collect::<Vec<_>>();
|
||||
indexes_in_bytes.shuffle(&mut thread_rng());
|
||||
Self {
|
||||
all_tokens,
|
||||
tokens,
|
||||
current_tokens,
|
||||
indexes_in_bytes,
|
||||
seq_len,
|
||||
device,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for DatasetRandomIter<'a> {
|
||||
type Item = Result<(Tensor, Tensor)>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
let seq_len = self.seq_len;
|
||||
if self.indexes_in_bytes.is_empty() {
|
||||
if self.tokens.is_empty() {
|
||||
self.tokens = self.all_tokens.iter().collect();
|
||||
self.tokens.shuffle(&mut thread_rng());
|
||||
}
|
||||
self.current_tokens = self.tokens.pop().unwrap();
|
||||
let seq_len_in_bytes = self.seq_len * 2;
|
||||
self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
|
||||
.step_by(seq_len_in_bytes)
|
||||
.collect::<Vec<_>>();
|
||||
self.indexes_in_bytes.shuffle(&mut thread_rng());
|
||||
}
|
||||
let start_idx = self.indexes_in_bytes.pop().unwrap();
|
||||
let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
|
||||
let mut tokens = vec![0u16; bytes.len() / 2];
|
||||
if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
|
||||
return Some(Err(err.into()));
|
||||
}
|
||||
let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
|
||||
let inputs = Tensor::new(&tokens[..seq_len], &self.device);
|
||||
let targets = Tensor::new(&tokens[1..], &self.device);
|
||||
Some(candle::error::zip(inputs, targets))
|
||||
}
|
||||
}
|
||||
use candle::{DType, Device, Result};
|
||||
use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};
|
||||
|
||||
fn valid_loss(
|
||||
dataset: &Dataset,
|
||||
@ -121,7 +9,7 @@ fn valid_loss(
|
||||
device: &Device,
|
||||
) -> Result<f64> {
|
||||
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
|
||||
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||
let mut sum_ce = 0f64;
|
||||
let mut cnt = 0usize;
|
||||
for inp_tgt in batch_iter.take(50) {
|
||||
@ -139,14 +27,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
|
||||
let dataset = Dataset::new(&args.pretokenized_dir)?;
|
||||
println!(
|
||||
"loaded dataset, train: {} files, valid: {} files",
|
||||
dataset.train_tokens.len(),
|
||||
dataset.valid_tokens.len()
|
||||
dataset.train_tokens(),
|
||||
dataset.valid_tokens()
|
||||
);
|
||||
let varmap = candle_nn::VarMap::new();
|
||||
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
|
||||
let config = Config::tiny();
|
||||
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
|
||||
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||
|
||||
let cache = Cache::new(false, &config, vb.pp("rot"))?;
|
||||
let model = Llama::load(vb, &cache, config)?;
|
||||
|
@ -104,7 +104,14 @@ impl TransformerWeights {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder> {
|
||||
pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'static>> {
|
||||
// TODO: As of 2023-08-04, gemm is slower than expected when multiplying a matrix of
|
||||
// size (1, k) with the transpose of a matrix of size (k, n) as it ends up transposing the
|
||||
// second matrix back. We detect this case here and as a temporary hack make the weight
|
||||
// matrix column major rather than row major. This ends up speeding up text generation from
|
||||
// 120 token/s to 220 token/s on a Ryzen 2600X.
|
||||
let tr = device.is_cpu() && !candle::utils::has_mkl();
|
||||
let tr = |x: Tensor| if tr { x.t()?.contiguous()?.t() } else { Ok(x) };
|
||||
let mut ws = std::collections::HashMap::new();
|
||||
let mut insert = |name: &str, t: Tensor| {
|
||||
ws.insert(name.to_string(), t);
|
||||
@ -115,36 +122,36 @@ impl TransformerWeights {
|
||||
"model.embed_tokens.weight",
|
||||
self.token_embedding_table.clone(),
|
||||
);
|
||||
insert("lm_head.weight", self.token_embedding_table.clone());
|
||||
insert("lm_head.weight", tr(self.token_embedding_table.clone())?);
|
||||
insert("model.norm.weight", self.rms_final_weight.clone());
|
||||
for layer in 0..cfg.n_layers {
|
||||
ws.insert(
|
||||
format!("model.layers.{layer}.self_attn.q_proj.weight"),
|
||||
self.wq.i(layer)?,
|
||||
tr(self.wq.i(layer)?)?,
|
||||
);
|
||||
ws.insert(
|
||||
format!("model.layers.{layer}.self_attn.k_proj.weight"),
|
||||
self.wk.i(layer)?,
|
||||
tr(self.wk.i(layer)?)?,
|
||||
);
|
||||
ws.insert(
|
||||
format!("model.layers.{layer}.self_attn.v_proj.weight"),
|
||||
self.wv.i(layer)?,
|
||||
tr(self.wv.i(layer)?)?,
|
||||
);
|
||||
ws.insert(
|
||||
format!("model.layers.{layer}.self_attn.o_proj.weight"),
|
||||
self.wo.i(layer)?,
|
||||
tr(self.wo.i(layer)?)?,
|
||||
);
|
||||
ws.insert(
|
||||
format!("model.layers.{layer}.mlp.gate_proj.weight"),
|
||||
self.w1.i(layer)?,
|
||||
tr(self.w1.i(layer)?)?,
|
||||
);
|
||||
ws.insert(
|
||||
format!("model.layers.{layer}.mlp.down_proj.weight"),
|
||||
self.w2.i(layer)?,
|
||||
tr(self.w2.i(layer)?)?,
|
||||
);
|
||||
ws.insert(
|
||||
format!("model.layers.{layer}.mlp.up_proj.weight"),
|
||||
self.w3.i(layer)?,
|
||||
tr(self.w3.i(layer)?)?,
|
||||
);
|
||||
ws.insert(
|
||||
format!("model.layers.{layer}.input_layernorm.weight"),
|
||||
|
@ -5,9 +5,6 @@
|
||||
//
|
||||
// The tokenizer config can be retrieved from:
|
||||
// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json
|
||||
//
|
||||
// In order to convert the llama weights to a .npz file, run:
|
||||
// python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
@ -63,7 +63,7 @@ struct TrainingArgs {
|
||||
}
|
||||
|
||||
fn training_loop<M: Model>(
|
||||
m: candle_nn::vision::Dataset,
|
||||
m: candle_datasets::vision::Dataset,
|
||||
args: &TrainingArgs,
|
||||
) -> anyhow::Result<()> {
|
||||
let dev = candle::Device::cuda_if_available(0)?;
|
||||
@ -140,7 +140,7 @@ struct Args {
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
// Load the dataset
|
||||
let m = candle_nn::vision::mnist::load_dir("data")?;
|
||||
let m = candle_datasets::vision::mnist::load_dir("data")?;
|
||||
println!("train-images: {:?}", m.train_images.shape());
|
||||
println!("train-labels: {:?}", m.train_labels.shape());
|
||||
println!("test-images: {:?}", m.test_images.shape());
|
||||
|
473
candle-examples/examples/stable-diffusion/attention.rs
Normal file
473
candle-examples/examples/stable-diffusion/attention.rs
Normal file
@ -0,0 +1,473 @@
|
||||
#![allow(dead_code)]
|
||||
//! Attention Based Building Blocks
|
||||
use candle::{IndexOp, Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct GeGlu {
|
||||
proj: nn::Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl GeGlu {
|
||||
fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
|
||||
let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "geglu");
|
||||
Ok(Self { proj, span })
|
||||
}
|
||||
}
|
||||
|
||||
impl GeGlu {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
|
||||
&hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
|
||||
}
|
||||
}
|
||||
|
||||
/// A feed-forward layer.
|
||||
#[derive(Debug)]
|
||||
struct FeedForward {
|
||||
project_in: GeGlu,
|
||||
linear: nn::Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl FeedForward {
|
||||
// The glu parameter in the python code is unused?
|
||||
// https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L347
|
||||
/// Creates a new feed-forward layer based on some given input dimension, some
|
||||
/// output dimension, and a multiplier to be used for the intermediary layer.
|
||||
fn new(vs: nn::VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> {
|
||||
let inner_dim = dim * mult;
|
||||
let dim_out = dim_out.unwrap_or(dim);
|
||||
let vs = vs.pp("net");
|
||||
let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
|
||||
let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "ff");
|
||||
Ok(Self {
|
||||
project_in,
|
||||
linear,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl FeedForward {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = self.project_in.forward(xs)?;
|
||||
self.linear.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CrossAttention {
|
||||
to_q: nn::Linear,
|
||||
to_k: nn::Linear,
|
||||
to_v: nn::Linear,
|
||||
to_out: nn::Linear,
|
||||
heads: usize,
|
||||
scale: f64,
|
||||
slice_size: Option<usize>,
|
||||
span: tracing::Span,
|
||||
span_attn: tracing::Span,
|
||||
}
|
||||
|
||||
impl CrossAttention {
|
||||
// Defaults should be heads = 8, dim_head = 64, context_dim = None
|
||||
fn new(
|
||||
vs: nn::VarBuilder,
|
||||
query_dim: usize,
|
||||
context_dim: Option<usize>,
|
||||
heads: usize,
|
||||
dim_head: usize,
|
||||
slice_size: Option<usize>,
|
||||
) -> Result<Self> {
|
||||
let inner_dim = dim_head * heads;
|
||||
let context_dim = context_dim.unwrap_or(query_dim);
|
||||
let scale = 1.0 / f64::sqrt(dim_head as f64);
|
||||
let to_q = nn::linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
|
||||
let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
|
||||
let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
|
||||
let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "xa");
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
|
||||
Ok(Self {
|
||||
to_q,
|
||||
to_k,
|
||||
to_v,
|
||||
to_out,
|
||||
heads,
|
||||
scale,
|
||||
slice_size,
|
||||
span,
|
||||
span_attn,
|
||||
})
|
||||
}
|
||||
|
||||
fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (batch_size, seq_len, dim) = xs.dims3()?;
|
||||
xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))?
|
||||
.transpose(1, 2)?
|
||||
.reshape((batch_size * self.heads, seq_len, dim / self.heads))
|
||||
}
|
||||
|
||||
fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (batch_size, seq_len, dim) = xs.dims3()?;
|
||||
xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))?
|
||||
.transpose(1, 2)?
|
||||
.reshape((batch_size / self.heads, seq_len, dim * self.heads))
|
||||
}
|
||||
|
||||
fn sliced_attention(
|
||||
&self,
|
||||
query: &Tensor,
|
||||
key: &Tensor,
|
||||
value: &Tensor,
|
||||
slice_size: usize,
|
||||
) -> Result<Tensor> {
|
||||
let batch_size_attention = query.dim(0)?;
|
||||
let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
|
||||
|
||||
for i in 0..batch_size_attention / slice_size {
|
||||
let start_idx = i * slice_size;
|
||||
let end_idx = (i + 1) * slice_size;
|
||||
|
||||
let xs = query
|
||||
.i(start_idx..end_idx)?
|
||||
.matmul(&(key.i(start_idx..end_idx)?.t()? * self.scale)?)?;
|
||||
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
|
||||
hidden_states.push(xs)
|
||||
}
|
||||
let hidden_states = Tensor::stack(&hidden_states, 0)?;
|
||||
self.reshape_batch_dim_to_heads(&hidden_states)
|
||||
}
|
||||
|
||||
fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span_attn.enter();
|
||||
let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
|
||||
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
|
||||
self.reshape_batch_dim_to_heads(&xs)
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let query = self.to_q.forward(xs)?;
|
||||
let context = context.unwrap_or(xs);
|
||||
let key = self.to_k.forward(context)?;
|
||||
let value = self.to_v.forward(context)?;
|
||||
let query = self.reshape_heads_to_batch_dim(&query)?;
|
||||
let key = self.reshape_heads_to_batch_dim(&key)?;
|
||||
let value = self.reshape_heads_to_batch_dim(&value)?;
|
||||
let xs = match self.slice_size {
|
||||
None => self.attention(&query, &key, &value)?,
|
||||
Some(slice_size) => {
|
||||
if query.dim(0)? / slice_size <= 1 {
|
||||
self.attention(&query, &key, &value)?
|
||||
} else {
|
||||
self.sliced_attention(&query, &key, &value, slice_size)?
|
||||
}
|
||||
}
|
||||
};
|
||||
self.to_out.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
/// A basic Transformer block.
|
||||
#[derive(Debug)]
|
||||
struct BasicTransformerBlock {
|
||||
attn1: CrossAttention,
|
||||
ff: FeedForward,
|
||||
attn2: CrossAttention,
|
||||
norm1: nn::LayerNorm,
|
||||
norm2: nn::LayerNorm,
|
||||
norm3: nn::LayerNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl BasicTransformerBlock {
|
||||
fn new(
|
||||
vs: nn::VarBuilder,
|
||||
dim: usize,
|
||||
n_heads: usize,
|
||||
d_head: usize,
|
||||
context_dim: Option<usize>,
|
||||
sliced_attention_size: Option<usize>,
|
||||
) -> Result<Self> {
|
||||
let attn1 = CrossAttention::new(
|
||||
vs.pp("attn1"),
|
||||
dim,
|
||||
None,
|
||||
n_heads,
|
||||
d_head,
|
||||
sliced_attention_size,
|
||||
)?;
|
||||
let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
|
||||
let attn2 = CrossAttention::new(
|
||||
vs.pp("attn2"),
|
||||
dim,
|
||||
context_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
sliced_attention_size,
|
||||
)?;
|
||||
let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
|
||||
let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
|
||||
let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "basic-transformer");
|
||||
Ok(Self {
|
||||
attn1,
|
||||
ff,
|
||||
attn2,
|
||||
norm1,
|
||||
norm2,
|
||||
norm3,
|
||||
span,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?;
|
||||
let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?;
|
||||
self.ff.forward(&self.norm3.forward(&xs)?)? + xs
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct SpatialTransformerConfig {
|
||||
pub depth: usize,
|
||||
pub num_groups: usize,
|
||||
pub context_dim: Option<usize>,
|
||||
pub sliced_attention_size: Option<usize>,
|
||||
pub use_linear_projection: bool,
|
||||
}
|
||||
|
||||
impl Default for SpatialTransformerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
depth: 1,
|
||||
num_groups: 32,
|
||||
context_dim: None,
|
||||
sliced_attention_size: None,
|
||||
use_linear_projection: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Proj {
|
||||
Conv2d(nn::Conv2d),
|
||||
Linear(nn::Linear),
|
||||
}
|
||||
|
||||
// Aka Transformer2DModel
|
||||
#[derive(Debug)]
|
||||
pub struct SpatialTransformer {
|
||||
norm: nn::GroupNorm,
|
||||
proj_in: Proj,
|
||||
transformer_blocks: Vec<BasicTransformerBlock>,
|
||||
proj_out: Proj,
|
||||
span: tracing::Span,
|
||||
pub config: SpatialTransformerConfig,
|
||||
}
|
||||
|
||||
impl SpatialTransformer {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
n_heads: usize,
|
||||
d_head: usize,
|
||||
config: SpatialTransformerConfig,
|
||||
) -> Result<Self> {
|
||||
let inner_dim = n_heads * d_head;
|
||||
let norm = nn::group_norm(config.num_groups, in_channels, 1e-6, vs.pp("norm"))?;
|
||||
let proj_in = if config.use_linear_projection {
|
||||
Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_in"))?)
|
||||
} else {
|
||||
Proj::Conv2d(nn::conv2d(
|
||||
in_channels,
|
||||
inner_dim,
|
||||
1,
|
||||
Default::default(),
|
||||
vs.pp("proj_in"),
|
||||
)?)
|
||||
};
|
||||
let mut transformer_blocks = vec![];
|
||||
let vs_tb = vs.pp("transformer_blocks");
|
||||
for index in 0..config.depth {
|
||||
let tb = BasicTransformerBlock::new(
|
||||
vs_tb.pp(&index.to_string()),
|
||||
inner_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
config.context_dim,
|
||||
config.sliced_attention_size,
|
||||
)?;
|
||||
transformer_blocks.push(tb)
|
||||
}
|
||||
let proj_out = if config.use_linear_projection {
|
||||
Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_out"))?)
|
||||
} else {
|
||||
Proj::Conv2d(nn::conv2d(
|
||||
inner_dim,
|
||||
in_channels,
|
||||
1,
|
||||
Default::default(),
|
||||
vs.pp("proj_out"),
|
||||
)?)
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "spatial-transformer");
|
||||
Ok(Self {
|
||||
norm,
|
||||
proj_in,
|
||||
transformer_blocks,
|
||||
proj_out,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (batch, _channel, height, weight) = xs.dims4()?;
|
||||
let residual = xs;
|
||||
let xs = self.norm.forward(xs)?;
|
||||
let (inner_dim, xs) = match &self.proj_in {
|
||||
Proj::Conv2d(p) => {
|
||||
let xs = p.forward(&xs)?;
|
||||
let inner_dim = xs.dim(1)?;
|
||||
let xs = xs
|
||||
.transpose(1, 2)?
|
||||
.t()?
|
||||
.reshape((batch, height * weight, inner_dim))?;
|
||||
(inner_dim, xs)
|
||||
}
|
||||
Proj::Linear(p) => {
|
||||
let inner_dim = xs.dim(1)?;
|
||||
let xs = xs
|
||||
.transpose(1, 2)?
|
||||
.t()?
|
||||
.reshape((batch, height * weight, inner_dim))?;
|
||||
(inner_dim, p.forward(&xs)?)
|
||||
}
|
||||
};
|
||||
let mut xs = xs;
|
||||
for block in self.transformer_blocks.iter() {
|
||||
xs = block.forward(&xs, context)?
|
||||
}
|
||||
let xs = match &self.proj_out {
|
||||
Proj::Conv2d(p) => p.forward(
|
||||
&xs.reshape((batch, height, weight, inner_dim))?
|
||||
.t()?
|
||||
.transpose(1, 2)?,
|
||||
)?,
|
||||
Proj::Linear(p) => p
|
||||
.forward(&xs)?
|
||||
.reshape((batch, height, weight, inner_dim))?
|
||||
.t()?
|
||||
.transpose(1, 2)?,
|
||||
};
|
||||
xs + residual
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for an attention block.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct AttentionBlockConfig {
|
||||
pub num_head_channels: Option<usize>,
|
||||
pub num_groups: usize,
|
||||
pub rescale_output_factor: f64,
|
||||
pub eps: f64,
|
||||
}
|
||||
|
||||
impl Default for AttentionBlockConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_head_channels: None,
|
||||
num_groups: 32,
|
||||
rescale_output_factor: 1.,
|
||||
eps: 1e-5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AttentionBlock {
|
||||
group_norm: nn::GroupNorm,
|
||||
query: nn::Linear,
|
||||
key: nn::Linear,
|
||||
value: nn::Linear,
|
||||
proj_attn: nn::Linear,
|
||||
channels: usize,
|
||||
num_heads: usize,
|
||||
span: tracing::Span,
|
||||
config: AttentionBlockConfig,
|
||||
}
|
||||
|
||||
impl AttentionBlock {
|
||||
pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> {
|
||||
let num_head_channels = config.num_head_channels.unwrap_or(channels);
|
||||
let num_heads = channels / num_head_channels;
|
||||
let group_norm =
|
||||
nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
|
||||
let query = nn::linear(channels, channels, vs.pp("query"))?;
|
||||
let key = nn::linear(channels, channels, vs.pp("key"))?;
|
||||
let value = nn::linear(channels, channels, vs.pp("value"))?;
|
||||
let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn-block");
|
||||
Ok(Self {
|
||||
group_norm,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
proj_attn,
|
||||
channels,
|
||||
num_heads,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
fn transpose_for_scores(&self, xs: Tensor) -> Result<Tensor> {
|
||||
let (batch, t, h_times_d) = xs.dims3()?;
|
||||
xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
|
||||
.transpose(1, 2)
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let residual = xs;
|
||||
let (batch, channel, height, width) = xs.dims4()?;
|
||||
let xs = self
|
||||
.group_norm
|
||||
.forward(xs)?
|
||||
.reshape((batch, channel, height * width))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let query_proj = self.query.forward(&xs)?;
|
||||
let key_proj = self.key.forward(&xs)?;
|
||||
let value_proj = self.value.forward(&xs)?;
|
||||
|
||||
let query_states = self.transpose_for_scores(query_proj)?;
|
||||
let key_states = self.transpose_for_scores(key_proj)?;
|
||||
let value_states = self.transpose_for_scores(value_proj)?;
|
||||
|
||||
let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
|
||||
let attention_scores =
|
||||
// TODO: Check that this needs two multiplication by `scale`.
|
||||
(query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
|
||||
let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
|
||||
|
||||
let xs = attention_probs.matmul(&value_states.contiguous()?)?;
|
||||
let xs = xs.transpose(1, 2)?.contiguous()?;
|
||||
let xs = xs.flatten_from(D::Minus2)?;
|
||||
let xs = self
|
||||
.proj_attn
|
||||
.forward(&xs)?
|
||||
.t()?
|
||||
.reshape((batch, channel, height, width))?;
|
||||
(xs + residual)? / self.config.rescale_output_factor
|
||||
}
|
||||
}
|
305
candle-examples/examples/stable-diffusion/clip.rs
Normal file
305
candle-examples/examples/stable-diffusion/clip.rs
Normal file
@ -0,0 +1,305 @@
|
||||
#![allow(dead_code)]
|
||||
//! Contrastive Language-Image Pre-Training
|
||||
//!
|
||||
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
|
||||
//! pairs of images with related texts.
|
||||
//!
|
||||
//! https://github.com/openai/CLIP
|
||||
use candle::{Device, Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Activation {
|
||||
QuickGelu,
|
||||
Gelu,
|
||||
}
|
||||
|
||||
impl Activation {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
|
||||
Activation::Gelu => xs.gelu(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
vocab_size: usize,
|
||||
embed_dim: usize, // aka config.hidden_size
|
||||
activation: Activation, // aka config.hidden_act
|
||||
intermediate_size: usize,
|
||||
pub max_position_embeddings: usize,
|
||||
// The character to use for padding, use EOS when not set.
|
||||
pub pad_with: Option<String>,
|
||||
num_hidden_layers: usize,
|
||||
num_attention_heads: usize,
|
||||
#[allow(dead_code)]
|
||||
projection_dim: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
// The config details can be found in the "text_config" section of this json file:
|
||||
// https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
|
||||
pub fn v1_5() -> Self {
|
||||
Self {
|
||||
vocab_size: 49408,
|
||||
embed_dim: 768,
|
||||
intermediate_size: 3072,
|
||||
max_position_embeddings: 77,
|
||||
pad_with: None,
|
||||
num_hidden_layers: 12,
|
||||
num_attention_heads: 12,
|
||||
projection_dim: 768,
|
||||
activation: Activation::QuickGelu,
|
||||
}
|
||||
}
|
||||
|
||||
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/text_encoder/config.json
|
||||
pub fn v2_1() -> Self {
|
||||
Self {
|
||||
vocab_size: 49408,
|
||||
embed_dim: 1024,
|
||||
intermediate_size: 4096,
|
||||
max_position_embeddings: 77,
|
||||
pad_with: Some("!".to_string()),
|
||||
num_hidden_layers: 23,
|
||||
num_attention_heads: 16,
|
||||
projection_dim: 512,
|
||||
activation: Activation::Gelu,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CLIP Text Model
|
||||
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py
|
||||
#[derive(Debug)]
|
||||
struct ClipTextEmbeddings {
|
||||
token_embedding: candle_nn::Embedding,
|
||||
position_embedding: candle_nn::Embedding,
|
||||
position_ids: Tensor,
|
||||
}
|
||||
|
||||
impl ClipTextEmbeddings {
|
||||
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
|
||||
let token_embedding =
|
||||
candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
|
||||
let position_embedding = candle_nn::embedding(
|
||||
c.max_position_embeddings,
|
||||
c.embed_dim,
|
||||
vs.pp("position_embedding"),
|
||||
)?;
|
||||
let position_ids =
|
||||
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
|
||||
Ok(ClipTextEmbeddings {
|
||||
token_embedding,
|
||||
position_embedding,
|
||||
position_ids,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ClipTextEmbeddings {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let token_embedding = self.token_embedding.forward(xs)?;
|
||||
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
|
||||
token_embedding.broadcast_add(&position_embedding)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ClipAttention {
|
||||
k_proj: candle_nn::Linear,
|
||||
v_proj: candle_nn::Linear,
|
||||
q_proj: candle_nn::Linear,
|
||||
out_proj: candle_nn::Linear,
|
||||
head_dim: usize,
|
||||
scale: f64,
|
||||
num_attention_heads: usize,
|
||||
}
|
||||
|
||||
impl ClipAttention {
|
||||
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
|
||||
let embed_dim = c.embed_dim;
|
||||
let num_attention_heads = c.num_attention_heads;
|
||||
let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
|
||||
let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
|
||||
let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
|
||||
let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
|
||||
let head_dim = embed_dim / num_attention_heads;
|
||||
let scale = (head_dim as f64).powf(-0.5);
|
||||
Ok(ClipAttention {
|
||||
k_proj,
|
||||
v_proj,
|
||||
q_proj,
|
||||
out_proj,
|
||||
head_dim,
|
||||
scale,
|
||||
num_attention_heads,
|
||||
})
|
||||
}
|
||||
|
||||
fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
|
||||
xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let (bsz, seq_len, embed_dim) = xs.dims3()?;
|
||||
let query_states = (self.q_proj.forward(xs)? * self.scale)?;
|
||||
let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
|
||||
let query_states = self
|
||||
.shape(&query_states, seq_len, bsz)?
|
||||
.reshape(proj_shape)?;
|
||||
let key_states = self
|
||||
.shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
|
||||
.reshape(proj_shape)?;
|
||||
let value_states = self
|
||||
.shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
|
||||
.reshape(proj_shape)?;
|
||||
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
|
||||
|
||||
let src_len = key_states.dim(1)?;
|
||||
let attn_weights = attn_weights
|
||||
.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
|
||||
.broadcast_add(causal_attention_mask)?;
|
||||
let attn_weights =
|
||||
attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
|
||||
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
|
||||
|
||||
let attn_output = attn_weights.matmul(&value_states)?;
|
||||
let attn_output = attn_output
|
||||
.reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.reshape((bsz, seq_len, embed_dim))?;
|
||||
self.out_proj.forward(&attn_output)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ClipMlp {
|
||||
fc1: candle_nn::Linear,
|
||||
fc2: candle_nn::Linear,
|
||||
activation: Activation,
|
||||
}
|
||||
|
||||
impl ClipMlp {
|
||||
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
|
||||
let fc1 = candle_nn::linear(c.embed_dim, c.intermediate_size, vs.pp("fc1"))?;
|
||||
let fc2 = candle_nn::linear(c.intermediate_size, c.embed_dim, vs.pp("fc2"))?;
|
||||
Ok(ClipMlp {
|
||||
fc1,
|
||||
fc2,
|
||||
activation: c.activation,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ClipMlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.fc1.forward(xs)?;
|
||||
self.fc2.forward(&self.activation.forward(&xs)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ClipEncoderLayer {
|
||||
self_attn: ClipAttention,
|
||||
layer_norm1: candle_nn::LayerNorm,
|
||||
mlp: ClipMlp,
|
||||
layer_norm2: candle_nn::LayerNorm,
|
||||
}
|
||||
|
||||
impl ClipEncoderLayer {
|
||||
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
|
||||
let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
|
||||
let layer_norm1 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm1"))?;
|
||||
let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
|
||||
let layer_norm2 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm2"))?;
|
||||
Ok(ClipEncoderLayer {
|
||||
self_attn,
|
||||
layer_norm1,
|
||||
mlp,
|
||||
layer_norm2,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = self.layer_norm1.forward(xs)?;
|
||||
let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
|
||||
let xs = (xs + residual)?;
|
||||
|
||||
let residual = &xs;
|
||||
let xs = self.layer_norm2.forward(&xs)?;
|
||||
let xs = self.mlp.forward(&xs)?;
|
||||
xs + residual
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ClipEncoder {
|
||||
layers: Vec<ClipEncoderLayer>,
|
||||
}
|
||||
|
||||
impl ClipEncoder {
|
||||
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
|
||||
let vs = vs.pp("layers");
|
||||
let mut layers: Vec<ClipEncoderLayer> = Vec::new();
|
||||
for index in 0..c.num_hidden_layers {
|
||||
let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
|
||||
layers.push(layer)
|
||||
}
|
||||
Ok(ClipEncoder { layers })
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.clone();
|
||||
for layer in self.layers.iter() {
|
||||
xs = layer.forward(&xs, causal_attention_mask)?;
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
/// A CLIP transformer based model.
|
||||
#[derive(Debug)]
|
||||
pub struct ClipTextTransformer {
|
||||
embeddings: ClipTextEmbeddings,
|
||||
encoder: ClipEncoder,
|
||||
final_layer_norm: candle_nn::LayerNorm,
|
||||
}
|
||||
|
||||
impl ClipTextTransformer {
|
||||
pub fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
|
||||
let vs = vs.pp("text_model");
|
||||
let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
|
||||
let encoder = ClipEncoder::new(vs.pp("encoder"), c)?;
|
||||
let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
|
||||
Ok(ClipTextTransformer {
|
||||
embeddings,
|
||||
encoder,
|
||||
final_layer_norm,
|
||||
})
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
|
||||
fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..seq_len)
|
||||
.flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
|
||||
mask.broadcast_as((bsz, seq_len, seq_len))
|
||||
}
|
||||
}
|
||||
|
||||
impl ClipTextTransformer {
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (bsz, seq_len) = xs.dims2()?;
|
||||
let xs = self.embeddings.forward(xs)?;
|
||||
let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?;
|
||||
let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
|
||||
self.final_layer_norm.forward(&xs)
|
||||
}
|
||||
}
|
181
candle-examples/examples/stable-diffusion/ddim.rs
Normal file
181
candle-examples/examples/stable-diffusion/ddim.rs
Normal file
@ -0,0 +1,181 @@
|
||||
#![allow(dead_code)]
|
||||
//! # Denoising Diffusion Implicit Models
|
||||
//!
|
||||
//! The Denoising Diffusion Implicit Models (DDIM) is a simple scheduler
|
||||
//! similar to Denoising Diffusion Probabilistic Models (DDPM). The DDPM
|
||||
//! generative process is the reverse of a Markovian process, DDIM generalizes
|
||||
//! this to non-Markovian guidance.
|
||||
//!
|
||||
//! Denoising Diffusion Implicit Models, J. Song et al, 2020.
|
||||
//! https://arxiv.org/abs/2010.02502
|
||||
use crate::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
/// The configuration for the DDIM scheduler.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct DDIMSchedulerConfig {
|
||||
/// The value of beta at the beginning of training.
|
||||
pub beta_start: f64,
|
||||
/// The value of beta at the end of training.
|
||||
pub beta_end: f64,
|
||||
/// How beta evolved during training.
|
||||
pub beta_schedule: BetaSchedule,
|
||||
/// The amount of noise to be added at each step.
|
||||
pub eta: f64,
|
||||
/// Adjust the indexes of the inference schedule by this value.
|
||||
pub steps_offset: usize,
|
||||
/// prediction type of the scheduler function, one of `epsilon` (predicting
|
||||
/// the noise of the diffusion process), `sample` (directly predicting the noisy sample`)
|
||||
/// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf)
|
||||
pub prediction_type: PredictionType,
|
||||
/// number of diffusion steps used to train the model
|
||||
pub train_timesteps: usize,
|
||||
}
|
||||
|
||||
impl Default for DDIMSchedulerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
beta_start: 0.00085f64,
|
||||
beta_end: 0.012f64,
|
||||
beta_schedule: BetaSchedule::ScaledLinear,
|
||||
eta: 0.,
|
||||
steps_offset: 1,
|
||||
prediction_type: PredictionType::Epsilon,
|
||||
train_timesteps: 1000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The DDIM scheduler.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DDIMScheduler {
|
||||
timesteps: Vec<usize>,
|
||||
alphas_cumprod: Vec<f64>,
|
||||
step_ratio: usize,
|
||||
init_noise_sigma: f64,
|
||||
pub config: DDIMSchedulerConfig,
|
||||
}
|
||||
|
||||
// clip_sample: False, set_alpha_to_one: False
|
||||
impl DDIMScheduler {
|
||||
/// Creates a new DDIM scheduler given the number of steps to be
|
||||
/// used for inference as well as the number of steps that was used
|
||||
/// during training.
|
||||
pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
|
||||
let step_ratio = config.train_timesteps / inference_steps;
|
||||
let timesteps: Vec<usize> = (0..(inference_steps))
|
||||
.map(|s| s * step_ratio + config.steps_offset)
|
||||
.rev()
|
||||
.collect();
|
||||
let betas = match config.beta_schedule {
|
||||
BetaSchedule::ScaledLinear => crate::utils::linspace(
|
||||
config.beta_start.sqrt(),
|
||||
config.beta_end.sqrt(),
|
||||
config.train_timesteps,
|
||||
)?
|
||||
.sqr()?,
|
||||
BetaSchedule::Linear => {
|
||||
crate::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
|
||||
}
|
||||
BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
|
||||
};
|
||||
let betas = betas.to_vec1::<f64>()?;
|
||||
let mut alphas_cumprod = Vec::with_capacity(betas.len());
|
||||
for &beta in betas.iter() {
|
||||
let alpha = 1.0 - beta;
|
||||
alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
|
||||
}
|
||||
Ok(Self {
|
||||
alphas_cumprod,
|
||||
timesteps,
|
||||
step_ratio,
|
||||
init_noise_sigma: 1.,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn timesteps(&self) -> &[usize] {
|
||||
self.timesteps.as_slice()
|
||||
}
|
||||
|
||||
/// Ensures interchangeability with schedulers that need to scale the denoising model input
|
||||
/// depending on the current timestep.
|
||||
pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
|
||||
Ok(sample)
|
||||
}
|
||||
|
||||
/// Performs a backward step during inference.
|
||||
pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
|
||||
let timestep = if timestep >= self.alphas_cumprod.len() {
|
||||
timestep - 1
|
||||
} else {
|
||||
timestep
|
||||
};
|
||||
// https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195
|
||||
let prev_timestep = if timestep > self.step_ratio {
|
||||
timestep - self.step_ratio
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let alpha_prod_t = self.alphas_cumprod[timestep];
|
||||
let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep];
|
||||
let beta_prod_t = 1. - alpha_prod_t;
|
||||
let beta_prod_t_prev = 1. - alpha_prod_t_prev;
|
||||
|
||||
let (pred_original_sample, pred_epsilon) = match self.config.prediction_type {
|
||||
PredictionType::Epsilon => {
|
||||
let pred_original_sample = ((sample - (model_output * beta_prod_t.sqrt())?)?
|
||||
* (1. / alpha_prod_t.sqrt()))?;
|
||||
(pred_original_sample, model_output.clone())
|
||||
}
|
||||
PredictionType::VPrediction => {
|
||||
let pred_original_sample =
|
||||
((sample * alpha_prod_t.sqrt())? - (model_output * beta_prod_t.sqrt())?)?;
|
||||
let pred_epsilon =
|
||||
((model_output * alpha_prod_t.sqrt())? + (sample * beta_prod_t.sqrt())?)?;
|
||||
(pred_original_sample, pred_epsilon)
|
||||
}
|
||||
PredictionType::Sample => {
|
||||
let pred_original_sample = model_output.clone();
|
||||
let pred_epsilon = ((sample - &pred_original_sample * alpha_prod_t.sqrt())?
|
||||
* (1. / beta_prod_t.sqrt()))?;
|
||||
(pred_original_sample, pred_epsilon)
|
||||
}
|
||||
};
|
||||
|
||||
let variance = (beta_prod_t_prev / beta_prod_t) * (1. - alpha_prod_t / alpha_prod_t_prev);
|
||||
let std_dev_t = self.config.eta * variance.sqrt();
|
||||
|
||||
let pred_sample_direction =
|
||||
(pred_epsilon * (1. - alpha_prod_t_prev - std_dev_t * std_dev_t).sqrt())?;
|
||||
let prev_sample =
|
||||
((pred_original_sample * alpha_prod_t_prev.sqrt())? + pred_sample_direction)?;
|
||||
if self.config.eta > 0. {
|
||||
&prev_sample
|
||||
+ Tensor::randn(
|
||||
0f32,
|
||||
std_dev_t as f32,
|
||||
prev_sample.shape(),
|
||||
prev_sample.device(),
|
||||
)?
|
||||
} else {
|
||||
Ok(prev_sample)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
|
||||
let timestep = if timestep >= self.alphas_cumprod.len() {
|
||||
timestep - 1
|
||||
} else {
|
||||
timestep
|
||||
};
|
||||
let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt();
|
||||
let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt();
|
||||
(original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
|
||||
}
|
||||
|
||||
pub fn init_noise_sigma(&self) -> f64 {
|
||||
self.init_noise_sigma
|
||||
}
|
||||
}
|
65
candle-examples/examples/stable-diffusion/embeddings.rs
Normal file
65
candle-examples/examples/stable-diffusion/embeddings.rs
Normal file
@ -0,0 +1,65 @@
|
||||
#![allow(dead_code)]
|
||||
use candle::{Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TimestepEmbedding {
|
||||
linear_1: nn::Linear,
|
||||
linear_2: nn::Linear,
|
||||
}
|
||||
|
||||
impl TimestepEmbedding {
|
||||
// act_fn: "silu"
|
||||
pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> {
|
||||
let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?;
|
||||
let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?;
|
||||
Ok(Self { linear_1, linear_2 })
|
||||
}
|
||||
}
|
||||
|
||||
impl TimestepEmbedding {
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?;
|
||||
self.linear_2.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Timesteps {
|
||||
num_channels: usize,
|
||||
flip_sin_to_cos: bool,
|
||||
downscale_freq_shift: f64,
|
||||
}
|
||||
|
||||
impl Timesteps {
|
||||
pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self {
|
||||
Self {
|
||||
num_channels,
|
||||
flip_sin_to_cos,
|
||||
downscale_freq_shift,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Timesteps {
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let half_dim = (self.num_channels / 2) as u32;
|
||||
let exponent =
|
||||
(Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?;
|
||||
let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
|
||||
let emb = exponent.exp()?;
|
||||
// emb = timesteps[:, None].float() * emb[None, :]
|
||||
let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
|
||||
let (cos, sin) = (emb.cos()?, emb.sin()?);
|
||||
let emb = if self.flip_sin_to_cos {
|
||||
Tensor::cat(&[&cos, &sin], D::Minus1)?
|
||||
} else {
|
||||
Tensor::cat(&[&sin, &cos], D::Minus1)?
|
||||
};
|
||||
if self.num_channels % 2 == 1 {
|
||||
emb.pad_with_zeros(D::Minus2, 0, 1)
|
||||
} else {
|
||||
Ok(emb)
|
||||
}
|
||||
}
|
||||
}
|
326
candle-examples/examples/stable-diffusion/main.rs
Normal file
326
candle-examples/examples/stable-diffusion/main.rs
Normal file
@ -0,0 +1,326 @@
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
mod attention;
|
||||
mod clip;
|
||||
mod ddim;
|
||||
mod embeddings;
|
||||
mod resnet;
|
||||
mod schedulers;
|
||||
mod stable_diffusion;
|
||||
mod unet_2d;
|
||||
mod unet_2d_blocks;
|
||||
mod utils;
|
||||
mod vae;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Device, IndexOp, Tensor};
|
||||
use clap::Parser;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const GUIDANCE_SCALE: f64 = 7.5;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// The prompt to be used for image generation.
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "A very realistic photo of a rusty robot walking on a sandy beach"
|
||||
)]
|
||||
prompt: String,
|
||||
|
||||
#[arg(long, default_value = "")]
|
||||
uncond_prompt: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// The height in pixels of the generated image.
|
||||
#[arg(long)]
|
||||
height: Option<usize>,
|
||||
|
||||
/// The width in pixels of the generated image.
|
||||
#[arg(long)]
|
||||
width: Option<usize>,
|
||||
|
||||
/// The UNet weight file, in .safetensors format.
|
||||
#[arg(long, value_name = "FILE")]
|
||||
unet_weights: Option<String>,
|
||||
|
||||
/// The CLIP weight file, in .safetensors format.
|
||||
#[arg(long, value_name = "FILE")]
|
||||
clip_weights: Option<String>,
|
||||
|
||||
/// The VAE weight file, in .safetensors format.
|
||||
#[arg(long, value_name = "FILE")]
|
||||
vae_weights: Option<String>,
|
||||
|
||||
#[arg(long, value_name = "FILE")]
|
||||
/// The file specifying the tokenizer to used for tokenization.
|
||||
tokenizer: Option<String>,
|
||||
|
||||
/// The size of the sliced attention or 0 for automatic slicing (disabled by default)
|
||||
#[arg(long)]
|
||||
sliced_attention_size: Option<usize>,
|
||||
|
||||
/// The number of steps to run the diffusion for.
|
||||
#[arg(long, default_value_t = 30)]
|
||||
n_steps: usize,
|
||||
|
||||
/// The number of samples to generate.
|
||||
#[arg(long, default_value_t = 1)]
|
||||
num_samples: i64,
|
||||
|
||||
/// The name of the final image to generate.
|
||||
#[arg(long, value_name = "FILE", default_value = "sd_final.png")]
|
||||
final_image: String,
|
||||
|
||||
#[arg(long, value_enum, default_value = "v2-1")]
|
||||
sd_version: StableDiffusionVersion,
|
||||
|
||||
/// Generate intermediary images at each step.
|
||||
#[arg(long, action)]
|
||||
intermediary_images: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
|
||||
enum StableDiffusionVersion {
|
||||
V1_5,
|
||||
V2_1,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ModelFile {
|
||||
Tokenizer,
|
||||
Clip,
|
||||
Unet,
|
||||
Vae,
|
||||
}
|
||||
|
||||
impl StableDiffusionVersion {
|
||||
fn repo(&self) -> &'static str {
|
||||
match self {
|
||||
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
|
||||
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
|
||||
}
|
||||
}
|
||||
|
||||
fn unet_file(&self) -> &'static str {
|
||||
match self {
|
||||
Self::V1_5 | Self::V2_1 => "unet/diffusion_pytorch_model.safetensors",
|
||||
}
|
||||
}
|
||||
|
||||
fn vae_file(&self) -> &'static str {
|
||||
match self {
|
||||
Self::V1_5 | Self::V2_1 => "vae/diffusion_pytorch_model.safetensors",
|
||||
}
|
||||
}
|
||||
|
||||
fn clip_file(&self) -> &'static str {
|
||||
match self {
|
||||
Self::V1_5 | Self::V2_1 => "text_encoder/model.safetensors",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelFile {
|
||||
const TOKENIZER_REPO: &str = "openai/clip-vit-base-patch32";
|
||||
const TOKENIZER_PATH: &str = "tokenizer.json";
|
||||
|
||||
fn get(
|
||||
&self,
|
||||
filename: Option<String>,
|
||||
version: StableDiffusionVersion,
|
||||
) -> Result<std::path::PathBuf> {
|
||||
use hf_hub::api::sync::Api;
|
||||
match filename {
|
||||
Some(filename) => Ok(std::path::PathBuf::from(filename)),
|
||||
None => {
|
||||
let (repo, path) = match self {
|
||||
Self::Tokenizer => (Self::TOKENIZER_REPO, Self::TOKENIZER_PATH),
|
||||
Self::Clip => (version.repo(), version.clip_file()),
|
||||
Self::Unet => (version.repo(), version.unet_file()),
|
||||
Self::Vae => (version.repo(), version.vae_file()),
|
||||
};
|
||||
let filename = Api::new()?.model(repo.to_string()).get(path)?;
|
||||
Ok(filename)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn output_filename(
|
||||
basename: &str,
|
||||
sample_idx: i64,
|
||||
num_samples: i64,
|
||||
timestep_idx: Option<usize>,
|
||||
) -> String {
|
||||
let filename = if num_samples > 1 {
|
||||
match basename.rsplit_once('.') {
|
||||
None => format!("{basename}.{sample_idx}.png"),
|
||||
Some((filename_no_extension, extension)) => {
|
||||
format!("{filename_no_extension}.{sample_idx}.{extension}")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
basename.to_string()
|
||||
};
|
||||
match timestep_idx {
|
||||
None => filename,
|
||||
Some(timestep_idx) => match filename.rsplit_once('.') {
|
||||
None => format!("{filename}-{timestep_idx}.png"),
|
||||
Some((filename_no_extension, extension)) => {
|
||||
format!("{filename_no_extension}-{timestep_idx}.{extension}")
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn run(args: Args) -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let Args {
|
||||
prompt,
|
||||
uncond_prompt,
|
||||
cpu,
|
||||
height,
|
||||
width,
|
||||
n_steps,
|
||||
tokenizer,
|
||||
final_image,
|
||||
sliced_attention_size,
|
||||
num_samples,
|
||||
sd_version,
|
||||
clip_weights,
|
||||
vae_weights,
|
||||
unet_weights,
|
||||
tracing,
|
||||
..
|
||||
} = args;
|
||||
|
||||
let _guard = if tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let sd_config = match sd_version {
|
||||
StableDiffusionVersion::V1_5 => {
|
||||
stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width)
|
||||
}
|
||||
StableDiffusionVersion::V2_1 => {
|
||||
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
|
||||
}
|
||||
};
|
||||
|
||||
let scheduler = sd_config.build_scheduler(n_steps)?;
|
||||
let device = candle_examples::device(cpu)?;
|
||||
|
||||
let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
let pad_id = match &sd_config.clip.pad_with {
|
||||
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
|
||||
None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
|
||||
};
|
||||
println!("Running with prompt \"{prompt}\".");
|
||||
let mut tokens = tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
while tokens.len() < sd_config.clip.max_position_embeddings {
|
||||
tokens.push(pad_id)
|
||||
}
|
||||
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
|
||||
let mut uncond_tokens = tokenizer
|
||||
.encode(uncond_prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
|
||||
uncond_tokens.push(pad_id)
|
||||
}
|
||||
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
|
||||
println!("Building the Clip transformer.");
|
||||
let clip_weights = ModelFile::Clip.get(clip_weights, sd_version)?;
|
||||
let text_model = sd_config.build_clip_transformer(&clip_weights, &device)?;
|
||||
let text_embeddings = text_model.forward(&tokens)?;
|
||||
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
|
||||
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
|
||||
|
||||
println!("Building the autoencoder.");
|
||||
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version)?;
|
||||
let vae = sd_config.build_vae(&vae_weights, &device)?;
|
||||
println!("Building the unet.");
|
||||
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version)?;
|
||||
let unet = sd_config.build_unet(&unet_weights, &device, 4)?;
|
||||
|
||||
let bsize = 1;
|
||||
for idx in 0..num_samples {
|
||||
let mut latents = Tensor::randn(
|
||||
0f32,
|
||||
1f32,
|
||||
(bsize, 4, sd_config.height / 8, sd_config.width / 8),
|
||||
&device,
|
||||
)?;
|
||||
|
||||
// scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = (latents * scheduler.init_noise_sigma())?;
|
||||
|
||||
for (timestep_index, ×tep) in scheduler.timesteps().iter().enumerate() {
|
||||
println!("Timestep {timestep_index}/{n_steps}");
|
||||
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
|
||||
|
||||
let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
|
||||
let noise_pred =
|
||||
unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
|
||||
let noise_pred = noise_pred.chunk(2, 0)?;
|
||||
let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
|
||||
let noise_pred =
|
||||
(noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?;
|
||||
latents = scheduler.step(&noise_pred, timestep, &latents)?;
|
||||
|
||||
if args.intermediary_images {
|
||||
let image = vae.decode(&(&latents / 0.18215)?)?;
|
||||
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
|
||||
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
||||
let image_filename =
|
||||
output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
|
||||
crate::utils::save_image(&image, image_filename)?
|
||||
}
|
||||
}
|
||||
|
||||
println!(
|
||||
"Generating the final image for sample {}/{}.",
|
||||
idx + 1,
|
||||
num_samples
|
||||
);
|
||||
let image = vae.decode(&(&latents / 0.18215)?)?;
|
||||
// TODO: Add the clamping between 0 and 1.
|
||||
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
|
||||
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
||||
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
|
||||
crate::utils::save_image(&image, image_filename)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
run(args)
|
||||
}
|
134
candle-examples/examples/stable-diffusion/resnet.rs
Normal file
134
candle-examples/examples/stable-diffusion/resnet.rs
Normal file
@ -0,0 +1,134 @@
|
||||
#![allow(dead_code)]
|
||||
//! ResNet Building Blocks
|
||||
//!
|
||||
//! Some Residual Network blocks used in UNet models.
|
||||
//!
|
||||
//! Denoising Diffusion Implicit Models, K. He and al, 2015.
|
||||
//! https://arxiv.org/abs/1512.03385
|
||||
use crate::utils::{conv2d, Conv2d};
|
||||
use candle::{Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
|
||||
/// Configuration for a ResNet block.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ResnetBlock2DConfig {
|
||||
/// The number of output channels, defaults to the number of input channels.
|
||||
pub out_channels: Option<usize>,
|
||||
pub temb_channels: Option<usize>,
|
||||
/// The number of groups to use in group normalization.
|
||||
pub groups: usize,
|
||||
pub groups_out: Option<usize>,
|
||||
/// The epsilon to be used in the group normalization operations.
|
||||
pub eps: f64,
|
||||
/// Whether to use a 2D convolution in the skip connection. When using None,
|
||||
/// such a convolution is used if the number of input channels is different from
|
||||
/// the number of output channels.
|
||||
pub use_in_shortcut: Option<bool>,
|
||||
// non_linearity: silu
|
||||
/// The final output is scaled by dividing by this value.
|
||||
pub output_scale_factor: f64,
|
||||
}
|
||||
|
||||
impl Default for ResnetBlock2DConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
out_channels: None,
|
||||
temb_channels: Some(512),
|
||||
groups: 32,
|
||||
groups_out: None,
|
||||
eps: 1e-6,
|
||||
use_in_shortcut: None,
|
||||
output_scale_factor: 1.,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ResnetBlock2D {
|
||||
norm1: nn::GroupNorm,
|
||||
conv1: Conv2d,
|
||||
norm2: nn::GroupNorm,
|
||||
conv2: Conv2d,
|
||||
time_emb_proj: Option<nn::Linear>,
|
||||
conv_shortcut: Option<Conv2d>,
|
||||
span: tracing::Span,
|
||||
config: ResnetBlock2DConfig,
|
||||
}
|
||||
|
||||
impl ResnetBlock2D {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
config: ResnetBlock2DConfig,
|
||||
) -> Result<Self> {
|
||||
let out_channels = config.out_channels.unwrap_or(in_channels);
|
||||
let conv_cfg = nn::Conv2dConfig {
|
||||
stride: 1,
|
||||
padding: 1,
|
||||
};
|
||||
let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
|
||||
let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
|
||||
let groups_out = config.groups_out.unwrap_or(config.groups);
|
||||
let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?;
|
||||
let conv2 = conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
|
||||
let use_in_shortcut = config
|
||||
.use_in_shortcut
|
||||
.unwrap_or(in_channels != out_channels);
|
||||
let conv_shortcut = if use_in_shortcut {
|
||||
let conv_cfg = nn::Conv2dConfig {
|
||||
stride: 1,
|
||||
padding: 0,
|
||||
};
|
||||
Some(conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg,
|
||||
vs.pp("conv_shortcut"),
|
||||
)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let time_emb_proj = match config.temb_channels {
|
||||
None => None,
|
||||
Some(temb_channels) => Some(nn::linear(
|
||||
temb_channels,
|
||||
out_channels,
|
||||
vs.pp("time_emb_proj"),
|
||||
)?),
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "resnet2d");
|
||||
Ok(Self {
|
||||
norm1,
|
||||
conv1,
|
||||
norm2,
|
||||
conv2,
|
||||
time_emb_proj,
|
||||
span,
|
||||
config,
|
||||
conv_shortcut,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let shortcut_xs = match &self.conv_shortcut {
|
||||
Some(conv_shortcut) => conv_shortcut.forward(xs)?,
|
||||
None => xs.clone(),
|
||||
};
|
||||
let xs = self.norm1.forward(xs)?;
|
||||
let xs = self.conv1.forward(&nn::ops::silu(&xs)?)?;
|
||||
let xs = match (temb, &self.time_emb_proj) {
|
||||
(Some(temb), Some(time_emb_proj)) => time_emb_proj
|
||||
.forward(&nn::ops::silu(temb)?)?
|
||||
.unsqueeze(D::Minus1)?
|
||||
.unsqueeze(D::Minus1)?
|
||||
.broadcast_add(&xs)?,
|
||||
_ => xs,
|
||||
};
|
||||
let xs = self
|
||||
.conv2
|
||||
.forward(&nn::ops::silu(&self.norm2.forward(&xs)?)?)?;
|
||||
(shortcut_xs + xs)? / self.config.output_scale_factor
|
||||
}
|
||||
}
|
45
candle-examples/examples/stable-diffusion/schedulers.rs
Normal file
45
candle-examples/examples/stable-diffusion/schedulers.rs
Normal file
@ -0,0 +1,45 @@
|
||||
#![allow(dead_code)]
|
||||
//! # Diffusion pipelines and models
|
||||
//!
|
||||
//! Noise schedulers can be used to set the trade-off between
|
||||
//! inference speed and quality.
|
||||
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
/// This represents how beta ranges from its minimum value to the maximum
|
||||
/// during training.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum BetaSchedule {
|
||||
/// Linear interpolation.
|
||||
Linear,
|
||||
/// Linear interpolation of the square root of beta.
|
||||
ScaledLinear,
|
||||
/// Glide cosine schedule
|
||||
SquaredcosCapV2,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum PredictionType {
|
||||
Epsilon,
|
||||
VPrediction,
|
||||
Sample,
|
||||
}
|
||||
|
||||
/// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
/// `(1-beta)` over time from `t = [0,1]`.
|
||||
///
|
||||
/// Contains a function `alpha_bar` that takes an argument `t` and transforms it to the cumulative product of `(1-beta)`
|
||||
/// up to that part of the diffusion process.
|
||||
pub(crate) fn betas_for_alpha_bar(num_diffusion_timesteps: usize, max_beta: f64) -> Result<Tensor> {
|
||||
let alpha_bar = |time_step: usize| {
|
||||
f64::cos((time_step as f64 + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2)
|
||||
};
|
||||
let mut betas = Vec::with_capacity(num_diffusion_timesteps);
|
||||
for i in 0..num_diffusion_timesteps {
|
||||
let t1 = i / num_diffusion_timesteps;
|
||||
let t2 = (i + 1) / num_diffusion_timesteps;
|
||||
betas.push((1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta));
|
||||
}
|
||||
let betas_len = betas.len();
|
||||
Tensor::from_vec(betas, betas_len, &candle::Device::Cpu)
|
||||
}
|
216
candle-examples/examples/stable-diffusion/stable_diffusion.rs
Normal file
216
candle-examples/examples/stable-diffusion/stable_diffusion.rs
Normal file
@ -0,0 +1,216 @@
|
||||
#![allow(dead_code)]
|
||||
use crate::schedulers::PredictionType;
|
||||
use crate::{clip, ddim, unet_2d, vae};
|
||||
use candle::{DType, Device, Result};
|
||||
use candle_nn as nn;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct StableDiffusionConfig {
|
||||
pub width: usize,
|
||||
pub height: usize,
|
||||
pub clip: clip::Config,
|
||||
autoencoder: vae::AutoEncoderKLConfig,
|
||||
unet: unet_2d::UNet2DConditionModelConfig,
|
||||
scheduler: ddim::DDIMSchedulerConfig,
|
||||
}
|
||||
|
||||
impl StableDiffusionConfig {
|
||||
pub fn v1_5(
|
||||
sliced_attention_size: Option<usize>,
|
||||
height: Option<usize>,
|
||||
width: Option<usize>,
|
||||
) -> Self {
|
||||
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
|
||||
out_channels,
|
||||
use_cross_attn,
|
||||
attention_head_dim,
|
||||
};
|
||||
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json
|
||||
let unet = unet_2d::UNet2DConditionModelConfig {
|
||||
blocks: vec![
|
||||
bc(320, true, 8),
|
||||
bc(640, true, 8),
|
||||
bc(1280, true, 8),
|
||||
bc(1280, false, 8),
|
||||
],
|
||||
center_input_sample: false,
|
||||
cross_attention_dim: 768,
|
||||
downsample_padding: 1,
|
||||
flip_sin_to_cos: true,
|
||||
freq_shift: 0.,
|
||||
layers_per_block: 2,
|
||||
mid_block_scale_factor: 1.,
|
||||
norm_eps: 1e-5,
|
||||
norm_num_groups: 32,
|
||||
sliced_attention_size,
|
||||
use_linear_projection: false,
|
||||
};
|
||||
let autoencoder = vae::AutoEncoderKLConfig {
|
||||
block_out_channels: vec![128, 256, 512, 512],
|
||||
layers_per_block: 2,
|
||||
latent_channels: 4,
|
||||
norm_num_groups: 32,
|
||||
};
|
||||
let height = if let Some(height) = height {
|
||||
assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
|
||||
height
|
||||
} else {
|
||||
512
|
||||
};
|
||||
|
||||
let width = if let Some(width) = width {
|
||||
assert_eq!(width % 8, 0, "width has to be divisible by 8");
|
||||
width
|
||||
} else {
|
||||
512
|
||||
};
|
||||
|
||||
Self {
|
||||
width,
|
||||
height,
|
||||
clip: clip::Config::v1_5(),
|
||||
autoencoder,
|
||||
scheduler: Default::default(),
|
||||
unet,
|
||||
}
|
||||
}
|
||||
|
||||
fn v2_1_(
|
||||
sliced_attention_size: Option<usize>,
|
||||
height: Option<usize>,
|
||||
width: Option<usize>,
|
||||
prediction_type: PredictionType,
|
||||
) -> Self {
|
||||
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
|
||||
out_channels,
|
||||
use_cross_attn,
|
||||
attention_head_dim,
|
||||
};
|
||||
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json
|
||||
let unet = unet_2d::UNet2DConditionModelConfig {
|
||||
blocks: vec![
|
||||
bc(320, true, 5),
|
||||
bc(640, true, 10),
|
||||
bc(1280, true, 20),
|
||||
bc(1280, false, 20),
|
||||
],
|
||||
center_input_sample: false,
|
||||
cross_attention_dim: 1024,
|
||||
downsample_padding: 1,
|
||||
flip_sin_to_cos: true,
|
||||
freq_shift: 0.,
|
||||
layers_per_block: 2,
|
||||
mid_block_scale_factor: 1.,
|
||||
norm_eps: 1e-5,
|
||||
norm_num_groups: 32,
|
||||
sliced_attention_size,
|
||||
use_linear_projection: true,
|
||||
};
|
||||
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/vae/config.json
|
||||
let autoencoder = vae::AutoEncoderKLConfig {
|
||||
block_out_channels: vec![128, 256, 512, 512],
|
||||
layers_per_block: 2,
|
||||
latent_channels: 4,
|
||||
norm_num_groups: 32,
|
||||
};
|
||||
let scheduler = ddim::DDIMSchedulerConfig {
|
||||
prediction_type,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let height = if let Some(height) = height {
|
||||
assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
|
||||
height
|
||||
} else {
|
||||
768
|
||||
};
|
||||
|
||||
let width = if let Some(width) = width {
|
||||
assert_eq!(width % 8, 0, "width has to be divisible by 8");
|
||||
width
|
||||
} else {
|
||||
768
|
||||
};
|
||||
|
||||
Self {
|
||||
width,
|
||||
height,
|
||||
clip: clip::Config::v2_1(),
|
||||
autoencoder,
|
||||
scheduler,
|
||||
unet,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn v2_1(
|
||||
sliced_attention_size: Option<usize>,
|
||||
height: Option<usize>,
|
||||
width: Option<usize>,
|
||||
) -> Self {
|
||||
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/scheduler/scheduler_config.json
|
||||
Self::v2_1_(
|
||||
sliced_attention_size,
|
||||
height,
|
||||
width,
|
||||
PredictionType::VPrediction,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn v2_1_inpaint(
|
||||
sliced_attention_size: Option<usize>,
|
||||
height: Option<usize>,
|
||||
width: Option<usize>,
|
||||
) -> Self {
|
||||
// https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/scheduler/scheduler_config.json
|
||||
// This uses a PNDM scheduler rather than DDIM but the biggest difference is the prediction
|
||||
// type being "epsilon" by default and not "v_prediction".
|
||||
Self::v2_1_(
|
||||
sliced_attention_size,
|
||||
height,
|
||||
width,
|
||||
PredictionType::Epsilon,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn build_vae<P: AsRef<std::path::Path>>(
|
||||
&self,
|
||||
vae_weights: P,
|
||||
device: &Device,
|
||||
) -> Result<vae::AutoEncoderKL> {
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
|
||||
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
|
||||
let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?;
|
||||
Ok(autoencoder)
|
||||
}
|
||||
|
||||
pub fn build_unet<P: AsRef<std::path::Path>>(
|
||||
&self,
|
||||
unet_weights: P,
|
||||
device: &Device,
|
||||
in_channels: usize,
|
||||
) -> Result<unet_2d::UNet2DConditionModel> {
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
|
||||
let unet = unet_2d::UNet2DConditionModel::new(vs_unet, in_channels, 4, self.unet.clone())?;
|
||||
Ok(unet)
|
||||
}
|
||||
|
||||
pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
|
||||
ddim::DDIMScheduler::new(n_steps, self.scheduler)
|
||||
}
|
||||
|
||||
pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
|
||||
&self,
|
||||
clip_weights: P,
|
||||
device: &Device,
|
||||
) -> Result<clip::ClipTextTransformer> {
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vs = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
|
||||
let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?;
|
||||
Ok(text_model)
|
||||
}
|
||||
}
|
386
candle-examples/examples/stable-diffusion/unet_2d.rs
Normal file
386
candle-examples/examples/stable-diffusion/unet_2d.rs
Normal file
@ -0,0 +1,386 @@
|
||||
#![allow(dead_code)]
|
||||
//! 2D UNet Denoising Models
|
||||
//!
|
||||
//! The 2D Unet models take as input a noisy sample and the current diffusion
|
||||
//! timestep and return a denoised version of the input.
|
||||
use crate::embeddings::{TimestepEmbedding, Timesteps};
|
||||
use crate::unet_2d_blocks::*;
|
||||
use crate::utils::{conv2d, Conv2d};
|
||||
use candle::{DType, Result, Tensor};
|
||||
use candle_nn as nn;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct BlockConfig {
|
||||
pub out_channels: usize,
|
||||
pub use_cross_attn: bool,
|
||||
pub attention_head_dim: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UNet2DConditionModelConfig {
|
||||
pub center_input_sample: bool,
|
||||
pub flip_sin_to_cos: bool,
|
||||
pub freq_shift: f64,
|
||||
pub blocks: Vec<BlockConfig>,
|
||||
pub layers_per_block: usize,
|
||||
pub downsample_padding: usize,
|
||||
pub mid_block_scale_factor: f64,
|
||||
pub norm_num_groups: usize,
|
||||
pub norm_eps: f64,
|
||||
pub cross_attention_dim: usize,
|
||||
pub sliced_attention_size: Option<usize>,
|
||||
pub use_linear_projection: bool,
|
||||
}
|
||||
|
||||
impl Default for UNet2DConditionModelConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
center_input_sample: false,
|
||||
flip_sin_to_cos: true,
|
||||
freq_shift: 0.,
|
||||
blocks: vec![
|
||||
BlockConfig {
|
||||
out_channels: 320,
|
||||
use_cross_attn: true,
|
||||
attention_head_dim: 8,
|
||||
},
|
||||
BlockConfig {
|
||||
out_channels: 640,
|
||||
use_cross_attn: true,
|
||||
attention_head_dim: 8,
|
||||
},
|
||||
BlockConfig {
|
||||
out_channels: 1280,
|
||||
use_cross_attn: true,
|
||||
attention_head_dim: 8,
|
||||
},
|
||||
BlockConfig {
|
||||
out_channels: 1280,
|
||||
use_cross_attn: false,
|
||||
attention_head_dim: 8,
|
||||
},
|
||||
],
|
||||
layers_per_block: 2,
|
||||
downsample_padding: 1,
|
||||
mid_block_scale_factor: 1.,
|
||||
norm_num_groups: 32,
|
||||
norm_eps: 1e-5,
|
||||
cross_attention_dim: 1280,
|
||||
sliced_attention_size: None,
|
||||
use_linear_projection: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum UNetDownBlock {
|
||||
Basic(DownBlock2D),
|
||||
CrossAttn(CrossAttnDownBlock2D),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum UNetUpBlock {
|
||||
Basic(UpBlock2D),
|
||||
CrossAttn(CrossAttnUpBlock2D),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UNet2DConditionModel {
|
||||
conv_in: Conv2d,
|
||||
time_proj: Timesteps,
|
||||
time_embedding: TimestepEmbedding,
|
||||
down_blocks: Vec<UNetDownBlock>,
|
||||
mid_block: UNetMidBlock2DCrossAttn,
|
||||
up_blocks: Vec<UNetUpBlock>,
|
||||
conv_norm_out: nn::GroupNorm,
|
||||
conv_out: Conv2d,
|
||||
span: tracing::Span,
|
||||
config: UNet2DConditionModelConfig,
|
||||
}
|
||||
|
||||
impl UNet2DConditionModel {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
config: UNet2DConditionModelConfig,
|
||||
) -> Result<Self> {
|
||||
let n_blocks = config.blocks.len();
|
||||
let b_channels = config.blocks[0].out_channels;
|
||||
let bl_channels = config.blocks.last().unwrap().out_channels;
|
||||
let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim;
|
||||
let time_embed_dim = b_channels * 4;
|
||||
let conv_cfg = nn::Conv2dConfig {
|
||||
stride: 1,
|
||||
padding: 1,
|
||||
};
|
||||
let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
|
||||
|
||||
let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift);
|
||||
let time_embedding =
|
||||
TimestepEmbedding::new(vs.pp("time_embedding"), b_channels, time_embed_dim)?;
|
||||
|
||||
let vs_db = vs.pp("down_blocks");
|
||||
let down_blocks = (0..n_blocks)
|
||||
.map(|i| {
|
||||
let BlockConfig {
|
||||
out_channels,
|
||||
use_cross_attn,
|
||||
attention_head_dim,
|
||||
} = config.blocks[i];
|
||||
|
||||
// Enable automatic attention slicing if the config sliced_attention_size is set to 0.
|
||||
let sliced_attention_size = match config.sliced_attention_size {
|
||||
Some(0) => Some(attention_head_dim / 2),
|
||||
_ => config.sliced_attention_size,
|
||||
};
|
||||
|
||||
let in_channels = if i > 0 {
|
||||
config.blocks[i - 1].out_channels
|
||||
} else {
|
||||
b_channels
|
||||
};
|
||||
let db_cfg = DownBlock2DConfig {
|
||||
num_layers: config.layers_per_block,
|
||||
resnet_eps: config.norm_eps,
|
||||
resnet_groups: config.norm_num_groups,
|
||||
add_downsample: i < n_blocks - 1,
|
||||
downsample_padding: config.downsample_padding,
|
||||
..Default::default()
|
||||
};
|
||||
if use_cross_attn {
|
||||
let config = CrossAttnDownBlock2DConfig {
|
||||
downblock: db_cfg,
|
||||
attn_num_head_channels: attention_head_dim,
|
||||
cross_attention_dim: config.cross_attention_dim,
|
||||
sliced_attention_size,
|
||||
use_linear_projection: config.use_linear_projection,
|
||||
};
|
||||
let block = CrossAttnDownBlock2D::new(
|
||||
vs_db.pp(&i.to_string()),
|
||||
in_channels,
|
||||
out_channels,
|
||||
Some(time_embed_dim),
|
||||
config,
|
||||
)?;
|
||||
Ok(UNetDownBlock::CrossAttn(block))
|
||||
} else {
|
||||
let block = DownBlock2D::new(
|
||||
vs_db.pp(&i.to_string()),
|
||||
in_channels,
|
||||
out_channels,
|
||||
Some(time_embed_dim),
|
||||
db_cfg,
|
||||
)?;
|
||||
Ok(UNetDownBlock::Basic(block))
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let mid_cfg = UNetMidBlock2DCrossAttnConfig {
|
||||
resnet_eps: config.norm_eps,
|
||||
output_scale_factor: config.mid_block_scale_factor,
|
||||
cross_attn_dim: config.cross_attention_dim,
|
||||
attn_num_head_channels: bl_attention_head_dim,
|
||||
resnet_groups: Some(config.norm_num_groups),
|
||||
use_linear_projection: config.use_linear_projection,
|
||||
..Default::default()
|
||||
};
|
||||
let mid_block = UNetMidBlock2DCrossAttn::new(
|
||||
vs.pp("mid_block"),
|
||||
bl_channels,
|
||||
Some(time_embed_dim),
|
||||
mid_cfg,
|
||||
)?;
|
||||
|
||||
let vs_ub = vs.pp("up_blocks");
|
||||
let up_blocks = (0..n_blocks)
|
||||
.map(|i| {
|
||||
let BlockConfig {
|
||||
out_channels,
|
||||
use_cross_attn,
|
||||
attention_head_dim,
|
||||
} = config.blocks[n_blocks - 1 - i];
|
||||
|
||||
// Enable automatic attention slicing if the config sliced_attention_size is set to 0.
|
||||
let sliced_attention_size = match config.sliced_attention_size {
|
||||
Some(0) => Some(attention_head_dim / 2),
|
||||
_ => config.sliced_attention_size,
|
||||
};
|
||||
|
||||
let prev_out_channels = if i > 0 {
|
||||
config.blocks[n_blocks - i].out_channels
|
||||
} else {
|
||||
bl_channels
|
||||
};
|
||||
let in_channels = {
|
||||
let index = if i == n_blocks - 1 {
|
||||
0
|
||||
} else {
|
||||
n_blocks - i - 2
|
||||
};
|
||||
config.blocks[index].out_channels
|
||||
};
|
||||
let ub_cfg = UpBlock2DConfig {
|
||||
num_layers: config.layers_per_block + 1,
|
||||
resnet_eps: config.norm_eps,
|
||||
resnet_groups: config.norm_num_groups,
|
||||
add_upsample: i < n_blocks - 1,
|
||||
..Default::default()
|
||||
};
|
||||
if use_cross_attn {
|
||||
let config = CrossAttnUpBlock2DConfig {
|
||||
upblock: ub_cfg,
|
||||
attn_num_head_channels: attention_head_dim,
|
||||
cross_attention_dim: config.cross_attention_dim,
|
||||
sliced_attention_size,
|
||||
use_linear_projection: config.use_linear_projection,
|
||||
};
|
||||
let block = CrossAttnUpBlock2D::new(
|
||||
vs_ub.pp(&i.to_string()),
|
||||
in_channels,
|
||||
prev_out_channels,
|
||||
out_channels,
|
||||
Some(time_embed_dim),
|
||||
config,
|
||||
)?;
|
||||
Ok(UNetUpBlock::CrossAttn(block))
|
||||
} else {
|
||||
let block = UpBlock2D::new(
|
||||
vs_ub.pp(&i.to_string()),
|
||||
in_channels,
|
||||
prev_out_channels,
|
||||
out_channels,
|
||||
Some(time_embed_dim),
|
||||
ub_cfg,
|
||||
)?;
|
||||
Ok(UNetUpBlock::Basic(block))
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let conv_norm_out = nn::group_norm(
|
||||
config.norm_num_groups,
|
||||
b_channels,
|
||||
config.norm_eps,
|
||||
vs.pp("conv_norm_out"),
|
||||
)?;
|
||||
let conv_out = conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "unet2d");
|
||||
Ok(Self {
|
||||
conv_in,
|
||||
time_proj,
|
||||
time_embedding,
|
||||
down_blocks,
|
||||
mid_block,
|
||||
up_blocks,
|
||||
conv_norm_out,
|
||||
conv_out,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
timestep: f64,
|
||||
encoder_hidden_states: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None)
|
||||
}
|
||||
|
||||
pub fn forward_with_additional_residuals(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
timestep: f64,
|
||||
encoder_hidden_states: &Tensor,
|
||||
down_block_additional_residuals: Option<&[Tensor]>,
|
||||
mid_block_additional_residual: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let (bsize, _channels, height, width) = xs.dims4()?;
|
||||
let device = xs.device();
|
||||
let n_blocks = self.config.blocks.len();
|
||||
let num_upsamplers = n_blocks - 1;
|
||||
let default_overall_up_factor = 2usize.pow(num_upsamplers as u32);
|
||||
let forward_upsample_size =
|
||||
height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0;
|
||||
// 0. center input if necessary
|
||||
let xs = if self.config.center_input_sample {
|
||||
((xs * 2.0)? - 1.0)?
|
||||
} else {
|
||||
xs.clone()
|
||||
};
|
||||
// 1. time
|
||||
let emb = (Tensor::ones(bsize, DType::F32, device)? * timestep)?;
|
||||
let emb = self.time_proj.forward(&emb)?;
|
||||
let emb = self.time_embedding.forward(&emb)?;
|
||||
// 2. pre-process
|
||||
let xs = self.conv_in.forward(&xs)?;
|
||||
// 3. down
|
||||
let mut down_block_res_xs = vec![xs.clone()];
|
||||
let mut xs = xs;
|
||||
for down_block in self.down_blocks.iter() {
|
||||
let (_xs, res_xs) = match down_block {
|
||||
UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb))?,
|
||||
UNetDownBlock::CrossAttn(b) => {
|
||||
b.forward(&xs, Some(&emb), Some(encoder_hidden_states))?
|
||||
}
|
||||
};
|
||||
down_block_res_xs.extend(res_xs);
|
||||
xs = _xs;
|
||||
}
|
||||
|
||||
let new_down_block_res_xs =
|
||||
if let Some(down_block_additional_residuals) = down_block_additional_residuals {
|
||||
let mut v = vec![];
|
||||
// A previous version of this code had a bug because of the addition being made
|
||||
// in place via += hence modifying the input of the mid block.
|
||||
for (i, residuals) in down_block_additional_residuals.iter().enumerate() {
|
||||
v.push((&down_block_res_xs[i] + residuals)?)
|
||||
}
|
||||
v
|
||||
} else {
|
||||
down_block_res_xs
|
||||
};
|
||||
let mut down_block_res_xs = new_down_block_res_xs;
|
||||
|
||||
// 4. mid
|
||||
let xs = self
|
||||
.mid_block
|
||||
.forward(&xs, Some(&emb), Some(encoder_hidden_states))?;
|
||||
let xs = match mid_block_additional_residual {
|
||||
None => xs,
|
||||
Some(m) => (m + xs)?,
|
||||
};
|
||||
// 5. up
|
||||
let mut xs = xs;
|
||||
let mut upsample_size = None;
|
||||
for (i, up_block) in self.up_blocks.iter().enumerate() {
|
||||
let n_resnets = match up_block {
|
||||
UNetUpBlock::Basic(b) => b.resnets.len(),
|
||||
UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(),
|
||||
};
|
||||
let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets);
|
||||
if i < n_blocks - 1 && forward_upsample_size {
|
||||
let (_, _, h, w) = down_block_res_xs.last().unwrap().dims4()?;
|
||||
upsample_size = Some((h, w))
|
||||
}
|
||||
xs = match up_block {
|
||||
UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size)?,
|
||||
UNetUpBlock::CrossAttn(b) => b.forward(
|
||||
&xs,
|
||||
&res_xs,
|
||||
Some(&emb),
|
||||
upsample_size,
|
||||
Some(encoder_hidden_states),
|
||||
)?,
|
||||
};
|
||||
}
|
||||
// 6. post-process
|
||||
let xs = self.conv_norm_out.forward(&xs)?;
|
||||
let xs = nn::ops::silu(&xs)?;
|
||||
self.conv_out.forward(&xs)
|
||||
}
|
||||
}
|
851
candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
Normal file
851
candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
Normal file
@ -0,0 +1,851 @@
|
||||
#![allow(dead_code)]
|
||||
//! 2D UNet Building Blocks
|
||||
//!
|
||||
use crate::attention::{
|
||||
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
|
||||
};
|
||||
use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
|
||||
use crate::utils::{conv2d, Conv2d};
|
||||
use candle::{Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Downsample2D {
|
||||
conv: Option<Conv2d>,
|
||||
padding: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Downsample2D {
|
||||
fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
use_conv: bool,
|
||||
out_channels: usize,
|
||||
padding: usize,
|
||||
) -> Result<Self> {
|
||||
let conv = if use_conv {
|
||||
let config = nn::Conv2dConfig { stride: 2, padding };
|
||||
let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
|
||||
Some(conv)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "downsample2d");
|
||||
Ok(Self {
|
||||
conv,
|
||||
padding,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Downsample2D {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
match &self.conv {
|
||||
None => xs.avg_pool2d((2, 2), (2, 2)),
|
||||
Some(conv) => {
|
||||
if self.padding == 0 {
|
||||
let xs = xs
|
||||
.pad_with_zeros(D::Minus1, 0, 1)?
|
||||
.pad_with_zeros(D::Minus2, 0, 1)?;
|
||||
conv.forward(&xs)
|
||||
} else {
|
||||
conv.forward(xs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This does not support the conv-transpose mode.
|
||||
#[derive(Debug)]
|
||||
struct Upsample2D {
|
||||
conv: Conv2d,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Upsample2D {
|
||||
fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> {
|
||||
let config = nn::Conv2dConfig {
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "upsample2d");
|
||||
Ok(Self { conv, span })
|
||||
}
|
||||
}
|
||||
|
||||
impl Upsample2D {
|
||||
fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = match size {
|
||||
None => {
|
||||
let (_bsize, _channels, h, w) = xs.dims4()?;
|
||||
xs.upsample_nearest2d(2 * h, 2 * w)?
|
||||
}
|
||||
Some((h, w)) => xs.upsample_nearest2d(h, w)?,
|
||||
};
|
||||
self.conv.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct DownEncoderBlock2DConfig {
|
||||
pub num_layers: usize,
|
||||
pub resnet_eps: f64,
|
||||
pub resnet_groups: usize,
|
||||
pub output_scale_factor: f64,
|
||||
pub add_downsample: bool,
|
||||
pub downsample_padding: usize,
|
||||
}
|
||||
|
||||
impl Default for DownEncoderBlock2DConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_layers: 1,
|
||||
resnet_eps: 1e-6,
|
||||
resnet_groups: 32,
|
||||
output_scale_factor: 1.,
|
||||
add_downsample: true,
|
||||
downsample_padding: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DownEncoderBlock2D {
|
||||
resnets: Vec<ResnetBlock2D>,
|
||||
downsampler: Option<Downsample2D>,
|
||||
span: tracing::Span,
|
||||
pub config: DownEncoderBlock2DConfig,
|
||||
}
|
||||
|
||||
impl DownEncoderBlock2D {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
config: DownEncoderBlock2DConfig,
|
||||
) -> Result<Self> {
|
||||
let resnets: Vec<_> = {
|
||||
let vs = vs.pp("resnets");
|
||||
let conv_cfg = ResnetBlock2DConfig {
|
||||
eps: config.resnet_eps,
|
||||
out_channels: Some(out_channels),
|
||||
groups: config.resnet_groups,
|
||||
output_scale_factor: config.output_scale_factor,
|
||||
temb_channels: None,
|
||||
..Default::default()
|
||||
};
|
||||
(0..(config.num_layers))
|
||||
.map(|i| {
|
||||
let in_channels = if i == 0 { in_channels } else { out_channels };
|
||||
ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
};
|
||||
let downsampler = if config.add_downsample {
|
||||
let downsample = Downsample2D::new(
|
||||
vs.pp("downsamplers").pp("0"),
|
||||
out_channels,
|
||||
true,
|
||||
out_channels,
|
||||
config.downsample_padding,
|
||||
)?;
|
||||
Some(downsample)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "down-enc2d");
|
||||
Ok(Self {
|
||||
resnets,
|
||||
downsampler,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl DownEncoderBlock2D {
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = xs.clone();
|
||||
for resnet in self.resnets.iter() {
|
||||
xs = resnet.forward(&xs, None)?
|
||||
}
|
||||
match &self.downsampler {
|
||||
Some(downsampler) => downsampler.forward(&xs),
|
||||
None => Ok(xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct UpDecoderBlock2DConfig {
|
||||
pub num_layers: usize,
|
||||
pub resnet_eps: f64,
|
||||
pub resnet_groups: usize,
|
||||
pub output_scale_factor: f64,
|
||||
pub add_upsample: bool,
|
||||
}
|
||||
|
||||
impl Default for UpDecoderBlock2DConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_layers: 1,
|
||||
resnet_eps: 1e-6,
|
||||
resnet_groups: 32,
|
||||
output_scale_factor: 1.,
|
||||
add_upsample: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UpDecoderBlock2D {
|
||||
resnets: Vec<ResnetBlock2D>,
|
||||
upsampler: Option<Upsample2D>,
|
||||
span: tracing::Span,
|
||||
pub config: UpDecoderBlock2DConfig,
|
||||
}
|
||||
|
||||
impl UpDecoderBlock2D {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
config: UpDecoderBlock2DConfig,
|
||||
) -> Result<Self> {
|
||||
let resnets: Vec<_> = {
|
||||
let vs = vs.pp("resnets");
|
||||
let conv_cfg = ResnetBlock2DConfig {
|
||||
out_channels: Some(out_channels),
|
||||
eps: config.resnet_eps,
|
||||
groups: config.resnet_groups,
|
||||
output_scale_factor: config.output_scale_factor,
|
||||
temb_channels: None,
|
||||
..Default::default()
|
||||
};
|
||||
(0..(config.num_layers))
|
||||
.map(|i| {
|
||||
let in_channels = if i == 0 { in_channels } else { out_channels };
|
||||
ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
};
|
||||
let upsampler = if config.add_upsample {
|
||||
let upsample =
|
||||
Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
|
||||
Some(upsample)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "up-dec2d");
|
||||
Ok(Self {
|
||||
resnets,
|
||||
upsampler,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl UpDecoderBlock2D {
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = xs.clone();
|
||||
for resnet in self.resnets.iter() {
|
||||
xs = resnet.forward(&xs, None)?
|
||||
}
|
||||
match &self.upsampler {
|
||||
Some(upsampler) => upsampler.forward(&xs, None),
|
||||
None => Ok(xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct UNetMidBlock2DConfig {
|
||||
pub num_layers: usize,
|
||||
pub resnet_eps: f64,
|
||||
pub resnet_groups: Option<usize>,
|
||||
pub attn_num_head_channels: Option<usize>,
|
||||
// attention_type "default"
|
||||
pub output_scale_factor: f64,
|
||||
}
|
||||
|
||||
impl Default for UNetMidBlock2DConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_layers: 1,
|
||||
resnet_eps: 1e-6,
|
||||
resnet_groups: Some(32),
|
||||
attn_num_head_channels: Some(1),
|
||||
output_scale_factor: 1.,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UNetMidBlock2D {
|
||||
resnet: ResnetBlock2D,
|
||||
attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
|
||||
span: tracing::Span,
|
||||
pub config: UNetMidBlock2DConfig,
|
||||
}
|
||||
|
||||
impl UNetMidBlock2D {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
temb_channels: Option<usize>,
|
||||
config: UNetMidBlock2DConfig,
|
||||
) -> Result<Self> {
|
||||
let vs_resnets = vs.pp("resnets");
|
||||
let vs_attns = vs.pp("attentions");
|
||||
let resnet_groups = config
|
||||
.resnet_groups
|
||||
.unwrap_or_else(|| usize::min(in_channels / 4, 32));
|
||||
let resnet_cfg = ResnetBlock2DConfig {
|
||||
eps: config.resnet_eps,
|
||||
groups: resnet_groups,
|
||||
output_scale_factor: config.output_scale_factor,
|
||||
temb_channels,
|
||||
..Default::default()
|
||||
};
|
||||
let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
|
||||
let attn_cfg = AttentionBlockConfig {
|
||||
num_head_channels: config.attn_num_head_channels,
|
||||
num_groups: resnet_groups,
|
||||
rescale_output_factor: config.output_scale_factor,
|
||||
eps: config.resnet_eps,
|
||||
};
|
||||
let mut attn_resnets = vec![];
|
||||
for index in 0..config.num_layers {
|
||||
let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?;
|
||||
let resnet = ResnetBlock2D::new(
|
||||
vs_resnets.pp(&(index + 1).to_string()),
|
||||
in_channels,
|
||||
resnet_cfg,
|
||||
)?;
|
||||
attn_resnets.push((attn, resnet))
|
||||
}
|
||||
let span = tracing::span!(tracing::Level::TRACE, "mid2d");
|
||||
Ok(Self {
|
||||
resnet,
|
||||
attn_resnets,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = self.resnet.forward(xs, temb)?;
|
||||
for (attn, resnet) in self.attn_resnets.iter() {
|
||||
xs = resnet.forward(&attn.forward(&xs)?, temb)?
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct UNetMidBlock2DCrossAttnConfig {
|
||||
pub num_layers: usize,
|
||||
pub resnet_eps: f64,
|
||||
pub resnet_groups: Option<usize>,
|
||||
pub attn_num_head_channels: usize,
|
||||
// attention_type "default"
|
||||
pub output_scale_factor: f64,
|
||||
pub cross_attn_dim: usize,
|
||||
pub sliced_attention_size: Option<usize>,
|
||||
pub use_linear_projection: bool,
|
||||
}
|
||||
|
||||
impl Default for UNetMidBlock2DCrossAttnConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_layers: 1,
|
||||
resnet_eps: 1e-6,
|
||||
resnet_groups: Some(32),
|
||||
attn_num_head_channels: 1,
|
||||
output_scale_factor: 1.,
|
||||
cross_attn_dim: 1280,
|
||||
sliced_attention_size: None, // Sliced attention disabled
|
||||
use_linear_projection: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UNetMidBlock2DCrossAttn {
|
||||
resnet: ResnetBlock2D,
|
||||
attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
|
||||
span: tracing::Span,
|
||||
pub config: UNetMidBlock2DCrossAttnConfig,
|
||||
}
|
||||
|
||||
impl UNetMidBlock2DCrossAttn {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
temb_channels: Option<usize>,
|
||||
config: UNetMidBlock2DCrossAttnConfig,
|
||||
) -> Result<Self> {
|
||||
let vs_resnets = vs.pp("resnets");
|
||||
let vs_attns = vs.pp("attentions");
|
||||
let resnet_groups = config
|
||||
.resnet_groups
|
||||
.unwrap_or_else(|| usize::min(in_channels / 4, 32));
|
||||
let resnet_cfg = ResnetBlock2DConfig {
|
||||
eps: config.resnet_eps,
|
||||
groups: resnet_groups,
|
||||
output_scale_factor: config.output_scale_factor,
|
||||
temb_channels,
|
||||
..Default::default()
|
||||
};
|
||||
let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
|
||||
let n_heads = config.attn_num_head_channels;
|
||||
let attn_cfg = SpatialTransformerConfig {
|
||||
depth: 1,
|
||||
num_groups: resnet_groups,
|
||||
context_dim: Some(config.cross_attn_dim),
|
||||
sliced_attention_size: config.sliced_attention_size,
|
||||
use_linear_projection: config.use_linear_projection,
|
||||
};
|
||||
let mut attn_resnets = vec![];
|
||||
for index in 0..config.num_layers {
|
||||
let attn = SpatialTransformer::new(
|
||||
vs_attns.pp(&index.to_string()),
|
||||
in_channels,
|
||||
n_heads,
|
||||
in_channels / n_heads,
|
||||
attn_cfg,
|
||||
)?;
|
||||
let resnet = ResnetBlock2D::new(
|
||||
vs_resnets.pp(&(index + 1).to_string()),
|
||||
in_channels,
|
||||
resnet_cfg,
|
||||
)?;
|
||||
attn_resnets.push((attn, resnet))
|
||||
}
|
||||
let span = tracing::span!(tracing::Level::TRACE, "xa-mid2d");
|
||||
Ok(Self {
|
||||
resnet,
|
||||
attn_resnets,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
temb: Option<&Tensor>,
|
||||
encoder_hidden_states: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = self.resnet.forward(xs, temb)?;
|
||||
for (attn, resnet) in self.attn_resnets.iter() {
|
||||
xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct DownBlock2DConfig {
|
||||
pub num_layers: usize,
|
||||
pub resnet_eps: f64,
|
||||
// resnet_time_scale_shift: "default"
|
||||
// resnet_act_fn: "swish"
|
||||
pub resnet_groups: usize,
|
||||
pub output_scale_factor: f64,
|
||||
pub add_downsample: bool,
|
||||
pub downsample_padding: usize,
|
||||
}
|
||||
|
||||
impl Default for DownBlock2DConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_layers: 1,
|
||||
resnet_eps: 1e-6,
|
||||
resnet_groups: 32,
|
||||
output_scale_factor: 1.,
|
||||
add_downsample: true,
|
||||
downsample_padding: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DownBlock2D {
|
||||
resnets: Vec<ResnetBlock2D>,
|
||||
downsampler: Option<Downsample2D>,
|
||||
span: tracing::Span,
|
||||
pub config: DownBlock2DConfig,
|
||||
}
|
||||
|
||||
impl DownBlock2D {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
temb_channels: Option<usize>,
|
||||
config: DownBlock2DConfig,
|
||||
) -> Result<Self> {
|
||||
let vs_resnets = vs.pp("resnets");
|
||||
let resnet_cfg = ResnetBlock2DConfig {
|
||||
out_channels: Some(out_channels),
|
||||
eps: config.resnet_eps,
|
||||
output_scale_factor: config.output_scale_factor,
|
||||
temb_channels,
|
||||
..Default::default()
|
||||
};
|
||||
let resnets = (0..config.num_layers)
|
||||
.map(|i| {
|
||||
let in_channels = if i == 0 { in_channels } else { out_channels };
|
||||
ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let downsampler = if config.add_downsample {
|
||||
let downsampler = Downsample2D::new(
|
||||
vs.pp("downsamplers").pp("0"),
|
||||
out_channels,
|
||||
true,
|
||||
out_channels,
|
||||
config.downsample_padding,
|
||||
)?;
|
||||
Some(downsampler)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "down2d");
|
||||
Ok(Self {
|
||||
resnets,
|
||||
downsampler,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = xs.clone();
|
||||
let mut output_states = vec![];
|
||||
for resnet in self.resnets.iter() {
|
||||
xs = resnet.forward(&xs, temb)?;
|
||||
output_states.push(xs.clone());
|
||||
}
|
||||
let xs = match &self.downsampler {
|
||||
Some(downsampler) => {
|
||||
let xs = downsampler.forward(&xs)?;
|
||||
output_states.push(xs.clone());
|
||||
xs
|
||||
}
|
||||
None => xs,
|
||||
};
|
||||
Ok((xs, output_states))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct CrossAttnDownBlock2DConfig {
|
||||
pub downblock: DownBlock2DConfig,
|
||||
pub attn_num_head_channels: usize,
|
||||
pub cross_attention_dim: usize,
|
||||
// attention_type: "default"
|
||||
pub sliced_attention_size: Option<usize>,
|
||||
pub use_linear_projection: bool,
|
||||
}
|
||||
|
||||
impl Default for CrossAttnDownBlock2DConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
downblock: Default::default(),
|
||||
attn_num_head_channels: 1,
|
||||
cross_attention_dim: 1280,
|
||||
sliced_attention_size: None,
|
||||
use_linear_projection: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CrossAttnDownBlock2D {
|
||||
downblock: DownBlock2D,
|
||||
attentions: Vec<SpatialTransformer>,
|
||||
span: tracing::Span,
|
||||
pub config: CrossAttnDownBlock2DConfig,
|
||||
}
|
||||
|
||||
impl CrossAttnDownBlock2D {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
temb_channels: Option<usize>,
|
||||
config: CrossAttnDownBlock2DConfig,
|
||||
) -> Result<Self> {
|
||||
let downblock = DownBlock2D::new(
|
||||
vs.clone(),
|
||||
in_channels,
|
||||
out_channels,
|
||||
temb_channels,
|
||||
config.downblock,
|
||||
)?;
|
||||
let n_heads = config.attn_num_head_channels;
|
||||
let cfg = SpatialTransformerConfig {
|
||||
depth: 1,
|
||||
context_dim: Some(config.cross_attention_dim),
|
||||
num_groups: config.downblock.resnet_groups,
|
||||
sliced_attention_size: config.sliced_attention_size,
|
||||
use_linear_projection: config.use_linear_projection,
|
||||
};
|
||||
let vs_attn = vs.pp("attentions");
|
||||
let attentions = (0..config.downblock.num_layers)
|
||||
.map(|i| {
|
||||
SpatialTransformer::new(
|
||||
vs_attn.pp(&i.to_string()),
|
||||
out_channels,
|
||||
n_heads,
|
||||
out_channels / n_heads,
|
||||
cfg,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "xa-down2d");
|
||||
Ok(Self {
|
||||
downblock,
|
||||
attentions,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
temb: Option<&Tensor>,
|
||||
encoder_hidden_states: Option<&Tensor>,
|
||||
) -> Result<(Tensor, Vec<Tensor>)> {
|
||||
let _enter = self.span.enter();
|
||||
let mut output_states = vec![];
|
||||
let mut xs = xs.clone();
|
||||
for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
|
||||
xs = resnet.forward(&xs, temb)?;
|
||||
xs = attn.forward(&xs, encoder_hidden_states)?;
|
||||
output_states.push(xs.clone());
|
||||
}
|
||||
let xs = match &self.downblock.downsampler {
|
||||
Some(downsampler) => {
|
||||
let xs = downsampler.forward(&xs)?;
|
||||
output_states.push(xs.clone());
|
||||
xs
|
||||
}
|
||||
None => xs,
|
||||
};
|
||||
Ok((xs, output_states))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct UpBlock2DConfig {
|
||||
pub num_layers: usize,
|
||||
pub resnet_eps: f64,
|
||||
// resnet_time_scale_shift: "default"
|
||||
// resnet_act_fn: "swish"
|
||||
pub resnet_groups: usize,
|
||||
pub output_scale_factor: f64,
|
||||
pub add_upsample: bool,
|
||||
}
|
||||
|
||||
impl Default for UpBlock2DConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_layers: 1,
|
||||
resnet_eps: 1e-6,
|
||||
resnet_groups: 32,
|
||||
output_scale_factor: 1.,
|
||||
add_upsample: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UpBlock2D {
|
||||
pub resnets: Vec<ResnetBlock2D>,
|
||||
upsampler: Option<Upsample2D>,
|
||||
span: tracing::Span,
|
||||
pub config: UpBlock2DConfig,
|
||||
}
|
||||
|
||||
impl UpBlock2D {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
prev_output_channels: usize,
|
||||
out_channels: usize,
|
||||
temb_channels: Option<usize>,
|
||||
config: UpBlock2DConfig,
|
||||
) -> Result<Self> {
|
||||
let vs_resnets = vs.pp("resnets");
|
||||
let resnet_cfg = ResnetBlock2DConfig {
|
||||
out_channels: Some(out_channels),
|
||||
temb_channels,
|
||||
eps: config.resnet_eps,
|
||||
output_scale_factor: config.output_scale_factor,
|
||||
..Default::default()
|
||||
};
|
||||
let resnets = (0..config.num_layers)
|
||||
.map(|i| {
|
||||
let res_skip_channels = if i == config.num_layers - 1 {
|
||||
in_channels
|
||||
} else {
|
||||
out_channels
|
||||
};
|
||||
let resnet_in_channels = if i == 0 {
|
||||
prev_output_channels
|
||||
} else {
|
||||
out_channels
|
||||
};
|
||||
let in_channels = resnet_in_channels + res_skip_channels;
|
||||
ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let upsampler = if config.add_upsample {
|
||||
let upsampler =
|
||||
Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
|
||||
Some(upsampler)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "up2d");
|
||||
Ok(Self {
|
||||
resnets,
|
||||
upsampler,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
res_xs: &[Tensor],
|
||||
temb: Option<&Tensor>,
|
||||
upsample_size: Option<(usize, usize)>,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = xs.clone();
|
||||
for (index, resnet) in self.resnets.iter().enumerate() {
|
||||
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
|
||||
xs = resnet.forward(&xs, temb)?;
|
||||
}
|
||||
match &self.upsampler {
|
||||
Some(upsampler) => upsampler.forward(&xs, upsample_size),
|
||||
None => Ok(xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct CrossAttnUpBlock2DConfig {
|
||||
pub upblock: UpBlock2DConfig,
|
||||
pub attn_num_head_channels: usize,
|
||||
pub cross_attention_dim: usize,
|
||||
// attention_type: "default"
|
||||
pub sliced_attention_size: Option<usize>,
|
||||
pub use_linear_projection: bool,
|
||||
}
|
||||
|
||||
impl Default for CrossAttnUpBlock2DConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
upblock: Default::default(),
|
||||
attn_num_head_channels: 1,
|
||||
cross_attention_dim: 1280,
|
||||
sliced_attention_size: None,
|
||||
use_linear_projection: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CrossAttnUpBlock2D {
|
||||
pub upblock: UpBlock2D,
|
||||
pub attentions: Vec<SpatialTransformer>,
|
||||
span: tracing::Span,
|
||||
pub config: CrossAttnUpBlock2DConfig,
|
||||
}
|
||||
|
||||
impl CrossAttnUpBlock2D {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
prev_output_channels: usize,
|
||||
out_channels: usize,
|
||||
temb_channels: Option<usize>,
|
||||
config: CrossAttnUpBlock2DConfig,
|
||||
) -> Result<Self> {
|
||||
let upblock = UpBlock2D::new(
|
||||
vs.clone(),
|
||||
in_channels,
|
||||
prev_output_channels,
|
||||
out_channels,
|
||||
temb_channels,
|
||||
config.upblock,
|
||||
)?;
|
||||
let n_heads = config.attn_num_head_channels;
|
||||
let cfg = SpatialTransformerConfig {
|
||||
depth: 1,
|
||||
context_dim: Some(config.cross_attention_dim),
|
||||
num_groups: config.upblock.resnet_groups,
|
||||
sliced_attention_size: config.sliced_attention_size,
|
||||
use_linear_projection: config.use_linear_projection,
|
||||
};
|
||||
let vs_attn = vs.pp("attentions");
|
||||
let attentions = (0..config.upblock.num_layers)
|
||||
.map(|i| {
|
||||
SpatialTransformer::new(
|
||||
vs_attn.pp(&i.to_string()),
|
||||
out_channels,
|
||||
n_heads,
|
||||
out_channels / n_heads,
|
||||
cfg,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "xa-up2d");
|
||||
Ok(Self {
|
||||
upblock,
|
||||
attentions,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
res_xs: &[Tensor],
|
||||
temb: Option<&Tensor>,
|
||||
upsample_size: Option<(usize, usize)>,
|
||||
encoder_hidden_states: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = xs.clone();
|
||||
for (index, resnet) in self.upblock.resnets.iter().enumerate() {
|
||||
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
|
||||
xs = resnet.forward(&xs, temb)?;
|
||||
xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
|
||||
}
|
||||
match &self.upblock.upsampler {
|
||||
Some(upsampler) => upsampler.forward(&xs, upsample_size),
|
||||
None => Ok(xs),
|
||||
}
|
||||
}
|
||||
}
|
57
candle-examples/examples/stable-diffusion/utils.rs
Normal file
57
candle-examples/examples/stable-diffusion/utils.rs
Normal file
@ -0,0 +1,57 @@
|
||||
use candle::{Device, Result, Tensor};
|
||||
|
||||
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
|
||||
if steps < 1 {
|
||||
candle::bail!("cannot use linspace with steps {steps} <= 1")
|
||||
}
|
||||
let delta = (stop - start) / (steps - 1) as f64;
|
||||
let vs = (0..steps)
|
||||
.map(|step| start + step as f64 * delta)
|
||||
.collect::<Vec<_>>();
|
||||
Tensor::from_vec(vs, steps, &Device::Cpu)
|
||||
}
|
||||
|
||||
/// Saves an image to disk using the image crate, this expects an input with shape
|
||||
/// (c, width, height).
|
||||
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
|
||||
let p = p.as_ref();
|
||||
let (channel, width, height) = img.dims3()?;
|
||||
if channel != 3 {
|
||||
candle::bail!("save_image expects an input of shape (3, width, height)")
|
||||
}
|
||||
let img = img.transpose(0, 1)?.t()?.flatten_all()?;
|
||||
let pixels = img.to_vec1::<u8>()?;
|
||||
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
||||
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
|
||||
Some(image) => image,
|
||||
None => candle::bail!("error saving image {p:?}"),
|
||||
};
|
||||
image.save(p).map_err(candle::Error::wrap)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Wrap the conv2d op to provide some tracing.
|
||||
#[derive(Debug)]
|
||||
pub struct Conv2d {
|
||||
inner: candle_nn::Conv2d,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Conv2d {
|
||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn conv2d(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
kernel_size: usize,
|
||||
cfg: candle_nn::Conv2dConfig,
|
||||
vs: candle_nn::VarBuilder,
|
||||
) -> Result<Conv2d> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "conv2d");
|
||||
let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
|
||||
Ok(Conv2d { inner, span })
|
||||
}
|
378
candle-examples/examples/stable-diffusion/vae.rs
Normal file
378
candle-examples/examples/stable-diffusion/vae.rs
Normal file
@ -0,0 +1,378 @@
|
||||
#![allow(dead_code)]
|
||||
//! # Variational Auto-Encoder (VAE) Models.
|
||||
//!
|
||||
//! Auto-encoder models compress their input to a usually smaller latent space
|
||||
//! before expanding it back to its original shape. This results in the latent values
|
||||
//! compressing the original information.
|
||||
use crate::unet_2d_blocks::{
|
||||
DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
|
||||
UpDecoderBlock2D, UpDecoderBlock2DConfig,
|
||||
};
|
||||
use candle::{Result, Tensor};
|
||||
use candle_nn as nn;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct EncoderConfig {
|
||||
// down_block_types: DownEncoderBlock2D
|
||||
block_out_channels: Vec<usize>,
|
||||
layers_per_block: usize,
|
||||
norm_num_groups: usize,
|
||||
double_z: bool,
|
||||
}
|
||||
|
||||
impl Default for EncoderConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
block_out_channels: vec![64],
|
||||
layers_per_block: 2,
|
||||
norm_num_groups: 32,
|
||||
double_z: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Encoder {
|
||||
conv_in: nn::Conv2d,
|
||||
down_blocks: Vec<DownEncoderBlock2D>,
|
||||
mid_block: UNetMidBlock2D,
|
||||
conv_norm_out: nn::GroupNorm,
|
||||
conv_out: nn::Conv2d,
|
||||
#[allow(dead_code)]
|
||||
config: EncoderConfig,
|
||||
}
|
||||
|
||||
impl Encoder {
|
||||
fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
config: EncoderConfig,
|
||||
) -> Result<Self> {
|
||||
let conv_cfg = nn::Conv2dConfig {
|
||||
stride: 1,
|
||||
padding: 1,
|
||||
};
|
||||
let conv_in = nn::conv2d(
|
||||
in_channels,
|
||||
config.block_out_channels[0],
|
||||
3,
|
||||
conv_cfg,
|
||||
vs.pp("conv_in"),
|
||||
)?;
|
||||
let mut down_blocks = vec![];
|
||||
let vs_down_blocks = vs.pp("down_blocks");
|
||||
for index in 0..config.block_out_channels.len() {
|
||||
let out_channels = config.block_out_channels[index];
|
||||
let in_channels = if index > 0 {
|
||||
config.block_out_channels[index - 1]
|
||||
} else {
|
||||
config.block_out_channels[0]
|
||||
};
|
||||
let is_final = index + 1 == config.block_out_channels.len();
|
||||
let cfg = DownEncoderBlock2DConfig {
|
||||
num_layers: config.layers_per_block,
|
||||
resnet_eps: 1e-6,
|
||||
resnet_groups: config.norm_num_groups,
|
||||
add_downsample: !is_final,
|
||||
downsample_padding: 0,
|
||||
..Default::default()
|
||||
};
|
||||
let down_block = DownEncoderBlock2D::new(
|
||||
vs_down_blocks.pp(&index.to_string()),
|
||||
in_channels,
|
||||
out_channels,
|
||||
cfg,
|
||||
)?;
|
||||
down_blocks.push(down_block)
|
||||
}
|
||||
let last_block_out_channels = *config.block_out_channels.last().unwrap();
|
||||
let mid_cfg = UNetMidBlock2DConfig {
|
||||
resnet_eps: 1e-6,
|
||||
output_scale_factor: 1.,
|
||||
attn_num_head_channels: None,
|
||||
resnet_groups: Some(config.norm_num_groups),
|
||||
..Default::default()
|
||||
};
|
||||
let mid_block =
|
||||
UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
|
||||
let conv_norm_out = nn::group_norm(
|
||||
config.norm_num_groups,
|
||||
last_block_out_channels,
|
||||
1e-6,
|
||||
vs.pp("conv_norm_out"),
|
||||
)?;
|
||||
let conv_out_channels = if config.double_z {
|
||||
2 * out_channels
|
||||
} else {
|
||||
out_channels
|
||||
};
|
||||
let conv_cfg = nn::Conv2dConfig {
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let conv_out = nn::conv2d(
|
||||
last_block_out_channels,
|
||||
conv_out_channels,
|
||||
3,
|
||||
conv_cfg,
|
||||
vs.pp("conv_out"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
conv_in,
|
||||
down_blocks,
|
||||
mid_block,
|
||||
conv_norm_out,
|
||||
conv_out,
|
||||
config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = self.conv_in.forward(xs)?;
|
||||
for down_block in self.down_blocks.iter() {
|
||||
xs = down_block.forward(&xs)?
|
||||
}
|
||||
let xs = self.mid_block.forward(&xs, None)?;
|
||||
let xs = self.conv_norm_out.forward(&xs)?;
|
||||
let xs = nn::ops::silu(&xs)?;
|
||||
self.conv_out.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct DecoderConfig {
|
||||
// up_block_types: UpDecoderBlock2D
|
||||
block_out_channels: Vec<usize>,
|
||||
layers_per_block: usize,
|
||||
norm_num_groups: usize,
|
||||
}
|
||||
|
||||
impl Default for DecoderConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
block_out_channels: vec![64],
|
||||
layers_per_block: 2,
|
||||
norm_num_groups: 32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Decoder {
|
||||
conv_in: nn::Conv2d,
|
||||
up_blocks: Vec<UpDecoderBlock2D>,
|
||||
mid_block: UNetMidBlock2D,
|
||||
conv_norm_out: nn::GroupNorm,
|
||||
conv_out: nn::Conv2d,
|
||||
#[allow(dead_code)]
|
||||
config: DecoderConfig,
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
config: DecoderConfig,
|
||||
) -> Result<Self> {
|
||||
let n_block_out_channels = config.block_out_channels.len();
|
||||
let last_block_out_channels = *config.block_out_channels.last().unwrap();
|
||||
let conv_cfg = nn::Conv2dConfig {
|
||||
stride: 1,
|
||||
padding: 1,
|
||||
};
|
||||
let conv_in = nn::conv2d(
|
||||
in_channels,
|
||||
last_block_out_channels,
|
||||
3,
|
||||
conv_cfg,
|
||||
vs.pp("conv_in"),
|
||||
)?;
|
||||
let mid_cfg = UNetMidBlock2DConfig {
|
||||
resnet_eps: 1e-6,
|
||||
output_scale_factor: 1.,
|
||||
attn_num_head_channels: None,
|
||||
resnet_groups: Some(config.norm_num_groups),
|
||||
..Default::default()
|
||||
};
|
||||
let mid_block =
|
||||
UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
|
||||
let mut up_blocks = vec![];
|
||||
let vs_up_blocks = vs.pp("up_blocks");
|
||||
let reversed_block_out_channels: Vec<_> =
|
||||
config.block_out_channels.iter().copied().rev().collect();
|
||||
for index in 0..n_block_out_channels {
|
||||
let out_channels = reversed_block_out_channels[index];
|
||||
let in_channels = if index > 0 {
|
||||
reversed_block_out_channels[index - 1]
|
||||
} else {
|
||||
reversed_block_out_channels[0]
|
||||
};
|
||||
let is_final = index + 1 == n_block_out_channels;
|
||||
let cfg = UpDecoderBlock2DConfig {
|
||||
num_layers: config.layers_per_block + 1,
|
||||
resnet_eps: 1e-6,
|
||||
resnet_groups: config.norm_num_groups,
|
||||
add_upsample: !is_final,
|
||||
..Default::default()
|
||||
};
|
||||
let up_block = UpDecoderBlock2D::new(
|
||||
vs_up_blocks.pp(&index.to_string()),
|
||||
in_channels,
|
||||
out_channels,
|
||||
cfg,
|
||||
)?;
|
||||
up_blocks.push(up_block)
|
||||
}
|
||||
let conv_norm_out = nn::group_norm(
|
||||
config.norm_num_groups,
|
||||
config.block_out_channels[0],
|
||||
1e-6,
|
||||
vs.pp("conv_norm_out"),
|
||||
)?;
|
||||
let conv_cfg = nn::Conv2dConfig {
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let conv_out = nn::conv2d(
|
||||
config.block_out_channels[0],
|
||||
out_channels,
|
||||
3,
|
||||
conv_cfg,
|
||||
vs.pp("conv_out"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
conv_in,
|
||||
up_blocks,
|
||||
mid_block,
|
||||
conv_norm_out,
|
||||
conv_out,
|
||||
config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = self.mid_block.forward(&self.conv_in.forward(xs)?, None)?;
|
||||
for up_block in self.up_blocks.iter() {
|
||||
xs = up_block.forward(&xs)?
|
||||
}
|
||||
let xs = self.conv_norm_out.forward(&xs)?;
|
||||
let xs = nn::ops::silu(&xs)?;
|
||||
self.conv_out.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AutoEncoderKLConfig {
|
||||
pub block_out_channels: Vec<usize>,
|
||||
pub layers_per_block: usize,
|
||||
pub latent_channels: usize,
|
||||
pub norm_num_groups: usize,
|
||||
}
|
||||
|
||||
impl Default for AutoEncoderKLConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
block_out_channels: vec![64],
|
||||
layers_per_block: 1,
|
||||
latent_channels: 4,
|
||||
norm_num_groups: 32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DiagonalGaussianDistribution {
|
||||
mean: Tensor,
|
||||
std: Tensor,
|
||||
}
|
||||
|
||||
impl DiagonalGaussianDistribution {
|
||||
pub fn new(parameters: &Tensor) -> Result<Self> {
|
||||
let mut parameters = parameters.chunk(2, 1)?.into_iter();
|
||||
let mean = parameters.next().unwrap();
|
||||
let logvar = parameters.next().unwrap();
|
||||
let std = (logvar * 0.5)?.exp()?;
|
||||
Ok(DiagonalGaussianDistribution { mean, std })
|
||||
}
|
||||
|
||||
pub fn sample(&self) -> Result<Tensor> {
|
||||
let sample = Tensor::randn(0., 1f32, self.mean.shape(), self.mean.device());
|
||||
&self.mean + &self.std * sample
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/diffusers/blob/970e30606c2944e3286f56e8eb6d3dc6d1eb85f7/src/diffusers/models/vae.py#L485
|
||||
// This implementation is specific to the config used in stable-diffusion-v1-5
|
||||
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
|
||||
#[derive(Debug)]
|
||||
pub struct AutoEncoderKL {
|
||||
encoder: Encoder,
|
||||
decoder: Decoder,
|
||||
quant_conv: nn::Conv2d,
|
||||
post_quant_conv: nn::Conv2d,
|
||||
pub config: AutoEncoderKLConfig,
|
||||
}
|
||||
|
||||
impl AutoEncoderKL {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
config: AutoEncoderKLConfig,
|
||||
) -> Result<Self> {
|
||||
let latent_channels = config.latent_channels;
|
||||
let encoder_cfg = EncoderConfig {
|
||||
block_out_channels: config.block_out_channels.clone(),
|
||||
layers_per_block: config.layers_per_block,
|
||||
norm_num_groups: config.norm_num_groups,
|
||||
double_z: true,
|
||||
};
|
||||
let encoder = Encoder::new(vs.pp("encoder"), in_channels, latent_channels, encoder_cfg)?;
|
||||
let decoder_cfg = DecoderConfig {
|
||||
block_out_channels: config.block_out_channels.clone(),
|
||||
layers_per_block: config.layers_per_block,
|
||||
norm_num_groups: config.norm_num_groups,
|
||||
};
|
||||
let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?;
|
||||
let conv_cfg = Default::default();
|
||||
let quant_conv = nn::conv2d(
|
||||
2 * latent_channels,
|
||||
2 * latent_channels,
|
||||
1,
|
||||
conv_cfg,
|
||||
vs.pp("quant_conv"),
|
||||
)?;
|
||||
let post_quant_conv = nn::conv2d(
|
||||
latent_channels,
|
||||
latent_channels,
|
||||
1,
|
||||
conv_cfg,
|
||||
vs.pp("post_quant_conv"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
encoder,
|
||||
decoder,
|
||||
quant_conv,
|
||||
post_quant_conv,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the distribution in the latent space.
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> {
|
||||
let xs = self.encoder.forward(xs)?;
|
||||
let parameters = self.quant_conv.forward(&xs)?;
|
||||
DiagonalGaussianDistribution::new(¶meters)
|
||||
}
|
||||
|
||||
/// Takes as input some sampled values.
|
||||
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.post_quant_conv.forward(xs)?;
|
||||
self.decoder.forward(&xs)
|
||||
}
|
||||
}
|
@ -1,18 +1,18 @@
|
||||
#![allow(dead_code)]
|
||||
// https://github.com/openai/whisper/blob/main/whisper/model.py
|
||||
// https://github.com/openai/whisper/blob/main/whisper/model.py/rgs
|
||||
// TODO:
|
||||
// - kv-cache support?
|
||||
// - Language detection?
|
||||
// - Batch size greater than 1.
|
||||
// - More token filters (SuppressBlanks, ApplyTimestampRules).
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{safetensors::Load, DType, Device, Tensor};
|
||||
use candle::{DType, Device, IndexOp, Tensor};
|
||||
use candle_nn::{ops::softmax, VarBuilder};
|
||||
use clap::Parser;
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
use tokenizers::Tokenizer;
|
||||
@ -20,6 +20,7 @@ use tokenizers::Tokenizer;
|
||||
mod audio;
|
||||
mod model;
|
||||
use model::{Config, Whisper};
|
||||
mod multilingual;
|
||||
|
||||
const DTYPE: DType = DType::F32;
|
||||
|
||||
@ -31,9 +32,6 @@ const HOP_LENGTH: usize = 160;
|
||||
const CHUNK_LENGTH: usize = 30;
|
||||
const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
|
||||
const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
|
||||
const N_SAMPLES_PER_TOKEN: usize = HOP_LENGTH * 2; // the initial convolutions has stride 2
|
||||
const FRAMES_PER_SECOND: usize = SAMPLE_RATE / HOP_LENGTH; // 10ms per audio frame
|
||||
const TOKENS_PER_SECOND: usize = SAMPLE_RATE / N_SAMPLES_PER_TOKEN; // 20ms per audio token
|
||||
|
||||
const NO_SPEECH_THRESHOLD: f64 = 0.6;
|
||||
const LOGPROB_THRESHOLD: f64 = -1.0;
|
||||
@ -41,21 +39,12 @@ const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
|
||||
const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
|
||||
|
||||
// Tokenizer dependent bits.
|
||||
const SOT_TOKEN: u32 = 50257;
|
||||
const EOT_TOKEN: u32 = 50256;
|
||||
const NO_SPEECH_TOKEN: u32 = 50361;
|
||||
const NO_TIMESTAMP_TOKEN: u32 = 50362;
|
||||
// From the _get_suppress_tokens function + 50362 (no timestamp)
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
|
||||
const SUPPRESS_TOKENS: [u32; 91] = [
|
||||
1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
|
||||
366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782,
|
||||
1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959,
|
||||
10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992,
|
||||
19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549,
|
||||
47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
|
||||
];
|
||||
const SOT_TOKEN: &str = "<|startoftranscript|>";
|
||||
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
|
||||
const EOT_TOKEN: &str = "<|endoftext|>";
|
||||
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct DecodingResult {
|
||||
tokens: Vec<u32>,
|
||||
@ -66,6 +55,7 @@ struct DecodingResult {
|
||||
compression_ratio: f64,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct Segment {
|
||||
start: f64,
|
||||
@ -78,13 +68,24 @@ struct Decoder {
|
||||
rng: rand::rngs::StdRng,
|
||||
tokenizer: Tokenizer,
|
||||
suppress_tokens: Tensor,
|
||||
sot_token: u32,
|
||||
transcribe_token: u32,
|
||||
eot_token: u32,
|
||||
no_speech_token: u32,
|
||||
language_token: Option<u32>,
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
fn new(model: Whisper, tokenizer: Tokenizer, seed: u64, device: &Device) -> Result<Self> {
|
||||
fn new(
|
||||
model: Whisper,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
device: &Device,
|
||||
language_token: Option<u32>,
|
||||
) -> Result<Self> {
|
||||
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
|
||||
.map(|i| {
|
||||
if SUPPRESS_TOKENS.contains(&i) {
|
||||
if model.config.suppress_tokens.contains(&i) {
|
||||
f32::NEG_INFINITY
|
||||
} else {
|
||||
0f32
|
||||
@ -92,43 +93,59 @@ impl Decoder {
|
||||
})
|
||||
.collect();
|
||||
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
|
||||
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
|
||||
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
|
||||
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
|
||||
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
|
||||
Ok(Self {
|
||||
model,
|
||||
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
||||
tokenizer,
|
||||
suppress_tokens,
|
||||
sot_token,
|
||||
transcribe_token,
|
||||
eot_token,
|
||||
no_speech_token,
|
||||
language_token,
|
||||
})
|
||||
}
|
||||
|
||||
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
|
||||
let model = &self.model;
|
||||
let audio_features = model.encoder.forward(mel)?;
|
||||
let model = &mut self.model;
|
||||
let audio_features = model.encoder.forward(mel, true)?;
|
||||
println!("audio features: {:?}", audio_features.dims());
|
||||
let sample_len = model.config.max_target_positions / 2;
|
||||
let mut sum_logprob = 0f64;
|
||||
let mut no_speech_prob = f64::NAN;
|
||||
let mut tokens = vec![SOT_TOKEN];
|
||||
let mut tokens = vec![self.sot_token];
|
||||
if let Some(language_token) = self.language_token {
|
||||
tokens.push(language_token)
|
||||
}
|
||||
tokens.push(self.transcribe_token);
|
||||
for i in 0..sample_len {
|
||||
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
|
||||
|
||||
// The model expects a batch dim but this inference loop does not handle
|
||||
// it so we add it at this point.
|
||||
let tokens_t = tokens_t.unsqueeze(0)?;
|
||||
let logits = model.decoder.forward(&tokens_t, &audio_features)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let ys = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
|
||||
|
||||
// Extract the no speech probability on the first iteration by looking at the first
|
||||
// token logits and the probability for the according token.
|
||||
if i == 0 {
|
||||
no_speech_prob = softmax(&logits.get(0)?, 0)?
|
||||
.get(NO_SPEECH_TOKEN as usize)?
|
||||
let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
||||
no_speech_prob = softmax(&logits, 0)?
|
||||
.i(self.no_speech_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
}
|
||||
|
||||
let (seq_len, _) = logits.dims2()?;
|
||||
let logits = logits
|
||||
.get(seq_len - 1)?
|
||||
.broadcast_add(&self.suppress_tokens)?;
|
||||
let (_, seq_len, _) = ys.dims3()?;
|
||||
let logits = model
|
||||
.decoder
|
||||
.final_linear(&ys.i((..1, seq_len - 1..))?)?
|
||||
.i(0)?
|
||||
.i(0)?;
|
||||
let logits = logits.broadcast_add(&self.suppress_tokens)?;
|
||||
let next_token = if t > 0f64 {
|
||||
let prs = softmax(&(&logits / t)?, 0)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
@ -145,17 +162,14 @@ impl Decoder {
|
||||
};
|
||||
tokens.push(next_token);
|
||||
let prob = softmax(&logits, candle::D::Minus1)?
|
||||
.get(next_token as usize)?
|
||||
.i(next_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
|
||||
if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
|
||||
break;
|
||||
}
|
||||
sum_logprob += prob.ln();
|
||||
}
|
||||
let text = self
|
||||
.tokenizer
|
||||
.decode(tokens.clone(), true)
|
||||
.map_err(E::msg)?;
|
||||
let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
|
||||
let avg_logprob = sum_logprob / tokens.len() as f64;
|
||||
|
||||
Ok(DecodingResult {
|
||||
@ -219,6 +233,44 @@ impl Decoder {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
|
||||
match tokenizer.token_to_id(token) {
|
||||
None => candle::bail!("no token-id for {token}"),
|
||||
Some(id) => Ok(id),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum WhichModel {
|
||||
Tiny,
|
||||
TinyEn,
|
||||
Base,
|
||||
BaseEn,
|
||||
SmallEn,
|
||||
MediumEn,
|
||||
LargeV2,
|
||||
}
|
||||
|
||||
impl WhichModel {
|
||||
fn is_multilingual(&self) -> bool {
|
||||
match self {
|
||||
Self::Tiny | Self::Base | Self::LargeV2 => true,
|
||||
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
|
||||
}
|
||||
}
|
||||
fn model_and_revision(&self) -> (&'static str, &'static str) {
|
||||
match self {
|
||||
Self::Tiny => ("openai/whisper-tiny", "main"),
|
||||
Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"),
|
||||
Self::Base => ("openai/whisper-base", "refs/pr/22"),
|
||||
Self::BaseEn => ("openai/whisper-base.en", "refs/pr/13"),
|
||||
Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"),
|
||||
Self::MediumEn => ("openai/whisper-medium.en", "refs/pr/11"),
|
||||
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
@ -234,6 +286,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
/// The model to be used, can be tiny, small, medium.
|
||||
#[arg(long, default_value = "tiny-en")]
|
||||
model: WhichModel,
|
||||
|
||||
/// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively
|
||||
/// this can be set to sample:jfk, sample:gb1, ... to fetch a sample from the following
|
||||
/// repo: https://huggingface.co/datasets/Narsil/candle_demo/
|
||||
@ -244,20 +300,33 @@ struct Args {
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The mel filters in safetensors format.
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "candle-examples/examples/whisper/mel_filters.safetensors"
|
||||
)]
|
||||
filters: String,
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Language.
|
||||
#[arg(long)]
|
||||
language: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
println!("tracing...");
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let default_model = "openai/whisper-tiny.en".to_string();
|
||||
let (default_model, default_revision) = args.model.model_and_revision();
|
||||
let default_model = default_model.to_string();
|
||||
let default_revision = default_revision.to_string();
|
||||
let path = std::path::PathBuf::from(default_model.clone());
|
||||
let default_revision = "refs/pr/15".to_string();
|
||||
let (model_id, revision) = match (args.model_id, args.revision) {
|
||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
||||
@ -301,11 +370,9 @@ fn main() -> Result<()> {
|
||||
};
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
|
||||
let mel_filters = mel_filters.deserialize()?;
|
||||
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
|
||||
println!("loaded mel filters {:?}", mel_filters.shape());
|
||||
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
|
||||
let mel_bytes = include_bytes!("melfilters.bytes");
|
||||
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
|
||||
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
|
||||
|
||||
let mut input = std::fs::File::open(input)?;
|
||||
let (header, data) = wav::read(&mut input)?;
|
||||
@ -328,8 +395,20 @@ fn main() -> Result<()> {
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
||||
let model = Whisper::load(&vb, config)?;
|
||||
let mut dc = Decoder::new(model, tokenizer, args.seed, &device)?;
|
||||
let mut model = Whisper::load(&vb, config)?;
|
||||
|
||||
let language_token = match (args.model.is_multilingual(), args.language) {
|
||||
(true, None) => Some(multilingual::detect_language(&mut model, &tokenizer, &mel)?),
|
||||
(false, None) => None,
|
||||
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
|
||||
Ok(token_id) => Some(token_id),
|
||||
Err(_) => anyhow::bail!("language {language} is not supported"),
|
||||
},
|
||||
(false, Some(_)) => {
|
||||
anyhow::bail!("a language cannot be set for non-multilingual models")
|
||||
}
|
||||
};
|
||||
let mut dc = Decoder::new(model, tokenizer, args.seed, &device, language_token)?;
|
||||
dc.run(&mel)?;
|
||||
Ok(())
|
||||
}
|
||||
|
Binary file not shown.
@ -1,8 +1,5 @@
|
||||
// We use anyhow rather than candle errors as it provides better support for getting the backtrace
|
||||
// back when using RUST_LIB_BACKTRACE=1.
|
||||
use anyhow::Result;
|
||||
use candle::{Device, Tensor};
|
||||
use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
|
||||
use candle::{Device, IndexOp, Result, Tensor};
|
||||
use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
// The names in comments correspond to the original implementation:
|
||||
@ -19,10 +16,21 @@ pub struct Config {
|
||||
// pub n_text_state: usize,
|
||||
pub decoder_attention_heads: usize, // n_text_head
|
||||
pub decoder_layers: usize, // n_text_layer
|
||||
pub suppress_tokens: Vec<u32>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
#[allow(dead_code)]
|
||||
pub fn tiny_en() -> Self {
|
||||
let suppress_tokens = vec![
|
||||
1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93,
|
||||
357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391,
|
||||
1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329,
|
||||
7203, 9609, 9959, 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306,
|
||||
16410, 16791, 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409,
|
||||
34949, 40283, 40493, 40549, 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361,
|
||||
50362,
|
||||
];
|
||||
Self {
|
||||
num_mel_bins: 80,
|
||||
vocab_size: 51864,
|
||||
@ -34,6 +42,7 @@ impl Config {
|
||||
// n_text_state: 384,
|
||||
decoder_attention_heads: 6,
|
||||
decoder_layers: 4,
|
||||
suppress_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -42,16 +51,32 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
//
|
||||
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
|
||||
// model.
|
||||
#[derive(Debug)]
|
||||
pub struct Linear {
|
||||
inner: candle_nn::Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
let bias = vb.get(size2, "bias")?;
|
||||
Ok(Linear::new(weight, Some(bias)))
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
let inner = candle_nn::linear(size1, size2, vb)?;
|
||||
Ok(Linear { inner, span })
|
||||
}
|
||||
|
||||
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
Ok(Linear::new(weight, None))
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
|
||||
Ok(Linear { inner, span })
|
||||
}
|
||||
|
||||
fn conv1d(
|
||||
@ -66,32 +91,6 @@ fn conv1d(
|
||||
Ok(Conv1d::new(weight, Some(bias), config))
|
||||
}
|
||||
|
||||
fn conv1d_no_bias(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
kernel_size: usize,
|
||||
config: Conv1dConfig,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Conv1d> {
|
||||
let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
|
||||
Ok(Conv1d::new(weight, None, config))
|
||||
}
|
||||
|
||||
struct Dropout {
|
||||
pr: f64,
|
||||
}
|
||||
|
||||
impl Dropout {
|
||||
fn new(pr: f64) -> Self {
|
||||
Self { pr }
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
// TODO
|
||||
Ok(x.clone())
|
||||
}
|
||||
}
|
||||
|
||||
fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let weight = vb.get(size, "weight")?;
|
||||
let bias = vb.get(size, "bias")?;
|
||||
@ -105,10 +104,17 @@ struct MultiHeadAttention {
|
||||
value: Linear,
|
||||
out: Linear,
|
||||
n_head: usize,
|
||||
span: tracing::Span,
|
||||
softmax_span: tracing::Span,
|
||||
matmul_span: tracing::Span,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
impl MultiHeadAttention {
|
||||
fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
|
||||
let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax");
|
||||
let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul");
|
||||
let query = linear(n_state, n_state, vb.pp("q_proj"))?;
|
||||
let value = linear(n_state, n_state, vb.pp("v_proj"))?;
|
||||
let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
|
||||
@ -119,13 +125,42 @@ impl MultiHeadAttention {
|
||||
value,
|
||||
out,
|
||||
n_head,
|
||||
span,
|
||||
softmax_span,
|
||||
matmul_span,
|
||||
kv_cache: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
fn forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
xa: Option<&Tensor>,
|
||||
mask: Option<&Tensor>,
|
||||
flush_cache: bool,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let q = self.query.forward(x)?;
|
||||
let k = self.key.forward(xa.unwrap_or(x))?;
|
||||
let v = self.value.forward(xa.unwrap_or(x))?;
|
||||
let (k, v) = match xa {
|
||||
None => {
|
||||
let k = self.key.forward(x)?;
|
||||
let v = self.value.forward(x)?;
|
||||
(k, v)
|
||||
}
|
||||
Some(x) => {
|
||||
if flush_cache {
|
||||
self.kv_cache = None;
|
||||
}
|
||||
if let Some((k, v)) = &self.kv_cache {
|
||||
(k.clone(), v.clone())
|
||||
} else {
|
||||
let k = self.key.forward(x)?;
|
||||
let v = self.value.forward(x)?;
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
(k, v)
|
||||
}
|
||||
}
|
||||
};
|
||||
let wv = self.qkv_attention(&q, &k, &v, mask)?;
|
||||
let out = self.out.forward(&wv)?;
|
||||
Ok(out)
|
||||
@ -134,7 +169,7 @@ impl MultiHeadAttention {
|
||||
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let (n_batch, n_ctx, n_state) = x.dims3()?;
|
||||
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
|
||||
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
|
||||
x.reshape(target_dims)?.transpose(1, 2)
|
||||
}
|
||||
|
||||
fn qkv_attention(
|
||||
@ -149,13 +184,24 @@ impl MultiHeadAttention {
|
||||
let q = (self.reshape_head(q)? * scale)?;
|
||||
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
|
||||
let v = self.reshape_head(v)?.contiguous()?;
|
||||
let mut qk = q.matmul(&k)?;
|
||||
let mut qk = {
|
||||
let _enter = self.matmul_span.enter();
|
||||
q.matmul(&k)?
|
||||
};
|
||||
if let Some(mask) = mask {
|
||||
let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
|
||||
let mask = mask.i((0..n_ctx, 0..n_ctx))?;
|
||||
qk = qk.broadcast_add(&mask)?
|
||||
}
|
||||
let w = softmax(&qk, candle::D::Minus1)?;
|
||||
let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
|
||||
let w = {
|
||||
let _enter = self.softmax_span.enter();
|
||||
softmax(&qk, candle::D::Minus1)?
|
||||
};
|
||||
let wv = {
|
||||
let _enter = self.matmul_span.enter();
|
||||
w.matmul(&v)?
|
||||
}
|
||||
.transpose(1, 2)?
|
||||
.flatten_from(2)?;
|
||||
Ok(wv)
|
||||
}
|
||||
}
|
||||
@ -168,10 +214,12 @@ struct ResidualAttentionBlock {
|
||||
mlp_linear1: Linear,
|
||||
mlp_linear2: Linear,
|
||||
mlp_ln: LayerNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl ResidualAttentionBlock {
|
||||
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
|
||||
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
|
||||
let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
|
||||
let cross_attn = if ca {
|
||||
@ -192,14 +240,24 @@ impl ResidualAttentionBlock {
|
||||
mlp_linear1,
|
||||
mlp_linear2,
|
||||
mlp_ln,
|
||||
span,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
|
||||
fn forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
xa: Option<&Tensor>,
|
||||
mask: Option<&Tensor>,
|
||||
flush_kv_cache: bool,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let attn = self
|
||||
.attn
|
||||
.forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
|
||||
let mut x = (x + attn)?;
|
||||
if let Some((attn, ln)) = &self.cross_attn {
|
||||
x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?;
|
||||
if let Some((attn, ln)) = &mut self.cross_attn {
|
||||
x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
|
||||
}
|
||||
let mlp = self.mlp_linear2.forward(
|
||||
&self
|
||||
@ -207,7 +265,7 @@ impl ResidualAttentionBlock {
|
||||
.forward(&self.mlp_ln.forward(&x)?)?
|
||||
.gelu()?,
|
||||
)?;
|
||||
Ok((x + mlp)?)
|
||||
x + mlp
|
||||
}
|
||||
}
|
||||
|
||||
@ -234,10 +292,16 @@ pub struct AudioEncoder {
|
||||
positional_embedding: Tensor,
|
||||
blocks: Vec<ResidualAttentionBlock>,
|
||||
ln_post: LayerNorm,
|
||||
span: tracing::Span,
|
||||
conv1_span: tracing::Span,
|
||||
conv2_span: tracing::Span,
|
||||
}
|
||||
|
||||
impl AudioEncoder {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
|
||||
let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
|
||||
let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
|
||||
let n_state = cfg.d_model;
|
||||
let n_head = cfg.encoder_attention_heads;
|
||||
let n_ctx = cfg.max_source_positions;
|
||||
@ -264,17 +328,28 @@ impl AudioEncoder {
|
||||
positional_embedding,
|
||||
blocks,
|
||||
ln_post,
|
||||
conv1_span,
|
||||
conv2_span,
|
||||
span,
|
||||
})
|
||||
}
|
||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = self.conv1.forward(x)?.gelu()?;
|
||||
let x = self.conv2.forward(&x)?.gelu()?;
|
||||
|
||||
pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let x = {
|
||||
let _enter = self.conv1_span.enter();
|
||||
self.conv1.forward(x)?.gelu()?
|
||||
};
|
||||
let x = {
|
||||
let _enter = self.conv2_span.enter();
|
||||
self.conv2.forward(&x)?.gelu()?
|
||||
};
|
||||
let x = x.transpose(1, 2)?;
|
||||
let (_bsize, seq_len, _hidden) = x.dims3()?;
|
||||
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
|
||||
let mut x = x.broadcast_add(&positional_embedding)?;
|
||||
for block in self.blocks.iter() {
|
||||
x = block.forward(&x, None, None)?
|
||||
for block in self.blocks.iter_mut() {
|
||||
x = block.forward(&x, None, None, flush_kv_cache)?
|
||||
}
|
||||
let x = self.ln_post.forward(&x)?;
|
||||
Ok(x)
|
||||
@ -288,10 +363,14 @@ pub struct TextDecoder {
|
||||
blocks: Vec<ResidualAttentionBlock>,
|
||||
ln: LayerNorm,
|
||||
mask: Tensor,
|
||||
span: tracing::Span,
|
||||
span_final: tracing::Span,
|
||||
}
|
||||
|
||||
impl TextDecoder {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
|
||||
let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
|
||||
let n_state = cfg.d_model;
|
||||
let n_head = cfg.decoder_attention_heads;
|
||||
let n_ctx = cfg.max_target_positions;
|
||||
@ -307,31 +386,37 @@ impl TextDecoder {
|
||||
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
|
||||
.collect();
|
||||
let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
|
||||
|
||||
Ok(Self {
|
||||
token_embedding,
|
||||
positional_embedding,
|
||||
blocks,
|
||||
ln,
|
||||
mask,
|
||||
span,
|
||||
span_final,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
|
||||
pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let x_dims = x.dims();
|
||||
let last = x_dims[x_dims.len() - 1];
|
||||
let token_embedding = self.token_embedding.forward(x)?;
|
||||
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
|
||||
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
|
||||
for block in self.blocks.iter() {
|
||||
x = block.forward(&x, Some(xa), Some(&self.mask))?;
|
||||
for block in self.blocks.iter_mut() {
|
||||
x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
|
||||
}
|
||||
let x = self.ln.forward(&x)?;
|
||||
let w = self
|
||||
.token_embedding
|
||||
.embeddings()
|
||||
.broadcast_left(x_dims[0])?;
|
||||
let logits = x.matmul(&w.t()?)?;
|
||||
self.ln.forward(&x)
|
||||
}
|
||||
|
||||
pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let b_size = x.dim(0)?;
|
||||
let w = self.token_embedding.embeddings().broadcast_left(b_size)?;
|
||||
let logits = {
|
||||
let _enter = self.span_final.enter();
|
||||
x.matmul(&w.t()?)?
|
||||
};
|
||||
Ok(logits)
|
||||
}
|
||||
}
|
||||
@ -353,10 +438,4 @@ impl Whisper {
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
|
||||
let enc = self.encoder.forward(mel)?;
|
||||
let dec = self.decoder.forward(tokens, &enc)?;
|
||||
Ok(dec)
|
||||
}
|
||||
}
|
||||
|
135
candle-examples/examples/whisper/multilingual.rs
Normal file
135
candle-examples/examples/whisper/multilingual.rs
Normal file
@ -0,0 +1,135 @@
|
||||
use crate::Whisper;
|
||||
use candle::{IndexOp, Result, Tensor, D};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const LANGUAGES: [(&str, &str); 99] = [
|
||||
("en", "english"),
|
||||
("zh", "chinese"),
|
||||
("de", "german"),
|
||||
("es", "spanish"),
|
||||
("ru", "russian"),
|
||||
("ko", "korean"),
|
||||
("fr", "french"),
|
||||
("ja", "japanese"),
|
||||
("pt", "portuguese"),
|
||||
("tr", "turkish"),
|
||||
("pl", "polish"),
|
||||
("ca", "catalan"),
|
||||
("nl", "dutch"),
|
||||
("ar", "arabic"),
|
||||
("sv", "swedish"),
|
||||
("it", "italian"),
|
||||
("id", "indonesian"),
|
||||
("hi", "hindi"),
|
||||
("fi", "finnish"),
|
||||
("vi", "vietnamese"),
|
||||
("he", "hebrew"),
|
||||
("uk", "ukrainian"),
|
||||
("el", "greek"),
|
||||
("ms", "malay"),
|
||||
("cs", "czech"),
|
||||
("ro", "romanian"),
|
||||
("da", "danish"),
|
||||
("hu", "hungarian"),
|
||||
("ta", "tamil"),
|
||||
("no", "norwegian"),
|
||||
("th", "thai"),
|
||||
("ur", "urdu"),
|
||||
("hr", "croatian"),
|
||||
("bg", "bulgarian"),
|
||||
("lt", "lithuanian"),
|
||||
("la", "latin"),
|
||||
("mi", "maori"),
|
||||
("ml", "malayalam"),
|
||||
("cy", "welsh"),
|
||||
("sk", "slovak"),
|
||||
("te", "telugu"),
|
||||
("fa", "persian"),
|
||||
("lv", "latvian"),
|
||||
("bn", "bengali"),
|
||||
("sr", "serbian"),
|
||||
("az", "azerbaijani"),
|
||||
("sl", "slovenian"),
|
||||
("kn", "kannada"),
|
||||
("et", "estonian"),
|
||||
("mk", "macedonian"),
|
||||
("br", "breton"),
|
||||
("eu", "basque"),
|
||||
("is", "icelandic"),
|
||||
("hy", "armenian"),
|
||||
("ne", "nepali"),
|
||||
("mn", "mongolian"),
|
||||
("bs", "bosnian"),
|
||||
("kk", "kazakh"),
|
||||
("sq", "albanian"),
|
||||
("sw", "swahili"),
|
||||
("gl", "galician"),
|
||||
("mr", "marathi"),
|
||||
("pa", "punjabi"),
|
||||
("si", "sinhala"),
|
||||
("km", "khmer"),
|
||||
("sn", "shona"),
|
||||
("yo", "yoruba"),
|
||||
("so", "somali"),
|
||||
("af", "afrikaans"),
|
||||
("oc", "occitan"),
|
||||
("ka", "georgian"),
|
||||
("be", "belarusian"),
|
||||
("tg", "tajik"),
|
||||
("sd", "sindhi"),
|
||||
("gu", "gujarati"),
|
||||
("am", "amharic"),
|
||||
("yi", "yiddish"),
|
||||
("lo", "lao"),
|
||||
("uz", "uzbek"),
|
||||
("fo", "faroese"),
|
||||
("ht", "haitian creole"),
|
||||
("ps", "pashto"),
|
||||
("tk", "turkmen"),
|
||||
("nn", "nynorsk"),
|
||||
("mt", "maltese"),
|
||||
("sa", "sanskrit"),
|
||||
("lb", "luxembourgish"),
|
||||
("my", "myanmar"),
|
||||
("bo", "tibetan"),
|
||||
("tl", "tagalog"),
|
||||
("mg", "malagasy"),
|
||||
("as", "assamese"),
|
||||
("tt", "tatar"),
|
||||
("haw", "hawaiian"),
|
||||
("ln", "lingala"),
|
||||
("ha", "hausa"),
|
||||
("ba", "bashkir"),
|
||||
("jw", "javanese"),
|
||||
("su", "sundanese"),
|
||||
];
|
||||
|
||||
/// Returns the token id for the selected language.
|
||||
pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
|
||||
let (_bsize, _, seq_len) = mel.dims3()?;
|
||||
let mel = mel.narrow(2, 0, usize::min(seq_len, model.config.max_source_positions))?;
|
||||
let device = mel.device();
|
||||
let language_token_ids = LANGUAGES
|
||||
.iter()
|
||||
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
|
||||
let audio_features = model.encoder.forward(&mel, true)?;
|
||||
let tokens = Tensor::new(&[[sot_token]], device)?;
|
||||
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
|
||||
let logits = model
|
||||
.decoder
|
||||
.forward(&tokens, &audio_features, true)?
|
||||
.i(0)?
|
||||
.i(0)?;
|
||||
let logits = logits.index_select(&language_token_ids, 0)?;
|
||||
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
|
||||
let probs = probs.to_vec1::<f32>()?;
|
||||
let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
|
||||
probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for ((_, language), p) in probs.iter().take(5) {
|
||||
println!("{language}: {p}")
|
||||
}
|
||||
let language = crate::token_id(tokenizer, &format!("<|{}|>", probs[0].0 .0))?;
|
||||
Ok(language)
|
||||
}
|
@ -11,3 +11,102 @@ pub fn device(cpu: bool) -> Result<Device> {
|
||||
Ok(device)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
|
||||
#[rustfmt::skip]
|
||||
#[tokio::test]
|
||||
async fn book_hub_1() {
|
||||
// ANCHOR: book_hub_1
|
||||
use candle::Device;
|
||||
use hf_hub::api::tokio::Api;
|
||||
|
||||
let api = Api::new().unwrap();
|
||||
let repo = api.model("bert-base-uncased".to_string());
|
||||
|
||||
let weights_filename = repo.get("model.safetensors").await.unwrap();
|
||||
|
||||
let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
|
||||
// ANCHOR_END: book_hub_1
|
||||
assert_eq!(weights.len(), 206);
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_hub_2() {
|
||||
// ANCHOR: book_hub_2
|
||||
use candle::Device;
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::Mmap;
|
||||
use std::fs;
|
||||
|
||||
let api = Api::new().unwrap();
|
||||
let repo = api.model("bert-base-uncased".to_string());
|
||||
let weights_filename = repo.get("model.safetensors").unwrap();
|
||||
|
||||
let file = fs::File::open(weights_filename).unwrap();
|
||||
let mmap = unsafe { Mmap::map(&file).unwrap() };
|
||||
let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
|
||||
// ANCHOR_END: book_hub_2
|
||||
assert_eq!(weights.len(), 206);
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_hub_3() {
|
||||
// ANCHOR: book_hub_3
|
||||
use candle::{DType, Device, Tensor};
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::Mmap;
|
||||
use safetensors::slice::IndexOp;
|
||||
use safetensors::SafeTensors;
|
||||
use std::fs;
|
||||
|
||||
let api = Api::new().unwrap();
|
||||
let repo = api.model("bert-base-uncased".to_string());
|
||||
let weights_filename = repo.get("model.safetensors").unwrap();
|
||||
|
||||
let file = fs::File::open(weights_filename).unwrap();
|
||||
let mmap = unsafe { Mmap::map(&file).unwrap() };
|
||||
|
||||
// Use safetensors directly
|
||||
let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
|
||||
let view = tensors
|
||||
.tensor("bert.encoder.layer.0.attention.self.query.weight")
|
||||
.unwrap();
|
||||
|
||||
// We're going to load shard with rank 1, within a world_size of 4
|
||||
// We're going to split along dimension 0 doing VIEW[start..stop, :]
|
||||
let rank = 1;
|
||||
let world_size = 4;
|
||||
let dim = 0;
|
||||
let dtype = view.dtype();
|
||||
let mut tp_shape = view.shape().to_vec();
|
||||
let size = tp_shape[0];
|
||||
|
||||
if size % world_size != 0 {
|
||||
panic!("The dimension is not divisble by `world_size`");
|
||||
}
|
||||
let block_size = size / world_size;
|
||||
let start = rank * block_size;
|
||||
let stop = (rank + 1) * block_size;
|
||||
|
||||
// Everything is expressed in tensor dimension
|
||||
// bytes offsets is handled automatically for safetensors.
|
||||
|
||||
let iterator = view.slice(start..stop).unwrap();
|
||||
|
||||
tp_shape[dim] = block_size;
|
||||
|
||||
// Convert safetensors Dtype to candle DType
|
||||
let dtype: DType = dtype.try_into().unwrap();
|
||||
|
||||
// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
|
||||
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
|
||||
let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
|
||||
// ANCHOR_END: book_hub_3
|
||||
assert_eq!(view.shape(), &[768, 768]);
|
||||
assert_eq!(tp_tensor.dims(), &[192, 768]);
|
||||
}
|
||||
}
|
||||
|
@ -1,17 +1,17 @@
|
||||
[package]
|
||||
name = "candle-flash-attn"
|
||||
version = "0.1.0"
|
||||
version = "0.1.1"
|
||||
edition = "2021"
|
||||
|
||||
description = "Flash attention layer for the candle ML framework."
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
keywords = ["blas", "tensor", "machine-learning"]
|
||||
categories = ["science"]
|
||||
license = "MIT/Apache-2.0"
|
||||
license = "MIT OR Apache-2.0"
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", features = ["cuda"], version = "0.1.0", package = "candle-core" }
|
||||
candle = { path = "../candle-core", features = ["cuda"], version = "0.1.1", package = "candle-core" }
|
||||
half = { version = "2.3.1", features = ["num-traits"] }
|
||||
|
||||
[build-dependencies]
|
||||
@ -21,4 +21,4 @@ rayon = "1.7.0"
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
candle-nn = { path = "../candle-nn", version = "0.1.0", features = ["cuda"] }
|
||||
candle-nn = { path = "../candle-nn", version = "0.1.1", features = ["cuda"] }
|
||||
|
@ -88,6 +88,7 @@ fn main() -> Result<()> {
|
||||
.map(|(cu_file, obj_file)| {
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
command
|
||||
.arg("-std=c++17")
|
||||
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.arg("-c")
|
||||
.args(["-o", obj_file.to_str().unwrap()])
|
||||
|
@ -1,13 +1,13 @@
|
||||
[package]
|
||||
name = "candle-kernels"
|
||||
version = "0.1.0"
|
||||
version = "0.1.1"
|
||||
edition = "2021"
|
||||
|
||||
description = "CUDA kernels for Candle"
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
keywords = ["blas", "tensor", "machine-learning"]
|
||||
categories = ["science"]
|
||||
license = "MIT/Apache-2.0"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
|
||||
|
@ -6,21 +6,12 @@
|
||||
|
||||
// FIXME: the minimum compute capabilities are just guesses since the table is not specific enough
|
||||
|
||||
// #if __CUDA_ARCH__ < 600
|
||||
// __device__ __forceinline__ __half __hmax(__half a, __half b) {
|
||||
// return __float2half(fmaxf(__half2float(a), __half2float(b)));
|
||||
// }
|
||||
// __device__ __forceinline__ __half __hmin(__half a, __half b) {
|
||||
// return __float2half(fminf(__half2float(a), __half2float(b)));
|
||||
// }
|
||||
// #endif
|
||||
|
||||
#if __CUDA_ARCH__ < 800
|
||||
#if (__CUDACC_VER_MAJOR__ < 12 || __CUDACC_VER_MINOR__ < 2) && __CUDA_ARCH__ < 800
|
||||
__device__ __forceinline__ __half __hmax_nan(__half a, __half b) {
|
||||
// return __hisnan(a) ? a : (__hisnan(b) ? b : __hmax(a, b));
|
||||
return __hisnan(a) ? a : (__hisnan(b) ? b : __hmax(a, b));
|
||||
}
|
||||
__device__ __forceinline__ __half __hmin_nan(__half a, __half b) {
|
||||
// return __hisnan(a) ? a : (__hisnan(b) ? b : __hmin(a, b));
|
||||
return __hisnan(a) ? a : (__hisnan(b) ? b : __hmin(a, b));
|
||||
}
|
||||
#endif
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user