Compare commits

..

139 Commits

Author SHA1 Message Date
0bb344f798 [RFC] Start removal of VarBuilder.
- Uses an `Initializer` trait instead.
- Allows more decoupling between init and load, which are very different
  ops.
- Allows more decoupling between backends (safetensors, npy, ggml,
  etc...)

This is a minimum viable change.

There are three kinds of objects with various relations.

The `Model`:
    This is `Llama`, `Linear`, `Rms`, ...
    They contain tensors (and possibly other things) and are essentially
    used to call `forward`.
    They should not own any internals such as the RNG state or the actual
    shapes of the tensors (the tensors already own those).

The `Initializer`:
    This is a struct containing the information necessary to generate new
    random tensors. It should typically own a random generator and
    produce different kinds of random tensors depending on what kind of
    `Model` it is initializing.
    It does not own any information about the `Model` itself.
    The default init stores a `Vec<Var>` for now, in order to send it to
    the optimizer.

The `Config`:
    This is the information necessary to link the `Model` and the
    `Initializer`. It is another struct that acts as a companion to the
    initialization implementation.
    Typical information includes the shapes of the tensors for a simple
    `Model`, the `eps` for RMS, or the `use_bias` boolean saying whether
    the linear layer should have a bias.

This should remove all need for `VarBuilder` during initialization, and
allow removing every initialization bit within `VarBuilder`.

Modifying `llama2-c` to follow this initialization scheme is deliberately
left for a follow-up, to keep the current PR rather small.
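
As a rough illustration of how these three pieces could relate (hypothetical names and signatures only, not the actual API introduced by this change), a minimal sketch might look like:

```rust
// Hypothetical sketch only: names and signatures are illustrative and not
// the actual API introduced by this change.
use candle_core::{DType, Device, Result, Tensor};

/// The `Config`: everything needed to link a `Model` with an `Initializer`.
struct LinearConfig {
    in_dim: usize,
    out_dim: usize,
    use_bias: bool,
}

/// The `Model`: owns its tensors and is used to call `forward`.
struct Linear {
    weight: Tensor,
    bias: Option<Tensor>,
}

/// The `Initializer`: owns the randomness, knows nothing about the `Model` internals.
trait Initializer<M> {
    type Config;
    fn init(&mut self, config: &Self::Config, device: &Device) -> Result<M>;
}

struct DefaultInit;

impl Initializer<Linear> for DefaultInit {
    type Config = LinearConfig;
    fn init(&mut self, cfg: &Self::Config, device: &Device) -> Result<Linear> {
        // A real implementation would also keep track of the created `Var`s
        // so that they can be handed over to the optimizer.
        let weight = Tensor::randn(0f32, 1., (cfg.out_dim, cfg.in_dim), device)?;
        let bias = if cfg.use_bias {
            Some(Tensor::zeros(cfg.out_dim, DType::F32, device)?)
        } else {
            None
        };
        Ok(Linear { weight, bias })
    }
}

fn main() -> Result<()> {
    let cfg = LinearConfig { in_dim: 784, out_dim: 10, use_bias: true };
    let mut init = DefaultInit;
    let model = init.init(&cfg, &Device::Cpu)?;
    println!("weight shape: {:?}", model.weight.shape());
    Ok(())
}
```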
2023-08-16 14:39:36 +02:00
965597a873 Add a test for qmatmul. (#459) 2023-08-16 06:36:27 +01:00
ca449f9ee1 Add quantized tensors. (#458)
* Add quantized tensors.

* Implement the debug trait for QTensor.

* Add the QMatMul custom op.
2023-08-15 22:45:53 +01:00
b8263aa15c Quantized support for f16 and f32 (#457)
* Add f32 as a quantized type.

* Add f16 as a quantized type too.
2023-08-15 21:09:37 +01:00
e68b2accb4 Split out the quantized file. (#456) 2023-08-15 20:26:27 +01:00
08effe3762 More quantization support (#455)
* Properly initialize wdata.

* Simplify the matmul bits.

* Add from_float for q4_0.

* Fix a couple bugs.

* Get the test to work.

* Get clippy to be happy.
2023-08-15 18:58:04 +01:00
8ad4a21ffc Add a basic optimizer example. (#454) 2023-08-15 17:19:18 +01:00
5e49922be2 Basic quantization support (#453)
* Add a vecdot trait.

* Start implementing mul_mat.

* Add to the mul mat implementation.

* Add q8_0 quantization.

* Implement the GgmlType trait for all types.

* Add the missing block.

* Add a TODO.
2023-08-15 15:53:19 +01:00
ebcfd96d94 add c++17 flags (#452) 2023-08-15 15:29:34 +01:00
5b1690fffa Tweak the llama example. (#450) 2023-08-15 12:18:20 +01:00
3cc87058b7 Support local weights & dynamic outputs (#447)
* Support local weights & dynamic outputs

* Revise as suggested

* Cargo code format
2023-08-15 11:51:57 +01:00
531f23b4d0 Rename vec-dot to vec-ops. (#449)
* Rename vec-dot to vec-ops.

* Also bump the crate version.

* Add a currently empty readme.
2023-08-15 10:48:57 +01:00
495e0b7580 Simd support (#448)
* Import the simd intrinsics in candle-core.

* simd version of reduce-sum.

* Bugfix.

* Fix some clippy lints.
2023-08-15 09:50:38 +01:00
90374097dc Cudnn support (#445)
* Add a cudnn feature to be used for conv2d.

* Allocate the proper workspace.

* Only create a single cudnn handle per cuda device.

* Proper cudnn usage.

* Bugfix.
2023-08-14 21:30:41 +01:00
c84883ecf2 Add a cuda kernel for upsampling. (#441)
* Add a cuda kernel for upsampling.

* Update for the latest tokenizers version.
2023-08-14 13:12:17 +01:00
a094dc503d Add a cuda kernel for avg-pool2d. (#440)
* Add a cuda kernel for avg-pool2d.

* Avoid running out of bounds.

* Finish wiring the avg pool kernel + add some testing.

* Support for max-pool + testing.
2023-08-14 12:32:05 +01:00
34f4b3187e Add a naive conv2d cuda kernel. (#438)
* Add a naive conv2d cuda kernel.

* Proper conv2d support on the rust side.

* Conv1d testing on gpu.

* Also use the test on gpus.

* Fix the clean-ptx target.
2023-08-14 10:34:42 +01:00
eab54e4490 Fix the tests for mkl. (#437) 2023-08-14 08:09:27 +01:00
9e7e6e0288 Add dequantization for ggmls q4_0, q4_1, q5_0, q5_1 and q8_0 (#407)
* Added dequantization for `q4_0`, `q4_1`, `q5_0`, `q5_1` and `q8_0`

* expose `tensor_from_ggml` for external usage

* bugfixes & example
2023-08-13 23:22:57 +01:00
8bd2b22b33 Optimize the logit computations in the whisper example. (#434) 2023-08-13 22:00:13 +01:00
d379a76a9e Add a softmax bench. (#433)
* Add a softmax bench.

* Add the vectorized sum reduce.
2023-08-13 20:09:18 +01:00
9af438ac1b Track the conv2d operations in stable-diffusion. (#431)
* Track the conv2d operations in stable-diffusion.

* Add more tracing to stable-diffusion.

* Also trace the resnet bits.

* Trace the attention blocks.

* Also trace the attention inner part.

* Small tweak.
2023-08-13 15:58:26 +01:00
b1ff78f762 Allow using accelerate with stable-diffusion. (#430) 2023-08-13 14:14:20 +01:00
5a63b51f14 Add a matmul benchmark. (#429) 2023-08-13 13:41:03 +01:00
6d694554b8 Support longer sequences in language detection. (#428) 2023-08-13 13:16:15 +01:00
9aca398a4f More accelerate optimizations (#427)
* Add more tracing to the whisper example.

* Support accelerate in more examples.

* Use accelerate for pointwise functions.

* Use accelerate for binary operations too.

* Bugfix for binary operation: use the rhs before the lhs.
2023-08-13 12:53:34 +01:00
60cd1551ca Add a KV cache to whisper. (#426) 2023-08-12 21:17:08 +01:00
a0908d212c Add a -language argument. (#425) 2023-08-12 17:08:40 +01:00
972078e1ae Update the readme with the discord server and common errors. (#423) 2023-08-12 16:45:58 +01:00
16b89f5b83 fix: can directly save the loaded weights (#421) 2023-08-12 16:33:29 +01:00
0741ebbd51 More multilingual support for whisper. (#419)
* More multilingual support for whisper.

* Use the language token appropriately.
2023-08-12 15:32:52 +01:00
0c3f109faa Basic multilingual support for whisper (#417)
* Multi-lingual support for whisper.

* Avoid hardcoding the token names.

* More multi-lingual support.

* Remove the todo.
2023-08-12 11:23:04 +01:00
2ba6b2826f Fix the readme instructions for stable-diffusion. (#415) 2023-08-11 18:59:04 +01:00
1d0157bbc4 Stable diffusion: retrieve the model files from the HF hub. (#414)
* Retrieve the model files from the HF hub in the stable diffusion example.

* Add to the readme.
2023-08-11 18:57:06 +01:00
91dbf907d3 Add more whisper variants. (#413) 2023-08-11 17:33:55 +01:00
e12372021b Expose the tensor write-bytes function. (#412) 2023-08-11 17:13:42 +01:00
55e428c8ae Expose the varmap inner data. (#411) 2023-08-11 16:58:56 +01:00
01ea57da8c Fix the conv tests. (#409) 2023-08-11 14:59:54 +01:00
662db45fc3 Use zero padding in conv1d and conv2d (same as pytorch). (#408) 2023-08-11 14:53:05 +01:00
906c0f3eb5 Remove the checkpoint conversion script. (#405)
* Remove the checkpoint conversion script.

* Remove references to the script.
2023-08-11 05:59:48 +01:00
e29c7809ec Parallelise the CPU kernels for the conv ops. (#401)
* Parallelise the conv2d op.

* Tighter control on threading.

* Also parallelise conv1d.

* Add some safety comment.
2023-08-11 05:51:58 +01:00
a325c1aa50 Upsample test + bugfix. (#399) 2023-08-10 21:02:35 +02:00
b6cf26e48e Merge pull request #393 from huggingface/older_gpus
Working on older GPUs (still not compute 52 it seems but > 6 could be OK)
2023-08-10 20:49:23 +02:00
379eadc68e Working now. 2023-08-10 19:43:25 +02:00
7e4fbc1e17 [DO NOT MERGE] temporary PR so users can try out on older GPUs. 2023-08-10 19:36:31 +02:00
80f0482f26 Fix the stable-diffusion vae. (#398)
* Fix the stable-diffusion vae.

* Fix for saving images.
2023-08-10 18:24:31 +01:00
94eff56aee Optimize the cpu conv2d kernel (#396)
* Conv2d simd optimization.

* Fix the contiguous copying.

* Small tweak.
2023-08-10 17:40:09 +01:00
a55133effd Merge pull request #395 from huggingface/fix_compat_windows
Compat windows.
2023-08-10 18:05:12 +02:00
ff53f38467 Small example for benchmarking some cpu ops (#394)
* Refactor the benchmark example.

* Rename the example.

* Add some comments.
2023-08-10 17:00:17 +01:00
4a95d34c83 Compat windows. 2023-08-10 17:46:47 +02:00
7f710a573d Merge pull request #374 from Rocketknight1/readme_fixes
README.md typos and grammar fixes
2023-08-10 16:34:19 +02:00
c8039579a5 Conv1d optimize (#392)
* Reorder the conv1d loops in the cpu backend.

* Optimize the 1d convolution.

* Conv1D optimize.

* Fix some clippy lints.
2023-08-10 15:23:52 +01:00
0b0fa56978 Merge pull request #386 from huggingface/enabling_61_maybe
This is duplicated code on Cuda 12.2.
2023-08-10 16:23:17 +02:00
385f0d261c Normalize embeddings in the bert example. (#390) 2023-08-10 13:05:55 +01:00
b765f2c37f Update the wasm build instructions. (#389) 2023-08-10 11:29:43 +01:00
66d1c093e0 This is duplicated code on Cuda 12.2.
Without it we can compile for 52 (but I get Operation Not supported
when actually trying to use those kernels).
2023-08-10 09:20:18 +02:00
de7c31bfe9 Merge pull request #368 from huggingface/add_cuda_ci
Adding cuda CI
2023-08-10 08:49:39 +02:00
8e7ef96588 Fix CI cuda. 2023-08-10 08:47:15 +02:00
f3fe730a30 Npy tweaks & error with path (#384)
* Simplify the npy writing.

* Wrap the file path so as to provide better errors.
2023-08-10 06:21:58 +01:00
c7f92f985e Further randn tweaks: use the appropriate rng rather than the f64 one, some cleanup. (#383) 2023-08-10 05:48:19 +01:00
Lei
3bbc08a8df Fix randn cpu (#382)
* Change distributions

Standard generates in [0, 1), Normal is correct.

* Add test

Not sure if this is the best place to put  the test

* Remove unnecessary use
2023-08-10 05:33:44 +01:00
6a2137af4f Update README.md 2023-08-10 00:19:58 +01:00
0dc1e5f387 Merge branch 'main' into readme_fixes 2023-08-10 00:19:20 +01:00
bd2fb6216b Testing in release mode because debug is too slow. 2023-08-09 23:19:55 +02:00
3542b26143 ssl update. 2023-08-09 23:11:45 +02:00
a690f14a77 Fix by hardcoding paths 2023-08-09 23:08:50 +02:00
90d778c059 ? 2023-08-09 23:02:11 +02:00
171fcbe539 CI ssh in the meantime. 2023-08-09 22:58:47 +02:00
07e83c55c0 Attempt nb2 2023-08-09 22:47:01 +02:00
25ec2d9f6b fix: remove incorrect unwrap (#379) 2023-08-09 21:45:24 +01:00
da26e2832c Update gemm to 0.15.6. (#378) 2023-08-09 21:04:28 +01:00
fcfdcbd337 Add a conv1d benchmark based on the whisper sizes. (#377)
* Add a conv1d benchmark based on the whisper sizes.

* Enforce the batch-dim in conv1d.
2023-08-09 20:27:03 +01:00
653ec5abc1 Update README.md (#376)
add missing word
2023-08-09 20:09:21 +01:00
c3a0761e62 Add some tracing to the whisper example. (#375) 2023-08-09 19:58:36 +01:00
0cef3998fd README.md typos and grammar fixes 2023-08-09 19:36:03 +01:00
e5f510d209 SSH to debug. 2023-08-09 19:54:40 +02:00
0dd94eff4c Merge pull request #367 from eltociear/eltociear-patch-1
Update README.md
2023-08-09 19:48:31 +02:00
a3b1699409 Embed the mel filters in the whisper binary. (#373) 2023-08-09 18:27:26 +01:00
5b79b38bc7 Remove extra square bracket (#372) 2023-08-09 18:24:28 +01:00
a5c5a893aa add max_pool2d (#371)
Co-authored-by: 赵理山 <ls@zhaolishandeMacBook-Air.local>
2023-08-09 18:05:26 +01:00
e6ce47f9e0 ? 2023-08-09 19:00:25 +02:00
1892bd139c Extract the strides in the conv ops. (#370) 2023-08-09 17:57:05 +01:00
749c8c7f51 Better rust GH action. 2023-08-09 18:42:53 +02:00
d9b4fef189 Chnage name 2023-08-09 18:14:29 +02:00
8fa329aca2 Adding cuda CI 2023-08-09 18:13:27 +02:00
cd225bd3b1 More testing for avg-pool2d. (#366)
* More testing for avg-pool2d.

* Another fix.

* Add a max-pool test with non-divisible kernel sizes.
2023-08-09 16:12:23 +01:00
a4f6977087 Update README.md
dauting -> daunting
2023-08-10 00:11:11 +09:00
dece0b8a76 Merge pull request #263 from huggingface/book_3
Book 3 (advanced loading + hub)
2023-08-09 16:50:11 +02:00
b80348d22f Bugfix for avg-pool + add some test. (#365) 2023-08-09 15:44:16 +01:00
3a62aee91f Write the generated images using the image crate. (#363)
* Use the image crate to write the generated images.

* Make the dependency optional.
2023-08-09 15:26:44 +01:00
be21d7e75a Fix the padding used in stable diffusion. (#362) 2023-08-09 13:23:59 +01:00
9c4cf6804b Merge pull request #355 from cksac/fix_book
fix repo link
2023-08-09 09:08:16 +02:00
dbc6f281c9 Conv1d test with padding. (#356) 2023-08-09 05:45:38 +01:00
47a5bee249 fix repo link 2023-08-09 11:29:48 +08:00
cf965ecaa8 Simplify the conv1d and conv2d code. (#352) 2023-08-08 22:10:59 +01:00
b9864e1357 Fix size-in-bytes for u8. (#351) 2023-08-08 21:15:18 +01:00
608b2358c6 Add some conv1d test + bugfix using padding. (#349) 2023-08-08 20:50:20 +01:00
1e6dbeac01 Add some conv2d tests. (#347)
* Add some conv2d tests.

* Add a simpler conv2d test.

* More conv2d testing + bugfix.

* Add a todo.
2023-08-08 19:02:42 +01:00
13ce68ff9b Bugfix for conv2d. (#343) 2023-08-08 15:20:00 +01:00
89d3926c9b Fixes for the stable diffusion example. (#342)
* Fixes for the stable diffusion example.

* Bugfix.

* Another fix.

* Fix for group-norm.

* More fixes to get SD to work.
2023-08-08 14:57:09 +01:00
ab35684326 Naive implementation for conv2d. (#341) 2023-08-08 06:34:36 +01:00
b5bb5e056d Add more conv2d support. (#340)
* Add more conv2d support.

* Conv2d cpu work.

* Conv2d output shape.
2023-08-08 06:04:32 +01:00
d0d7010682 CPU implementation for upsample-nearest2d. (#339) 2023-08-07 20:07:10 +01:00
fc265d9dcf Some CLIP fixes for stable diffusion. (#338)
* Some CLIP fixes for stable diffusion.

* Add the avg-pool2d operation on cpu.
2023-08-07 18:31:45 +01:00
2345b8ce3f Skeleton for the avg-pool2d and upsample-nearest2d ops. (#337)
* Skeleton for the avg-pool2d and upsample-nearest2d ops.

* Preliminary conv2d support.
2023-08-07 16:15:38 +01:00
f53a333ea9 Simple pad support. (#336)
* Simple pad support.

* Fix the tensor indexing when padding.
2023-08-07 15:24:56 +01:00
e72ba0b9e7 Add the license files. (#335) 2023-08-07 14:11:27 +01:00
5bb2fce998 Implement group-norm. (#334)
* Implement group-norm.

* Add some testing for group-norm.
2023-08-07 06:53:05 +01:00
2c9f605976 Add rand-like/randn-like. (#333) 2023-08-06 21:51:08 +01:00
141df4ad2b Main diffusion loop for the SD example. (#332) 2023-08-06 21:39:53 +01:00
166bfd5847 Add the recip op + use it in stable-diffusion. (#331)
* Add the recip unary op.

* Fix the cuda kernel.

* Use the recip op in sigmoid.
2023-08-06 21:14:52 +01:00
1c062bf06b Add the ddim scheduler. (#330) 2023-08-06 20:44:00 +01:00
d34039e352 Add a stable diffusion example (#328)
* Start adding a stable-diffusion example.

* Proper computation of the causal mask.

* Add the chunk operation.

* Work in progress: port the attention module.

* Add some dummy modules for conv2d and group-norm, get the attention module to compile.

* Re-enable the 2d convolution.

* Add the embeddings module.

* Add the resnet module.

* Add the unet blocks.

* Add the unet.

* And add the variational auto-encoder.

* Use the pad function from utils.
2023-08-06 17:49:43 +01:00
93cfe5642f Pyo3 dtype (#327)
* Better handling of dtypes in pyo3.

* More pyo3 dtype.
2023-08-06 10:17:43 +01:00
88bd3b604a Add some tensor creation functions to the pyo3 bindings. (#326) 2023-08-06 06:50:33 +01:00
b278834267 Support the Accelerate BLAS on macOS. (#325)
* Add the accelerate feature.

* Ffi tweaks.
2023-08-05 17:25:24 +01:00
0b175fcbbd Fix the pyo3 build for macos. (#324)
* Fix the pyo3 build for macos.

* rustfmt fix.
2023-08-05 14:53:57 +01:00
620f83cf66 Add the candle-datasets crate (#322)
* Move the vision datasets to a separate crate.

* Move the batcher bits.

* Update the readme.

* Move the tiny-stories bits.

---------

Co-authored-by: Jane Doe <jane.doe@example.org>
2023-08-05 08:56:50 +01:00
f7b2a0391d Transpose the weight matrixes for llama2.c. (#321) 2023-08-04 13:32:20 +01:00
8b6f5be1cc Support q5k quantized data. (#320) 2023-08-04 09:51:30 +01:00
df6667ba88 Add some tracing to llama. (#318) 2023-08-03 13:52:22 +01:00
a79286885c Support safetensors weights in llama2.c inference. (#317) 2023-08-03 11:10:58 +01:00
74845a4dcd Use the assert! function as it turns out to be const. (#316) 2023-08-03 10:03:43 +01:00
aa76b783eb Q6K dequantization. (#315) 2023-08-03 09:31:20 +01:00
25564357f7 Support some ggml quantized types (#314)
* Add the quantized types for GGML loading.

* Support quantization for Q2K.

* More quantization support.

* Fix some clippy lints.
2023-08-03 09:16:26 +01:00
634700d84a Use some consts for ggml values. (#312) 2023-08-02 22:03:05 +01:00
e635f18eda Initial support for reading ggml files. (#311)
* Start adding support for reading ggml files.

* Compute the proper tensor size.

* Print the read tensors.

* Fix file reading.
2023-08-02 21:59:02 +01:00
dba31473d4 Typos and format and CD only when PR lands. 2023-08-02 19:18:43 +02:00
1b2b32e58d Remove dead page.t 2023-08-02 18:59:36 +02:00
166f4d1101 s/candle/candle_core/g 2023-08-02 18:40:24 +02:00
ae68635af9 Add small error management. 2023-08-02 18:40:24 +02:00
c11e78b334 Odd rebase artifact. 2023-08-02 18:40:24 +02:00
1b705a426f Remove duplicate. 2023-08-02 18:40:24 +02:00
a70b95f9e7 Marking unwritten chapters as Draft (disables the link). 2023-08-02 18:40:24 +02:00
a44471a305 Adding more details on how to load things.
- Loading with memmap
- Loading a sharded tensor
- Moved some snippets to `candle-examples/src/lib.rs`, because managing
  book-specific dependencies is a pain (https://github.com/rust-lang/mdBook/issues/706).
- This causes a non-aligned inclusion (https://github.com/rust-lang/mdBook/pull/1856), which we
  have to ignore in fmt to remove.

mdbook might need some more love :)
2023-08-02 18:40:24 +02:00
45642a8530 Fixing examples. 2023-08-02 18:40:24 +02:00
82464166e4 3rd phase. 2023-08-02 18:40:24 +02:00
52414ba5c8 Bugfix for the llama2 wasm example. (#310)
* Clean-up the llama2.c wasm example.

* Use a proper tokenizer.

* Add a prompt.

* Bugfix for the llama2 wasm example.
2023-08-02 17:32:36 +01:00
186c308d51 Wasm llama2 tweaks (#309)
* Clean-up the llama2.c wasm example.

* Use a proper tokenizer.
2023-08-02 15:49:43 +01:00
123 changed files with 9640 additions and 876 deletions

View File

@ -1,7 +1,5 @@
name: Deploy Rust book
on:
# TODO put this back only when merging after this PR lands.
pull_request:
push:
branches:
- main

87
.github/workflows/ci_cuda.yaml vendored Normal file
View File

@ -0,0 +1,87 @@
name: CI / cuda
on:
workflow_dispatch:
pull_request:
jobs:
start-runner:
name: Start self-hosted EC2 runner
runs-on: ubuntu-latest
env:
AWS_REGION: us-east-1
EC2_AMI_ID: ami-03cfed9ea28f4b002
EC2_INSTANCE_TYPE: g5.xlarge
EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
EC2_SECURITY_GROUP: sg-030175c435ac141d6
outputs:
label: ${{ steps.start-ec2-runner.outputs.label }}
ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v1
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Start EC2 runner
id: start-ec2-runner
uses: philschmid/philschmid-ec2-github-runner@main
with:
mode: start
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
ec2-image-id: ${{ env.EC2_AMI_ID }}
ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
subnet-id: ${{ env.EC2_SUBNET_ID }}
security-group-id: ${{ env.EC2_SECURITY_GROUP }}
aws-resource-tags: > # optional, requires additional permissions
[
{"Key": "Name", "Value": "ec2-tgi-github-runner"},
{"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
]
test-cuda:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
needs: start-runner # required to start the main job when the runner is ready
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
permissions:
contents: write
packages: write
# This is used to complete the identity challenge
# with sigstore/fulcio when running outside of PRs.
id-token: write
security-events: write
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Install Rust Stable
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
- uses: Swatinem/rust-cache@v2
- run: apt update -y && apt install libssl-dev -y
- name: Test (cuda)
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
stop-runner:
name: Stop self-hosted EC2 runner
needs:
- start-runner
- test-cuda
runs-on: ubuntu-latest
env:
AWS_REGION: us-east-1
if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v1
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Stop EC2 runner
uses: philschmid/philschmid-ec2-github-runner@main
with:
mode: stop
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
label: ${{ needs.start-runner.outputs.label }}
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}

1
.gitignore vendored
View File

@ -20,6 +20,7 @@ Cargo.lock
perf.data
flamegraph.svg
*.dylib
*.so
*.swp
trace-*.json

View File

@ -1,6 +1,7 @@
[workspace]
members = [
"candle-core",
"candle-datasets",
"candle-examples",
"candle-nn",
"candle-pyo3",
@ -14,23 +15,25 @@ exclude = [
]
[workspace.package]
version = "0.1.0"
version = "0.1.1"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT/Apache-2.0"
license = "MIT OR Apache-2.0"
[workspace.dependencies]
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] }
cudarc = { version = "0.9.13", features = ["f16"] }
cudarc = { version = "0.9.14", features = ["f16"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
gemm = { version = "0.15.5", package = "candle-gemm" }
gemm = { version = "0.15.6", package = "candle-gemm" }
hf-hub = "0.2.0"
half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
@ -38,11 +41,13 @@ memmap2 = "0.7.1"
num_cpus = "1.15.0"
num-traits = "0.2.15"
rand = "0.8.5"
rand_distr = "0.4.3"
rayon = "1.7.0"
safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_json = "1.0.99"
thiserror = "1"
tokenizers = { version = "0.13.3", default-features = false }
tokenizers = { version = "0.13.4", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"

201
LICENSE-APACHE Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

23
LICENSE-MIT Normal file
View File

@ -0,0 +1,23 @@
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

View File

@ -2,6 +2,8 @@ clean-ptx:
find target -name "*.ptx" -type f -delete
echo "" > candle-kernels/src/lib.rs
touch candle-kernels/build.rs
touch candle-examples/build.rs
touch candle-flash-attn/build.rs
clean:
cargo clean
@ -9,4 +11,8 @@ clean:
test:
cargo test
pyo3-test:
cargo build --profile=release-with-debug --package candle-pyo3
python3 candle-pyo3/test.py
all: test

View File

@ -1,10 +1,11 @@
# candle
[![discord server](https://dcbadge.vercel.app/api/server/hugging-face-879548962464493619)](https://discord.com/channels/879548962464493619/1136218819447238726)
[![Latest version](https://img.shields.io/crates/v/candle-core.svg)](https://crates.io/crates/candle-core)
[![Documentation](https://docs.rs/candle-core/badge.svg)](https://docs.rs/candle-core)
![License](https://img.shields.io/crates/l/candle-core.svg)
Candle is a minimalist ML framework for Rust with a focus on easiness of use and
on performance (including GPU support). Try our online demos:
Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
and ease of use. Try our online demos:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
@ -26,6 +27,8 @@ Check out our [examples](./candle-examples/examples/):
- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
generation.
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
image generative model, yet to be optimized.
Run them using the following commands:
```
@ -34,6 +37,7 @@ cargo run --example llama --release
cargo run --example falcon --release
cargo run --example bert --release
cargo run --example bigcode --release
cargo run --example stable-diffusion --release --features image -- --prompt "a rusty robot holding a fire torch"
```
In order to use **CUDA** add `--features cuda` to the example command line.
@ -48,37 +52,40 @@ For llama2, run the following command to retrieve the weight files and start a
test server:
```bash
cd candle-wasm-examples/llama2-c
wget https://karpathy.ai/llama2c/model.bin
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
trunk serve --release --public-url /candle-llama2/ --port 8081
```
And then browse to
And then head over to
[http://localhost:8081/candle-llama2](http://localhost:8081/candle-llama2).
<!--- ANCHOR: features --->
## Features
- Simple syntax, looks and like PyTorch.
- CPU and Cuda backends, m1, f16, bf16.
- Enable serverless (CPU), small and fast deployments
- WASM support, run your models in a browser.
- Model training.
- Distributed computing using NCCL.
- Models out of the box: Llama, Whisper, Falcon, StarCoder...
- Embed user-defined ops/kernels, such as [flash-attention
v2](https://github.com/LaurentMazare/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
- Simple syntax, looks and feels like PyTorch.
- Model training.
- Embed user-defined ops/kernels, such as [flash-attention v2](https://github.com/huggingface/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
- Backends.
- Optimized CPU backend with optional MKL support for x86 and Accelerate for macs.
- CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
- WASM support, run your models in a browser.
- Model support out of the box.
- LLMs: Llama v1 and v2, Falcon, StarCoder.
- Whisper.
- Stable Diffusion.
- Serverless (on CPU), small and fast deployments.
<!--- ANCHOR_END: features --->
## How to use ?
## How to use
<!--- ANCHOR: cheatsheet --->
Cheatsheet:
| | Using PyTorch | Using Candle |
|------------|------------------------------------------|------------------------------------------------------------------|
| Creation | `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(&[[1f32, 2.]], [3., 4.]], &Device::Cpu)?` |
| Creation | `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?` |
| Creation | `torch.zeros((2, 2))` | `Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?` |
| Indexing | `tensor[:, :4]` | `tensor.i((.., ..4))?` |
| Operations | `tensor.view((2, 2))` | `tensor.reshape((2, 2))?` |
@ -95,43 +102,46 @@ Cheatsheet:
## Structure
- [candle-core](./candle-core): Core ops, devices, and `Tensor` struct definition
- [candle-nn](./candle-nn/): Facilities to build real models
- [candle-examples](./candle-examples/): Real-world like examples on how to use the library in real settings
- [candle-nn](./candle-nn/): Tools to build real models
- [candle-examples](./candle-examples/): Examples of using the library in realistic settings
- [candle-kernels](./candle-kernels/): CUDA custom kernels
- [candle-datasets](./candle-datasets/): Datasets and data loaders.
- [candle-transformers](./candle-transformers): transformers-related utilities.
- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
## FAQ
### Why Candle?
### Why should I use Candle?
Candle stems from the need to reduce binary size in order to *enable serverless*
possible by making the whole engine smaller than PyTorch very large library volume.
This enables creating runtimes on a cluster much faster.
Candle's core goal is to *make serverless inference possible*. Full machine learning frameworks like PyTorch
are very large, which makes creating instances on a cluster slow. Candle allows deployment of lightweight
binaries.
And simply *removing Python* from production workloads.
Python can really add overhead in more complex workflows and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-future/) is a notorious source of headaches.
Secondly, Candle lets you *remove Python* from production workloads. Python overhead can seriously hurt performance,
and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-future/) is a notorious source of headaches.
Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers).
Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers).
### Other ML frameworks
- [dfdx](https://github.com/coreylowman/dfdx) is a formidable crate, with shapes being included
in types preventing a lot of headaches by getting compiler to complain about shape mismatch right off the bat
However we found that some features still require nightly and writing code can be a bit dauting for non rust experts.
in types. This prevents a lot of headaches by getting the compiler to complain about shape mismatches right off the bat.
However, we found that some features still require nightly, and writing code can be a bit daunting for non rust experts.
We're leveraging and contributing to other core crates for the runtime so hopefully both crates can benefit from each
other
other.
- [burn](https://github.com/burn-rs/burn) is a general crate that can leverage multiple backends so you can choose the best
engine for your workload
engine for your workload.
- [tch-rs](https://github.com/LaurentMazare/tch-rs.git) Bindings to the torch library in Rust. Extremely versatile, but they
do bring in the entire torch library into the runtime. The main contributor of `tch-rs` is also involved in the development
bring in the entire torch library into the runtime. The main contributor of `tch-rs` is also involved in the development
of `candle`.
### Missing symbols when compiling with the mkl feature.
### Common Errors
#### Missing symbols when compiling with the mkl feature.
If you get some missing symbols when compiling binaries/tests using the mkl
features, e.g.:
@ -144,13 +154,25 @@ features, e.g.:
= note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo (see https://doc.rust-lang.org/cargo/reference/build-scripts.html#cargorustc-link-libkindname)
```
This is likely due to some missing linker flag that enable the mkl library. You
This is likely due to a missing linker flag that was needed to enable the mkl library. You
can try adding the following at the top of your binary:
```
extern crate intel_mkl_src;
```
### How to know where an error comes from.
#### Cannot run llama example : access to source requires login credentials
```
Error: request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401
```
This is likely because you're not permissioned for the llama-v2 model. To fix
this, you have to register on the huggingface-hub, accept the [llama-v2 model
conditions](https://huggingface.co/meta-llama/Llama-2-7b-hf), and set up your
authentication token. See issue
[#350](https://github.com/huggingface/candle/issues/350) for more details.
#### Tracking down errors
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
error is generated.

View File

@ -12,16 +12,16 @@
- [Running a model](inference/README.md)
- [Using the hub](inference/hub.md)
- [Serialization](inference/serialization.md)
- [Advanced Cuda usage](inference/cuda/README.md)
- [Writing a custom kernel](inference/cuda/writing.md)
- [Porting a custom kernel](inference/cuda/porting.md)
- [Error management](error_manage.md)
- [Creating apps](apps/README.md)
- [Creating a WASM app](apps/wasm.md)
- [Creating a REST api webserver](apps/rest.md)
- [Creating a desktop Tauri app](apps/dekstop.md)
- [Training](training/README.md)
- [MNIST](training/mnist.md)
- [Fine-tuning](training/finetuning.md)
- [Using MKL](advanced/mkl.md)
- [Error management]()
- [Advanced Cuda usage]()
- [Writing a custom kernel]()
- [Porting a custom kernel]()
- [Using MKL]()
- [Creating apps]()
- [Creating a WASM app]()
- [Creating a REST api webserver]()
- [Creating a desktop Tauri app]()
- [Training]()
- [MNIST]()
- [Fine-tuning]()
- [Serialization]()

View File

@ -0,0 +1 @@
# Advanced Cuda usage

View File

@ -0,0 +1 @@
# Porting a custom kernel

View File

@ -0,0 +1 @@
# Writing a custom kernel

View File

@ -1 +1,51 @@
# Error management
You might have seen in the code base a lot of `.unwrap()` or `?`.
If you're unfamiliar with Rust, check out the [Rust book](https://doc.rust-lang.org/book/ch09-02-recoverable-errors-with-result.html)
for more information.
What's important to know, though, is that if you want to know *where* a particular operation failed,
you can simply use `RUST_BACKTRACE=1` to get the location where the model actually failed.
Let's look at some failing code:
```rust,ignore
let x = Tensor::zeros((1, 784), DType::F32, &device)?;
let y = Tensor::zeros((1, 784), DType::F32, &device)?;
let z = x.matmul(&y)?;
```
This will print at runtime:
```bash
Error: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }
```
After adding `RUST_BACKTRACE=1`:
```bash
Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
```
Not super pretty at the moment, but we can see that the error occurred on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`.
Another thing to note is that since Rust is compiled, it is not necessarily as easy to recover proper stack traces,
especially in release builds. We're using [`anyhow`](https://docs.rs/anyhow/latest/anyhow/) for that.
The library is still young; please [report](https://github.com/LaurentMazare/candle/issues) any issues with detecting where an error comes from.
## Cuda error management
When running a model on CUDA, you might get a stack trace that does not really represent the error.
The reason is that CUDA is asynchronous by nature, so the error might be caught while you are launching totally different kernels.
One way to avoid this is to set the `CUDA_LAUNCH_BLOCKING=1` environment variable, which forces every kernel to be launched sequentially.
You might still see the error reported on other kernels, however, as the faulty kernel might exit without an error while corrupting some pointer, in which case the error only surfaces when the corresponding `CudaSlice` is dropped.
If this occurs, you can use [`compute-sanitizer`](https://docs.nvidia.com/compute-sanitizer/ComputeSanitizer/index.html).
This tool is like `valgrind` but for CUDA, and it will help locate the errors in the kernels.
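
As a rough, illustrative example of putting these tools together (the example name and paths below are placeholders, not part of this page):

```bash
# Force every kernel launch to be synchronous so the failure is reported
# on the kernel that actually caused it.
CUDA_LAUNCH_BLOCKING=1 cargo run --release --features cuda --example llama

# Re-run the same binary under compute-sanitizer (valgrind-like, but for CUDA)
# to pinpoint out-of-bounds accesses inside the kernels.
compute-sanitizer ./target/release/examples/llama
```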

View File

@ -128,17 +128,17 @@ fn main() -> Result<()> {
```
Now it works, it is a great way to create your own layers.
But most of the classical layers are already implemented in [candle-nn](https://github.com/LaurentMazare/candle/tree/main/candle-nn).
But most of the classical layers are already implemented in [candle-nn](https://github.com/huggingface/candle/tree/main/candle-nn).
## Using `candle_nn`.
For instance [Linear](https://github.com/LaurentMazare/candle/blob/main/candle-nn/src/linear.rs) is already there.
For instance [Linear](https://github.com/huggingface/candle/blob/main/candle-nn/src/linear.rs) is already there.
This Linear is coded with PyTorch layout in mind, to reuse better existing models out there, so it uses the transpose of the weights and not the weights directly.
So instead we can simplify our example:
```bash
cargo add --git https://github.com/LaurentMazare/candle.git candle-nn
cargo add --git https://github.com/huggingface/candle.git candle-nn
```
And rewrite our examples using it

View File

@ -5,13 +5,13 @@ Start by creating a new app:
```bash
cargo new myapp
cd myapp
cargo add --git https://github.com/LaurentMazare/candle.git candle
cargo add --git https://github.com/huggingface/candle.git candle-core
```
At this point, candle will be built **without** CUDA support.
To get CUDA support use the `cuda` feature
```bash
cargo add --git https://github.com/LaurentMazare/candle.git candle --features cuda
cargo add --git https://github.com/huggingface/candle.git candle-core --features cuda
```
You can check everything works properly:

View File

@ -1 +1,7 @@
# Running a model
In order to run an existing model, you will need to download and use existing weights.
Most models are already available on https://huggingface.co/ in [`safetensors`](https://github.com/huggingface/safetensors) format.
Let's get started by running an old model: `bert-base-uncased`.

View File

@ -1 +1,104 @@
# Using the hub
Install the [`hf-hub`](https://github.com/huggingface/hf-hub) crate:
```bash
cargo add hf-hub
```
Then let's start by downloading the [model file](https://huggingface.co/bert-base-uncased/tree/main).
```rust
# extern crate candle_core;
# extern crate hf_hub;
use hf_hub::api::sync::Api;
use candle_core::Device;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights = repo.get("model.safetensors").unwrap();
let weights = candle_core::safetensors::load(weights, &Device::Cpu);
```
We now have access to all the [tensors](https://huggingface.co/bert-base-uncased?show_tensors=true) within the file.
You can check all the names of the tensors [here](https://huggingface.co/bert-base-uncased?show_tensors=true).
## Using async
`hf-hub` comes with an async API.
```bash
cargo add hf-hub --features tokio
```
```rust,ignore
# This is tested directly in examples crate because it needs external dependencies unfortunately:
# See [this](https://github.com/rust-lang/mdBook/issues/706)
{{#include ../../../candle-examples/src/lib.rs:book_hub_1}}
```
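A minimal sketch of an async download (assuming `tokio` as a dependency; this is a sketch and not necessarily the snippet included above) might look like:

```rust
// Sketch only; assumes the `tokio` feature of hf-hub and a tokio runtime.
use hf_hub::api::tokio::Api;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let api = Api::new()?;
    let repo = api.model("bert-base-uncased".to_string());
    // Downloads the file (or reuses the cached copy) and returns its local path.
    let weights = repo.get("model.safetensors").await?;
    println!("weights downloaded to {weights:?}");
    Ok(())
}
```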
## Using in a real model.
Now that we have our weights, we can use them in our bert architecture:
```rust
# extern crate candle_core;
# extern crate candle_nn;
# extern crate hf_hub;
# use hf_hub::api::sync::Api;
#
# let api = Api::new().unwrap();
# let repo = api.model("bert-base-uncased".to_string());
#
# let weights = repo.get("model.safetensors").unwrap();
use candle_core::{Device, Tensor, DType};
use candle_nn::Linear;
let weights = candle_core::safetensors::load(weights, &Device::Cpu).unwrap();
let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap();
let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap();
let linear = Linear::new(weight.clone(), Some(bias.clone()));
let input_ids = Tensor::zeros((3, 768), DType::F32, &Device::Cpu).unwrap();
let output = linear.forward(&input_ids).unwrap();
```
For a full reference, you can check out the full [bert](https://github.com/LaurentMazare/candle/tree/main/candle-examples/examples/bert) example.
## Memory mapping
For more efficient loading, instead of reading the whole file, you could use [`memmap2`](https://docs.rs/memmap2/latest/memmap2/).
**Note**: Be careful with memory mapping; it seems to cause issues on [Windows and WSL](https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/5893),
and it will definitely be slower on a network-mounted disk because it will issue more read calls.
```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_hub_2}}
```
**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
In practice, model files should never be modified, and the mmaps should be mostly read-only anyway, so the caveat most likely does not apply, but always keep it in mind.
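
As a sketch of what memory-mapped loading can look like when using the `memmap2` and `safetensors` crates directly (illustrative only; not necessarily the code behind the `book_hub_2` include above):

```rust
// Sketch only: memory-map the weights file and parse it lazily with safetensors.
use memmap2::MmapOptions;
use safetensors::SafeTensors;
use std::fs::File;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let file = File::open("model.safetensors")?;
    // Safety: we rely on the file not being modified while it is mapped.
    let mmap = unsafe { MmapOptions::new().map(&file)? };
    // Only the header is parsed eagerly; tensor data is read from the mapping on demand.
    let tensors = SafeTensors::deserialize(&mmap)?;
    let view = tensors.tensor("bert.encoder.layer.0.attention.self.query.weight")?;
    println!("dtype: {:?}, shape: {:?}", view.dtype(), view.shape());
    Ok(())
}
```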
## Tensor Parallel Sharding
When using multiple GPUs for tensor parallelism in order to get good latency, you can load only the part of the tensor you need.
For that, you need to use [`safetensors`](https://crates.io/crates/safetensors) directly.
```bash
cargo add safetensors
```
```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_hub_3}}
```

View File

@ -10,8 +10,9 @@ license.workspace = true
readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.1.0", optional = true }
candle-kernels = { path = "../candle-kernels", version = "0.1.1", optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
@ -21,14 +22,19 @@ memmap2 = { workspace = true }
num-traits = { workspace = true }
num_cpus = { workspace = true }
rand = { workspace = true }
rand_distr = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
zip = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
clap = { workspace = true }
[features]
default = []
cuda = ["dep:cudarc", "dep:candle-kernels"]
cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]

View File

@ -1,29 +1,18 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
let c = a.matmul(&b)?;
println!("{a} {b} {c}");
let data = &[[3f32, 1., 4., 1., 5.], [2., 7., 1., 8., 2.]];
let t1 = Tensor::new(data, &Device::Cpu)?;
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 7., 1., 8., 2.]];
let t2 = Tensor::new(data2, &Device::Cpu)?;
assert_eq!(
Tensor::cat(&[&t1.t()?, &t2.t()?], 1)?
.t()?
.to_vec2::<f32>()?,
[
[3.0, 1.0, 4.0, 1.0, 5.0],
[2.0, 7.0, 1.0, 8.0, 2.0],
[5.0, 5.0, 5.0, 5.0, 5.0],
[2.0, 7.0, 1.0, 8.0, 2.0]
]
);
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
let start = std::time::Instant::now();
let res = inp.conv2d(&w, 0, 1);
println!("{:?}", start.elapsed());
println!("{res:?}");
Ok(())
}

View File

@ -0,0 +1,142 @@
/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_core::{Device, Result, Tensor, D};
use clap::{Parser, Subcommand};
fn softmax<D: candle_core::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
let dim = dim.to_index(xs.shape(), "softmax")?;
let max = xs.max_keepdim(dim)?;
let diff = xs.broadcast_sub(&max)?;
let num = diff.exp()?;
let den = num.sum_keepdim(dim)?;
num.broadcast_div(&den)
}
trait Benchmark {
type PreProcessData;
type RunResult;
fn preprocess() -> Result<Self::PreProcessData>;
fn run_one(_: &Self::PreProcessData) -> Result<Self::RunResult>;
const ITERS: usize;
}
// Conv1d example as used in whisper.
struct Conv1d;
impl Benchmark for Conv1d {
type PreProcessData = (Tensor, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let inp = Tensor::randn(0f32, 1., (1, 384, 3000), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (384, 384, 3), &Device::Cpu)?;
Ok((inp, w))
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
d.0.conv1d(&d.1, 0, 1)
}
const ITERS: usize = 5;
}
// Conv2d example as used in stable-diffusion.
struct Conv2d;
impl Benchmark for Conv2d {
type PreProcessData = (Tensor, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
Ok((inp, w))
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
d.0.conv2d(&d.1, 0, 1)
}
const ITERS: usize = 1;
}
struct Matmul;
impl Benchmark for Matmul {
type PreProcessData = (Tensor, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
Ok((lhs, rhs))
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
d.0.matmul(&d.1)
}
const ITERS: usize = 100;
}
struct Softmax;
impl Benchmark for Softmax {
type PreProcessData = Tensor;
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
// Typical whisper tiny size.
let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
Ok(x)
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
softmax(d, D::Minus1)
}
const ITERS: usize = 100;
}
fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
use std::hint::black_box;
let iters = iters.unwrap_or(B::ITERS);
let d = B::preprocess()?;
let start = std::time::Instant::now();
for _iter in 0..iters {
let _res = black_box(B::run_one(black_box(&d))?);
}
println!("{:?}", start.elapsed() / iters as u32);
Ok(())
}
#[derive(Subcommand, Debug, Clone)]
enum Task {
Conv1d,
Conv2d,
Matmul,
Softmax,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
/// The benchmark to be run.
#[command(subcommand)]
task: Task,
#[arg(long)]
iters: Option<usize>,
}
fn main() -> Result<()> {
let args = Args::parse();
match args.task {
Task::Conv1d => run::<Conv1d>(args.iters)?,
Task::Conv2d => run::<Conv2d>(args.iters)?,
Task::Matmul => run::<Matmul>(args.iters)?,
Task::Softmax => run::<Softmax>(args.iters)?,
}
Ok(())
}

View File

@ -1,3 +1,6 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@ -6,10 +9,9 @@ use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let t = Tensor::new(&[[1f32, 2., 3., 4.2]], &device)?;
let sum = t.sum_keepdim(0)?;
println!("{sum}");
let sum = t.sum_keepdim(1)?;
println!("{sum}");
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
let res = t.conv2d(&w, 1, 1)?;
println!("{res:?}");
Ok(())
}

View File

@ -1,6 +1,9 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use std::str::FromStr;
use anyhow::Result;

View File

@ -0,0 +1,350 @@
#![allow(dead_code)]
use libc::{c_char, c_double, c_float, c_int, c_long, c_ulong};
mod ffi {
use super::*;
extern "C" {
// It would be nice to be able to switch to the NEWLAPACK version of the function but this
// seems to trigger some link error. Available function names can be seen here:
// /Library/Developer/CommandLineTools/SDKs/MacOSX13.3.sdk/System/Library/Frameworks/Accelerate.framework/Versions/A/Accelerate.tbd
#[link_name = "sgemm_"]
pub fn sgemm_ffi(
transa: *const c_char,
transb: *const c_char,
m: *const c_int,
n: *const c_int,
k: *const c_int,
alpha: *const c_float,
a: *const c_float,
lda: *const c_int,
b: *const c_float,
ldb: *const c_int,
beta: *const c_float,
c: *mut c_float,
ldc: *const c_int,
);
#[link_name = "dgemm_"]
pub fn dgemm_ffi(
transa: *const c_char,
transb: *const c_char,
m: *const c_int,
n: *const c_int,
k: *const c_int,
alpha: *const c_double,
a: *const c_double,
lda: *const c_int,
b: *const c_double,
ldb: *const c_int,
beta: *const c_double,
c: *mut c_double,
ldc: *const c_int,
);
pub fn vvexpf(dst: *mut c_float, src: *const c_float, len: *const c_int);
pub fn vvexp(dst: *mut c_double, src: *const c_double, len: *const c_int);
pub fn vvsqrtf(dst: *mut c_float, src: *const c_float, len: *const c_int);
pub fn vvsqrt(dst: *mut c_double, src: *const c_double, len: *const c_int);
pub fn vvsinf(dst: *mut c_float, src: *const c_float, len: *const c_int);
pub fn vvsin(dst: *mut c_double, src: *const c_double, len: *const c_int);
pub fn vvcosf(dst: *mut c_float, src: *const c_float, len: *const c_int);
pub fn vvcos(dst: *mut c_double, src: *const c_double, len: *const c_int);
pub fn vvlogf(dst: *mut c_float, src: *const c_float, len: *const c_int);
pub fn vvlog(dst: *mut c_double, src: *const c_double, len: *const c_int);
pub fn vDSP_vaddD(
_: *const c_double,
_: c_long,
_: *const c_double,
_: c_long,
_: *mut c_double,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vadd(
_: *const c_float,
_: c_long,
_: *const c_float,
_: c_long,
_: *mut c_float,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vsubD(
_: *const c_double,
_: c_long,
_: *const c_double,
_: c_long,
_: *mut c_double,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vsub(
_: *const c_float,
_: c_long,
_: *const c_float,
_: c_long,
_: *mut c_float,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vmulD(
_: *const c_double,
_: c_long,
_: *const c_double,
_: c_long,
_: *mut c_double,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vmul(
_: *const c_float,
_: c_long,
_: *const c_float,
_: c_long,
_: *mut c_float,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vdivD(
_: *const c_double,
_: c_long,
_: *const c_double,
_: c_long,
_: *mut c_double,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vdiv(
_: *const c_float,
_: c_long,
_: *const c_float,
_: c_long,
_: *mut c_float,
_: c_long,
_: c_ulong,
);
}
}
#[allow(clippy::too_many_arguments)]
#[inline]
pub unsafe fn sgemm(
transa: u8,
transb: u8,
m: i32,
n: i32,
k: i32,
alpha: f32,
a: &[f32],
lda: i32,
b: &[f32],
ldb: i32,
beta: f32,
c: &mut [f32],
ldc: i32,
) {
ffi::sgemm_ffi(
&(transa as c_char),
&(transb as c_char),
&m,
&n,
&k,
&alpha,
a.as_ptr(),
&lda,
b.as_ptr(),
&ldb,
&beta,
c.as_mut_ptr(),
&ldc,
)
}
#[allow(clippy::too_many_arguments)]
#[inline]
pub unsafe fn dgemm(
transa: u8,
transb: u8,
m: i32,
n: i32,
k: i32,
alpha: f64,
a: &[f64],
lda: i32,
b: &[f64],
ldb: i32,
beta: f64,
c: &mut [f64],
ldc: i32,
) {
ffi::dgemm_ffi(
&(transa as c_char),
&(transb as c_char),
&m,
&n,
&k,
&alpha,
a.as_ptr(),
&lda,
b.as_ptr(),
&ldb,
&beta,
c.as_mut_ptr(),
&ldc,
)
}
#[inline]
pub fn vs_exp(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvexpf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vd_exp(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvexp(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vs_sqrt(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvsqrtf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vd_sqrt(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvsqrt(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vs_sin(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvsinf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vd_sin(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvsin(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vs_cos(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvcosf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vd_cos(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvcos(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vs_ln(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvlogf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vd_ln(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvlog(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vs_sqr(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
}
#[inline]
pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
}
macro_rules! binary_op {
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
#[inline]
pub fn $fn_name(a: &[$ty], b: &[$ty], y: &mut [$ty]) {
let a_len = a.len();
let b_len = b.len();
let y_len = y.len();
if a_len != y_len || b_len != y_len {
panic!(
"{} a,b,y len mismatch {a_len} {b_len} {y_len}",
stringify!($fn_name)
);
}
unsafe {
// Weird quirk of accelerate, the rhs comes before the lhs.
ffi::$accelerate_name(
b.as_ptr(),
1,
a.as_ptr(),
1,
y.as_mut_ptr(),
1,
a_len as u64,
)
}
}
};
}
binary_op!(vs_add, f32, vDSP_vadd);
binary_op!(vd_add, f64, vDSP_vaddD);
binary_op!(vs_sub, f32, vDSP_vsub);
binary_op!(vd_sub, f64, vDSP_vsubD);
binary_op!(vs_mul, f32, vDSP_vmul);
binary_op!(vd_mul, f64, vDSP_vmulD);
binary_op!(vs_div, f32, vDSP_vdiv);
binary_op!(vd_div, f64, vDSP_vdivD);

View File

@ -37,6 +37,18 @@ pub trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv1D,
) -> Result<Self>;
fn conv2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConv2D,
) -> Result<Self>;
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
fn scatter_add(
&self,

View File

@ -55,6 +55,11 @@ impl Tensor {
kernel: rhs,
..
}
| Op::Conv2D {
arg: lhs,
kernel: rhs,
..
}
| Op::CustomOp2(lhs, rhs, _)
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
@ -81,6 +86,9 @@ impl Tensor {
}
}
Op::Reshape(node)
| Op::UpsampleNearest2D(node)
| Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. }
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
@ -163,6 +171,12 @@ impl Tensor {
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?,
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
Op::Gather(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
@ -291,6 +305,11 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&grad)?
}
Op::Unary(arg, UnaryOp::Recip) => {
let sum_grad = grads.or_insert(arg)?;
let grad = (grad / arg.sqr()?)?;
*sum_grad = sum_grad.sub(&grad)?
}
&Op::Narrow(ref arg, dim, start_idx, len) => {
let arg_dims = arg.dims();
let left_pad = if start_idx == 0 {

View File

@ -1,6 +1,6 @@
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
pub(crate) b_size: Option<usize>,
pub(crate) b_size: usize,
// Maybe we should have a version without l_in as this bit depends on the input and not only on
// the weights.
pub(crate) l_in: usize,
@ -19,9 +19,35 @@ impl ParamsConv1D {
pub(crate) fn out_dims(&self) -> Vec<usize> {
let l_out = self.l_out();
match self.b_size {
None => vec![self.c_out, l_out],
Some(n) => vec![n, self.c_out, l_out],
}
vec![self.b_size, self.c_out, l_out]
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
pub(crate) b_size: usize,
pub(crate) i_h: usize,
pub(crate) i_w: usize,
pub(crate) k_h: usize,
pub(crate) k_w: usize,
pub(crate) c_out: usize,
pub(crate) c_in: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
}
impl ParamsConv2D {
pub(crate) fn out_h(&self) -> usize {
let dilation = 1;
(self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
}
pub(crate) fn out_w(&self) -> usize {
let dilation = 1;
(self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
}
}

View File

@ -1,6 +1,6 @@
//! Implement conversion traits for tensors
use crate::{Device, Error, Tensor, WithDType};
use half::{bf16, f16};
use crate::{DType, Device, Error, Tensor, WithDType};
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::convert::TryFrom;
impl<T: WithDType> TryFrom<&Tensor> for Vec<T> {
@ -94,3 +94,46 @@ from_tensor!(f16);
from_tensor!(bf16);
from_tensor!(u32);
from_tensor!(u8);
impl Tensor {
pub fn write_bytes<W: std::io::Write>(&self, f: &mut W) -> crate::Result<()> {
use byteorder::{LittleEndian, WriteBytesExt};
let vs = self.flatten_all()?;
match self.dtype() {
DType::BF16 => {
let vs = vs.to_vec1::<bf16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F16 => {
let vs = vs.to_vec1::<f16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F32 => {
// TODO: Avoid using a buffer when data is already on the CPU.
for v in vs.to_vec1::<f32>()? {
f.write_f32::<LittleEndian>(v)?
}
}
DType::F64 => {
for v in vs.to_vec1::<f64>()? {
f.write_f64::<LittleEndian>(v)?
}
}
DType::U32 => {
for v in vs.to_vec1::<u32>()? {
f.write_u32::<LittleEndian>(v)?
}
}
DType::U8 => {
let vs = vs.to_vec1::<u8>()?;
f.write_all(&vs)?;
}
}
Ok(())
}
}

candle-core/src/cpu/avx.rs (new file, 148 lines)
View File

@ -0,0 +1,148 @@
use super::{Cpu, CpuF16};
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
use half::f16;
pub struct CurrentCpu {}
const STEP: usize = 32;
const EPR: usize = 8;
const ARR: usize = STEP / EPR;
impl Cpu<ARR> for CurrentCpu {
type Unit = __m256;
type Array = [__m256; ARR];
const STEP: usize = STEP;
const EPR: usize = EPR;
fn n() -> usize {
ARR
}
unsafe fn zero() -> Self::Unit {
_mm256_setzero_ps()
}
unsafe fn zero_array() -> Self::Array {
[Self::zero(); ARR]
}
unsafe fn from_f32(v: f32) -> Self::Unit {
_mm256_set1_ps(v)
}
unsafe fn load(mem_addr: *const f32) -> Self::Unit {
_mm256_loadu_ps(mem_addr)
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
_mm256_add_ps(a, b)
}
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
_mm256_add_ps(_mm256_mul_ps(b, c), a)
}
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
_mm256_storeu_ps(mem_addr, a);
}
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
for i in 0..ARR / 2 {
x[2 * i] = _mm256_add_ps(x[2 * i], x[2 * i + 1]);
}
for i in 0..ARR / 4 {
x[4 * i] = _mm256_add_ps(x[4 * i], x[4 * i + 2]);
}
#[allow(clippy::reversed_empty_ranges)]
for i in 0..ARR / 8 {
x[8 * i] = _mm256_add_ps(x[8 * i], x[8 * i + 4]);
}
let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
let t1 = _mm_hadd_ps(t0, t0);
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
}
}
pub struct CurrentCpuF16 {}
impl CpuF16<ARR> for CurrentCpuF16 {
type Unit = __m256;
type Array = [__m256; ARR];
const STEP: usize = STEP;
const EPR: usize = EPR;
fn n() -> usize {
ARR
}
unsafe fn zero() -> Self::Unit {
_mm256_setzero_ps()
}
unsafe fn zero_array() -> Self::Array {
[Self::zero(); ARR]
}
unsafe fn from_f32(v: f32) -> Self::Unit {
_mm256_set1_ps(v)
}
#[cfg(target_feature = "f16c")]
unsafe fn load(mem_addr: *const f16) -> Self::Unit {
_mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
}
#[cfg(not(target_feature = "f16c"))]
unsafe fn load(mem_addr: *const f16) -> Self::Unit {
let mut tmp = [0.0f32; 8];
for i in 0..8 {
tmp[i] = (*mem_addr.add(i)).to_f32();
}
_mm_loadu_ps(tmp.as_ptr())
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
_mm256_add_ps(a, b)
}
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
_mm256_add_ps(_mm256_mul_ps(b, c), a)
}
#[cfg(target_feature = "f16c")]
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
_mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
}
#[cfg(not(target_feature = "f16c"))]
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
let mut tmp = [0.0f32; 8];
_mm256_storeu_ps(tmp.as_mut_ptr(), a);
for i in 0..8 {
*mem_addr.add(i) = f16::from_f32(tmp[i]);
}
}
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
let mut offset = ARR >> 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
offset >>= 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
offset >>= 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
let t1 = _mm_hadd_ps(t0, t0);
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
}
}

View File

@ -0,0 +1,89 @@
pub trait VecOps: num_traits::NumAssign + Copy {
/// Dot-product of two vectors.
///
/// # Safety
///
/// The length of `lhs` and `rhs` have to be at least `len`. `res` has to point to a valid
/// element.
#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
*res = Self::zero();
for i in 0..len {
*res += *lhs.add(i) * *rhs.add(i)
}
}
/// Sum of all elements in a vector.
///
/// # Safety
///
/// The length of `xs` must be at least `len`. `res` has to point to a valid
/// element.
#[inline(always)]
unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
*res = Self::zero();
for i in 0..len {
*res += *xs.add(i)
}
}
}
impl VecOps for f32 {
#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
super::vec_dot_f32(lhs, rhs, res, len)
}
#[inline(always)]
unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
super::vec_sum(xs, res, len)
}
}
impl VecOps for half::f16 {
#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
let mut res_f32 = 0f32;
super::vec_dot_f16(lhs, rhs, &mut res_f32, len);
*res = half::f16::from_f32(res_f32);
}
}
impl VecOps for f64 {}
impl VecOps for half::bf16 {}
impl VecOps for u8 {}
impl VecOps for u32 {}
#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
if n_threads == 1 {
func(0)
} else {
rayon::scope(|s| {
for thread_idx in 0..n_threads {
let func = &func;
s.spawn(move |_| func(thread_idx));
}
})
}
}
#[inline(always)]
pub fn par_range(lo: usize, up: usize, n_threads: usize, func: impl Fn(usize) + Send + Sync) {
if n_threads == 1 {
for i in lo..up {
func(i)
}
} else {
rayon::scope(|s| {
for thread_idx in 0..n_threads {
let func = &func;
s.spawn(move |_| {
for i in (thread_idx..up).step_by(n_threads) {
func(i)
}
});
}
})
}
}

candle-core/src/cpu/mod.rs (new file, 179 lines)
View File

@ -0,0 +1,179 @@
pub mod kernels;
trait Cpu<const ARR: usize> {
type Unit;
type Array;
const STEP: usize;
const EPR: usize;
fn n() -> usize;
unsafe fn zero() -> Self::Unit;
unsafe fn zero_array() -> Self::Array;
unsafe fn load(mem_addr: *const f32) -> Self::Unit;
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
unsafe fn from_f32(v: f32) -> Self::Unit;
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
}
trait CpuF16<const ARR: usize> {
type Unit;
type Array;
const STEP: usize;
const EPR: usize;
fn n() -> usize;
unsafe fn zero() -> Self::Unit;
unsafe fn zero_array() -> Self::Array;
unsafe fn load(mem_addr: *const f16) -> Self::Unit;
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
unsafe fn from_f32(v: f32) -> Self::Unit;
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit);
}
use half::f16;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub mod avx;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub use avx::{CurrentCpu, CurrentCpuF16};
#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
pub mod simd128;
#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
pub use simd128::CurrentCpu;
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(target_feature = "neon")]
pub use neon::CurrentCpu;
#[cfg(any(
target_feature = "neon",
target_feature = "avx",
target_feature = "simd128"
))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
let np = k & !(CurrentCpu::STEP - 1);
let mut sum = CurrentCpu::zero_array();
let mut ax = CurrentCpu::zero_array();
let mut ay = CurrentCpu::zero_array();
for i in (0..np).step_by(CurrentCpu::STEP) {
for j in 0..CurrentCpu::n() {
ax[j] = CurrentCpu::load(a_row.add(i + j * CurrentCpu::EPR));
ay[j] = CurrentCpu::load(b_row.add(i + j * CurrentCpu::EPR));
sum[j] = CurrentCpu::vec_fma(sum[j], ax[j], ay[j]);
}
}
CurrentCpu::vec_reduce(sum, c);
// leftovers
for i in np..k {
*c += *a_row.add(i) * (*b_row.add(i));
}
}
#[cfg(not(any(
target_feature = "neon",
target_feature = "avx",
target_feature = "simd128"
)))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
// leftovers
for i in 0..k {
*c += *a_row.add(i) * (*b_row.add(i));
}
}
#[cfg(any(
target_feature = "neon",
target_feature = "avx",
target_feature = "simd128"
))]
#[inline(always)]
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
let np = k & !(CurrentCpu::STEP - 1);
let mut sum = CurrentCpu::zero_array();
let mut x = CurrentCpu::zero_array();
for i in (0..np).step_by(CurrentCpu::STEP) {
for j in 0..CurrentCpu::n() {
x[j] = CurrentCpu::load(row.add(i + j * CurrentCpu::EPR));
sum[j] = CurrentCpu::vec_add(sum[j], x[j]);
}
}
CurrentCpu::vec_reduce(sum, b);
// leftovers
for i in np..k {
*b += *row.add(i)
}
}
#[cfg(not(any(
target_feature = "neon",
target_feature = "avx",
target_feature = "simd128"
)))]
#[inline(always)]
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
*b = 0f32;
for i in 0..k {
*b += *row.add(i)
}
}
#[cfg(target_feature = "avx")]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
let mut sumf = 0.0f32;
let np = k & !(CurrentCpuF16::STEP - 1);
let mut sum = CurrentCpuF16::zero_array();
let mut ax = CurrentCpuF16::zero_array();
let mut ay = CurrentCpuF16::zero_array();
for i in (0..np).step_by(CurrentCpuF16::STEP) {
for j in 0..CurrentCpuF16::n() {
ax[j] = CurrentCpuF16::load(a_row.add(i + j * CurrentCpuF16::EPR));
ay[j] = CurrentCpuF16::load(b_row.add(i + j * CurrentCpuF16::EPR));
sum[j] = CurrentCpuF16::vec_fma(sum[j], ax[j], ay[j]);
}
}
CurrentCpuF16::vec_reduce(sum, &mut sumf);
// leftovers
for i in np..k {
sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
}
*c = sumf;
}
#[cfg(not(target_feature = "avx"))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
// leftovers
let mut sum = 0.0;
for i in 0..k {
sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
}
*c = sum;
}

View File

@ -0,0 +1,74 @@
use super::Cpu;
#[cfg(target_arch = "arm")]
use core::arch::arm::*;
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
pub struct CurrentCpu {}
const STEP: usize = 16;
const EPR: usize = 4;
const ARR: usize = STEP / EPR;
impl CurrentCpu {
#[cfg(target_arch = "aarch64")]
unsafe fn reduce_one(x: float32x4_t) -> f32 {
vaddvq_f32(x)
}
#[cfg(target_arch = "arm")]
unsafe fn reduce_one(x: float32x4_t) -> f32 {
vgetq_lane_f32(x, 0) + vgetq_lane_f32(x, 1) + vgetq_lane_f32(x, 2) + vgetq_lane_f32(x, 3)
}
}
impl Cpu<ARR> for CurrentCpu {
type Unit = float32x4_t;
type Array = [float32x4_t; ARR];
const STEP: usize = STEP;
const EPR: usize = EPR;
fn n() -> usize {
ARR
}
unsafe fn zero() -> Self::Unit {
vdupq_n_f32(0.0)
}
unsafe fn from_f32(x: f32) -> Self::Unit {
vdupq_n_f32(x)
}
unsafe fn zero_array() -> Self::Array {
[Self::zero(); ARR]
}
unsafe fn load(mem_addr: *const f32) -> Self::Unit {
vld1q_f32(mem_addr)
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
vaddq_f32(a, b)
}
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
vfmaq_f32(a, b, c)
}
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
vst1q_f32(mem_addr, a);
}
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
for i in 0..ARR / 2 {
x[2 * i] = vaddq_f32(x[2 * i], x[2 * i + 1]);
}
for i in 0..ARR / 4 {
x[4 * i] = vaddq_f32(x[4 * i], x[4 * i + 2]);
}
*y = Self::reduce_one(x[0]);
}
}

View File

@ -0,0 +1,64 @@
use super::Cpu;
use core::arch::wasm32::*;
pub struct CurrentCpu {}
const STEP: usize = 16;
const EPR: usize = 4;
const ARR: usize = STEP / EPR;
impl Cpu<ARR> for CurrentCpu {
type Unit = v128;
type Array = [v128; ARR];
const STEP: usize = STEP;
const EPR: usize = EPR;
fn n() -> usize {
ARR
}
unsafe fn zero() -> Self::Unit {
f32x4_splat(0.0)
}
unsafe fn zero_array() -> Self::Array {
[Self::zero(); ARR]
}
unsafe fn from_f32(v: f32) -> Self::Unit {
f32x4_splat(v)
}
unsafe fn load(mem_addr: *const f32) -> Self::Unit {
v128_load(mem_addr as *mut v128)
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
f32x4_add(a, b)
}
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
f32x4_add(f32x4_mul(b, c), a)
}
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
v128_store(mem_addr as *mut v128, a);
}
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
for i in 0..ARR / 2 {
x[2 * i] = f32x4_add(x[2 * i], x[2 * i + 1]);
}
for i in 0..ARR / 4 {
x[4 * i] = f32x4_add(x[4 * i], x[4 * i + 2]);
}
for i in 0..ARR / 8 {
x[8 * i] = f32x4_add(x[8 * i], x[8 * i + 4]);
}
*y = f32x4_extract_lane::<0>(x[0])
+ f32x4_extract_lane::<1>(x[0])
+ f32x4_extract_lane::<2>(x[0])
+ f32x4_extract_lane::<3>(x[0]);
}
}

View File

@ -278,17 +278,17 @@ impl Map1Any for ReduceIndex {
}
}
struct Reduce<'a> {
struct ReduceSum<'a> {
dst_shape: &'a Shape,
reduce_dims: &'a [usize],
reduce_dims_and_stride: Vec<(usize, usize)>,
}
impl<'a> Reduce<'a> {
impl<'a> ReduceSum<'a> {
#[inline(always)]
fn fold_impl<T, F>(&self, src: &[T], src_l: &Layout, start_elt: T, f: F) -> Result<Vec<T>>
where
T: Clone + Copy,
T: WithDType,
F: Fn(T, T) -> T,
{
let mut dst = vec![start_elt; self.dst_shape.elem_count()];
@ -310,12 +310,15 @@ impl<'a> Reduce<'a> {
.iter()
.map(|(u, _)| u)
.product::<usize>();
let mut src_i = 0;
for dst_v in dst.iter_mut() {
for &s in src[src_i..src_i + reduce_sz].iter() {
*dst_v = f(*dst_v, s)
}
src_i += reduce_sz
for (dst_i, dst_v) in dst.iter_mut().enumerate() {
let src_i = dst_i * reduce_sz;
unsafe {
T::vec_reduce_sum(
src[src_i..src_i + reduce_sz].as_ptr(),
dst_v,
reduce_sz,
)
};
}
return Ok(dst);
};
@ -347,7 +350,7 @@ impl<'a> Reduce<'a> {
}
}
impl<'a> Map1 for Reduce<'a> {
impl<'a> Map1 for ReduceSum<'a> {
#[inline(always)]
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
self.fold_impl(src, src_l, T::zero(), |x, y| x + y)
@ -633,6 +636,126 @@ impl Map1 for Affine {
}
}
struct AvgPool2D((usize, usize), (usize, usize));
impl Map1 for AvgPool2D {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
let (k_h, k_w) = self.0;
let (s_h, s_w) = self.1;
let (b_sz, c, h, w) = layout.shape().dims4()?;
let stride = layout.stride();
let (stride_h, stride_w) = (stride[2], stride[3]);
let h_out = (h - k_h) / s_h + 1;
let w_out = (w - k_w) / s_w + 1;
let src_index = layout.start_offset();
let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
let scale = 1f64 / (k_h * k_w) as f64;
let scale = T::from_f64(scale);
for b_idx in 0..b_sz {
let dst = &mut dst[b_idx * c * h_out * w_out..];
let src_index = src_index + b_idx * stride[0];
for c_idx in 0..c {
let dst = &mut dst[c_idx * h_out * w_out..];
let src_index = src_index + c_idx * stride[1];
for h_idx in 0..h_out {
for w_idx in 0..w_out {
let mut sum = T::zero();
for m in 0..k_h {
for n in 0..k_w {
let m = s_h * h_idx + m;
let n = s_w * w_idx + n;
sum += src[src_index + m * stride_h + n * stride_w]
}
}
dst[h_idx * w_out + w_idx] = sum * scale;
}
}
}
}
Ok(dst)
}
}
struct MaxPool2D((usize, usize), (usize, usize));
impl Map1 for MaxPool2D {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
let (k_h, k_w) = self.0;
let (s_h, s_w) = self.1;
let (b_sz, c, h, w) = layout.shape().dims4()?;
let stride = layout.stride();
let (stride_h, stride_w) = (stride[2], stride[3]);
let h_out = (h - k_h) / s_h + 1;
let w_out = (w - k_w) / s_w + 1;
let src_index = layout.start_offset();
let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
for b_idx in 0..b_sz {
let dst = &mut dst[b_idx * c * h_out * w_out..];
let src_index = src_index + b_idx * stride[0];
for c_idx in 0..c {
let dst = &mut dst[c_idx * h_out * w_out..];
let src_index = src_index + c_idx * stride[1];
for h_idx in 0..h_out {
for w_idx in 0..w_out {
let mut largest =
src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
for m in 0..k_h {
for n in 0..k_w {
let m = s_h * h_idx + m;
let n = s_w * w_idx + n;
if largest < src[src_index + m * stride_h + n * stride_w] {
largest = src[src_index + m * stride_h + n * stride_w]
}
}
}
dst[h_idx * w_out + w_idx] = largest;
}
}
}
}
Ok(dst)
}
}
struct UpsampleNearest2D(usize, usize);
impl Map1 for UpsampleNearest2D {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
// TODO: Specialized implementation for the case 2*h, 2*w?
let (dst_h, dst_w) = (self.0, self.1);
let (b_sz, c, src_h, src_w) = layout.shape().dims4()?;
let stride = layout.stride();
let (stride_h, stride_w) = (stride[2], stride[3]);
let src_index = layout.start_offset();
let scale_h = src_h as f64 / dst_h as f64;
let scale_w = src_w as f64 / dst_w as f64;
let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
let src_h_idxs = (0..dst_h)
.map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
.collect::<Vec<_>>();
let src_w_idxs = (0..dst_w)
.map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
.collect::<Vec<_>>();
for b_idx in 0..b_sz {
let dst = &mut dst[b_idx * c * dst_h * dst_w..];
let src_index = src_index + b_idx * stride[0];
for c_idx in 0..c {
let dst = &mut dst[c_idx * dst_h * dst_w..];
let src_index = src_index + c_idx * stride[1];
for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() {
for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() {
let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w;
dst[h_idx * dst_w + w_idx] = src[src_index]
}
}
}
}
Ok(dst)
}
}
struct Gather<'a, I: IntDType> {
ids: &'a [I],
ids_l: &'a Layout,
@ -903,56 +1026,152 @@ struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
impl<'a> Map2 for Conv1D<'a> {
const OP: &'static str = "conv1d";
fn f<T: 'static + num_traits::NumAssign + Copy>(
&self,
inp: &[T],
inp_l: &Layout,
k: &[T],
k_l: &Layout,
) -> Result<Vec<T>> {
// TODO: Optimize this (proper algorithm, simd, multithread, remove bound checks, etc).
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let k = &k[k_l.start_offset()..];
let inp_stride = inp_l.stride();
let (inp_stride0, inp_stride) = if inp_stride.len() == 3 {
(inp_stride[0], &inp_stride[1..])
} else {
(0, inp_stride) // This value never gets used anyway
};
let k_stride = k_l.stride();
let k_over_2 = p.k_size / 2;
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
let l_out = p.l_out();
let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
let mut dst = vec![T::zero(); dst_elems];
let dst_elems = p.c_out * l_out * p.b_size;
// The output shape is [b_size, c_out, l_out]
for b_idx in 0..p.b_size.unwrap_or(1) {
let inp_idx = b_idx * inp_stride0;
let dst_idx = b_idx * p.c_out * l_out;
for dst_c_idx in 0..p.c_out {
let dst_idx = dst_idx + dst_c_idx * l_out;
for dst_l in 0..l_out {
let dst_idx = dst_idx + dst_l;
let mut d = T::zero();
for offset in 0..p.k_size {
let src_l_plus = p.stride * dst_l + offset;
// inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
if k_over_2 <= src_l_plus && src_l_plus < k_over_2 + p.l_in {
let src_l = src_l_plus - k_over_2;
for src_c_idx in 0..p.c_in {
let inp_idx =
inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
let k_idx = dst_c_idx * k_stride[0]
+ src_c_idx * k_stride[1]
+ offset * k_stride[2];
d += inp[inp_idx] * k[k_idx]
}
}
}
dst[dst_idx] = d
let dst = vec![T::zero(); dst_elems];
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
for b_idx in 0..p.b_size {
for src_l in 0..p.l_in {
for src_c_idx in 0..p.c_in {
let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
}
}
}
let num_threads = crate::utils::get_num_threads();
for offset in 0..p.k_size {
crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
let dst_idx = dst_c_idx * l_out;
let k_cont = (0..p.c_in)
.map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
.collect::<Vec<_>>();
for b_idx in 0..p.b_size {
let dst_idx = dst_idx + b_idx * p.c_out * l_out;
for dst_l in 0..l_out {
let dst_idx = dst_idx + dst_l;
let src_l = p.stride * dst_l + offset;
if src_l < p.padding || src_l >= p.padding + p.l_in {
continue;
}
let src_l = src_l - p.padding;
let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
assert!(inp_cont.len() >= p.c_in);
assert!(k_cont.len() >= p.c_in);
let mut d = T::zero();
unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
let dst_p = dst.as_ptr();
// Safety: dst_idx values are unique per dst_c_idx, which is used to parallelise
// the different tasks, so no two threads can try to write to the same
// location.
unsafe {
let ptr = dst_p.add(dst_idx) as *mut T;
*ptr += d
}
}
}
})
}
Ok(dst)
}
}
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> {
const OP: &'static str = "conv2d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
let k = &k[k_l.start_offset()..];
let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
let (out_h, out_w) = (p.out_h(), p.out_w());
// Output shape: [b_size, c_out, out_h, out_w].
let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
let cont_s0 = p.i_h * p.i_w * p.c_in;
let cont_s1 = p.i_w * p.c_in;
let cont_s2 = p.c_in;
for b_idx in 0..p.b_size {
for h_idx in 0..p.i_h {
for w_idx in 0..p.i_w {
for c_idx in 0..p.c_in {
let src_idx =
b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
inp_cont[dst_idx] = inp[src_idx]
}
}
}
}
let num_threads = crate::utils::get_num_threads();
for offset_h in 0..p.k_h {
for offset_w in 0..p.k_w {
crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
let dst_idx = dst_c_idx * out_w * out_h;
let k_cont = (0..p.c_in)
.map(|c_in_idx| {
k[dst_c_idx * k_s0
+ c_in_idx * k_s1
+ offset_h * k_s2
+ offset_w * k_s3]
})
.collect::<Vec<_>>();
for b_idx in 0..p.b_size {
let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
for dst_h in 0..out_h {
let dst_idx = dst_idx + dst_h * out_w;
let src_h = p.stride * dst_h + offset_h;
if src_h < p.padding || src_h >= p.i_h + p.padding {
continue;
}
let src_h = src_h - p.padding;
for dst_w in 0..out_w {
let dst_idx = dst_idx + dst_w;
let src_w = p.stride * dst_w + offset_w;
if src_w < p.padding || src_w >= p.i_w + p.padding {
continue;
}
let src_w = src_w - p.padding;
let inp_cont = &inp_cont
[b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..];
assert!(inp_cont.len() >= p.c_in);
assert!(k_cont.len() >= p.c_in);
let mut d = T::zero();
unsafe {
T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
}
let dst_p = dst.as_ptr();
// Safety: dst_idx values are unique per dst_c_idx, which is used to parallelise
// the different tasks, so no two threads can try to write to the same
// location.
unsafe {
let ptr = dst_p.add(dst_idx) as *mut T;
*ptr += d
}
}
}
}
});
}
}
Ok(dst)
}
}
@ -974,7 +1193,7 @@ impl MatMul {
impl Map2 for MatMul {
const OP: &'static str = "mat_mul";
#[cfg(not(feature = "mkl"))]
#[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
&self,
lhs: &[T],
@ -1053,6 +1272,109 @@ impl Map2 for MatMul {
Ok(dst)
}
#[cfg(feature = "accelerate")]
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
&self,
lhs: &[T],
lhs_l: &Layout,
rhs: &[T],
rhs_l: &Layout,
) -> Result<Vec<T>> {
let (b, m, n, k) = self.0;
let lhs = &lhs[lhs_l.start_offset()..];
let rhs = &rhs[rhs_l.start_offset()..];
let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride();
let rank = lhs_stride.len();
let a_skip: usize = match lhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
};
let c_skip: usize = m * n;
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
(n as i32, b'N')
} else if rhs_m1 == k && rhs_m2 == 1 {
(k as i32, b'T')
} else {
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
};
// The b tensor has dims batching, m, k (lhs)
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
(k as i32, b'N')
} else if lhs_m1 == m && lhs_m2 == 1 {
(m as i32, b'T')
} else {
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
};
let mut dst = vec![T::zero(); b * m * n];
match T::DTYPE {
DType::F16 => {
crate::bail!("the accelerate backend does not support f16 matmul")
}
DType::F32 => {
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];
let dst_p = &mut dst[step * c_skip..];
unsafe {
let a = rhs_p.as_ptr() as *const f32;
let b = lhs_p.as_ptr() as *const f32;
let c = dst_p.as_mut_ptr() as *mut f32;
let a = std::slice::from_raw_parts(a, a_skip);
let b = std::slice::from_raw_parts(b, b_skip);
let c = std::slice::from_raw_parts_mut(c, c_skip);
crate::accelerate::sgemm(
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
/* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
)
}
}
}
DType::F64 => {
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];
let dst_p = &mut dst[step * c_skip..];
unsafe {
let a = rhs_p.as_ptr() as *const f64;
let b = lhs_p.as_ptr() as *const f64;
let c = dst_p.as_mut_ptr() as *mut f64;
let a = std::slice::from_raw_parts(a, a_skip);
let b = std::slice::from_raw_parts(b, b_skip);
let c = std::slice::from_raw_parts_mut(c, c_skip);
crate::accelerate::dgemm(
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
/* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
)
}
}
}
dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
}
Ok(dst)
}
#[cfg(feature = "mkl")]
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
&self,
@ -1379,7 +1701,7 @@ impl BackendStorage for CpuStorage {
.iter()
.map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
.collect();
Reduce {
ReduceSum {
dst_shape: &dst_shape,
reduce_dims: &reduce_dims,
reduce_dims_and_stride,
@ -1426,6 +1748,28 @@ impl BackendStorage for CpuStorage {
Affine(mul, add).map(self, layout)
}
fn avg_pool2d(
&self,
layout: &Layout,
kernel_size: (usize, usize),
stride: (usize, usize),
) -> Result<Self> {
AvgPool2D(kernel_size, stride).map(self, layout)
}
fn max_pool2d(
&self,
layout: &Layout,
kernel_size: (usize, usize),
stride: (usize, usize),
) -> Result<Self> {
MaxPool2D(kernel_size, stride).map(self, layout)
}
fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
UpsampleNearest2D(h, w).map(self, layout)
}
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
match self {
@ -1612,6 +1956,16 @@ impl BackendStorage for CpuStorage {
Conv1D(params).map(self, l, kernel, kernel_l)
}
fn conv2d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
Conv2D(params).map(self, l, kernel, kernel_l)
}
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
match ids {
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
@ -1767,35 +2121,36 @@ impl BackendDevice for CpuDevice {
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
let std = bf16::from_f64(std);
let mean = bf16::from_f64(mean);
let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
.map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean)
data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::BF16(data))
}
DType::F16 => {
let mut data = Vec::with_capacity(elem_count);
let std = f16::from_f64(std);
let mean = f16::from_f64(mean);
let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
.map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean)
data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F16(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
let std = std as f32;
let mean = mean as f32;
let normal =
rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F32(data))
}
DType::F64 => {
let mut data = Vec::with_capacity(elem_count);
let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F64(data))
}

View File

@ -64,7 +64,7 @@ impl From<CudaError> for crate::Error {
/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub(crate) struct DeviceId(usize);
pub struct DeviceId(usize);
impl DeviceId {
fn new() -> Self {
@ -111,6 +111,14 @@ impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
}
impl CudaDevice {
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
self.device.clone()
}
pub fn id(&self) -> DeviceId {
self.id
}
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
@ -897,14 +905,13 @@ impl<'a> Map2 for Conv1D<'a> {
// Kernel shape: (c_out, c_in_k, k_size)
// Input shape: (b_size, c_in, l_in) or (c_in, l_in)
let p = &self.0;
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(k_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
let l_out = p.l_out();
let dst_el = p.c_out * l_out * p.b_size.unwrap_or(1);
let dst_el = p.c_out * l_out * p.b_size;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
@ -917,7 +924,136 @@ impl<'a> Map2 for Conv1D<'a> {
panic!("unexpected input shape for conv1d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (el, l_out, p.stride, &ds, inp, k, &out);
let params = (el, l_out, p.stride, p.padding, &ds, inp, k, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
inp_l: &Layout,
k: &CudaSlice<T>,
k_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
// Kernel shape: (c_out, c_in_k, w_k, h_k)
// Input shape: (b_size, c_in, w_in, h_in)
let p = &self.0;
let (out_w, out_h) = (p.out_w(), p.out_h());
let dst_el = p.c_out * out_w * out_h * p.b_size;
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(k_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv2d"), kernels::CONV)?;
let ds = if dims.len() == 4 {
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else {
panic!("unexpected input shape for conv1d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (el, out_w, out_h, p.stride, p.padding, &ds, inp, k, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
enum PoolOp {
Max,
Avg,
}
struct Pool2D {
w_k: usize,
h_k: usize,
w_stride: usize,
h_stride: usize,
op: PoolOp,
}
impl Map1 for Pool2D {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
dev: &CudaDevice,
inp_l: &Layout,
) -> Result<CudaSlice<T>> {
// Input shape: (b_size, c, h, w)
let inp = &inp.slice(inp_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
let ds = if dims.len() == 4 {
[dims, inp_l.stride()].concat()
} else {
panic!("unexpected input shape for conv1d {dims:?}")
};
let el = shape.elem_count();
let out_w = (dims[2] - self.w_k) / self.w_stride + 1;
let out_h = (dims[3] - self.h_k) / self.h_stride + 1;
let dst_el = out_w * out_h * dims[0] * dims[1];
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let kname = match self.op {
PoolOp::Max => "max_pool2d",
PoolOp::Avg => "avg_pool2d",
};
let func = dev.get_or_load_func(&kernel_name::<T>(kname), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let ds = dev.htod_copy(ds).w()?;
let params = (
el,
self.w_k,
self.h_k,
self.w_stride,
self.h_stride,
&ds,
inp,
&out,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
struct UpsampleNearest2D(usize, usize);
impl Map1 for UpsampleNearest2D {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
dev: &CudaDevice,
inp_l: &Layout,
) -> Result<CudaSlice<T>> {
// Input shape: (b_size, c, h, w)
let inp = &inp.slice(inp_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
let ds = if dims.len() == 4 {
[dims, inp_l.stride()].concat()
} else {
panic!("unexpected input shape for conv1d {dims:?}")
};
let (out_w, out_h) = (self.0, self.1);
let dst_el = out_w * out_h * dims[0] * dims[1];
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("upsample_nearest2d"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let ds = dev.htod_copy(ds).w()?;
let scale_w = dims[2] as f64 / out_w as f64;
let scale_h = dims[3] as f64 / out_h as f64;
let params = (out_w, out_h, scale_w, scale_h, &ds, inp, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
@ -1381,6 +1517,114 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
#[cfg(not(feature = "cudnn"))]
fn conv2d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
let device = self.device().clone();
let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
Ok(Self { slice, device })
}
#[cfg(feature = "cudnn")]
fn conv2d(
&self,
inp_l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
let device = self.device().clone();
if !kernel_l.is_contiguous() {
let slice = Conv2D(params).map(&self.slice, inp_l, &kernel.slice, kernel_l, &device)?;
return Ok(Self { slice, device });
}
let (out_w, out_h) = (params.out_w(), params.out_h());
let dst_el = params.c_out * out_w * out_h * params.b_size;
let slice = match (&self.slice, &kernel.slice) {
(S::U8(inp), S::U8(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<u8>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<u8>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::U8(out)
}
(S::BF16(inp), S::BF16(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<bf16>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<bf16>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::BF16(out)
}
(S::F16(inp), S::F16(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f16>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<f16>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F16(out)
}
(S::F32(inp), S::F32(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f32>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<f32>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F32(out)
}
(S::F64(inp), S::F64(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f64>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<f64>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F64(out)
}
(S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?,
_ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?,
};
Ok(Self { slice, device })
}
fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
let device = self.device().clone();
let slice = Pool2D {
w_k: k.0,
h_k: k.1,
w_stride: stride.0,
h_stride: stride.1,
op: PoolOp::Avg,
}
.map(&self.slice, &device, l)?;
Ok(Self { slice, device })
}
fn max_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
let device = self.device().clone();
let slice = Pool2D {
w_k: k.0,
h_k: k.1,
w_stride: stride.0,
h_stride: stride.1,
op: PoolOp::Max,
}
.map(&self.slice, &device, l)?;
Ok(Self { slice, device })
}
fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
let device = self.device().clone();
let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?;
Ok(Self { slice, device })
}
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
let device = self.device().clone();
let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;

candle-core/src/cudnn.rs (new file, 107 lines)
View File

@ -0,0 +1,107 @@
use crate::WithDType;
use cudarc;
use cudarc::cudnn::safe::{Conv2dForward, Cudnn};
use cudarc::driver::{CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits};
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::Arc;
// The cudnn handles are stored per thread here rather than on the CudaDevice as they are neither
// send nor sync.
thread_local! {
static CUDNN: RefCell<HashMap<crate::cuda_backend::DeviceId, Arc<Cudnn>>> = HashMap::new().into();
}
impl From<cudarc::cudnn::CudnnError> for crate::Error {
fn from(err: cudarc::cudnn::CudnnError) -> Self {
crate::Error::wrap(err)
}
}
impl From<cudarc::driver::DriverError> for crate::Error {
fn from(err: cudarc::driver::DriverError) -> Self {
crate::Error::wrap(err)
}
}
pub(crate) fn launch_conv2d<
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
>(
src: &CudaView<T>,
src_l: &crate::Layout,
filter: &CudaView<T>,
dst: &mut CudaSlice<T>,
params: &crate::conv::ParamsConv2D,
dev: &crate::cuda_backend::CudaDevice,
) -> crate::Result<()> {
let device_id = dev.id();
let cudnn = CUDNN.with(|cudnn| {
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
return Ok(cudnn.clone());
}
let c = Cudnn::new(dev.cuda_device());
if let Ok(c) = &c {
cudnn.borrow_mut().insert(device_id, c.clone());
}
c
})?;
let conv = cudnn.create_conv2d::<T>(
/* pad */ [params.padding as i32, params.padding as i32],
/* stride */ [params.stride as i32, params.stride as i32],
/* dilation */ [1, 1],
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
)?;
let x_shape = [
params.b_size as i32,
params.c_in as i32,
params.i_w as i32,
params.i_h as i32,
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
cudnn.create_4d_tensor(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
x_shape,
)?
} else {
let s = src_l.stride();
cudnn.create_4d_tensor_ex(
x_shape,
[s[0] as i32, s[1] as i32, s[2] as i32, s[3] as i32],
)?
};
let w = cudnn.create_4d_filter(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[
params.c_out as i32,
params.c_in as i32,
params.k_w as i32,
params.k_h as i32,
],
)?;
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
let y = cudnn.create_4d_tensor(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, w_out, h_out],
)?;
let conv2d = Conv2dForward {
conv: &conv,
x: &x,
w: &w,
y: &y,
};
let alg = conv2d.pick_algorithm()?;
let workspace_size = conv2d.get_workspace_size(alg)?;
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
unsafe {
conv2d.launch::<CudaSlice<u8>, _, _, _>(
alg,
Some(&mut workspace),
(T::one(), T::zero()),
src,
filter,
dst,
)?;
}
Ok(())
}

View File

@ -101,6 +101,13 @@ impl Device {
}
}
pub fn is_cpu(&self) -> bool {
match self {
Self::Cpu => true,
Self::Cuda(_) => false,
}
}
pub fn is_cuda(&self) -> bool {
match self {
Self::Cpu => false,


@ -43,7 +43,7 @@ impl DType {
pub fn size_in_bytes(&self) -> usize {
match self {
-Self::U8 => 4,
Self::U8 => 1,
Self::U32 => 4,
Self::BF16 => 2,
Self::F16 => 2,
@ -53,7 +53,17 @@ impl DType {
}
}
-pub trait WithDType: Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + 'static {
pub trait WithDType:
Sized
+ Copy
+ num_traits::NumAssign
+ std::cmp::PartialOrd
+ std::fmt::Display
+ 'static
+ Send
+ Sync
+ crate::cpu::kernels::VecOps
{
const DTYPE: DType;
fn from_f64(v: f64) -> Self;


@ -75,6 +75,16 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv2d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv2D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -119,6 +129,18 @@ impl crate::backend::BackendStorage for CudaStorage {
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
}
impl crate::backend::BackendDevice for CudaDevice {


@ -185,6 +185,13 @@ pub enum Error {
#[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>),
/// Adding path information to an error.
#[error("path: {path:?} {inner}")]
WithPath {
inner: Box<Self>,
path: std::path::PathBuf,
},
#[error("{inner}\n{backtrace}")]
WithBacktrace {
inner: Box<Self>,
@ -214,6 +221,13 @@ impl Error {
},
}
}
pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self {
Self::WithPath {
inner: Box::new(self),
path: p.as_ref().to_path_buf(),
}
}
}
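The new WithPath variant and with_path helper are exercised by the safetensors changes further down. A hypothetical usage sketch (open_weights is made up for illustration; the conversion from std::io::Error already exists in the crate):

use candle_core::{Error, Result};

// Hypothetical helper: attach the offending path to any error before propagating it.
fn open_weights(p: &std::path::Path) -> Result<Vec<u8>> {
    std::fs::read(p).map_err(|e| Error::from(e).with_path(p))
}

fn main() {
    if let Err(e) = open_weights(std::path::Path::new("/does/not/exist.safetensors")) {
        // The Display output now includes the path thanks to the WithPath variant.
        println!("{e}");
    }
}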
#[macro_export]


@ -33,13 +33,18 @@
//!
//! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
#[cfg(feature = "accelerate")]
mod accelerate;
pub mod backend;
pub mod backprop;
mod conv;
mod convert;
pub mod cpu;
pub mod cpu_backend;
#[cfg(feature = "cuda")]
pub mod cuda_backend;
#[cfg(feature = "cudnn")]
pub mod cudnn;
mod device;
pub mod display;
mod dtype;
@ -51,6 +56,7 @@ pub mod layout;
mod mkl;
pub mod npy;
mod op;
pub mod quantized;
pub mod safetensors;
pub mod shape;
mod storage;


@ -26,7 +26,7 @@
//! values = np.loadz("test.npz")
//! ```
use crate::{DType, Device, Error, Result, Shape, Tensor};
-use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use byteorder::{LittleEndian, ReadBytesExt};
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::collections::HashMap;
use std::fs::File;
@ -307,42 +307,7 @@ impl Tensor {
header.push('\n');
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
f.write_all(header.as_bytes())?;
-let elem_count = self.elem_count();
-match self.dtype() {
-DType::BF16 => {
-let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
-for &v in vs.reinterpret_cast() {
-f.write_u16::<LittleEndian>(v)?
-}
-}
-DType::F16 => {
-let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
-for &v in vs.reinterpret_cast() {
-f.write_u16::<LittleEndian>(v)?
-}
-}
-DType::F32 => {
-// TODO: Avoid using a buffer when data is already on the CPU.
-for v in self.reshape(elem_count)?.to_vec1::<f32>()? {
-f.write_f32::<LittleEndian>(v)?
-}
-}
-DType::F64 => {
-for v in self.reshape(elem_count)?.to_vec1::<f64>()? {
-f.write_f64::<LittleEndian>(v)?
-}
-}
-DType::U32 => {
-for v in self.reshape(elem_count)?.to_vec1::<u32>()? {
-f.write_u32::<LittleEndian>(v)?
-}
-}
-DType::U8 => {
-let data = self.reshape(elem_count)?.to_vec1::<u8>()?;
-f.write_all(&data)?;
-}
-}
-Ok(())
self.write_bytes(f)
}
/// Writes a multi-dimensional array in the npy format.
@ -373,7 +338,7 @@ pub struct NpzTensors {
index_per_name: HashMap<String, usize>,
path: std::path::PathBuf,
// We do not store a zip reader as it needs mutable access to extract data. Instead we
-// re-create a zip reader each time.
// re-create a zip reader for each tensor.
}
impl NpzTensors {


@ -51,6 +51,7 @@ pub enum UnaryOp {
Cos,
Abs,
Neg,
Recip,
Sqr,
Sqrt,
Gelu,
@ -79,6 +80,28 @@ pub enum Op {
stride: usize,
},
#[allow(dead_code)]
Conv2D {
arg: Tensor,
kernel: Tensor,
padding: usize,
stride: usize,
},
AvgPool2D {
arg: Tensor,
kernel_size: (usize, usize),
stride: (usize, usize),
},
MaxPool2D {
arg: Tensor,
kernel_size: (usize, usize),
stride: (usize, usize),
},
UpsampleNearest2D(Tensor),
Cat(Vec<Tensor>, usize),
#[allow(dead_code)] // add is currently unused.
@ -264,6 +287,7 @@ pub(crate) struct Sin;
pub(crate) struct Cos;
pub(crate) struct Abs;
pub(crate) struct Neg;
pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
@ -314,6 +338,21 @@ macro_rules! bin_op {
fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
crate::mkl::$f64_vec(xs1, xs2, ys)
}
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
crate::accelerate::$f32_vec(xs1, xs2, ys)
}
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
crate::accelerate::$f64_vec(xs1, xs2, ys)
}
}
};
}
@ -400,6 +439,21 @@ macro_rules! unary_op {
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::$f64_vec(xs, ys)
}
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::accelerate::$f32_vec(xs, ys)
}
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::accelerate::$f64_vec(xs, ys)
}
}
};
}
@ -410,6 +464,7 @@ unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);


@ -0,0 +1,215 @@
//! Support for the GGML file format.
use super::{k_quants, GgmlDType};
use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt};
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Magic {
Ggjt,
Ggla,
Ggmf,
Ggml,
Ggsn,
}
impl TryFrom<u32> for Magic {
type Error = crate::Error;
fn try_from(value: u32) -> Result<Self> {
let magic = match value {
0x67676a74 => Self::Ggjt,
0x67676c61 => Self::Ggla,
0x67676d66 => Self::Ggmf,
0x67676d6c => Self::Ggml,
0x6767736e => Self::Ggsn,
_ => crate::bail!("unknown magic {value:08x}"),
};
Ok(magic)
}
}
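As an aside, the magic constants are just the ASCII bytes of the container name (stored little-endian on disk, hence read with read_u32::<LittleEndian>). A quick check, not part of the diff:

fn main() {
    let magic: u32 = 0x6767_6a74;
    // Reading the value most-significant byte first spells out the name.
    assert_eq!(&magic.to_be_bytes(), b"ggjt");
    println!("{}", std::str::from_utf8(&magic.to_be_bytes()).unwrap());
}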
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VersionedMagic {
GgmlUnversioned,
GgmfV1,
GgjtV1,
GgjtV2,
GgjtV3,
}
impl VersionedMagic {
fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
let magic = reader.read_u32::<LittleEndian>()?;
let magic = Magic::try_from(magic)?;
if magic == Magic::Ggml {
return Ok(Self::GgmlUnversioned);
}
let version = reader.read_u32::<LittleEndian>()?;
let versioned_magic = match (magic, version) {
(Magic::Ggmf, 1) => Self::GgmfV1,
(Magic::Ggjt, 1) => Self::GgjtV1,
(Magic::Ggjt, 2) => Self::GgjtV2,
(Magic::Ggjt, 3) => Self::GgjtV3,
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
};
Ok(versioned_magic)
}
fn align32(&self) -> bool {
match self {
Self::GgmlUnversioned | Self::GgmfV1 => false,
Self::GgjtV1 | Self::GgjtV2 | Self::GgjtV3 => true,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HParams {
pub n_vocab: u32,
pub n_embd: u32,
pub n_mult: u32,
pub n_head: u32,
pub n_layer: u32,
pub n_rot: u32,
pub ftype: u32,
}
impl HParams {
fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
let n_vocab = reader.read_u32::<LittleEndian>()?;
let n_embd = reader.read_u32::<LittleEndian>()?;
let n_mult = reader.read_u32::<LittleEndian>()?;
let n_head = reader.read_u32::<LittleEndian>()?;
let n_layer = reader.read_u32::<LittleEndian>()?;
let n_rot = reader.read_u32::<LittleEndian>()?;
let ftype = reader.read_u32::<LittleEndian>()?;
Ok(Self {
n_vocab,
n_embd,
n_mult,
n_head,
n_layer,
n_rot,
ftype,
})
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct Vocab {
pub token_score_pairs: Vec<(Vec<u8>, f32)>,
}
impl Vocab {
fn read<R: std::io::Read>(reader: &mut R, n_vocab: usize) -> Result<Self> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L556
let mut token_score_pairs = Vec::with_capacity(n_vocab);
for _index in 0..n_vocab {
let len = reader.read_u32::<LittleEndian>()? as usize;
let mut word = vec![0u8; len];
reader.read_exact(&mut word)?;
let score = reader.read_f32::<LittleEndian>()?;
token_score_pairs.push((word, score))
}
Ok(Self { token_score_pairs })
}
}
fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
raw_data: &[u8],
size_in_bytes: usize,
dims: Vec<usize>,
) -> Result<super::QTensor> {
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
Ok(super::QTensor::new(data.to_vec(), dims))
}
/// Creates a [Tensor] from a raw GGML tensor.
pub fn qtensor_from_ggml(
ggml_dtype: GgmlDType,
raw_data: &[u8],
dims: Vec<usize>,
) -> Result<super::QTensor> {
let tensor_elems = dims.iter().product::<usize>();
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
match ggml_dtype {
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
}
}
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
reader: &mut R,
magic: VersionedMagic,
) -> Result<(String, super::QTensor)> {
let n_dims = reader.read_u32::<LittleEndian>()?;
let name_len = reader.read_u32::<LittleEndian>()?;
let ggml_dtype = reader.read_u32::<LittleEndian>()?;
let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
let mut dims = vec![0u32; n_dims as usize];
reader.read_u32_into::<LittleEndian>(&mut dims)?;
let mut name = vec![0u8; name_len as usize];
reader.read_exact(&mut name)?;
let name = String::from_utf8_lossy(&name).into_owned();
if magic.align32() {
let pos = reader.stream_position()?;
reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?;
}
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
let tensor_elems = dims.iter().product::<usize>();
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
println!("{name} {ggml_dtype:?} {dims:?}");
// TODO: Mmap version to avoid copying the data around?
let mut raw_data = vec![0u8; size_in_bytes];
reader.read_exact(&mut raw_data)?;
match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
Ok(tensor) => Ok((name, tensor)),
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
}
}
pub struct Content {
pub magic: VersionedMagic,
pub hparams: HParams,
pub vocab: Vocab,
pub tensors: Vec<(String, super::QTensor)>,
}
impl Content {
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
reader.seek(std::io::SeekFrom::Start(0))?;
let magic = VersionedMagic::read(reader)?;
let hparams = HParams::read(reader)?;
let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
let mut tensors = vec![];
while reader.stream_position()? != last_position {
let (name, tensor) = read_one_tensor(reader, magic)?;
tensors.push((name, tensor))
}
Ok(Self {
magic,
hparams,
vocab,
tensors,
})
}
}


@ -0,0 +1,802 @@
use super::GgmlDType;
use crate::Result;
use half::f16;
// Default to QK_K 256 rather than 64.
pub const QK_K: usize = 256;
pub const K_SCALE_SIZE: usize = 12;
pub const QK4_0: usize = 32;
pub const QK4_1: usize = 32;
pub const QK5_0: usize = 32;
pub const QK5_1: usize = 32;
pub const QK8_0: usize = 32;
pub const QK8_1: usize = 32;
pub trait GgmlType: Sized + Clone {
const DTYPE: GgmlDType;
const BLCK_SIZE: usize;
type VecDotType: GgmlType;
// This is only safe for types that include immediate values such as float/int/...
fn zeros() -> Self {
unsafe { std::mem::MaybeUninit::zeroed().assume_init() }
}
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()>;
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()>;
/// Dot product used as a building block for quantized mat-mul.
/// n is the number of elements to be considered.
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
}
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ4_0 {
d: f16,
qs: [u8; QK4_0 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18);
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ4_1 {
d: f16,
m: f16,
qs: [u8; QK4_1 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20);
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5_0 {
d: f16,
qh: [u8; 4],
qs: [u8; QK5_0 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22);
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5_1 {
d: f16,
m: f16,
qh: [u8; 4],
qs: [u8; QK5_1 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8_0 {
d: f16,
qs: [u8; QK8_0],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8_1 {
d: f16,
s: f16,
qs: [u8; QK8_1],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ2K {
scales: [u8; QK_K / 16],
qs: [u8; QK_K / 4],
d: f16,
dmin: f16,
}
const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>());
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ3K {
hmask: [u8; QK_K / 8],
qs: [u8; QK_K / 4],
scales: [u8; 12],
d: f16,
}
const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ3K>());
#[derive(Debug, Clone, PartialEq)]
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82
#[repr(C)]
pub struct BlockQ4K {
d: f16,
dmin: f16,
scales: [u8; K_SCALE_SIZE],
qs: [u8; QK_K / 2],
}
const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>());
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5K {
d: f16,
dmin: f16,
scales: [u8; K_SCALE_SIZE],
qh: [u8; QK_K / 8],
qs: [u8; QK_K / 2],
}
const _: () =
assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>());
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ6K {
ql: [u8; QK_K / 2],
qh: [u8; QK_K / 4],
scales: [i8; QK_K / 16],
d: f16,
}
const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>());
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8K {
d: f32,
qs: [i8; QK_K],
bsums: [i16; QK_K / 16],
}
const _: () = assert!(4 + QK_K + QK_K / 16 * 2 == std::mem::size_of::<BlockQ8K>());
impl GgmlType for BlockQ4_1 {
const DTYPE: GgmlDType = GgmlDType::Q4_1;
const BLCK_SIZE: usize = QK4_1;
type VecDotType = BlockQ8_1;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1545
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK4_1 != 0 {
crate::bail!("dequantize_row_q4_1: {k} is not divisible by {QK4_1}");
}
let nb = k / QK4_1;
for i in 0..nb {
let d = xs[i].d.to_f32();
let m = xs[i].m.to_f32();
for j in 0..(QK4_1 / 2) {
let x0 = xs[i].qs[j] & 0x0F;
let x1 = xs[i].qs[j] >> 4;
ys[i * QK4_1 + j] = (x0 as f32) * d + m;
ys[i * QK4_1 + j + QK4_1 / 2] = (x1 as f32) * d + m;
}
}
Ok(())
}
}
impl GgmlType for BlockQ5_0 {
const DTYPE: GgmlDType = GgmlDType::Q5_0;
const BLCK_SIZE: usize = QK5_0;
type VecDotType = BlockQ8_0;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1566
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK5_0 != 0 {
crate::bail!("dequantize_row_q5_0: {k} is not divisible by {QK5_0}");
}
let nb = k / QK5_0;
for i in 0..nb {
let d = xs[i].d.to_f32();
let qh: u32 = unsafe { std::mem::transmute_copy(&xs[i].qh) };
for j in 0..(QK5_0 / 2) {
let xh_0 = (((qh >> j) << 4) & 0x10) as u8;
let xh_1 = ((qh >> (j + 12)) & 0x10) as u8;
let x0 = ((xs[i].qs[j] & 0x0F) | xh_0) as i32 - 16;
let x1 = ((xs[i].qs[j] >> 4) | xh_1) as i32 - 16;
ys[i * QK5_0 + j] = (x0 as f32) * d;
ys[i * QK5_0 + j + QK5_0 / 2] = (x1 as f32) * d;
}
}
Ok(())
}
}
impl GgmlType for BlockQ5_1 {
const DTYPE: GgmlDType = GgmlDType::Q5_1;
const BLCK_SIZE: usize = QK5_1;
type VecDotType = BlockQ8_1;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1592
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK5_1 != 0 {
crate::bail!("dequantize_row_q5_1: {k} is not divisible by {QK5_1}");
}
let nb = k / QK5_1;
for i in 0..nb {
let d = xs[i].d.to_f32();
let m = xs[i].m.to_f32();
let qh: u32 = unsafe { std::mem::transmute_copy(&xs[i].qh) };
for j in 0..(QK5_1 / 2) {
let xh_0 = (((qh >> j) << 4) & 0x10) as u8;
let xh_1 = ((qh >> (j + 12)) & 0x10) as u8;
let x0 = (xs[i].qs[j] & 0x0F) | xh_0;
let x1 = (xs[i].qs[j] >> 4) | xh_1;
ys[i * QK5_1 + j] = (x0 as f32) * d + m;
ys[i * QK5_1 + j + QK5_1 / 2] = (x1 as f32) * d + m;
}
}
Ok(())
}
}
impl GgmlType for BlockQ2K {
const DTYPE: GgmlDType = GgmlDType::Q2K;
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK_K != 0 {
crate::bail!("dequantize_row_q2k: {k} is not divisible by {QK_K}")
}
let mut ys_index = 0;
for x in xs {
let d = x.d.to_f32();
let min = x.dmin.to_f32();
let q = &x.qs;
let mut is = 0;
for n in (0..QK_K).step_by(128) {
// Step by 32 over q.
let q = &q[n / 4..];
let mut shift = 0;
for _j in 0..4 {
let sc = x.scales[is];
is += 1;
let dl = d * (sc & 0xF) as f32;
let ml = min * (sc >> 4) as f32;
for q in &q[..16] {
let y = dl * ((q >> shift) & 3) as i8 as f32 - ml;
ys[ys_index] = y;
ys_index += 1;
}
let sc = x.scales[is];
is += 1;
let dl = d * (sc & 0xF) as f32;
let ml = min * (sc >> 4) as f32;
for q in &q[16..32] {
let y = dl * ((q >> shift) & 3) as i8 as f32 - ml;
ys[ys_index] = y;
ys_index += 1;
}
shift += 2;
}
}
}
Ok(())
}
}
fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
if j < 4 {
let d = q[j] & 63;
let m = q[j + 4] & 63;
(d, m)
} else {
let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
(d, m)
}
}
impl GgmlType for BlockQ4K {
const DTYPE: GgmlDType = GgmlDType::Q4K;
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK_K != 0 {
crate::bail!("dequantize_row_q4k: {k} is not divisible by {QK_K}")
}
let mut ys_index = 0;
for x in xs.iter() {
let d = x.d.to_f32();
let min = x.dmin.to_f32();
let q = &x.qs;
let mut is = 0;
for j in (0..QK_K).step_by(64) {
let q = &q[j / 2..j / 2 + 32];
let (sc, m) = get_scale_min_k4(is, &x.scales);
let d1 = d * sc as f32;
let m1 = min * m as f32;
let (sc, m) = get_scale_min_k4(is + 1, &x.scales);
let d2 = d * sc as f32;
let m2 = min * m as f32;
for q in q {
let y = d1 * (q & 0xF) as f32 - m1;
ys[ys_index] = y;
ys_index += 1;
}
for q in q {
let y = d2 * (q >> 4) as f32 - m2;
ys[ys_index] = y;
ys_index += 1;
}
is += 2;
}
}
Ok(())
}
}
impl GgmlType for BlockQ3K {
const DTYPE: GgmlDType = GgmlDType::Q3K;
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> {
todo!()
}
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928
impl GgmlType for BlockQ5K {
const DTYPE: GgmlDType = GgmlDType::Q5K;
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK_K != 0 {
crate::bail!("dequantize_row_q5k: {k} is not divisible by {QK_K}")
}
let mut ys_index = 0;
for x in xs.iter() {
let d = x.d.to_f32();
let min = x.dmin.to_f32();
let ql = &x.qs;
let qh = &x.qh;
let mut is = 0;
let mut u1 = 1;
let mut u2 = 2;
for j in (0..QK_K).step_by(64) {
let ql = &ql[j / 2..j / 2 + 32];
let (sc, m) = get_scale_min_k4(is, &x.scales);
let d1 = d * sc as f32;
let m1 = min * m as f32;
let (sc, m) = get_scale_min_k4(is + 1, &x.scales);
let d2 = d * sc as f32;
let m2 = min * m as f32;
for (ql, qh) in ql.iter().zip(qh) {
let to_add = if qh & u1 != 0 { 16 } else { 1 };
let y = d1 * ((ql & 0xF) + to_add) as f32 - m1;
ys[ys_index] = y;
ys_index += 1;
}
for (ql, qh) in ql.iter().zip(qh) {
let to_add = if qh & u2 != 0 { 16 } else { 1 };
let y = d2 * ((ql >> 4) + to_add) as f32 - m2;
ys[ys_index] = y;
ys_index += 1;
}
is += 2;
u1 <<= 2;
u2 <<= 2;
}
}
Ok(())
}
}
impl GgmlType for BlockQ6K {
const DTYPE: GgmlDType = GgmlDType::Q6K;
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK_K != 0 {
crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}")
}
for x in xs.iter() {
let d = x.d.to_f32();
let ql = &x.ql;
let qh = &x.qh;
let sc = &x.scales;
for n in (0..QK_K).step_by(128) {
let idx = n / 128;
let ys = &mut ys[n..];
let sc = &sc[8 * idx..];
let ql = &ql[64 * idx..];
let qh = &qh[32 * idx..];
for l in 0..32 {
let is = l / 16;
let q1 = ((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i8 - 32;
let q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i8 - 32;
let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i8 - 32;
let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i8 - 32;
ys[l] = d * sc[is] as f32 * q1 as f32;
ys[l + 32] = d * sc[is + 2] as f32 * q2 as f32;
ys[l + 64] = d * sc[is + 4] as f32 * q3 as f32;
ys[l + 96] = d * sc[is + 6] as f32 * q4 as f32;
}
}
}
Ok(())
}
}
impl GgmlType for BlockQ8K {
const DTYPE: GgmlDType = GgmlDType::Q8K;
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> {
todo!()
}
}
impl GgmlType for BlockQ4_0 {
const DTYPE: GgmlDType = GgmlDType::Q4_0;
const BLCK_SIZE: usize = QK4_0;
type VecDotType = BlockQ8_0;
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1525
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK4_0 != 0 {
crate::bail!("dequantize_row_q4_0: {k} is not divisible by {QK4_0}")
}
let nb = k / QK4_0;
for i in 0..nb {
let d = xs[i].d.to_f32();
for j in 0..(QK4_0 / 2) {
let x0 = (xs[i].qs[j] & 0x0F) as i16 - 8;
let x1 = (xs[i].qs[j] >> 4) as i16 - 8;
ys[i * QK4_0 + j] = (x0 as f32) * d;
ys[i * QK4_0 + j + QK4_0 / 2] = (x1 as f32) * d;
}
}
Ok(())
}
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
// quantize_row_q4_0
let qk = Self::BLCK_SIZE;
let k = xs.len();
if k % qk != 0 {
crate::bail!("{k} is not divisible by {}", qk);
};
let nb = k / qk;
if ys.len() != nb {
crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,)
}
for (i, ys) in ys.iter_mut().enumerate() {
let mut amax = 0f32;
let mut max = 0f32;
let xs = &xs[i * qk..(i + 1) * qk];
for &x in xs.iter() {
if amax < x.abs() {
amax = x.abs();
max = x;
}
}
let d = max / -8.0;
let id = if d != 0f32 { 1. / d } else { 0. };
ys.d = f16::from_f32(d);
for (j, q) in ys.qs.iter_mut().enumerate() {
let x0 = xs[j] * id;
let x1 = xs[qk / 2 + j] * id;
let xi0 = u8::min(15, (x0 + 8.5) as u8);
let xi1 = u8::min(15, (x1 + 8.5) as u8);
*q = xi0 | (xi1 << 4)
}
}
Ok(())
}
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L2361C10-L2361C122
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
// Generic implementation.
let mut sumf = 0f32;
for i in 0..nb {
let mut sum_i = 0;
for j in 0..qk / 2 {
let v0 = (xs[i].qs[j] & 0x0F) as i32 - 8;
let v1 = (xs[i].qs[j] >> 4) as i32 - 8;
sum_i += v0 * ys[i].qs[j] as i32 + v1 * ys[i].qs[j + qk / 2] as i32
}
sumf += sum_i as f32 * f16::to_f32(xs[i].d) * f16::to_f32(ys[i].d)
}
Ok(sumf)
}
}
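A minimal round-trip sketch, not part of the diff, using only the public Q4_0 API added above; 4-bit quantization is lossy, so values only come back approximately:

use candle_core::quantized::k_quants::{BlockQ4_0, GgmlType, QK4_0};
use candle_core::Result;

fn main() -> Result<()> {
    let xs: Vec<f32> = (0..QK4_0).map(|v| v as f32 / 4.0).collect();
    // One block holds QK4_0 = 32 values: an f16 scale plus 16 packed nibble bytes.
    let mut blocks = vec![BlockQ4_0::zeros(); 1];
    BlockQ4_0::from_float(&xs, &mut blocks)?;
    let mut ys = vec![0f32; QK4_0];
    BlockQ4_0::to_float(&blocks, &mut ys)?;
    for (x, y) in xs.iter().zip(ys.iter()) {
        // The quantization step for this data is roughly max(|x|) / 8, i.e. about 1.
        assert!((x - y).abs() < 1.0, "{x} vs {y}");
    }
    Ok(())
}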
impl GgmlType for BlockQ8_0 {
const DTYPE: GgmlDType = GgmlDType::Q8_0;
const BLCK_SIZE: usize = QK8_0;
type VecDotType = BlockQ8_0;
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1619
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK8_0 != 0 {
crate::bail!("dequantize_row_q8_0: {k} is not divisible by {QK8_0}");
}
let nb = k / QK8_0;
for i in 0..nb {
let d = xs[i].d.to_f32();
for j in 0..QK8_0 {
ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d;
}
}
Ok(())
}
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
// quantize_row_q8_0
let k = xs.len();
if k % Self::BLCK_SIZE != 0 {
crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE);
};
let nb = k / Self::BLCK_SIZE;
if ys.len() != nb {
crate::bail!(
"size mismatch {} {} {}",
xs.len(),
ys.len(),
Self::BLCK_SIZE
)
}
for (i, ys) in ys.iter_mut().enumerate() {
let mut amax = 0f32;
let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
for &x in xs.iter() {
amax = amax.max(x.abs())
}
let d = amax / ((1 << 7) - 1) as f32;
let id = if d != 0f32 { 1. / d } else { 0. };
ys.d = f16::from_f32(d);
for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) {
*y = f32::round(x * id) as u8
}
}
Ok(())
}
fn vec_dot(_: usize, _: &[Self], _: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
}
impl GgmlType for BlockQ8_1 {
const DTYPE: GgmlDType = GgmlDType::Q3K;
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8_1;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> {
todo!()
}
}
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605
pub fn matmul<T: GgmlType>(
mkn: (usize, usize, usize),
lhs: &[f32],
rhs_t: &[T],
dst: &mut [f32],
) -> Result<()> {
let (m, k, n) = mkn;
if m * k != lhs.len() {
crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len());
}
let k_in_lhs_blocks = (k + T::BLCK_SIZE - 1) / T::BLCK_SIZE;
let k_in_rhs_blocks = (k + T::VecDotType::BLCK_SIZE - 1) / T::VecDotType::BLCK_SIZE;
// TODO: Do not make this copy if the DotType is f32.
// TODO: Pre-allocate this.
let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks];
for row_idx in 0..m {
let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
let lhs = &lhs[row_idx * k..(row_idx + 1) * k];
T::VecDotType::from_float(lhs, lhs_b)?
}
let lhs_b = lhs_b.as_slice();
for row_idx in 0..m {
let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n];
for (col_idx, dst) in dst_row.iter_mut().enumerate() {
let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks];
*dst = T::vec_dot(k, rhs_col, lhs_row)?;
}
}
Ok(())
}
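Not part of the diff, a quick sanity check of the block bookkeeping above: with k = 64 and 32-element blocks each row is two blocks, so an n = 4 transposed rhs needs eight blocks, matching the vec![BlockQ4_0::zeros(); 8] in the qmatmul test further down:

fn main() {
    let (k, n) = (64usize, 4usize);
    let blck_size = 32usize; // QK4_0 and QK8_0
    // Same rounding-up computation as k_in_lhs_blocks / k_in_rhs_blocks above.
    let k_in_blocks = (k + blck_size - 1) / blck_size;
    assert_eq!(k_in_blocks, 2);
    assert_eq!(n * k_in_blocks, 8);
}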
impl GgmlType for f32 {
const DTYPE: GgmlDType = GgmlDType::F32;
const BLCK_SIZE: usize = 1;
type VecDotType = f32;
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if xs.len() < n {
crate::bail!("size mismatch {} < {n}", xs.len())
}
if ys.len() < n {
crate::bail!("size mismatch {} < {n}", ys.len())
}
let mut res = 0f32;
unsafe { crate::cpu::vec_dot_f32(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
Ok(res)
}
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
if xs.len() != ys.len() {
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
}
ys.copy_from_slice(xs);
Ok(())
}
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
if xs.len() != ys.len() {
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
}
ys.copy_from_slice(xs);
Ok(())
}
}
impl GgmlType for f16 {
const DTYPE: GgmlDType = GgmlDType::F16;
const BLCK_SIZE: usize = 1;
type VecDotType = f16;
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if xs.len() < n {
crate::bail!("size mismatch {} < {n}", xs.len())
}
if ys.len() < n {
crate::bail!("size mismatch {} < {n}", ys.len())
}
let mut res = 0f32;
unsafe { crate::cpu::vec_dot_f16(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
Ok(res)
}
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
if xs.len() != ys.len() {
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
}
// TODO: vectorize
for (x, y) in xs.iter().zip(ys.iter_mut()) {
*y = f16::from_f32(*x)
}
Ok(())
}
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
if xs.len() != ys.len() {
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
}
// TODO: vectorize
for (x, y) in xs.iter().zip(ys.iter_mut()) {
*y = x.to_f32()
}
Ok(())
}
}


@ -0,0 +1,194 @@
use crate::{Device, Result, Shape, Tensor};
pub mod ggml_file;
pub mod k_quants;
pub use k_quants::GgmlType;
pub struct QTensor {
data: Box<dyn QuantizedType>,
shape: Shape,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GgmlDType {
F32,
F16,
Q4_0,
Q4_1,
Q5_0,
Q5_1,
Q8_0,
Q8_1,
Q2K,
Q3K,
Q4K,
Q5K,
Q6K,
Q8K,
}
impl GgmlDType {
pub(crate) fn from_u32(u: u32) -> Result<Self> {
let dtype = match u {
0 => Self::F32,
1 => Self::F16,
2 => Self::Q4_0,
3 => Self::Q4_1,
6 => Self::Q5_0,
7 => Self::Q5_1,
8 => Self::Q8_0,
9 => Self::Q8_1,
10 => Self::Q2K,
11 => Self::Q3K,
12 => Self::Q4K,
13 => Self::Q5K,
14 => Self::Q6K,
15 => Self::Q8K,
_ => crate::bail!("unknown dtype for tensor {u}"),
};
Ok(dtype)
}
fn type_size(&self) -> usize {
use k_quants::*;
match self {
Self::F32 => 4,
Self::F16 => 2,
Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
Self::Q2K => std::mem::size_of::<BlockQ2K>(),
Self::Q3K => std::mem::size_of::<BlockQ3K>(),
Self::Q4K => std::mem::size_of::<BlockQ4K>(),
Self::Q5K => std::mem::size_of::<BlockQ5K>(),
Self::Q6K => std::mem::size_of::<BlockQ6K>(),
Self::Q8K => std::mem::size_of::<BlockQ8K>(),
}
}
fn blck_size(&self) -> usize {
match self {
Self::F32 => 1,
Self::F16 => 1,
Self::Q4_0 => k_quants::QK4_0,
Self::Q4_1 => k_quants::QK4_1,
Self::Q5_0 => k_quants::QK5_0,
Self::Q5_1 => k_quants::QK5_1,
Self::Q8_0 => k_quants::QK8_0,
Self::Q8_1 => k_quants::QK8_1,
Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
}
}
}
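The ggml loader sizes a tensor as elem_count * type_size / blck_size. A worked example, not from the diff (the 4096x4096 shape is made up for illustration):

fn main() {
    // Q4_0: each 32-element block is an f16 scale plus 16 packed nibble bytes.
    let (type_size, blck_size) = (18usize, 32usize);
    let elem_count = 4096usize * 4096;
    let bytes = elem_count * type_size / blck_size;
    assert_eq!(bytes, 9_437_184); // ~0.56 bytes per weight vs 4 bytes for f32
    println!("{bytes} bytes");
}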
// A version of GgmlType without `vec_dot` so that it can be dyn boxed.
pub trait QuantizedType: Send + Sync {
fn dtype(&self) -> GgmlDType;
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
fn to_float(&self, ys: &mut [f32]) -> Result<()>;
}
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
k_quants::matmul(mkn, lhs, self.as_slice(), dst)
}
fn dtype(&self) -> GgmlDType {
T::DTYPE
}
fn to_float(&self, ys: &mut [f32]) -> Result<()> {
T::to_float(self.as_slice(), ys)
}
}
impl std::fmt::Debug for QTensor {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
}
}
impl QTensor {
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
data: Vec<T>,
shape: S,
) -> Self {
Self {
data: Box::new(data),
shape: shape.into(),
}
}
pub fn dtype(&self) -> GgmlDType {
self.data.dtype()
}
pub fn shape(&self) -> &Shape {
&self.shape
}
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
let mut f32_data = vec![0f32; self.shape.elem_count()];
self.data.to_float(&mut f32_data)?;
Tensor::from_vec(f32_data, &self.shape, device)
}
pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
self.data.matmul_t(mkn, lhs, dst)
}
}
#[derive(Debug, Clone)]
pub struct QMatMul(std::sync::Arc<QTensor>);
impl QMatMul {
pub fn new(qtensor: std::sync::Arc<QTensor>) -> Self {
Self(qtensor)
}
}
impl crate::CustomOp1 for QMatMul {
fn name(&self) -> &'static str {
"qmatmul"
}
fn cpu_fwd(
&self,
storage: &crate::CpuStorage,
layout: &crate::Layout,
) -> Result<(crate::CpuStorage, Shape)> {
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
let (k, n) = self.0.shape.dims2()?;
if src_shape.rank() < 2 {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let mut dst_shape = src_shape.dims().to_vec();
let last_k = dst_shape.pop().unwrap();
if last_k != k {
crate::bail!(
"input tensor {layout:?} incompatible with {:?}",
self.0.shape
)
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
let storage = storage.as_slice::<f32>()?;
let storage =
&storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
let mut dst_storage = vec![0f32; dst_shape.elem_count()];
self.0.matmul_t(
(dst_shape.elem_count() / n, k, n),
storage,
&mut dst_storage,
)?;
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
}
}


@ -242,18 +242,28 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
let data = std::fs::read(filename.as_ref())?;
-let st = safetensors::SafeTensors::deserialize(&data)?;
load_buffer(&data[..], device)
}
pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>> {
let st = safetensors::SafeTensors::deserialize(data)?;
st.tensors()
.into_iter()
.map(|(name, view)| Ok((name, view.load(device)?)))
.collect()
}
-pub fn save<P: AsRef<Path>>(tensors: &HashMap<&str, Tensor>, filename: P) -> Result<()> {
pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
tensors: &HashMap<K, Tensor>,
filename: P,
) -> Result<()> {
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
}
-pub struct MmapedFile(memmap2::Mmap);
pub struct MmapedFile {
path: std::path::PathBuf,
inner: memmap2::Mmap,
}
impl MmapedFile {
/// Creates a wrapper around a memory mapped file from which you can retrieve
@ -263,13 +273,20 @@ impl MmapedFile {
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
-let file = std::fs::File::open(p)?;
-let mmap = memmap2::MmapOptions::new().map(&file)?;
-Ok(Self(mmap))
let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let inner = memmap2::MmapOptions::new()
.map(&file)
.map_err(|e| Error::from(e).with_path(p))?;
Ok(Self {
inner,
path: p.to_path_buf(),
})
}
pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
-let st = safetensors::SafeTensors::deserialize(&self.0)?;
let st = safetensors::SafeTensors::deserialize(&self.inner)
.map_err(|e| Error::from(e).with_path(&self.path))?;
Ok(st)
}
}
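A hypothetical usage of the new load_buffer entry point (the file name is made up); callers that already hold the serialized bytes in memory can skip the path-based load:

use candle_core::{safetensors, Device, Result};

fn main() -> Result<()> {
    // Hypothetical file name, for illustration only.
    let data = std::fs::read("model.safetensors")?;
    let tensors = safetensors::load_buffer(&data, &Device::Cpu)?;
    for (name, tensor) in tensors.iter() {
        println!("{name}: {:?}", tensor.shape());
    }
    Ok(())
}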


@ -79,20 +79,25 @@ impl From<Vec<usize>> for Shape {
macro_rules! extract_dims {
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
impl Shape {
pub fn $fn_name(&self) -> Result<$out_type> {
if self.0.len() != $cnt {
Err(Error::UnexpectedNumberOfDims {
expected: $cnt,
got: self.0.len(),
shape: self.clone(),
}
.bt())
} else {
Ok($dims(&self.0))
pub fn $fn_name(dims: &[usize]) -> Result<$out_type> {
if dims.len() != $cnt {
Err(Error::UnexpectedNumberOfDims {
expected: $cnt,
got: dims.len(),
shape: Shape::from(dims),
}
.bt())
} else {
Ok($dims(dims))
}
}
impl Shape {
pub fn $fn_name(&self) -> Result<$out_type> {
$fn_name(self.0.as_slice())
}
}
impl crate::Tensor {
pub fn $fn_name(&self) -> Result<$out_type> {
self.shape().$fn_name()
@ -340,7 +345,7 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
}
}
-extract_dims!(dims0, 0, |_: &Vec<usize>| (), ());
extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
extract_dims!(


@ -266,6 +266,82 @@ impl Storage {
}
}
pub(crate) fn conv2d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
self.same_device(kernel, "conv2d")?;
self.same_dtype(kernel, "conv2d")?;
match (self, &kernel) {
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
let s = inp.conv2d(l, kernel, kernel_l, params)?;
Ok(Self::Cpu(s))
}
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
let s = inp.conv2d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "conv2d",
}
.bt()),
}
}
pub(crate) fn avg_pool2d(
&self,
layout: &Layout,
kernel_size: (usize, usize),
stride: (usize, usize),
) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
Ok(Self::Cuda(storage))
}
}
}
pub(crate) fn max_pool2d(
&self,
layout: &Layout,
kernel_size: (usize, usize),
stride: (usize, usize),
) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
Ok(Self::Cuda(storage))
}
}
}
pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.upsample_nearest2d(layout, h, w)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.upsample_nearest2d(layout, h, w)?;
Ok(Self::Cuda(storage))
}
}
}
pub(crate) fn where_cond(
&self,
layout: &Layout,


@ -269,6 +269,10 @@ impl Tensor {
Self::rand_impl(lo, up, s, device, false)
}
pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self> {
Tensor::rand_f64_impl(lo, up, self.shape(), self.dtype(), self.device(), false)
}
pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
mean: T,
std: T,
@ -296,6 +300,17 @@ impl Tensor {
Ok(from_storage(storage, s, none, is_variable))
}
pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self> {
Tensor::randn_f64_impl(
mean,
stdev,
self.shape(),
self.dtype(),
self.device(),
false,
)
}
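Not part of the diff: a tiny sketch of the new *_like helpers, which reuse the shape, dtype and device of an existing tensor:

use candle_core::{Device, DType, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let u = t.rand_like(0.0, 1.0)?;  // uniform between lo and up
    let n = t.randn_like(0.0, 1.0)?; // normal with mean 0, std 1
    assert_eq!(u.dims(), t.dims());
    assert_eq!(n.dims(), t.dims());
    Ok(())
}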
/// Creates a new tensor initialized with values sampled from a normal distribution with the
/// specified `mean` and standard deviation `std`.
pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
@ -474,6 +489,7 @@ impl Tensor {
broadcast_binary_op!(broadcast_sub, sub);
broadcast_binary_op!(broadcast_div, div);
unary_op!(recip, Recip);
unary_op!(neg, Neg);
unary_op!(exp, Exp);
unary_op!(log, Log);
@ -548,6 +564,32 @@ impl Tensor {
}
}
/// Split a tensor into the specified number of chunks; this may return fewer chunks than specified.
pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
let dim = dim.to_index(self.shape(), "chunk")?;
let size = self.dim(dim)?;
if size < chunks {
(0..size).map(|i| self.narrow(dim, i, 1)).collect()
} else {
let chunk_size = size / chunks;
let cnt_additional = size % chunks;
let mut tensors = vec![];
let mut sum_chunk_size = 0;
for i in 0..chunks {
let chunk_size = if i < cnt_additional {
chunk_size + 1
} else {
chunk_size
};
let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
tensors.push(tensor);
sum_chunk_size += chunk_size
}
Ok(tensors)
}
}
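A small sketch, not from the diff, of the chunking rule above: a dimension of size 5 split into 2 chunks yields pieces of length 3 and 2, the remainder going to the leading chunks:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::arange(0f32, 5f32, &Device::Cpu)?;
    let chunks = t.chunk(2, 0)?;
    assert_eq!(chunks.len(), 2);
    assert_eq!(chunks[0].dims(), [3]);
    assert_eq!(chunks[1].dims(), [2]);
    Ok(())
}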
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
/// ranges from `start` to `start + len`.
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
@ -731,18 +773,7 @@ impl Tensor {
/// Applies a 1D convolution over the input tensor.
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
-let (b_size, c_in, l_in) = match *self.dims() {
-[b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
-[c_in, l_in] => (None, c_in, l_in),
-_ => Err(Error::Conv1dInvalidArgs {
-inp_shape: self.shape().clone(),
-k_shape: kernel.shape().clone(),
-padding,
-stride,
-msg: "input rank is not 2 or 3",
-}
-.bt())?,
-};
let (b_size, c_in, l_in) = self.dims3()?;
if c_in != c_in_k {
Err(Error::Conv1dInvalidArgs {
inp_shape: self.shape().clone(),
@ -775,6 +806,77 @@ impl Tensor {
Ok(from_storage(storage, out_dims, op, false))
}
pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
let params = crate::conv::ParamsConv2D {
b_size,
i_h,
i_w,
k_h,
k_w,
c_out,
c_in,
padding,
stride,
};
let storage =
self.storage()
.conv2d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
arg,
kernel,
padding,
stride,
});
let out_dims = params.out_dims();
Ok(from_storage(storage, out_dims, op, false))
}
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
let storage = self
.storage()
.upsample_nearest2d(self.layout(), target_h, target_w)?;
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
}
pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
let (n, c, h, w) = self.dims4()?;
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
let w_out = (w - kernel_size.1) / stride.1 + 1;
let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
arg,
kernel_size,
stride,
});
let storage = self
.storage()
.avg_pool2d(self.layout(), kernel_size, stride)?;
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}
pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
let (n, c, h, w) = self.dims4()?;
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
let w_out = (w - kernel_size.1) / stride.1 + 1;
let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
arg,
kernel_size,
stride,
});
let storage = self
.storage()
.max_pool2d(self.layout(), kernel_size, stride)?;
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}
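The output-size rule used above is h_out = (h - kernel) / stride + 1 (and likewise for the width). A quick check, not part of the diff, pooling a 4x4 map with a 2x2 kernel and stride 2:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::arange(0f32, 16f32, &Device::Cpu)?.reshape((1, 1, 4, 4))?;
    let pooled = t.max_pool2d((2, 2), (2, 2))?;
    // (4 - 2) / 2 + 1 = 2 along both spatial dimensions.
    assert_eq!(pooled.dims(), [1, 1, 2, 2]);
    Ok(())
}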
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
///
/// # Arguments
@ -1717,6 +1819,32 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
if left == 0 && right == 0 {
Ok(self.clone())
} else if left == 0 {
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
let mut dims = self.dims().to_vec();
dims[dim] = right;
let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
Tensor::cat(&[self, &right], dim)
} else if right == 0 {
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
let mut dims = self.dims().to_vec();
dims[dim] = left;
let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
Tensor::cat(&[&left, self], dim)
} else {
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
let mut dims = self.dims().to_vec();
dims[dim] = left;
let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
dims[dim] = right;
let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
Tensor::cat(&[&left, self, &right], dim)
}
}
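Not part of the diff, a short exercise of pad_with_zeros on a 1D tensor, one zero on the left and two on the right:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    let padded = t.pad_with_zeros(0, 1, 2)?;
    assert_eq!(padded.to_vec1::<f32>()?, [0., 1., 2., 3., 0., 0.]);
    Ok(())
}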
fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap()
}


@ -11,16 +11,14 @@ pub fn get_num_threads() -> usize {
}
}
pub fn has_accelerate() -> bool {
cfg!(feature = "accelerate")
}
pub fn has_mkl() -> bool {
#[cfg(feature = "mkl")]
return true;
#[cfg(not(feature = "mkl"))]
return false;
cfg!(feature = "mkl")
}
pub fn cuda_is_available() -> bool {
#[cfg(feature = "cuda")]
return true;
#[cfg(not(feature = "cuda"))]
return false;
cfg!(feature = "cuda")
}


@ -0,0 +1,178 @@
mod test_utils;
use anyhow::Result;
use candle_core::{Device, Tensor};
/* This test is based on the following script.
import torch
torch.manual_seed(4242)
t = torch.randn((1, 4, 5))
w = torch.randn((2, 4, 3))
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv1d(t, w)
print(res.flatten())
res = torch.nn.functional.conv1d(t, w, padding=1)
print(res.flatten())
*/
fn conv1d(dev: &Device) -> Result<()> {
let t = Tensor::new(
&[
0.4056f32, -0.8689, -0.0773, -1.5630, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145,
1.8025, -0.1536, 2.2013, -0.6836, 0.2477, 1.3127, -0.6957, 0.3278, -1.0124, 0.5599,
],
dev,
)?
.reshape((1, 4, 5))?;
let w = Tensor::new(
&[
-0.8404f32, -0.3490, 0.0130, 1.3123, 0.1763, -1.9249, 1.4270, 0.9421, 0.8670, -0.7181,
-1.1111, 0.8869, -1.2429, 1.8357, 1.6052, -1.3844, 0.3951, -1.2036, 0.6686, 1.6261,
-0.6451, -0.0840, -1.4247, 0.5512,
],
dev,
)?
.reshape((2, 4, 3))?;
let res = t.conv1d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 2, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
);
let res = t.conv1d(&w, /*padding*/ 1, 1)?;
assert_eq!(res.dims(), [1, 2, 5]);
// Same as pytorch default padding: use zeros.
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
Ok(())
}
fn conv1d_small(dev: &Device) -> Result<()> {
let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
let res = t.conv1d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 1, 2]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.4056, -0.8689]
);
let res = t.conv1d(&w, /*padding*/ 1, 1)?;
assert_eq!(res.dims(), [1, 1, 4]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.0, 0.4056, -0.8689, -0.0773],
);
Ok(())
}
/* This test is based on the following script.
import torch
torch.manual_seed(4242)
t = torch.randn((1, 4, 5, 5))
w = torch.randn((2, 4, 3, 3))
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
*/
fn conv2d(dev: &Device) -> Result<()> {
let t = Tensor::new(
&[
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
],
dev,
)?;
let w = Tensor::new(
&[
-0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
-2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
-0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
-0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
0.5583, 0.4623, 0.6026,
],
dev,
)?;
let t = t.reshape((1, 4, 5, 5))?;
let w = w.reshape((2, 4, 3, 3))?;
let res = t.conv2d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[
-4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
Ok(())
}
/* This test is based on the following script.
import torch
torch.manual_seed(4242)
t = torch.randn((1, 2, 3, 3))
w = torch.randn((1, 2, 1, 1))
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
*/
fn conv2d_small(dev: &Device) -> Result<()> {
let t = Tensor::new(
&[
0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145,
-0.6266, 0.3529, 2.2013, -0.6836, 0.2477, 1.3127, -0.6957, 0.3278,
],
dev,
)?;
let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?;
let t = t.reshape((1, 2, 3, 3))?;
let w = w.reshape((1, 2, 1, 1))?;
let res = t.conv2d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539]
);
Ok(())
}
fn conv2d_smaller(dev: &Device) -> Result<()> {
let t = Tensor::new(
&[
0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866,
],
dev,
)?;
let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?;
let t = t.reshape((1, 1, 3, 3))?;
let w = w.reshape((1, 1, 3, 3))?;
let res = t.conv2d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 1, 1, 1]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[-0.6197]
);
Ok(())
}
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);


@ -85,8 +85,14 @@ fn unary_grad(device: &Device) -> Result<()> {
let y = (x.log()? + 1.)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
-assert_eq!(y.to_vec1::<f32>()?, [2.0986123, 1.0, 2.3862944, -0.89712]);
-assert_eq!(grad_x.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[2.0986, 1.0, 2.3863, -0.8971]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[0.3333, 1.0, 0.25, 6.6667]
);
let y = x.exp()?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
@ -141,7 +147,7 @@ fn unary_grad(device: &Device) -> Result<()> {
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [3.0, 1.0, 4.0, 0.15]);
-assert_eq!(grad_x.to_vec1::<f32>()?, [1.0, 1.0, 1.0, 1.0]);
assert_eq!(test_utils::to_vec1_round(grad_x, 4)?, [1.0, 1.0, 1.0, 1.0]);
let y = x.neg()?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
@ -155,7 +161,10 @@ fn unary_grad(device: &Device) -> Result<()> {
let y = Tensor::new(1f32, device)?.broadcast_div(x)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
-assert_eq!(y.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[0.3333, 1.0, 0.25, 6.6667]
);
assert_eq!(
grad_x.to_vec1::<f32>()?,
[-0.11111111, -1.0, -0.0625, -44.444443],


@ -0,0 +1,89 @@
mod test_utils;
use candle_core::{Device, IndexOp, Result, Tensor};
// https://github.com/huggingface/candle/issues/364
fn avg_pool2d(dev: &Device) -> Result<()> {
let data: Vec<f32> = vec![
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
Ok(())
}
fn max_pool2d(dev: &Device) -> Result<()> {
let data: Vec<f32> = vec![
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
];
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
Ok(())
}
/* This test corresponds to the following PyTorch script.
import torch
torch.manual_seed(4242)
t = torch.randn((1, 2, 4, 4))
print(t.flatten())
res = torch.nn.functional.avg_pool2d(t, 2)
print(res)
*/
fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
let t = Tensor::new(
&[
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
0.2477, 1.3127,
],
dev,
)?
.reshape((1, 2, 4, 4))?;
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?;
assert_eq!(
test_utils::to_vec3_round(pool, 4)?,
[
[[-1.1926, -0.0395], [0.2688, 0.1871]],
[[0.1835, -0.1606], [0.6249, 0.3217]]
]
);
let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?;
assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]);
Ok(())
}
fn upsample_nearest2d(dev: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 6f32, dev)?.reshape((1, 1, 2, 3))?;
let upsampled = t.upsample_nearest2d(4, 6)?.i(0)?.i(0)?;
assert_eq!(
t.i(0)?.i(0)?.to_vec2::<f32>()?,
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
);
assert_eq!(
upsampled.to_vec2::<f32>()?,
[
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0]
]
);
Ok(())
}
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
test_device!(
avg_pool2d_pytorch,
avg_pool2d_pytorch_cpu,
avg_pool2d_pytorch_gpu
);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
test_device!(
upsample_nearest2d,
upsample_nearest2d_cpu,
upsample_nearest2d_gpu
);

View File

@ -0,0 +1,46 @@
use candle_core::{quantized, Device, Result, Tensor};
use quantized::{k_quants, GgmlType};
#[test]
fn quantized_matmul() -> Result<()> {
let cpu = &Device::Cpu;
let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
assert_eq!(
dst,
&[
85120.43, 214561.61, 345454.9, 474748.1, 213474.94, 604465.25, 1000686.4, 1388317.3,
341875.88, 994283.0, 1655708.8, 2301518.3
]
);
let mm = tensor_lhs.matmul(&tensor_rhs)?;
assert_eq!(
mm.to_vec2::<f32>()?,
&[
[85344.0, 214368.0, 343392.0, 472416.0],
[214368.0, 605536.0, 996704.0, 1387872.0],
[343392.0, 996704.0, 1650016.0, 2303328.0]
]
);
let qtensor = quantized::QTensor::new(rhs_t, (64, 4));
let op = quantized::QMatMul::new(std::sync::Arc::new(qtensor));
let res = tensor_lhs.custom_op1(op)?;
assert_eq!(
res.to_vec2::<f32>()?,
&[
[85120.43, 214561.61, 345454.9, 474748.1],
[213474.94, 604465.25, 1000686.4, 1388317.3],
[341875.88, 994283.0, 1655708.8, 2301518.3]
]
);
Ok(())
}
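The q4_0 path is lossy, which is why the quantized results above only approximate the exact f32 matmul. As a rough sanity check, and only as a sketch (the helper name and the 1% bound are assumptions, not part of the test), the relative error between the two result buffers could be bounded like this:
// Hypothetical helper: largest relative error between quantized and exact results.
fn max_rel_err(quantized: &[f32], exact: &[f32]) -> f32 {
    quantized
        .iter()
        .zip(exact.iter())
        .map(|(q, e)| ((q - e) / e).abs())
        .fold(0f32, f32::max)
}
// e.g. assert!(max_rel_err(&dst, &mm.to_vec2::<f32>()?.concat()) < 1e-2);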

View File

@ -869,3 +869,14 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
// There was originally a bug in the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
#[test]
fn randn_hasneg() -> Result<()> {
let t = Tensor::randn(0f32, 1f32, 200, &Device::Cpu)?.to_vec1::<f32>()?;
if t.iter().all(|&v| v >= 0.) {
candle_core::bail!("all values in tensors are non-negative")
}
Ok(())
}

View File

@ -1,5 +1,8 @@
#![allow(dead_code)]
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_core::{Result, Tensor};
#[macro_export]

View File

@ -0,0 +1,20 @@
[package]
name = "candle-datasets"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
readme = "README.md"
[dependencies]
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.1.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.1.1" }
hf-hub = { workspace = true}
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
rand = { workspace = true }

View File

@ -0,0 +1 @@
# candle-datasets

View File

@ -0,0 +1,6 @@
//! Datasets & Dataloaders for Candle
pub mod batcher;
pub mod nlp;
pub mod vision;
pub use batcher::Batcher;

View File

@ -0,0 +1 @@
pub mod tinystories;

View File

@ -0,0 +1,122 @@
//! Helper functions for the tinystories dataset. This uses the pre-tokenized version as generated
//! by the tools from https://github.com/karpathy/llama2.c
use candle::{Device, Result, Tensor};
pub struct Dataset {
valid_tokens: Vec<memmap2::Mmap>,
train_tokens: Vec<memmap2::Mmap>,
}
fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
let file = std::fs::File::open(p)?;
let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
Ok(mmap)
}
impl Dataset {
pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
let dir = dir.as_ref();
let mut bin_files = vec![];
for file in std::fs::read_dir(dir)?.flatten() {
let file = file.path();
if let Some(extension) = file.extension() {
if extension == "bin" {
bin_files.push(file)
}
}
}
if bin_files.len() < 2 {
candle::bail!("found less than two bin files in {:?}", dir)
}
bin_files.sort();
let valid_tokens = mmap_file(&bin_files[0])?;
let train_tokens = bin_files[1..]
.iter()
.map(mmap_file)
.collect::<Result<Vec<_>>>()?;
Ok(Self {
valid_tokens: vec![valid_tokens],
train_tokens,
})
}
pub fn train_tokens(&self) -> usize {
self.train_tokens.len()
}
pub fn valid_tokens(&self) -> usize {
self.valid_tokens.len()
}
}
pub struct DatasetRandomIter<'a> {
all_tokens: &'a [memmap2::Mmap],
tokens: Vec<&'a memmap2::Mmap>,
current_tokens: &'a memmap2::Mmap,
indexes_in_bytes: Vec<usize>,
seq_len: usize,
device: Device,
}
impl<'a> DatasetRandomIter<'a> {
pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
use rand::seq::SliceRandom;
use rand::thread_rng;
let all_tokens = if valid {
&ds.valid_tokens
} else {
&ds.train_tokens
};
let mut tokens = all_tokens.iter().collect::<Vec<_>>();
tokens.shuffle(&mut thread_rng());
let current_tokens = tokens.pop().unwrap();
let seq_len_in_bytes = seq_len * 2;
let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
.step_by(seq_len_in_bytes)
.collect::<Vec<_>>();
indexes_in_bytes.shuffle(&mut thread_rng());
Self {
all_tokens,
tokens,
current_tokens,
indexes_in_bytes,
seq_len,
device,
}
}
}
impl<'a> Iterator for DatasetRandomIter<'a> {
type Item = Result<(Tensor, Tensor)>;
fn next(&mut self) -> Option<Self::Item> {
use byteorder::{LittleEndian, ReadBytesExt};
use rand::seq::SliceRandom;
use rand::thread_rng;
let seq_len = self.seq_len;
if self.indexes_in_bytes.is_empty() {
if self.tokens.is_empty() {
self.tokens = self.all_tokens.iter().collect();
self.tokens.shuffle(&mut thread_rng());
}
self.current_tokens = self.tokens.pop().unwrap();
let seq_len_in_bytes = self.seq_len * 2;
self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
.step_by(seq_len_in_bytes)
.collect::<Vec<_>>();
self.indexes_in_bytes.shuffle(&mut thread_rng());
}
let start_idx = self.indexes_in_bytes.pop().unwrap();
let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
let mut tokens = vec![0u16; bytes.len() / 2];
if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
return Some(Err(err.into()));
}
let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
let inputs = Tensor::new(&tokens[..seq_len], &self.device);
let targets = Tensor::new(&tokens[1..], &self.device);
Some(candle::error::zip(inputs, targets))
}
}
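For context, a minimal usage sketch of this iterator (the directory, sequence length, and batch size are placeholders; the Batcher pattern mirrors the llama2-c training loop further down in this change set):
// Hypothetical usage sketch, not part of the crate.
#[allow(dead_code)]
fn usage_sketch(pretokenized_dir: &str) -> Result<()> {
    let dataset = Dataset::new(pretokenized_dir)?;
    let iter = DatasetRandomIter::new(&dataset, /* valid= */ false, /* seq_len= */ 256, Device::Cpu);
    let batches = crate::Batcher::new_r2(iter).batch_size(32);
    for batch in batches.take(4) {
        let (inputs, targets) = batch?;
        println!("{:?} {:?}", inputs.shape(), targets.shape());
    }
    Ok(())
}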

View File

@ -10,10 +10,12 @@ license.workspace = true
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.1.0" }
candle-transformers = { path = "../candle-transformers", version = "0.1.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.0", optional = true }
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.1.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.1.1" }
candle-nn = { path = "../candle-nn", version = "0.1.1" }
candle-transformers = { path = "../candle-transformers", version = "0.1.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.1", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
@ -21,12 +23,13 @@ num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
hf-hub = { workspace = true, features=["tokio"]}
clap = { workspace = true }
hf-hub = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
@ -34,13 +37,17 @@ tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate from the tokio version used in the wasm examples (1.28.1).
tokio = "1.29.1"
[build-dependencies]
anyhow = { workspace = true }
[features]
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
@ -48,3 +55,7 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"]
[[example]]
name = "llama_multiprocess"
required-features = ["cuda", "nccl", "flash-attn"]
[[example]]
name = "stable-diffusion"
required-features = ["image"]

View File

@ -39,6 +39,10 @@ struct Args {
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
}
impl Args {
@ -107,7 +111,10 @@ fn main() -> Result<()> {
let device = &model.device;
if let Some(prompt) = args.prompt {
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
@ -164,7 +171,13 @@ fn main() -> Result<()> {
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
let embeddings = if args.normalize_embeddings {
normalize_l2(&embeddings)?
} else {
embeddings
};
println!("pooled embeddings {:?}", embeddings.shape());
let mut similarities = vec![];
for i in 0..n_sentences {
let e_i = embeddings.get(i)?;
@ -184,3 +197,7 @@ fn main() -> Result<()> {
}
Ok(())
}
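/// L2-normalize each row of `v`, i.e. divide it by the square root of the sum of its squared entries.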
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}

View File

@ -65,10 +65,7 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
let token = self
.tokenizer
.decode(vec![next_token], true)
.map_err(E::msg)?;
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
}

View File

@ -1,5 +1,8 @@
// TODO: Add an offline mode.
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@ -69,16 +72,14 @@ impl TextGeneration {
"{} token: {} '{}'",
index + 1,
next_token,
self.tokenizer
.decode(vec![next_token], true)
.map_err(E::msg)?
self.tokenizer.decode(&[next_token], true).map_err(E::msg)?
);
}
let dt = start_gen.elapsed();
println!(
"{sample_len} tokens generated ({} token/s)\n----\n{}\n----",
sample_len as f64 / dt.as_secs_f64(),
self.tokenizer.decode(new_tokens, true).map_err(E::msg)?
self.tokenizer.decode(&new_tokens, true).map_err(E::msg)?
);
Ok(())
}

View File

@ -0,0 +1,28 @@
use anyhow::Result;
use clap::Parser;
use std::fs::File;
use candle::quantized::ggml_file::Content;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
#[arg(long)]
model: String,
}
fn main() -> Result<()> {
let args = Args::parse();
let mut file = File::open(args.model)?;
let start = std::time::Instant::now();
let model = Content::read(&mut file)?;
println!(
"Loaded {:?} tensors in {:?}",
model.tensors.len(),
start.elapsed()
);
Ok(())
}

View File

@ -1,199 +0,0 @@
# Adapted from:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
import argparse
import gc
import json
import math
import os
import shutil
import warnings
import torch
import numpy as np
"""
Sample usage:
```
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
--input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
```
"""
INTERMEDIATE_SIZE_MAP = {
"7B": 11008,
"13B": 13824,
"30B": 17920,
"65B": 22016,
}
NUM_SHARDS = {
"7B": 1,
"13B": 2,
"30B": 4,
"65B": 8,
}
def compute_intermediate_size(n):
return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def write_json(text, path):
with open(path, "w") as f:
json.dump(text, f)
def write_model(model_path, input_base_path, model_size):
os.makedirs(model_path, exist_ok=True)
params = read_json(os.path.join(input_base_path, "params.json"))
num_shards = NUM_SHARDS[model_size]
n_layers = params["n_layers"]
n_heads = params["n_heads"]
n_heads_per_shard = n_heads // num_shards
dim = params["dim"]
dims_per_head = dim // n_heads
base = 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
# permute for sliced rotary
def permute(w):
return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
# Load weights
if model_size == "7B":
# Not sharded
# (The sharded implementation would also work, but this is simpler.)
loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
else:
# Sharded
loaded = [
torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
for i in range(num_shards)
]
param_count = 0
all_dicts = {}
for layer_i in range(n_layers):
if model_size == "7B":
# Unsharded
state_dict = {
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
loaded[f"layers.{layer_i}.attention.wq.weight"]
),
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
loaded[f"layers.{layer_i}.attention.wk.weight"]
),
f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
}
else:
# Sharded
# Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
# the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
# redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.
state_dict = {
f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
f"layers.{layer_i}.attention_norm.weight"
].clone(),
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
f"layers.{layer_i}.ffn_norm.weight"
].clone(),
}
state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
for i in range(num_shards)
],
dim=0,
).reshape(dim, dim)
)
state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(n_heads_per_shard, dims_per_head, dim)
for i in range(num_shards)
],
dim=0,
).reshape(dim, dim)
)
state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(n_heads_per_shard, dims_per_head, dim)
for i in range(num_shards)
],
dim=0,
).reshape(dim, dim)
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
)
state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
)
state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
)
state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
)
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
all_dicts |= state_dict
if model_size == "7B":
# Unsharded
state_dict = {
"model.embed_tokens.weight": loaded["tok_embeddings.weight"],
"model.norm.weight": loaded["norm.weight"],
"lm_head.weight": loaded["output.weight"],
}
else:
state_dict = {
"model.norm.weight": loaded[0]["norm.weight"],
"model.embed_tokens.weight": torch.cat(
[loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
),
"lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
}
all_dicts |= state_dict
all_dicts = {k: v.numpy() for k, v in all_dicts.items()}
np.savez(os.path.join(model_path, "llama.npz"), **all_dicts)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of LLaMA weights, which contains tokenizer.model and model folders",
)
parser.add_argument(
"--model_size",
choices=["7B", "13B", "30B", "65B"],
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
args = parser.parse_args()
write_model(
model_path=args.output_dir,
input_base_path=os.path.join(args.input_dir, args.model_size),
model_size=args.model_size,
)
if __name__ == "__main__":
main()

View File

@ -5,9 +5,9 @@
//
// The tokenizer config can be retrieved from:
// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json
//
// In order to convert the llama weights to a .npz file, run:
// python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@ -19,62 +19,14 @@ use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::api::sync::Api;
use std::io::Write;
mod model;
use model::{Config, Llama};
const EOS_TOKEN: &str = "</s>";
const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = r"
EDWARD:
I wonder how our princely father 'scaped,
Or whether he be 'scaped away or no
From Clifford's and Northumberland's pursuit:
Had he been ta'en, we should have heard the news;
Had he been slain, we should have heard the news;
Or had he 'scaped, methinks we should have heard
The happy tidings of his good escape.
How fares my brother? why is he so sad?
RICHARD:
I cannot joy, until I be resolved
Where our right valiant father is become.
I saw him in the battle range about;
And watch'd him how he singled Clifford forth.
Methought he bore him in the thickest troop
As doth a lion in a herd of neat;
Or as a bear, encompass'd round with dogs,
Who having pinch'd a few and made them cry,
The rest stand all aloof, and bark at him.
So fared our father with his enemies;
So fled his enemies my warlike father:
Methinks, 'tis prize enough to be his son.
See how the morning opes her golden gates,
And takes her farewell of the glorious sun!
How well resembles it the prime of youth,
Trimm'd like a younker prancing to his love!
EDWARD:
Dazzle mine eyes, or do I see three suns?
RICHARD:
Three glorious suns, each one a perfect sun;
Not separated with the racking clouds,
But sever'd in a pale clear-shining sky.
See, see! they join, embrace, and seem to kiss,
As if they vow'd some league inviolable:
Now are they but one lamp, one light, one sun.
In this the heaven figures some event.
EDWARD:
'Tis wondrous strange, the like yet never heard of.
I think it cites us, brother, to the field,
That we, the sons of brave Plantagenet,
Each one already blazing by our meeds,
Should notwithstanding join our lights together
And over-shine the earth as this the world.
Whate'er it bodes, henceforward will I bear
Upon my target three fair-shining suns.
";
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
@ -111,6 +63,10 @@ struct Args {
#[arg(long)]
use_f32: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
model_id: Option<String>,
@ -119,12 +75,27 @@ struct Args {
#[arg(long)]
use_flash_attn: bool,
/// The folder name that contains safetensor weights and json files
/// (same structure as huggingface online)
#[arg(long)]
local_weights: Option<String>,
}
fn main() -> Result<()> {
use tokenizers::Tokenizer;
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let device = candle_examples::device(args.cpu)?;
let config = if args.v1 {
@ -151,14 +122,26 @@ fn main() -> Result<()> {
});
println!("loading the model weights from {model_id}");
let api = api.model(model_id);
let tokenizer_filename = api.get("tokenizer.json")?;
let tokenizer_filename = match &args.local_weights {
Some(path) => (path.to_owned() + "tokenizer.json").into(),
_ => api.get("tokenizer.json")?,
};
let mut filenames = vec![];
for rfilename in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
let filename = api.get(rfilename)?;
filenames.push(filename);
match &args.local_weights {
Some(path) => {
filenames.push((path.to_owned() + rfilename).into());
}
_ => {
let filename = api.get(rfilename)?;
filenames.push(filename);
}
};
}
println!("building the model");
@ -176,6 +159,7 @@ fn main() -> Result<()> {
}
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
@ -184,12 +168,12 @@ fn main() -> Result<()> {
.to_vec();
println!("starting the inference loop");
print!("{prompt}");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
let mut new_tokens = vec![];
let start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;
for index in 0..args.sample_len {
let start_gen = std::time::Instant::now();
let context_size = if cache.use_kv_cache && index > 0 {
1
} else {
@ -202,22 +186,27 @@ fn main() -> Result<()> {
index_pos += ctxt.len();
let next_token = logits_processor.sample(&logits)?;
token_generated += 1;
tokens.push(next_token);
new_tokens.push(next_token);
println!("> {:?}", start_gen.elapsed());
println!(
"{} token: {} '{}'",
index + 1,
next_token,
tokenizer.decode(vec![next_token], true).map_err(E::msg)?
);
// Extracting the last token as a string is complicated; here we just apply some simple
// heuristics that seem to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
if Some(next_token) == eos_token_id {
break;
}
}
let dt = start_gen.elapsed();
println!(
"{} tokens generated ({} token/s)\n----\n{}\n----",
args.sample_len,
args.sample_len as f64 / dt.as_secs_f64(),
tokenizer.decode(new_tokens, true).map_err(E::msg)?
"\n\n{} tokens generated ({} token/s)\n",
token_generated,
token_generated as f64 / dt.as_secs_f64(),
);
Ok(())
}

View File

@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Linear, VarBuilder};
use candle_nn::{Embedding, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@ -47,6 +47,21 @@ impl Config {
}
}
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
#[derive(Clone)]
pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
@ -106,8 +121,9 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
Ok(Linear::new(weight, None))
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
Ok(Linear { inner, span })
}
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
@ -118,15 +134,18 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
struct RmsNorm {
scale: Tensor,
eps: f64,
span: tracing::Span,
}
impl RmsNorm {
fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let scale = vb.get(size, "weight")?;
Ok(Self { scale, eps })
Ok(Self { scale, eps, span })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let in_dtype = x.dtype();
// This is a no-op if x's dtype is already f32.
let x = x.to_dtype(DType::F32)?;
@ -155,6 +174,8 @@ struct CausalSelfAttention {
head_dim: usize,
cache: Cache,
use_flash_attn: bool,
span: tracing::Span,
span_rot: tracing::Span,
}
#[cfg(feature = "flash-attn")]
@ -175,6 +196,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_rot.enter();
let (b_sz, _, seq_len, n_embd) = x.dims4()?;
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
@ -188,6 +210,7 @@ impl CausalSelfAttention {
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
@ -269,6 +292,8 @@ impl CausalSelfAttention {
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let size_in = cfg.hidden_size;
let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
@ -286,6 +311,8 @@ impl CausalSelfAttention {
head_dim: cfg.hidden_size / cfg.n_head,
cache: cache.clone(),
use_flash_attn: cfg.use_flash_attn,
span,
span_rot,
})
}
}
@ -301,15 +328,18 @@ struct Mlp {
c_fc1: Linear,
c_fc2: Linear,
c_proj: Linear,
span: tracing::Span,
}
impl Mlp {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "mlp");
let h_size = cfg.hidden_size;
let i_size = cfg.intermediate_size;
let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
@ -319,6 +349,7 @@ impl Mlp {
c_fc1,
c_fc2,
c_proj,
span,
})
}
}
@ -328,10 +359,12 @@ struct Block {
attn: CausalSelfAttention,
rms_2: RmsNorm,
mlp: Mlp,
span: tracing::Span,
}
impl Block {
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let _enter = self.span.enter();
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
@ -341,6 +374,7 @@ impl Block {
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "block");
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
@ -354,6 +388,7 @@ impl Block {
attn,
rms_2,
mlp,
span,
})
}
}

View File

@ -1,5 +1,8 @@
// https://github.com/karpathy/llama2.c
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@ -27,7 +30,7 @@ struct InferenceCmd {
#[arg(long, default_value = "")]
prompt: String,
/// Config file in binary format.
/// Config file in binary or safetensors format.
#[arg(long)]
config: Option<String>,
@ -200,7 +203,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
}
});
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
for inp_tgt in batch_iter {
let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0)?;
@ -225,11 +228,22 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let device = candle_examples::device(common_args.cpu)?;
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
println!("{config:?}");
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
let is_safetensors = config_path
.extension()
.map_or(false, |v| v == "safetensors");
let (vb, config) = if is_safetensors {
let config = Config::tiny();
let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
(vb, config)
} else {
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
println!("{config:?}");
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
(vb, config)
};
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;

View File

@ -1,118 +1,6 @@
#![allow(dead_code)]
#![allow(unused)]
use crate::model::{Cache, Config, Llama};
use candle::{DType, Device, Result, Tensor};
pub struct Dataset {
valid_tokens: Vec<memmap2::Mmap>,
train_tokens: Vec<memmap2::Mmap>,
}
fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
let file = std::fs::File::open(p)?;
let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
Ok(mmap)
}
impl Dataset {
pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
let dir = dir.as_ref();
let mut bin_files = vec![];
for file in std::fs::read_dir(dir)?.flatten() {
let file = file.path();
if let Some(extension) = file.extension() {
if extension == "bin" {
bin_files.push(file)
}
}
}
if bin_files.len() < 2 {
candle::bail!("found less than two bin files in {:?}", dir)
}
bin_files.sort();
let valid_tokens = mmap_file(&bin_files[0])?;
let train_tokens = bin_files[1..]
.iter()
.map(mmap_file)
.collect::<Result<Vec<_>>>()?;
Ok(Self {
valid_tokens: vec![valid_tokens],
train_tokens,
})
}
}
struct DatasetRandomIter<'a> {
all_tokens: &'a [memmap2::Mmap],
tokens: Vec<&'a memmap2::Mmap>,
current_tokens: &'a memmap2::Mmap,
indexes_in_bytes: Vec<usize>,
seq_len: usize,
device: Device,
}
impl<'a> DatasetRandomIter<'a> {
pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
use rand::seq::SliceRandom;
use rand::thread_rng;
let all_tokens = if valid {
&ds.valid_tokens
} else {
&ds.train_tokens
};
let mut tokens = all_tokens.iter().collect::<Vec<_>>();
tokens.shuffle(&mut thread_rng());
let current_tokens = tokens.pop().unwrap();
let seq_len_in_bytes = seq_len * 2;
let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
.step_by(seq_len_in_bytes)
.collect::<Vec<_>>();
indexes_in_bytes.shuffle(&mut thread_rng());
Self {
all_tokens,
tokens,
current_tokens,
indexes_in_bytes,
seq_len,
device,
}
}
}
impl<'a> Iterator for DatasetRandomIter<'a> {
type Item = Result<(Tensor, Tensor)>;
fn next(&mut self) -> Option<Self::Item> {
use byteorder::{LittleEndian, ReadBytesExt};
use rand::seq::SliceRandom;
use rand::thread_rng;
let seq_len = self.seq_len;
if self.indexes_in_bytes.is_empty() {
if self.tokens.is_empty() {
self.tokens = self.all_tokens.iter().collect();
self.tokens.shuffle(&mut thread_rng());
}
self.current_tokens = self.tokens.pop().unwrap();
let seq_len_in_bytes = self.seq_len * 2;
self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
.step_by(seq_len_in_bytes)
.collect::<Vec<_>>();
self.indexes_in_bytes.shuffle(&mut thread_rng());
}
let start_idx = self.indexes_in_bytes.pop().unwrap();
let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
let mut tokens = vec![0u16; bytes.len() / 2];
if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
return Some(Err(err.into()));
}
let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
let inputs = Tensor::new(&tokens[..seq_len], &self.device);
let targets = Tensor::new(&tokens[1..], &self.device);
Some(candle::error::zip(inputs, targets))
}
}
use candle::{DType, Device, Result};
use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};
fn valid_loss(
dataset: &Dataset,
@ -121,7 +9,7 @@ fn valid_loss(
device: &Device,
) -> Result<f64> {
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
let mut sum_ce = 0f64;
let mut cnt = 0usize;
for inp_tgt in batch_iter.take(50) {
@ -139,14 +27,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let dataset = Dataset::new(&args.pretokenized_dir)?;
println!(
"loaded dataset, train: {} files, valid: {} files",
dataset.train_tokens.len(),
dataset.valid_tokens.len()
dataset.train_tokens(),
dataset.valid_tokens()
);
let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
let config = Config::tiny();
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
let cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;

View File

@ -104,7 +104,14 @@ impl TransformerWeights {
})
}
pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder> {
pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'static>> {
// TODO: As of 2023-08-04, gemm is slower than expected when multiplying a matrix of
// size (1, k) with the transpose of a matrix of size (k, n) as it ends up transposing the
// second matrix back. We detect this case here and as a temporary hack make the weight
// matrix column major rather than row major. This ends up speeding up text generation from
// 120 token/s to 220 token/s on a Ryzen 2600X.
let tr = device.is_cpu() && !candle::utils::has_mkl();
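// Shadow the flag with a closure: when enabled, re-lay the weight out column major in memory while keeping its logical shape unchanged.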
let tr = |x: Tensor| if tr { x.t()?.contiguous()?.t() } else { Ok(x) };
let mut ws = std::collections::HashMap::new();
let mut insert = |name: &str, t: Tensor| {
ws.insert(name.to_string(), t);
@ -115,36 +122,36 @@ impl TransformerWeights {
"model.embed_tokens.weight",
self.token_embedding_table.clone(),
);
insert("lm_head.weight", self.token_embedding_table.clone());
insert("lm_head.weight", tr(self.token_embedding_table.clone())?);
insert("model.norm.weight", self.rms_final_weight.clone());
for layer in 0..cfg.n_layers {
ws.insert(
format!("model.layers.{layer}.self_attn.q_proj.weight"),
self.wq.i(layer)?,
tr(self.wq.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.self_attn.k_proj.weight"),
self.wk.i(layer)?,
tr(self.wk.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.self_attn.v_proj.weight"),
self.wv.i(layer)?,
tr(self.wv.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.self_attn.o_proj.weight"),
self.wo.i(layer)?,
tr(self.wo.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.mlp.gate_proj.weight"),
self.w1.i(layer)?,
tr(self.w1.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.mlp.down_proj.weight"),
self.w2.i(layer)?,
tr(self.w2.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.mlp.up_proj.weight"),
self.w3.i(layer)?,
tr(self.w3.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.input_layernorm.weight"),

View File

@ -5,9 +5,6 @@
//
// The tokenizer config can be retrieved from:
// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json
//
// In order to convert the llama weights to a .npz file, run:
// python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

View File

@ -63,7 +63,7 @@ struct TrainingArgs {
}
fn training_loop<M: Model>(
m: candle_nn::vision::Dataset,
m: candle_datasets::vision::Dataset,
args: &TrainingArgs,
) -> anyhow::Result<()> {
let dev = candle::Device::cuda_if_available(0)?;
@ -140,7 +140,7 @@ struct Args {
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Load the dataset
let m = candle_nn::vision::mnist::load_dir("data")?;
let m = candle_datasets::vision::mnist::load_dir("data")?;
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());

View File

@ -0,0 +1,473 @@
#![allow(dead_code)]
//! Attention Based Building Blocks
use candle::{IndexOp, Result, Tensor, D};
use candle_nn as nn;
#[derive(Debug)]
struct GeGlu {
proj: nn::Linear,
span: tracing::Span,
}
impl GeGlu {
fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
let span = tracing::span!(tracing::Level::TRACE, "geglu");
Ok(Self { proj, span })
}
}
impl GeGlu {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
&hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
}
}
/// A feed-forward layer.
#[derive(Debug)]
struct FeedForward {
project_in: GeGlu,
linear: nn::Linear,
span: tracing::Span,
}
impl FeedForward {
// The glu parameter in the python code is unused?
// https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L347
/// Creates a new feed-forward layer from a given input dimension, an optional
/// output dimension, and a multiplier used for the intermediate layer.
fn new(vs: nn::VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> {
let inner_dim = dim * mult;
let dim_out = dim_out.unwrap_or(dim);
let vs = vs.pp("net");
let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?;
let span = tracing::span!(tracing::Level::TRACE, "ff");
Ok(Self {
project_in,
linear,
span,
})
}
}
impl FeedForward {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.project_in.forward(xs)?;
self.linear.forward(&xs)
}
}
#[derive(Debug)]
struct CrossAttention {
to_q: nn::Linear,
to_k: nn::Linear,
to_v: nn::Linear,
to_out: nn::Linear,
heads: usize,
scale: f64,
slice_size: Option<usize>,
span: tracing::Span,
span_attn: tracing::Span,
}
impl CrossAttention {
// Defaults should be heads = 8, dim_head = 64, context_dim = None
fn new(
vs: nn::VarBuilder,
query_dim: usize,
context_dim: Option<usize>,
heads: usize,
dim_head: usize,
slice_size: Option<usize>,
) -> Result<Self> {
let inner_dim = dim_head * heads;
let context_dim = context_dim.unwrap_or(query_dim);
let scale = 1.0 / f64::sqrt(dim_head as f64);
let to_q = nn::linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
let span = tracing::span!(tracing::Level::TRACE, "xa");
let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
Ok(Self {
to_q,
to_k,
to_v,
to_out,
heads,
scale,
slice_size,
span,
span_attn,
})
}
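// (batch, seq_len, heads * head_dim) -> (batch * heads, seq_len, head_dim)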
fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
let (batch_size, seq_len, dim) = xs.dims3()?;
xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))?
.transpose(1, 2)?
.reshape((batch_size * self.heads, seq_len, dim / self.heads))
}
fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> {
let (batch_size, seq_len, dim) = xs.dims3()?;
xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))?
.transpose(1, 2)?
.reshape((batch_size / self.heads, seq_len, dim * self.heads))
}
fn sliced_attention(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
slice_size: usize,
) -> Result<Tensor> {
let batch_size_attention = query.dim(0)?;
let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
for i in 0..batch_size_attention / slice_size {
let start_idx = i * slice_size;
let end_idx = (i + 1) * slice_size;
let xs = query
.i(start_idx..end_idx)?
.matmul(&(key.i(start_idx..end_idx)?.t()? * self.scale)?)?;
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
hidden_states.push(xs)
}
let hidden_states = Tensor::stack(&hidden_states, 0)?;
self.reshape_batch_dim_to_heads(&hidden_states)
}
fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
let _enter = self.span_attn.enter();
let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
self.reshape_batch_dim_to_heads(&xs)
}
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let query = self.to_q.forward(xs)?;
let context = context.unwrap_or(xs);
let key = self.to_k.forward(context)?;
let value = self.to_v.forward(context)?;
let query = self.reshape_heads_to_batch_dim(&query)?;
let key = self.reshape_heads_to_batch_dim(&key)?;
let value = self.reshape_heads_to_batch_dim(&value)?;
let xs = match self.slice_size {
None => self.attention(&query, &key, &value)?,
Some(slice_size) => {
if query.dim(0)? / slice_size <= 1 {
self.attention(&query, &key, &value)?
} else {
self.sliced_attention(&query, &key, &value, slice_size)?
}
}
};
self.to_out.forward(&xs)
}
}
/// A basic Transformer block.
#[derive(Debug)]
struct BasicTransformerBlock {
attn1: CrossAttention,
ff: FeedForward,
attn2: CrossAttention,
norm1: nn::LayerNorm,
norm2: nn::LayerNorm,
norm3: nn::LayerNorm,
span: tracing::Span,
}
impl BasicTransformerBlock {
fn new(
vs: nn::VarBuilder,
dim: usize,
n_heads: usize,
d_head: usize,
context_dim: Option<usize>,
sliced_attention_size: Option<usize>,
) -> Result<Self> {
let attn1 = CrossAttention::new(
vs.pp("attn1"),
dim,
None,
n_heads,
d_head,
sliced_attention_size,
)?;
let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
let attn2 = CrossAttention::new(
vs.pp("attn2"),
dim,
context_dim,
n_heads,
d_head,
sliced_attention_size,
)?;
let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
let span = tracing::span!(tracing::Level::TRACE, "basic-transformer");
Ok(Self {
attn1,
ff,
attn2,
norm1,
norm2,
norm3,
span,
})
}
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?;
let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?;
self.ff.forward(&self.norm3.forward(&xs)?)? + xs
}
}
#[derive(Debug, Clone, Copy)]
pub struct SpatialTransformerConfig {
pub depth: usize,
pub num_groups: usize,
pub context_dim: Option<usize>,
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
}
impl Default for SpatialTransformerConfig {
fn default() -> Self {
Self {
depth: 1,
num_groups: 32,
context_dim: None,
sliced_attention_size: None,
use_linear_projection: false,
}
}
}
#[derive(Debug)]
enum Proj {
Conv2d(nn::Conv2d),
Linear(nn::Linear),
}
// Aka Transformer2DModel
#[derive(Debug)]
pub struct SpatialTransformer {
norm: nn::GroupNorm,
proj_in: Proj,
transformer_blocks: Vec<BasicTransformerBlock>,
proj_out: Proj,
span: tracing::Span,
pub config: SpatialTransformerConfig,
}
impl SpatialTransformer {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
n_heads: usize,
d_head: usize,
config: SpatialTransformerConfig,
) -> Result<Self> {
let inner_dim = n_heads * d_head;
let norm = nn::group_norm(config.num_groups, in_channels, 1e-6, vs.pp("norm"))?;
let proj_in = if config.use_linear_projection {
Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_in"))?)
} else {
Proj::Conv2d(nn::conv2d(
in_channels,
inner_dim,
1,
Default::default(),
vs.pp("proj_in"),
)?)
};
let mut transformer_blocks = vec![];
let vs_tb = vs.pp("transformer_blocks");
for index in 0..config.depth {
let tb = BasicTransformerBlock::new(
vs_tb.pp(&index.to_string()),
inner_dim,
n_heads,
d_head,
config.context_dim,
config.sliced_attention_size,
)?;
transformer_blocks.push(tb)
}
let proj_out = if config.use_linear_projection {
Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_out"))?)
} else {
Proj::Conv2d(nn::conv2d(
inner_dim,
in_channels,
1,
Default::default(),
vs.pp("proj_out"),
)?)
};
let span = tracing::span!(tracing::Level::TRACE, "spatial-transformer");
Ok(Self {
norm,
proj_in,
transformer_blocks,
proj_out,
span,
config,
})
}
pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let (batch, _channel, height, weight) = xs.dims4()?;
let residual = xs;
let xs = self.norm.forward(xs)?;
let (inner_dim, xs) = match &self.proj_in {
Proj::Conv2d(p) => {
let xs = p.forward(&xs)?;
let inner_dim = xs.dim(1)?;
let xs = xs
.transpose(1, 2)?
.t()?
.reshape((batch, height * weight, inner_dim))?;
(inner_dim, xs)
}
Proj::Linear(p) => {
let inner_dim = xs.dim(1)?;
let xs = xs
.transpose(1, 2)?
.t()?
.reshape((batch, height * weight, inner_dim))?;
(inner_dim, p.forward(&xs)?)
}
};
let mut xs = xs;
for block in self.transformer_blocks.iter() {
xs = block.forward(&xs, context)?
}
let xs = match &self.proj_out {
Proj::Conv2d(p) => p.forward(
&xs.reshape((batch, height, weight, inner_dim))?
.t()?
.transpose(1, 2)?,
)?,
Proj::Linear(p) => p
.forward(&xs)?
.reshape((batch, height, weight, inner_dim))?
.t()?
.transpose(1, 2)?,
};
xs + residual
}
}
/// Configuration for an attention block.
#[derive(Debug, Clone, Copy)]
pub struct AttentionBlockConfig {
pub num_head_channels: Option<usize>,
pub num_groups: usize,
pub rescale_output_factor: f64,
pub eps: f64,
}
impl Default for AttentionBlockConfig {
fn default() -> Self {
Self {
num_head_channels: None,
num_groups: 32,
rescale_output_factor: 1.,
eps: 1e-5,
}
}
}
#[derive(Debug)]
pub struct AttentionBlock {
group_norm: nn::GroupNorm,
query: nn::Linear,
key: nn::Linear,
value: nn::Linear,
proj_attn: nn::Linear,
channels: usize,
num_heads: usize,
span: tracing::Span,
config: AttentionBlockConfig,
}
impl AttentionBlock {
pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> {
let num_head_channels = config.num_head_channels.unwrap_or(channels);
let num_heads = channels / num_head_channels;
let group_norm =
nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
let query = nn::linear(channels, channels, vs.pp("query"))?;
let key = nn::linear(channels, channels, vs.pp("key"))?;
let value = nn::linear(channels, channels, vs.pp("value"))?;
let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
let span = tracing::span!(tracing::Level::TRACE, "attn-block");
Ok(Self {
group_norm,
query,
key,
value,
proj_attn,
channels,
num_heads,
span,
config,
})
}
fn transpose_for_scores(&self, xs: Tensor) -> Result<Tensor> {
let (batch, t, h_times_d) = xs.dims3()?;
xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
.transpose(1, 2)
}
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let residual = xs;
let (batch, channel, height, width) = xs.dims4()?;
let xs = self
.group_norm
.forward(xs)?
.reshape((batch, channel, height * width))?
.transpose(1, 2)?;
let query_proj = self.query.forward(&xs)?;
let key_proj = self.key.forward(&xs)?;
let value_proj = self.value.forward(&xs)?;
let query_states = self.transpose_for_scores(query_proj)?;
let key_states = self.transpose_for_scores(key_proj)?;
let value_states = self.transpose_for_scores(value_proj)?;
let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
let attention_scores =
// TODO: Check whether this really needs two multiplications by `scale`.
(query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
let xs = attention_probs.matmul(&value_states.contiguous()?)?;
let xs = xs.transpose(1, 2)?.contiguous()?;
let xs = xs.flatten_from(D::Minus2)?;
let xs = self
.proj_attn
.forward(&xs)?
.t()?
.reshape((batch, channel, height, width))?;
(xs + residual)? / self.config.rescale_output_factor
}
}

View File

@ -0,0 +1,305 @@
#![allow(dead_code)]
//! Contrastive Language-Image Pre-Training
//!
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images with related texts.
//!
//! https://github.com/openai/CLIP
use candle::{Device, Result, Tensor, D};
use candle_nn as nn;
#[derive(Debug, Clone, Copy)]
pub enum Activation {
QuickGelu,
Gelu,
}
impl Activation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
Activation::Gelu => xs.gelu(),
}
}
}
#[derive(Debug, Clone)]
pub struct Config {
vocab_size: usize,
embed_dim: usize, // aka config.hidden_size
activation: Activation, // aka config.hidden_act
intermediate_size: usize,
pub max_position_embeddings: usize,
// The character to use for padding, use EOS when not set.
pub pad_with: Option<String>,
num_hidden_layers: usize,
num_attention_heads: usize,
#[allow(dead_code)]
projection_dim: usize,
}
impl Config {
// The config details can be found in the "text_config" section of this json file:
// https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
pub fn v1_5() -> Self {
Self {
vocab_size: 49408,
embed_dim: 768,
intermediate_size: 3072,
max_position_embeddings: 77,
pad_with: None,
num_hidden_layers: 12,
num_attention_heads: 12,
projection_dim: 768,
activation: Activation::QuickGelu,
}
}
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/text_encoder/config.json
pub fn v2_1() -> Self {
Self {
vocab_size: 49408,
embed_dim: 1024,
intermediate_size: 4096,
max_position_embeddings: 77,
pad_with: Some("!".to_string()),
num_hidden_layers: 23,
num_attention_heads: 16,
projection_dim: 512,
activation: Activation::Gelu,
}
}
}
// CLIP Text Model
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py
#[derive(Debug)]
struct ClipTextEmbeddings {
token_embedding: candle_nn::Embedding,
position_embedding: candle_nn::Embedding,
position_ids: Tensor,
}
impl ClipTextEmbeddings {
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let token_embedding =
candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
let position_embedding = candle_nn::embedding(
c.max_position_embeddings,
c.embed_dim,
vs.pp("position_embedding"),
)?;
let position_ids =
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
Ok(ClipTextEmbeddings {
token_embedding,
position_embedding,
position_ids,
})
}
}
impl ClipTextEmbeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let token_embedding = self.token_embedding.forward(xs)?;
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
token_embedding.broadcast_add(&position_embedding)
}
}
#[derive(Debug)]
struct ClipAttention {
k_proj: candle_nn::Linear,
v_proj: candle_nn::Linear,
q_proj: candle_nn::Linear,
out_proj: candle_nn::Linear,
head_dim: usize,
scale: f64,
num_attention_heads: usize,
}
impl ClipAttention {
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let embed_dim = c.embed_dim;
let num_attention_heads = c.num_attention_heads;
let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
let head_dim = embed_dim / num_attention_heads;
let scale = (head_dim as f64).powf(-0.5);
Ok(ClipAttention {
k_proj,
v_proj,
q_proj,
out_proj,
head_dim,
scale,
num_attention_heads,
})
}
fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()
}
fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
let (bsz, seq_len, embed_dim) = xs.dims3()?;
let query_states = (self.q_proj.forward(xs)? * self.scale)?;
let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
let query_states = self
.shape(&query_states, seq_len, bsz)?
.reshape(proj_shape)?;
let key_states = self
.shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
.reshape(proj_shape)?;
let value_states = self
.shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
.reshape(proj_shape)?;
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
let src_len = key_states.dim(1)?;
let attn_weights = attn_weights
.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
.broadcast_add(causal_attention_mask)?;
let attn_weights =
attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
let attn_output = attn_weights.matmul(&value_states)?;
let attn_output = attn_output
.reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
.transpose(1, 2)?
.reshape((bsz, seq_len, embed_dim))?;
self.out_proj.forward(&attn_output)
}
}
#[derive(Debug)]
struct ClipMlp {
fc1: candle_nn::Linear,
fc2: candle_nn::Linear,
activation: Activation,
}
impl ClipMlp {
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let fc1 = candle_nn::linear(c.embed_dim, c.intermediate_size, vs.pp("fc1"))?;
let fc2 = candle_nn::linear(c.intermediate_size, c.embed_dim, vs.pp("fc2"))?;
Ok(ClipMlp {
fc1,
fc2,
activation: c.activation,
})
}
}
impl ClipMlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.fc1.forward(xs)?;
self.fc2.forward(&self.activation.forward(&xs)?)
}
}
#[derive(Debug)]
struct ClipEncoderLayer {
self_attn: ClipAttention,
layer_norm1: candle_nn::LayerNorm,
mlp: ClipMlp,
layer_norm2: candle_nn::LayerNorm,
}
impl ClipEncoderLayer {
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
let layer_norm1 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm1"))?;
let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
let layer_norm2 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm2"))?;
Ok(ClipEncoderLayer {
self_attn,
layer_norm1,
mlp,
layer_norm2,
})
}
fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = self.layer_norm1.forward(xs)?;
let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self.layer_norm2.forward(&xs)?;
let xs = self.mlp.forward(&xs)?;
xs + residual
}
}
#[derive(Debug)]
struct ClipEncoder {
layers: Vec<ClipEncoderLayer>,
}
impl ClipEncoder {
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let vs = vs.pp("layers");
let mut layers: Vec<ClipEncoderLayer> = Vec::new();
for index in 0..c.num_hidden_layers {
let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
layers.push(layer)
}
Ok(ClipEncoder { layers })
}
fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, causal_attention_mask)?;
}
Ok(xs)
}
}
/// A CLIP transformer based model.
#[derive(Debug)]
pub struct ClipTextTransformer {
embeddings: ClipTextEmbeddings,
encoder: ClipEncoder,
final_layer_norm: candle_nn::LayerNorm,
}
impl ClipTextTransformer {
pub fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let vs = vs.pp("text_model");
let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
let encoder = ClipEncoder::new(vs.pp("encoder"), c)?;
let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
Ok(ClipTextTransformer {
embeddings,
encoder,
final_layer_norm,
})
}
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
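// Build an upper-triangular mask filled with a very large negative value so
// that, after the softmax, each token only attends to itself and to the
// tokens before it.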
let mask: Vec<_> = (0..seq_len)
.flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
mask.broadcast_as((bsz, seq_len, seq_len))
}
}
impl ClipTextTransformer {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (bsz, seq_len) = xs.dims2()?;
let xs = self.embeddings.forward(xs)?;
let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?;
let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
self.final_layer_norm.forward(&xs)
}
}


@ -0,0 +1,181 @@
#![allow(dead_code)]
//! # Denoising Diffusion Implicit Models
//!
//! Denoising Diffusion Implicit Models (DDIM) provide a simple scheduler
//! similar to Denoising Diffusion Probabilistic Models (DDPM). The DDPM
//! generative process is the reverse of a Markovian diffusion process; DDIM
//! generalizes this to non-Markovian guidance.
//!
//! Denoising Diffusion Implicit Models, J. Song et al., 2020.
//! https://arxiv.org/abs/2010.02502
use crate::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
use candle::{Result, Tensor};
/// The configuration for the DDIM scheduler.
#[derive(Debug, Clone, Copy)]
pub struct DDIMSchedulerConfig {
/// The value of beta at the beginning of training.
pub beta_start: f64,
/// The value of beta at the end of training.
pub beta_end: f64,
/// How beta evolved during training.
pub beta_schedule: BetaSchedule,
/// The amount of noise to be added at each step.
pub eta: f64,
/// Adjust the indexes of the inference schedule by this value.
pub steps_offset: usize,
/// Prediction type of the scheduler function: one of `epsilon` (predicting
/// the noise of the diffusion process), `sample` (directly predicting the noisy sample),
/// or `v_prediction` (see section 2.4 of https://imagen.research.google/video/paper.pdf).
pub prediction_type: PredictionType,
/// number of diffusion steps used to train the model
pub train_timesteps: usize,
}
impl Default for DDIMSchedulerConfig {
fn default() -> Self {
Self {
beta_start: 0.00085f64,
beta_end: 0.012f64,
beta_schedule: BetaSchedule::ScaledLinear,
eta: 0.,
steps_offset: 1,
prediction_type: PredictionType::Epsilon,
train_timesteps: 1000,
}
}
}
/// The DDIM scheduler.
#[derive(Debug, Clone)]
pub struct DDIMScheduler {
timesteps: Vec<usize>,
alphas_cumprod: Vec<f64>,
step_ratio: usize,
init_noise_sigma: f64,
pub config: DDIMSchedulerConfig,
}
// clip_sample: False, set_alpha_to_one: False
impl DDIMScheduler {
/// Creates a new DDIM scheduler given the number of steps to be
/// used for inference as well as the number of steps that was used
/// during training.
pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
let step_ratio = config.train_timesteps / inference_steps;
let timesteps: Vec<usize> = (0..(inference_steps))
.map(|s| s * step_ratio + config.steps_offset)
.rev()
.collect();
let betas = match config.beta_schedule {
BetaSchedule::ScaledLinear => crate::utils::linspace(
config.beta_start.sqrt(),
config.beta_end.sqrt(),
config.train_timesteps,
)?
.sqr()?,
BetaSchedule::Linear => {
crate::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
}
BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
};
let betas = betas.to_vec1::<f64>()?;
let mut alphas_cumprod = Vec::with_capacity(betas.len());
for &beta in betas.iter() {
let alpha = 1.0 - beta;
alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
}
Ok(Self {
alphas_cumprod,
timesteps,
step_ratio,
init_noise_sigma: 1.,
config,
})
}
pub fn timesteps(&self) -> &[usize] {
self.timesteps.as_slice()
}
/// Ensures interchangeability with schedulers that need to scale the denoising model input
/// depending on the current timestep.
pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
Ok(sample)
}
/// Performs a backward step during inference.
pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
timestep
};
// https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195
let prev_timestep = if timestep > self.step_ratio {
timestep - self.step_ratio
} else {
0
};
let alpha_prod_t = self.alphas_cumprod[timestep];
let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep];
let beta_prod_t = 1. - alpha_prod_t;
let beta_prod_t_prev = 1. - alpha_prod_t_prev;
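// DDIM update rule (Song et al., 2020): first recover an estimate of the clean
// sample and of the noise from the model output according to the prediction
// type, then combine them as
//   prev_sample = sqrt(alpha_prod_t_prev) * pred_original_sample
//               + sqrt(1 - alpha_prod_t_prev - std_dev_t^2) * pred_epsilon
//               + std_dev_t * noise
// where std_dev_t = eta * sqrt(variance); the noise term is only added when
// eta > 0.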
let (pred_original_sample, pred_epsilon) = match self.config.prediction_type {
PredictionType::Epsilon => {
let pred_original_sample = ((sample - (model_output * beta_prod_t.sqrt())?)?
* (1. / alpha_prod_t.sqrt()))?;
(pred_original_sample, model_output.clone())
}
PredictionType::VPrediction => {
let pred_original_sample =
((sample * alpha_prod_t.sqrt())? - (model_output * beta_prod_t.sqrt())?)?;
let pred_epsilon =
((model_output * alpha_prod_t.sqrt())? + (sample * beta_prod_t.sqrt())?)?;
(pred_original_sample, pred_epsilon)
}
PredictionType::Sample => {
let pred_original_sample = model_output.clone();
let pred_epsilon = ((sample - &pred_original_sample * alpha_prod_t.sqrt())?
* (1. / beta_prod_t.sqrt()))?;
(pred_original_sample, pred_epsilon)
}
};
let variance = (beta_prod_t_prev / beta_prod_t) * (1. - alpha_prod_t / alpha_prod_t_prev);
let std_dev_t = self.config.eta * variance.sqrt();
let pred_sample_direction =
(pred_epsilon * (1. - alpha_prod_t_prev - std_dev_t * std_dev_t).sqrt())?;
let prev_sample =
((pred_original_sample * alpha_prod_t_prev.sqrt())? + pred_sample_direction)?;
if self.config.eta > 0. {
&prev_sample
+ Tensor::randn(
0f32,
std_dev_t as f32,
prev_sample.shape(),
prev_sample.device(),
)?
} else {
Ok(prev_sample)
}
}
pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
timestep
};
let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt();
let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt();
(original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
}
pub fn init_noise_sigma(&self) -> f64 {
self.init_noise_sigma
}
}
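// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the upstream file): a minimal sanity check
// that drives the scheduler with a dummy, all-zero "noise prediction". The
// 4 inference steps and the (1, 4, 8, 8) latent shape are arbitrary choices
// made for this example.
#[cfg(test)]
mod tests {
    use super::*;
    use candle::{Device, Result, Tensor};

    #[test]
    fn ddim_step_preserves_shape() -> Result<()> {
        let scheduler = DDIMScheduler::new(4, DDIMSchedulerConfig::default())?;
        let device = Device::Cpu;
        let mut sample = Tensor::randn(0f32, 1f32, (1, 4, 8, 8), &device)?;
        for &timestep in scheduler.timesteps() {
            // Any tensor with the same shape as `sample` can stand in for the
            // model output here.
            let noise_pred = sample.zeros_like()?;
            sample = scheduler.step(&noise_pred, timestep, &sample)?;
        }
        assert_eq!(sample.dims4()?, (1, 4, 8, 8));
        Ok(())
    }
}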


@ -0,0 +1,65 @@
#![allow(dead_code)]
use candle::{Result, Tensor, D};
use candle_nn as nn;
#[derive(Debug)]
pub struct TimestepEmbedding {
linear_1: nn::Linear,
linear_2: nn::Linear,
}
impl TimestepEmbedding {
// act_fn: "silu"
pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> {
let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?;
let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?;
Ok(Self { linear_1, linear_2 })
}
}
impl TimestepEmbedding {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?;
self.linear_2.forward(&xs)
}
}
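/// Sinusoidal timestep embeddings, analogous to the positional encodings used
/// in transformer models: half of the channels are `sin` features and the
/// other half `cos` features of geometrically spaced frequencies.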
#[derive(Debug)]
pub struct Timesteps {
num_channels: usize,
flip_sin_to_cos: bool,
downscale_freq_shift: f64,
}
impl Timesteps {
pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self {
Self {
num_channels,
flip_sin_to_cos,
downscale_freq_shift,
}
}
}
impl Timesteps {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let half_dim = (self.num_channels / 2) as u32;
let exponent =
(Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?;
let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
let emb = exponent.exp()?;
// emb = timesteps[:, None].float() * emb[None, :]
let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
let (cos, sin) = (emb.cos()?, emb.sin()?);
let emb = if self.flip_sin_to_cos {
Tensor::cat(&[&cos, &sin], D::Minus1)?
} else {
Tensor::cat(&[&sin, &cos], D::Minus1)?
};
if self.num_channels % 2 == 1 {
emb.pad_with_zeros(D::Minus2, 0, 1)
} else {
Ok(emb)
}
}
}


@ -0,0 +1,326 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod attention;
mod clip;
mod ddim;
mod embeddings;
mod resnet;
mod schedulers;
mod stable_diffusion;
mod unet_2d;
mod unet_2d_blocks;
mod utils;
mod vae;
use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor};
use clap::Parser;
use tokenizers::Tokenizer;
const GUIDANCE_SCALE: f64 = 7.5;
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The prompt to be used for image generation.
#[arg(
long,
default_value = "A very realistic photo of a rusty robot walking on a sandy beach"
)]
prompt: String,
#[arg(long, default_value = "")]
uncond_prompt: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The height in pixels of the generated image.
#[arg(long)]
height: Option<usize>,
/// The width in pixels of the generated image.
#[arg(long)]
width: Option<usize>,
/// The UNet weight file, in .safetensors format.
#[arg(long, value_name = "FILE")]
unet_weights: Option<String>,
/// The CLIP weight file, in .safetensors format.
#[arg(long, value_name = "FILE")]
clip_weights: Option<String>,
/// The VAE weight file, in .safetensors format.
#[arg(long, value_name = "FILE")]
vae_weights: Option<String>,
#[arg(long, value_name = "FILE")]
/// The file specifying the tokenizer to use for tokenization.
tokenizer: Option<String>,
/// The size of the sliced attention or 0 for automatic slicing (disabled by default)
#[arg(long)]
sliced_attention_size: Option<usize>,
/// The number of steps to run the diffusion for.
#[arg(long, default_value_t = 30)]
n_steps: usize,
/// The number of samples to generate.
#[arg(long, default_value_t = 1)]
num_samples: i64,
/// The name of the final image to generate.
#[arg(long, value_name = "FILE", default_value = "sd_final.png")]
final_image: String,
#[arg(long, value_enum, default_value = "v2-1")]
sd_version: StableDiffusionVersion,
/// Generate intermediary images at each step.
#[arg(long, action)]
intermediary_images: bool,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
enum StableDiffusionVersion {
V1_5,
V2_1,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelFile {
Tokenizer,
Clip,
Unet,
Vae,
}
impl StableDiffusionVersion {
fn repo(&self) -> &'static str {
match self {
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
}
}
fn unet_file(&self) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 => "unet/diffusion_pytorch_model.safetensors",
}
}
fn vae_file(&self) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 => "vae/diffusion_pytorch_model.safetensors",
}
}
fn clip_file(&self) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 => "text_encoder/model.safetensors",
}
}
}
impl ModelFile {
const TOKENIZER_REPO: &str = "openai/clip-vit-base-patch32";
const TOKENIZER_PATH: &str = "tokenizer.json";
fn get(
&self,
filename: Option<String>,
version: StableDiffusionVersion,
) -> Result<std::path::PathBuf> {
use hf_hub::api::sync::Api;
match filename {
Some(filename) => Ok(std::path::PathBuf::from(filename)),
None => {
let (repo, path) = match self {
Self::Tokenizer => (Self::TOKENIZER_REPO, Self::TOKENIZER_PATH),
Self::Clip => (version.repo(), version.clip_file()),
Self::Unet => (version.repo(), version.unet_file()),
Self::Vae => (version.repo(), version.vae_file()),
};
let filename = Api::new()?.model(repo.to_string()).get(path)?;
Ok(filename)
}
}
}
}
fn output_filename(
basename: &str,
sample_idx: i64,
num_samples: i64,
timestep_idx: Option<usize>,
) -> String {
let filename = if num_samples > 1 {
match basename.rsplit_once('.') {
None => format!("{basename}.{sample_idx}.png"),
Some((filename_no_extension, extension)) => {
format!("{filename_no_extension}.{sample_idx}.{extension}")
}
}
} else {
basename.to_string()
};
match timestep_idx {
None => filename,
Some(timestep_idx) => match filename.rsplit_once('.') {
None => format!("{filename}-{timestep_idx}.png"),
Some((filename_no_extension, extension)) => {
format!("{filename_no_extension}-{timestep_idx}.{extension}")
}
},
}
}
fn run(args: Args) -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let Args {
prompt,
uncond_prompt,
cpu,
height,
width,
n_steps,
tokenizer,
final_image,
sliced_attention_size,
num_samples,
sd_version,
clip_weights,
vae_weights,
unet_weights,
tracing,
..
} = args;
let _guard = if tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let sd_config = match sd_version {
StableDiffusionVersion::V1_5 => {
stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width)
}
StableDiffusionVersion::V2_1 => {
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
}
};
let scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version)?;
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let pad_id = match &sd_config.clip.pad_with {
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
};
println!("Running with prompt \"{prompt}\".");
let mut tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
while tokens.len() < sd_config.clip.max_position_embeddings {
tokens.push(pad_id)
}
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
let mut uncond_tokens = tokenizer
.encode(uncond_prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
uncond_tokens.push(pad_id)
}
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
println!("Building the Clip transformer.");
let clip_weights = ModelFile::Clip.get(clip_weights, sd_version)?;
let text_model = sd_config.build_clip_transformer(&clip_weights, &device)?;
let text_embeddings = text_model.forward(&tokens)?;
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
println!("Building the autoencoder.");
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version)?;
let vae = sd_config.build_vae(&vae_weights, &device)?;
println!("Building the unet.");
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version)?;
let unet = sd_config.build_unet(&unet_weights, &device, 4)?;
let bsize = 1;
for idx in 0..num_samples {
let mut latents = Tensor::randn(
0f32,
1f32,
(bsize, 4, sd_config.height / 8, sd_config.width / 8),
&device,
)?;
// scale the initial noise by the standard deviation required by the scheduler
latents = (latents * scheduler.init_noise_sigma())?;
for (timestep_index, &timestep) in scheduler.timesteps().iter().enumerate() {
println!("Timestep {timestep_index}/{n_steps}");
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
let noise_pred =
unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
let noise_pred = noise_pred.chunk(2, 0)?;
let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
let noise_pred =
(noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?;
latents = scheduler.step(&noise_pred, timestep, &latents)?;
if args.intermediary_images {
let image = vae.decode(&(&latents / 0.18215)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename =
output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
crate::utils::save_image(&image, image_filename)?
}
}
println!(
"Generating the final image for sample {}/{}.",
idx + 1,
num_samples
);
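// 0.18215 is the latent scaling factor used by the Stable Diffusion
// autoencoder; dividing by it maps the latents back into the range the VAE
// decoder expects.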
let image = vae.decode(&(&latents / 0.18215)?)?;
// TODO: Add the clamping between 0 and 1.
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
crate::utils::save_image(&image, image_filename)?
}
Ok(())
}
fn main() -> Result<()> {
let args = Args::parse();
run(args)
}


@ -0,0 +1,134 @@
#![allow(dead_code)]
//! ResNet Building Blocks
//!
//! Some Residual Network blocks used in UNet models.
//!
//! Deep Residual Learning for Image Recognition, K. He et al., 2015.
//! https://arxiv.org/abs/1512.03385
use crate::utils::{conv2d, Conv2d};
use candle::{Result, Tensor, D};
use candle_nn as nn;
/// Configuration for a ResNet block.
#[derive(Debug, Clone, Copy)]
pub struct ResnetBlock2DConfig {
/// The number of output channels, defaults to the number of input channels.
pub out_channels: Option<usize>,
pub temb_channels: Option<usize>,
/// The number of groups to use in group normalization.
pub groups: usize,
pub groups_out: Option<usize>,
/// The epsilon to be used in the group normalization operations.
pub eps: f64,
/// Whether to use a 2D convolution in the skip connection. When using None,
/// such a convolution is used if the number of input channels is different from
/// the number of output channels.
pub use_in_shortcut: Option<bool>,
// non_linearity: silu
/// The final output is scaled by dividing by this value.
pub output_scale_factor: f64,
}
impl Default for ResnetBlock2DConfig {
fn default() -> Self {
Self {
out_channels: None,
temb_channels: Some(512),
groups: 32,
groups_out: None,
eps: 1e-6,
use_in_shortcut: None,
output_scale_factor: 1.,
}
}
}
#[derive(Debug)]
pub struct ResnetBlock2D {
norm1: nn::GroupNorm,
conv1: Conv2d,
norm2: nn::GroupNorm,
conv2: Conv2d,
time_emb_proj: Option<nn::Linear>,
conv_shortcut: Option<Conv2d>,
span: tracing::Span,
config: ResnetBlock2DConfig,
}
impl ResnetBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
config: ResnetBlock2DConfig,
) -> Result<Self> {
let out_channels = config.out_channels.unwrap_or(in_channels);
let conv_cfg = nn::Conv2dConfig {
stride: 1,
padding: 1,
};
let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
let groups_out = config.groups_out.unwrap_or(config.groups);
let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?;
let conv2 = conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
let use_in_shortcut = config
.use_in_shortcut
.unwrap_or(in_channels != out_channels);
let conv_shortcut = if use_in_shortcut {
let conv_cfg = nn::Conv2dConfig {
stride: 1,
padding: 0,
};
Some(conv2d(
in_channels,
out_channels,
1,
conv_cfg,
vs.pp("conv_shortcut"),
)?)
} else {
None
};
let time_emb_proj = match config.temb_channels {
None => None,
Some(temb_channels) => Some(nn::linear(
temb_channels,
out_channels,
vs.pp("time_emb_proj"),
)?),
};
let span = tracing::span!(tracing::Level::TRACE, "resnet2d");
Ok(Self {
norm1,
conv1,
norm2,
conv2,
time_emb_proj,
span,
config,
conv_shortcut,
})
}
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
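// The skip branch is either the raw input or a 1x1 convolution when the
// input and output channel counts differ (see `use_in_shortcut` above).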
let shortcut_xs = match &self.conv_shortcut {
Some(conv_shortcut) => conv_shortcut.forward(xs)?,
None => xs.clone(),
};
let xs = self.norm1.forward(xs)?;
let xs = self.conv1.forward(&nn::ops::silu(&xs)?)?;
let xs = match (temb, &self.time_emb_proj) {
(Some(temb), Some(time_emb_proj)) => time_emb_proj
.forward(&nn::ops::silu(temb)?)?
.unsqueeze(D::Minus1)?
.unsqueeze(D::Minus1)?
.broadcast_add(&xs)?,
_ => xs,
};
let xs = self
.conv2
.forward(&nn::ops::silu(&self.norm2.forward(&xs)?)?)?;
(shortcut_xs + xs)? / self.config.output_scale_factor
}
}


@ -0,0 +1,45 @@
#![allow(dead_code)]
//! # Diffusion pipelines and models
//!
//! Noise schedulers can be used to set the trade-off between
//! inference speed and quality.
use candle::{Result, Tensor};
/// This represents how beta ranges from its minimum value to the maximum
/// during training.
#[derive(Debug, Clone, Copy)]
pub enum BetaSchedule {
/// Linear interpolation.
Linear,
/// Linear interpolation of the square root of beta.
ScaledLinear,
/// Glide cosine schedule
SquaredcosCapV2,
}
#[derive(Debug, Clone, Copy)]
pub enum PredictionType {
Epsilon,
VPrediction,
Sample,
}
/// Creates a beta schedule that discretizes the given alpha_t_bar function, which defines the
/// cumulative product of `(1-beta)` as `t` goes from 0 to 1.
///
/// Internally, the closure `alpha_bar` takes an argument `t` and maps it to the cumulative
/// product of `(1-beta)` up to that point of the diffusion process.
pub(crate) fn betas_for_alpha_bar(num_diffusion_timesteps: usize, max_beta: f64) -> Result<Tensor> {
let alpha_bar = |time_step: f64| {
f64::cos((time_step + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2)
};
let mut betas = Vec::with_capacity(num_diffusion_timesteps);
for i in 0..num_diffusion_timesteps {
// Use floating-point division so that t1 and t2 actually cover the [0, 1] range
// described above rather than collapsing to 0 via integer division.
let t1 = i as f64 / num_diffusion_timesteps as f64;
let t2 = (i + 1) as f64 / num_diffusion_timesteps as f64;
betas.push((1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta));
}
let betas_len = betas.len();
Tensor::from_vec(betas, betas_len, &candle::Device::Cpu)
}
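// Illustrative check (not part of the upstream file): with the squared-cosine
// schedule the betas should be positive, non-decreasing, and capped by
// `max_beta`. The 1000 steps and the 0.999 cap mirror the values used by the
// DDIM scheduler in this example.
#[cfg(test)]
mod tests {
    use super::*;
    use candle::Result;

    #[test]
    fn squaredcos_betas_are_bounded() -> Result<()> {
        let betas = betas_for_alpha_bar(1000, 0.999)?.to_vec1::<f64>()?;
        assert_eq!(betas.len(), 1000);
        assert!(betas.iter().all(|&b| b >= 0. && b <= 0.999));
        assert!(betas.windows(2).all(|w| w[0] <= w[1]));
        Ok(())
    }
}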


@ -0,0 +1,216 @@
#![allow(dead_code)]
use crate::schedulers::PredictionType;
use crate::{clip, ddim, unet_2d, vae};
use candle::{DType, Device, Result};
use candle_nn as nn;
#[derive(Clone, Debug)]
pub struct StableDiffusionConfig {
pub width: usize,
pub height: usize,
pub clip: clip::Config,
autoencoder: vae::AutoEncoderKLConfig,
unet: unet_2d::UNet2DConditionModelConfig,
scheduler: ddim::DDIMSchedulerConfig,
}
impl StableDiffusionConfig {
pub fn v1_5(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
bc(320, true, 8),
bc(640, true, 8),
bc(1280, true, 8),
bc(1280, false, 8),
],
center_input_sample: false,
cross_attention_dim: 768,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: false,
};
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
height
} else {
512
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
512
};
Self {
width,
height,
clip: clip::Config::v1_5(),
autoencoder,
scheduler: Default::default(),
unet,
}
}
fn v2_1_(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
prediction_type: PredictionType,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
bc(320, true, 5),
bc(640, true, 10),
bc(1280, true, 20),
bc(1280, false, 20),
],
center_input_sample: false,
cross_attention_dim: 1024,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: true,
};
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
};
let scheduler = ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
height
} else {
768
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
768
};
Self {
width,
height,
clip: clip::Config::v2_1(),
autoencoder,
scheduler,
unet,
}
}
pub fn v2_1(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
) -> Self {
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/scheduler/scheduler_config.json
Self::v2_1_(
sliced_attention_size,
height,
width,
PredictionType::VPrediction,
)
}
pub fn v2_1_inpaint(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
) -> Self {
// https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/scheduler/scheduler_config.json
// The upstream pipeline uses a PNDM scheduler rather than DDIM, but the main difference
// for our purposes is that the prediction type defaults to "epsilon" rather than "v_prediction".
Self::v2_1_(
sliced_attention_size,
height,
width,
PredictionType::Epsilon,
)
}
pub fn build_vae<P: AsRef<std::path::Path>>(
&self,
vae_weights: P,
device: &Device,
) -> Result<vae::AutoEncoderKL> {
let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
let weights = weights.deserialize()?;
let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?;
Ok(autoencoder)
}
pub fn build_unet<P: AsRef<std::path::Path>>(
&self,
unet_weights: P,
device: &Device,
in_channels: usize,
) -> Result<unet_2d::UNet2DConditionModel> {
let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
let weights = weights.deserialize()?;
let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
let unet = unet_2d::UNet2DConditionModel::new(vs_unet, in_channels, 4, self.unet.clone())?;
Ok(unet)
}
pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
ddim::DDIMScheduler::new(n_steps, self.scheduler)
}
pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
&self,
clip_weights: P,
device: &Device,
) -> Result<clip::ClipTextTransformer> {
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
let weights = weights.deserialize()?;
let vs = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?;
Ok(text_model)
}
}


@ -0,0 +1,386 @@
#![allow(dead_code)]
//! 2D UNet Denoising Models
//!
//! The 2D UNet models take a noisy sample and the current diffusion timestep
//! as input and return a denoised version of the input.
use crate::embeddings::{TimestepEmbedding, Timesteps};
use crate::unet_2d_blocks::*;
use crate::utils::{conv2d, Conv2d};
use candle::{DType, Result, Tensor};
use candle_nn as nn;
#[derive(Debug, Clone, Copy)]
pub struct BlockConfig {
pub out_channels: usize,
pub use_cross_attn: bool,
pub attention_head_dim: usize,
}
#[derive(Debug, Clone)]
pub struct UNet2DConditionModelConfig {
pub center_input_sample: bool,
pub flip_sin_to_cos: bool,
pub freq_shift: f64,
pub blocks: Vec<BlockConfig>,
pub layers_per_block: usize,
pub downsample_padding: usize,
pub mid_block_scale_factor: f64,
pub norm_num_groups: usize,
pub norm_eps: f64,
pub cross_attention_dim: usize,
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
}
impl Default for UNet2DConditionModelConfig {
fn default() -> Self {
Self {
center_input_sample: false,
flip_sin_to_cos: true,
freq_shift: 0.,
blocks: vec![
BlockConfig {
out_channels: 320,
use_cross_attn: true,
attention_head_dim: 8,
},
BlockConfig {
out_channels: 640,
use_cross_attn: true,
attention_head_dim: 8,
},
BlockConfig {
out_channels: 1280,
use_cross_attn: true,
attention_head_dim: 8,
},
BlockConfig {
out_channels: 1280,
use_cross_attn: false,
attention_head_dim: 8,
},
],
layers_per_block: 2,
downsample_padding: 1,
mid_block_scale_factor: 1.,
norm_num_groups: 32,
norm_eps: 1e-5,
cross_attention_dim: 1280,
sliced_attention_size: None,
use_linear_projection: false,
}
}
}
#[derive(Debug)]
pub(crate) enum UNetDownBlock {
Basic(DownBlock2D),
CrossAttn(CrossAttnDownBlock2D),
}
#[derive(Debug)]
enum UNetUpBlock {
Basic(UpBlock2D),
CrossAttn(CrossAttnUpBlock2D),
}
#[derive(Debug)]
pub struct UNet2DConditionModel {
conv_in: Conv2d,
time_proj: Timesteps,
time_embedding: TimestepEmbedding,
down_blocks: Vec<UNetDownBlock>,
mid_block: UNetMidBlock2DCrossAttn,
up_blocks: Vec<UNetUpBlock>,
conv_norm_out: nn::GroupNorm,
conv_out: Conv2d,
span: tracing::Span,
config: UNet2DConditionModelConfig,
}
impl UNet2DConditionModel {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: UNet2DConditionModelConfig,
) -> Result<Self> {
let n_blocks = config.blocks.len();
let b_channels = config.blocks[0].out_channels;
let bl_channels = config.blocks.last().unwrap().out_channels;
let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim;
let time_embed_dim = b_channels * 4;
let conv_cfg = nn::Conv2dConfig {
stride: 1,
padding: 1,
};
let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift);
let time_embedding =
TimestepEmbedding::new(vs.pp("time_embedding"), b_channels, time_embed_dim)?;
let vs_db = vs.pp("down_blocks");
let down_blocks = (0..n_blocks)
.map(|i| {
let BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
} = config.blocks[i];
// Enable automatic attention slicing if the config sliced_attention_size is set to 0.
let sliced_attention_size = match config.sliced_attention_size {
Some(0) => Some(attention_head_dim / 2),
_ => config.sliced_attention_size,
};
let in_channels = if i > 0 {
config.blocks[i - 1].out_channels
} else {
b_channels
};
let db_cfg = DownBlock2DConfig {
num_layers: config.layers_per_block,
resnet_eps: config.norm_eps,
resnet_groups: config.norm_num_groups,
add_downsample: i < n_blocks - 1,
downsample_padding: config.downsample_padding,
..Default::default()
};
if use_cross_attn {
let config = CrossAttnDownBlock2DConfig {
downblock: db_cfg,
attn_num_head_channels: attention_head_dim,
cross_attention_dim: config.cross_attention_dim,
sliced_attention_size,
use_linear_projection: config.use_linear_projection,
};
let block = CrossAttnDownBlock2D::new(
vs_db.pp(&i.to_string()),
in_channels,
out_channels,
Some(time_embed_dim),
config,
)?;
Ok(UNetDownBlock::CrossAttn(block))
} else {
let block = DownBlock2D::new(
vs_db.pp(&i.to_string()),
in_channels,
out_channels,
Some(time_embed_dim),
db_cfg,
)?;
Ok(UNetDownBlock::Basic(block))
}
})
.collect::<Result<Vec<_>>>()?;
let mid_cfg = UNetMidBlock2DCrossAttnConfig {
resnet_eps: config.norm_eps,
output_scale_factor: config.mid_block_scale_factor,
cross_attn_dim: config.cross_attention_dim,
attn_num_head_channels: bl_attention_head_dim,
resnet_groups: Some(config.norm_num_groups),
use_linear_projection: config.use_linear_projection,
..Default::default()
};
let mid_block = UNetMidBlock2DCrossAttn::new(
vs.pp("mid_block"),
bl_channels,
Some(time_embed_dim),
mid_cfg,
)?;
let vs_ub = vs.pp("up_blocks");
let up_blocks = (0..n_blocks)
.map(|i| {
let BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
} = config.blocks[n_blocks - 1 - i];
// Enable automatic attention slicing if the config sliced_attention_size is set to 0.
let sliced_attention_size = match config.sliced_attention_size {
Some(0) => Some(attention_head_dim / 2),
_ => config.sliced_attention_size,
};
let prev_out_channels = if i > 0 {
config.blocks[n_blocks - i].out_channels
} else {
bl_channels
};
let in_channels = {
let index = if i == n_blocks - 1 {
0
} else {
n_blocks - i - 2
};
config.blocks[index].out_channels
};
let ub_cfg = UpBlock2DConfig {
num_layers: config.layers_per_block + 1,
resnet_eps: config.norm_eps,
resnet_groups: config.norm_num_groups,
add_upsample: i < n_blocks - 1,
..Default::default()
};
if use_cross_attn {
let config = CrossAttnUpBlock2DConfig {
upblock: ub_cfg,
attn_num_head_channels: attention_head_dim,
cross_attention_dim: config.cross_attention_dim,
sliced_attention_size,
use_linear_projection: config.use_linear_projection,
};
let block = CrossAttnUpBlock2D::new(
vs_ub.pp(&i.to_string()),
in_channels,
prev_out_channels,
out_channels,
Some(time_embed_dim),
config,
)?;
Ok(UNetUpBlock::CrossAttn(block))
} else {
let block = UpBlock2D::new(
vs_ub.pp(&i.to_string()),
in_channels,
prev_out_channels,
out_channels,
Some(time_embed_dim),
ub_cfg,
)?;
Ok(UNetUpBlock::Basic(block))
}
})
.collect::<Result<Vec<_>>>()?;
let conv_norm_out = nn::group_norm(
config.norm_num_groups,
b_channels,
config.norm_eps,
vs.pp("conv_norm_out"),
)?;
let conv_out = conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
let span = tracing::span!(tracing::Level::TRACE, "unet2d");
Ok(Self {
conv_in,
time_proj,
time_embedding,
down_blocks,
mid_block,
up_blocks,
conv_norm_out,
conv_out,
span,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
timestep: f64,
encoder_hidden_states: &Tensor,
) -> Result<Tensor> {
let _enter = self.span.enter();
self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None)
}
pub fn forward_with_additional_residuals(
&self,
xs: &Tensor,
timestep: f64,
encoder_hidden_states: &Tensor,
down_block_additional_residuals: Option<&[Tensor]>,
mid_block_additional_residual: Option<&Tensor>,
) -> Result<Tensor> {
let (bsize, _channels, height, width) = xs.dims4()?;
let device = xs.device();
let n_blocks = self.config.blocks.len();
let num_upsamplers = n_blocks - 1;
let default_overall_up_factor = 2usize.pow(num_upsamplers as u32);
let forward_upsample_size =
height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0;
// 0. center input if necessary
let xs = if self.config.center_input_sample {
((xs * 2.0)? - 1.0)?
} else {
xs.clone()
};
// 1. time
let emb = (Tensor::ones(bsize, DType::F32, device)? * timestep)?;
let emb = self.time_proj.forward(&emb)?;
let emb = self.time_embedding.forward(&emb)?;
// 2. pre-process
let xs = self.conv_in.forward(&xs)?;
// 3. down
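// Each down block returns its intermediate activations; they are collected
// here and consumed later as skip connections by the matching up blocks.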
let mut down_block_res_xs = vec![xs.clone()];
let mut xs = xs;
for down_block in self.down_blocks.iter() {
let (_xs, res_xs) = match down_block {
UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb))?,
UNetDownBlock::CrossAttn(b) => {
b.forward(&xs, Some(&emb), Some(encoder_hidden_states))?
}
};
down_block_res_xs.extend(res_xs);
xs = _xs;
}
let new_down_block_res_xs =
if let Some(down_block_additional_residuals) = down_block_additional_residuals {
let mut v = vec![];
// A previous version of this code had a bug: the addition was made in place
// via `+=`, which also modified the input of the mid block.
for (i, residuals) in down_block_additional_residuals.iter().enumerate() {
v.push((&down_block_res_xs[i] + residuals)?)
}
v
} else {
down_block_res_xs
};
let mut down_block_res_xs = new_down_block_res_xs;
// 4. mid
let xs = self
.mid_block
.forward(&xs, Some(&emb), Some(encoder_hidden_states))?;
let xs = match mid_block_additional_residual {
None => xs,
Some(m) => (m + xs)?,
};
// 5. up
let mut xs = xs;
let mut upsample_size = None;
for (i, up_block) in self.up_blocks.iter().enumerate() {
let n_resnets = match up_block {
UNetUpBlock::Basic(b) => b.resnets.len(),
UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(),
};
let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets);
if i < n_blocks - 1 && forward_upsample_size {
let (_, _, h, w) = down_block_res_xs.last().unwrap().dims4()?;
upsample_size = Some((h, w))
}
xs = match up_block {
UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size)?,
UNetUpBlock::CrossAttn(b) => b.forward(
&xs,
&res_xs,
Some(&emb),
upsample_size,
Some(encoder_hidden_states),
)?,
};
}
// 6. post-process
let xs = self.conv_norm_out.forward(&xs)?;
let xs = nn::ops::silu(&xs)?;
self.conv_out.forward(&xs)
}
}


@ -0,0 +1,851 @@
#![allow(dead_code)]
//! 2D UNet Building Blocks
//!
use crate::attention::{
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
};
use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
use crate::utils::{conv2d, Conv2d};
use candle::{Result, Tensor, D};
use candle_nn as nn;
#[derive(Debug)]
struct Downsample2D {
conv: Option<Conv2d>,
padding: usize,
span: tracing::Span,
}
impl Downsample2D {
fn new(
vs: nn::VarBuilder,
in_channels: usize,
use_conv: bool,
out_channels: usize,
padding: usize,
) -> Result<Self> {
let conv = if use_conv {
let config = nn::Conv2dConfig { stride: 2, padding };
let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
Some(conv)
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "downsample2d");
Ok(Self {
conv,
padding,
span,
})
}
}
impl Downsample2D {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
match &self.conv {
None => xs.avg_pool2d((2, 2), (2, 2)),
Some(conv) => {
if self.padding == 0 {
let xs = xs
.pad_with_zeros(D::Minus1, 0, 1)?
.pad_with_zeros(D::Minus2, 0, 1)?;
conv.forward(&xs)
} else {
conv.forward(xs)
}
}
}
}
}
// This does not support the conv-transpose mode.
#[derive(Debug)]
struct Upsample2D {
conv: Conv2d,
span: tracing::Span,
}
impl Upsample2D {
fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> {
let config = nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
let span = tracing::span!(tracing::Level::TRACE, "upsample2d");
Ok(Self { conv, span })
}
}
impl Upsample2D {
fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = match size {
None => {
let (_bsize, _channels, h, w) = xs.dims4()?;
xs.upsample_nearest2d(2 * h, 2 * w)?
}
Some((h, w)) => xs.upsample_nearest2d(h, w)?,
};
self.conv.forward(&xs)
}
}
#[derive(Debug, Clone, Copy)]
pub struct DownEncoderBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
pub resnet_groups: usize,
pub output_scale_factor: f64,
pub add_downsample: bool,
pub downsample_padding: usize,
}
impl Default for DownEncoderBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: 32,
output_scale_factor: 1.,
add_downsample: true,
downsample_padding: 1,
}
}
}
#[derive(Debug)]
pub struct DownEncoderBlock2D {
resnets: Vec<ResnetBlock2D>,
downsampler: Option<Downsample2D>,
span: tracing::Span,
pub config: DownEncoderBlock2DConfig,
}
impl DownEncoderBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: DownEncoderBlock2DConfig,
) -> Result<Self> {
let resnets: Vec<_> = {
let vs = vs.pp("resnets");
let conv_cfg = ResnetBlock2DConfig {
eps: config.resnet_eps,
out_channels: Some(out_channels),
groups: config.resnet_groups,
output_scale_factor: config.output_scale_factor,
temb_channels: None,
..Default::default()
};
(0..(config.num_layers))
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
})
.collect::<Result<Vec<_>>>()?
};
let downsampler = if config.add_downsample {
let downsample = Downsample2D::new(
vs.pp("downsamplers").pp("0"),
out_channels,
true,
out_channels,
config.downsample_padding,
)?;
Some(downsample)
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "down-enc2d");
Ok(Self {
resnets,
downsampler,
span,
config,
})
}
}
impl DownEncoderBlock2D {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for resnet in self.resnets.iter() {
xs = resnet.forward(&xs, None)?
}
match &self.downsampler {
Some(downsampler) => downsampler.forward(&xs),
None => Ok(xs),
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct UpDecoderBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
pub resnet_groups: usize,
pub output_scale_factor: f64,
pub add_upsample: bool,
}
impl Default for UpDecoderBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: 32,
output_scale_factor: 1.,
add_upsample: true,
}
}
}
#[derive(Debug)]
pub struct UpDecoderBlock2D {
resnets: Vec<ResnetBlock2D>,
upsampler: Option<Upsample2D>,
span: tracing::Span,
pub config: UpDecoderBlock2DConfig,
}
impl UpDecoderBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: UpDecoderBlock2DConfig,
) -> Result<Self> {
let resnets: Vec<_> = {
let vs = vs.pp("resnets");
let conv_cfg = ResnetBlock2DConfig {
out_channels: Some(out_channels),
eps: config.resnet_eps,
groups: config.resnet_groups,
output_scale_factor: config.output_scale_factor,
temb_channels: None,
..Default::default()
};
(0..(config.num_layers))
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
})
.collect::<Result<Vec<_>>>()?
};
let upsampler = if config.add_upsample {
let upsample =
Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
Some(upsample)
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "up-dec2d");
Ok(Self {
resnets,
upsampler,
span,
config,
})
}
}
impl UpDecoderBlock2D {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for resnet in self.resnets.iter() {
xs = resnet.forward(&xs, None)?
}
match &self.upsampler {
Some(upsampler) => upsampler.forward(&xs, None),
None => Ok(xs),
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct UNetMidBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
pub resnet_groups: Option<usize>,
pub attn_num_head_channels: Option<usize>,
// attention_type "default"
pub output_scale_factor: f64,
}
impl Default for UNetMidBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: Some(32),
attn_num_head_channels: Some(1),
output_scale_factor: 1.,
}
}
}
#[derive(Debug)]
pub struct UNetMidBlock2D {
resnet: ResnetBlock2D,
attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
span: tracing::Span,
pub config: UNetMidBlock2DConfig,
}
impl UNetMidBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
temb_channels: Option<usize>,
config: UNetMidBlock2DConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
let vs_attns = vs.pp("attentions");
let resnet_groups = config
.resnet_groups
.unwrap_or_else(|| usize::min(in_channels / 4, 32));
let resnet_cfg = ResnetBlock2DConfig {
eps: config.resnet_eps,
groups: resnet_groups,
output_scale_factor: config.output_scale_factor,
temb_channels,
..Default::default()
};
let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
let attn_cfg = AttentionBlockConfig {
num_head_channels: config.attn_num_head_channels,
num_groups: resnet_groups,
rescale_output_factor: config.output_scale_factor,
eps: config.resnet_eps,
};
let mut attn_resnets = vec![];
for index in 0..config.num_layers {
let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?;
let resnet = ResnetBlock2D::new(
vs_resnets.pp(&(index + 1).to_string()),
in_channels,
resnet_cfg,
)?;
attn_resnets.push((attn, resnet))
}
let span = tracing::span!(tracing::Level::TRACE, "mid2d");
Ok(Self {
resnet,
attn_resnets,
span,
config,
})
}
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = self.resnet.forward(xs, temb)?;
for (attn, resnet) in self.attn_resnets.iter() {
xs = resnet.forward(&attn.forward(&xs)?, temb)?
}
Ok(xs)
}
}
#[derive(Debug, Clone, Copy)]
pub struct UNetMidBlock2DCrossAttnConfig {
pub num_layers: usize,
pub resnet_eps: f64,
pub resnet_groups: Option<usize>,
pub attn_num_head_channels: usize,
// attention_type "default"
pub output_scale_factor: f64,
pub cross_attn_dim: usize,
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
}
impl Default for UNetMidBlock2DCrossAttnConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: Some(32),
attn_num_head_channels: 1,
output_scale_factor: 1.,
cross_attn_dim: 1280,
sliced_attention_size: None, // Sliced attention disabled
use_linear_projection: false,
}
}
}
#[derive(Debug)]
pub struct UNetMidBlock2DCrossAttn {
resnet: ResnetBlock2D,
attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
span: tracing::Span,
pub config: UNetMidBlock2DCrossAttnConfig,
}
impl UNetMidBlock2DCrossAttn {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
temb_channels: Option<usize>,
config: UNetMidBlock2DCrossAttnConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
let vs_attns = vs.pp("attentions");
let resnet_groups = config
.resnet_groups
.unwrap_or_else(|| usize::min(in_channels / 4, 32));
let resnet_cfg = ResnetBlock2DConfig {
eps: config.resnet_eps,
groups: resnet_groups,
output_scale_factor: config.output_scale_factor,
temb_channels,
..Default::default()
};
let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
let n_heads = config.attn_num_head_channels;
let attn_cfg = SpatialTransformerConfig {
depth: 1,
num_groups: resnet_groups,
context_dim: Some(config.cross_attn_dim),
sliced_attention_size: config.sliced_attention_size,
use_linear_projection: config.use_linear_projection,
};
let mut attn_resnets = vec![];
for index in 0..config.num_layers {
let attn = SpatialTransformer::new(
vs_attns.pp(&index.to_string()),
in_channels,
n_heads,
in_channels / n_heads,
attn_cfg,
)?;
let resnet = ResnetBlock2D::new(
vs_resnets.pp(&(index + 1).to_string()),
in_channels,
resnet_cfg,
)?;
attn_resnets.push((attn, resnet))
}
let span = tracing::span!(tracing::Level::TRACE, "xa-mid2d");
Ok(Self {
resnet,
attn_resnets,
span,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
temb: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = self.resnet.forward(xs, temb)?;
for (attn, resnet) in self.attn_resnets.iter() {
xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
}
Ok(xs)
}
}
#[derive(Debug, Clone, Copy)]
pub struct DownBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
// resnet_time_scale_shift: "default"
// resnet_act_fn: "swish"
pub resnet_groups: usize,
pub output_scale_factor: f64,
pub add_downsample: bool,
pub downsample_padding: usize,
}
impl Default for DownBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: 32,
output_scale_factor: 1.,
add_downsample: true,
downsample_padding: 1,
}
}
}
#[derive(Debug)]
pub struct DownBlock2D {
resnets: Vec<ResnetBlock2D>,
downsampler: Option<Downsample2D>,
span: tracing::Span,
pub config: DownBlock2DConfig,
}
impl DownBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
config: DownBlock2DConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
let resnet_cfg = ResnetBlock2DConfig {
out_channels: Some(out_channels),
eps: config.resnet_eps,
output_scale_factor: config.output_scale_factor,
temb_channels,
..Default::default()
};
let resnets = (0..config.num_layers)
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
})
.collect::<Result<Vec<_>>>()?;
let downsampler = if config.add_downsample {
let downsampler = Downsample2D::new(
vs.pp("downsamplers").pp("0"),
out_channels,
true,
out_channels,
config.downsample_padding,
)?;
Some(downsampler)
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "down2d");
Ok(Self {
resnets,
downsampler,
span,
config,
})
}
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
let _enter = self.span.enter();
let mut xs = xs.clone();
let mut output_states = vec![];
for resnet in self.resnets.iter() {
xs = resnet.forward(&xs, temb)?;
output_states.push(xs.clone());
}
let xs = match &self.downsampler {
Some(downsampler) => {
let xs = downsampler.forward(&xs)?;
output_states.push(xs.clone());
xs
}
None => xs,
};
Ok((xs, output_states))
}
}
#[derive(Debug, Clone, Copy)]
pub struct CrossAttnDownBlock2DConfig {
pub downblock: DownBlock2DConfig,
pub attn_num_head_channels: usize,
pub cross_attention_dim: usize,
// attention_type: "default"
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
}
impl Default for CrossAttnDownBlock2DConfig {
fn default() -> Self {
Self {
downblock: Default::default(),
attn_num_head_channels: 1,
cross_attention_dim: 1280,
sliced_attention_size: None,
use_linear_projection: false,
}
}
}
#[derive(Debug)]
pub struct CrossAttnDownBlock2D {
downblock: DownBlock2D,
attentions: Vec<SpatialTransformer>,
span: tracing::Span,
pub config: CrossAttnDownBlock2DConfig,
}
impl CrossAttnDownBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
config: CrossAttnDownBlock2DConfig,
) -> Result<Self> {
let downblock = DownBlock2D::new(
vs.clone(),
in_channels,
out_channels,
temb_channels,
config.downblock,
)?;
let n_heads = config.attn_num_head_channels;
let cfg = SpatialTransformerConfig {
depth: 1,
context_dim: Some(config.cross_attention_dim),
num_groups: config.downblock.resnet_groups,
sliced_attention_size: config.sliced_attention_size,
use_linear_projection: config.use_linear_projection,
};
let vs_attn = vs.pp("attentions");
let attentions = (0..config.downblock.num_layers)
.map(|i| {
SpatialTransformer::new(
vs_attn.pp(&i.to_string()),
out_channels,
n_heads,
out_channels / n_heads,
cfg,
)
})
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "xa-down2d");
Ok(Self {
downblock,
attentions,
span,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
temb: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
) -> Result<(Tensor, Vec<Tensor>)> {
let _enter = self.span.enter();
let mut output_states = vec![];
let mut xs = xs.clone();
for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
xs = resnet.forward(&xs, temb)?;
xs = attn.forward(&xs, encoder_hidden_states)?;
output_states.push(xs.clone());
}
let xs = match &self.downblock.downsampler {
Some(downsampler) => {
let xs = downsampler.forward(&xs)?;
output_states.push(xs.clone());
xs
}
None => xs,
};
Ok((xs, output_states))
}
}
#[derive(Debug, Clone, Copy)]
pub struct UpBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
// resnet_time_scale_shift: "default"
// resnet_act_fn: "swish"
pub resnet_groups: usize,
pub output_scale_factor: f64,
pub add_upsample: bool,
}
impl Default for UpBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: 32,
output_scale_factor: 1.,
add_upsample: true,
}
}
}
#[derive(Debug)]
pub struct UpBlock2D {
pub resnets: Vec<ResnetBlock2D>,
upsampler: Option<Upsample2D>,
span: tracing::Span,
pub config: UpBlock2DConfig,
}
impl UpBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
prev_output_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
config: UpBlock2DConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
let resnet_cfg = ResnetBlock2DConfig {
out_channels: Some(out_channels),
temb_channels,
eps: config.resnet_eps,
output_scale_factor: config.output_scale_factor,
..Default::default()
};
let resnets = (0..config.num_layers)
.map(|i| {
let res_skip_channels = if i == config.num_layers - 1 {
in_channels
} else {
out_channels
};
let resnet_in_channels = if i == 0 {
prev_output_channels
} else {
out_channels
};
let in_channels = resnet_in_channels + res_skip_channels;
ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
})
.collect::<Result<Vec<_>>>()?;
let upsampler = if config.add_upsample {
let upsampler =
Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
Some(upsampler)
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "up2d");
Ok(Self {
resnets,
upsampler,
span,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
res_xs: &[Tensor],
temb: Option<&Tensor>,
upsample_size: Option<(usize, usize)>,
) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for (index, resnet) in self.resnets.iter().enumerate() {
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
xs = resnet.forward(&xs, temb)?;
}
match &self.upsampler {
Some(upsampler) => upsampler.forward(&xs, upsample_size),
None => Ok(xs),
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct CrossAttnUpBlock2DConfig {
pub upblock: UpBlock2DConfig,
pub attn_num_head_channels: usize,
pub cross_attention_dim: usize,
// attention_type: "default"
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
}
impl Default for CrossAttnUpBlock2DConfig {
fn default() -> Self {
Self {
upblock: Default::default(),
attn_num_head_channels: 1,
cross_attention_dim: 1280,
sliced_attention_size: None,
use_linear_projection: false,
}
}
}
#[derive(Debug)]
pub struct CrossAttnUpBlock2D {
pub upblock: UpBlock2D,
pub attentions: Vec<SpatialTransformer>,
span: tracing::Span,
pub config: CrossAttnUpBlock2DConfig,
}
impl CrossAttnUpBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
prev_output_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
config: CrossAttnUpBlock2DConfig,
) -> Result<Self> {
let upblock = UpBlock2D::new(
vs.clone(),
in_channels,
prev_output_channels,
out_channels,
temb_channels,
config.upblock,
)?;
let n_heads = config.attn_num_head_channels;
let cfg = SpatialTransformerConfig {
depth: 1,
context_dim: Some(config.cross_attention_dim),
num_groups: config.upblock.resnet_groups,
sliced_attention_size: config.sliced_attention_size,
use_linear_projection: config.use_linear_projection,
};
let vs_attn = vs.pp("attentions");
let attentions = (0..config.upblock.num_layers)
.map(|i| {
SpatialTransformer::new(
vs_attn.pp(&i.to_string()),
out_channels,
n_heads,
out_channels / n_heads,
cfg,
)
})
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "xa-up2d");
Ok(Self {
upblock,
attentions,
span,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
res_xs: &[Tensor],
temb: Option<&Tensor>,
upsample_size: Option<(usize, usize)>,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for (index, resnet) in self.upblock.resnets.iter().enumerate() {
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
xs = resnet.forward(&xs, temb)?;
xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
}
match &self.upblock.upsampler {
Some(upsampler) => upsampler.forward(&xs, upsample_size),
None => Ok(xs),
}
}
}


@ -0,0 +1,57 @@
use candle::{Device, Result, Tensor};
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
if steps <= 1 {
candle::bail!("cannot use linspace with steps {steps} <= 1")
}
let delta = (stop - start) / (steps - 1) as f64;
let vs = (0..steps)
.map(|step| start + step as f64 * delta)
.collect::<Vec<_>>();
Tensor::from_vec(vs, steps, &Device::Cpu)
}
/// Saves an image to disk using the image crate; this expects an input with shape
/// (c, height, width).
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
let p = p.as_ref();
let (channel, height, width) = img.dims3()?;
if channel != 3 {
candle::bail!("save_image expects an input of shape (3, height, width)")
}
// Move the channel dimension last, (c, h, w) -> (h, w, c), so that flattening yields
// the row-major pixel layout expected by the image crate.
let img = img.transpose(0, 1)?.t()?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => candle::bail!("error saving image {p:?}"),
};
image.save(p).map_err(candle::Error::wrap)?;
Ok(())
}
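// Usage sketch (illustrative; the file name is hypothetical): save a black 3x64x64 u8 image.
// let img = Tensor::zeros((3, 64, 64), candle::DType::U8, &Device::Cpu)?;
// save_image(&img, "out.png")?;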
// Wrap the conv2d op to provide some tracing.
#[derive(Debug)]
pub struct Conv2d {
inner: candle_nn::Conv2d,
span: tracing::Span,
}
impl Conv2d {
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
pub fn conv2d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: candle_nn::Conv2dConfig,
vs: candle_nn::VarBuilder,
) -> Result<Conv2d> {
let span = tracing::span!(tracing::Level::TRACE, "conv2d");
let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
Ok(Conv2d { inner, span })
}
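// Usage sketch (illustrative; `vs` and `xs` are assumed to exist): build a traced 3x3 conv and
// run it; the forward pass is then recorded under the "conv2d" tracing span.
// let cfg = candle_nn::Conv2dConfig { padding: 1, ..Default::default() };
// let conv = conv2d(3, 64, 3, cfg, vs.pp("conv_in"))?;
// let ys = conv.forward(&xs)?;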


@ -0,0 +1,378 @@
#![allow(dead_code)]
//! # Variational Auto-Encoder (VAE) Models.
//!
//! Auto-encoder models compress their input to a usually smaller latent space
//! before expanding it back to its original shape, so the latent values form a
//! compressed representation of the original information.
use crate::unet_2d_blocks::{
DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
UpDecoderBlock2D, UpDecoderBlock2DConfig,
};
use candle::{Result, Tensor};
use candle_nn as nn;
#[derive(Debug, Clone)]
struct EncoderConfig {
// down_block_types: DownEncoderBlock2D
block_out_channels: Vec<usize>,
layers_per_block: usize,
norm_num_groups: usize,
double_z: bool,
}
impl Default for EncoderConfig {
fn default() -> Self {
Self {
block_out_channels: vec![64],
layers_per_block: 2,
norm_num_groups: 32,
double_z: true,
}
}
}
#[derive(Debug)]
struct Encoder {
conv_in: nn::Conv2d,
down_blocks: Vec<DownEncoderBlock2D>,
mid_block: UNetMidBlock2D,
conv_norm_out: nn::GroupNorm,
conv_out: nn::Conv2d,
#[allow(dead_code)]
config: EncoderConfig,
}
impl Encoder {
fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: EncoderConfig,
) -> Result<Self> {
let conv_cfg = nn::Conv2dConfig {
stride: 1,
padding: 1,
};
let conv_in = nn::conv2d(
in_channels,
config.block_out_channels[0],
3,
conv_cfg,
vs.pp("conv_in"),
)?;
let mut down_blocks = vec![];
let vs_down_blocks = vs.pp("down_blocks");
for index in 0..config.block_out_channels.len() {
let out_channels = config.block_out_channels[index];
let in_channels = if index > 0 {
config.block_out_channels[index - 1]
} else {
config.block_out_channels[0]
};
let is_final = index + 1 == config.block_out_channels.len();
let cfg = DownEncoderBlock2DConfig {
num_layers: config.layers_per_block,
resnet_eps: 1e-6,
resnet_groups: config.norm_num_groups,
add_downsample: !is_final,
downsample_padding: 0,
..Default::default()
};
let down_block = DownEncoderBlock2D::new(
vs_down_blocks.pp(&index.to_string()),
in_channels,
out_channels,
cfg,
)?;
down_blocks.push(down_block)
}
let last_block_out_channels = *config.block_out_channels.last().unwrap();
let mid_cfg = UNetMidBlock2DConfig {
resnet_eps: 1e-6,
output_scale_factor: 1.,
attn_num_head_channels: None,
resnet_groups: Some(config.norm_num_groups),
..Default::default()
};
let mid_block =
UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
let conv_norm_out = nn::group_norm(
config.norm_num_groups,
last_block_out_channels,
1e-6,
vs.pp("conv_norm_out"),
)?;
let conv_out_channels = if config.double_z {
2 * out_channels
} else {
out_channels
};
let conv_cfg = nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let conv_out = nn::conv2d(
last_block_out_channels,
conv_out_channels,
3,
conv_cfg,
vs.pp("conv_out"),
)?;
Ok(Self {
conv_in,
down_blocks,
mid_block,
conv_norm_out,
conv_out,
config,
})
}
}
impl Encoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.conv_in.forward(xs)?;
for down_block in self.down_blocks.iter() {
xs = down_block.forward(&xs)?
}
let xs = self.mid_block.forward(&xs, None)?;
let xs = self.conv_norm_out.forward(&xs)?;
let xs = nn::ops::silu(&xs)?;
self.conv_out.forward(&xs)
}
}
#[derive(Debug, Clone)]
struct DecoderConfig {
// up_block_types: UpDecoderBlock2D
block_out_channels: Vec<usize>,
layers_per_block: usize,
norm_num_groups: usize,
}
impl Default for DecoderConfig {
fn default() -> Self {
Self {
block_out_channels: vec![64],
layers_per_block: 2,
norm_num_groups: 32,
}
}
}
#[derive(Debug)]
struct Decoder {
conv_in: nn::Conv2d,
up_blocks: Vec<UpDecoderBlock2D>,
mid_block: UNetMidBlock2D,
conv_norm_out: nn::GroupNorm,
conv_out: nn::Conv2d,
#[allow(dead_code)]
config: DecoderConfig,
}
impl Decoder {
fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: DecoderConfig,
) -> Result<Self> {
let n_block_out_channels = config.block_out_channels.len();
let last_block_out_channels = *config.block_out_channels.last().unwrap();
let conv_cfg = nn::Conv2dConfig {
stride: 1,
padding: 1,
};
let conv_in = nn::conv2d(
in_channels,
last_block_out_channels,
3,
conv_cfg,
vs.pp("conv_in"),
)?;
let mid_cfg = UNetMidBlock2DConfig {
resnet_eps: 1e-6,
output_scale_factor: 1.,
attn_num_head_channels: None,
resnet_groups: Some(config.norm_num_groups),
..Default::default()
};
let mid_block =
UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
let mut up_blocks = vec![];
let vs_up_blocks = vs.pp("up_blocks");
let reversed_block_out_channels: Vec<_> =
config.block_out_channels.iter().copied().rev().collect();
for index in 0..n_block_out_channels {
let out_channels = reversed_block_out_channels[index];
let in_channels = if index > 0 {
reversed_block_out_channels[index - 1]
} else {
reversed_block_out_channels[0]
};
let is_final = index + 1 == n_block_out_channels;
let cfg = UpDecoderBlock2DConfig {
num_layers: config.layers_per_block + 1,
resnet_eps: 1e-6,
resnet_groups: config.norm_num_groups,
add_upsample: !is_final,
..Default::default()
};
let up_block = UpDecoderBlock2D::new(
vs_up_blocks.pp(&index.to_string()),
in_channels,
out_channels,
cfg,
)?;
up_blocks.push(up_block)
}
let conv_norm_out = nn::group_norm(
config.norm_num_groups,
config.block_out_channels[0],
1e-6,
vs.pp("conv_norm_out"),
)?;
let conv_cfg = nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let conv_out = nn::conv2d(
config.block_out_channels[0],
out_channels,
3,
conv_cfg,
vs.pp("conv_out"),
)?;
Ok(Self {
conv_in,
up_blocks,
mid_block,
conv_norm_out,
conv_out,
config,
})
}
}
impl Decoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.mid_block.forward(&self.conv_in.forward(xs)?, None)?;
for up_block in self.up_blocks.iter() {
xs = up_block.forward(&xs)?
}
let xs = self.conv_norm_out.forward(&xs)?;
let xs = nn::ops::silu(&xs)?;
self.conv_out.forward(&xs)
}
}
#[derive(Debug, Clone)]
pub struct AutoEncoderKLConfig {
pub block_out_channels: Vec<usize>,
pub layers_per_block: usize,
pub latent_channels: usize,
pub norm_num_groups: usize,
}
impl Default for AutoEncoderKLConfig {
fn default() -> Self {
Self {
block_out_channels: vec![64],
layers_per_block: 1,
latent_channels: 4,
norm_num_groups: 32,
}
}
}
pub struct DiagonalGaussianDistribution {
mean: Tensor,
std: Tensor,
}
impl DiagonalGaussianDistribution {
pub fn new(parameters: &Tensor) -> Result<Self> {
let mut parameters = parameters.chunk(2, 1)?.into_iter();
let mean = parameters.next().unwrap();
let logvar = parameters.next().unwrap();
let std = (logvar * 0.5)?.exp()?;
Ok(DiagonalGaussianDistribution { mean, std })
}
pub fn sample(&self) -> Result<Tensor> {
let sample = Tensor::randn(0., 1f32, self.mean.shape(), self.mean.device());
&self.mean + &self.std * sample
}
}
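// Note (added for clarity): `sample` implements the reparameterization trick, drawing
// eps ~ N(0, I) with the same shape as the mean and returning mean + std * eps, so the
// randomness is kept in eps while mean and std come from the encoder output.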
// https://github.com/huggingface/diffusers/blob/970e30606c2944e3286f56e8eb6d3dc6d1eb85f7/src/diffusers/models/vae.py#L485
// This implementation is specific to the config used in stable-diffusion-v1-5
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
#[derive(Debug)]
pub struct AutoEncoderKL {
encoder: Encoder,
decoder: Decoder,
quant_conv: nn::Conv2d,
post_quant_conv: nn::Conv2d,
pub config: AutoEncoderKLConfig,
}
impl AutoEncoderKL {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: AutoEncoderKLConfig,
) -> Result<Self> {
let latent_channels = config.latent_channels;
let encoder_cfg = EncoderConfig {
block_out_channels: config.block_out_channels.clone(),
layers_per_block: config.layers_per_block,
norm_num_groups: config.norm_num_groups,
double_z: true,
};
let encoder = Encoder::new(vs.pp("encoder"), in_channels, latent_channels, encoder_cfg)?;
let decoder_cfg = DecoderConfig {
block_out_channels: config.block_out_channels.clone(),
layers_per_block: config.layers_per_block,
norm_num_groups: config.norm_num_groups,
};
let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?;
let conv_cfg = Default::default();
let quant_conv = nn::conv2d(
2 * latent_channels,
2 * latent_channels,
1,
conv_cfg,
vs.pp("quant_conv"),
)?;
let post_quant_conv = nn::conv2d(
latent_channels,
latent_channels,
1,
conv_cfg,
vs.pp("post_quant_conv"),
)?;
Ok(Self {
encoder,
decoder,
quant_conv,
post_quant_conv,
config,
})
}
/// Returns the distribution in the latent space.
pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> {
let xs = self.encoder.forward(xs)?;
let parameters = self.quant_conv.forward(&xs)?;
DiagonalGaussianDistribution::new(&parameters)
}
/// Takes as input some sampled values.
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.post_quant_conv.forward(xs)?;
self.decoder.forward(&xs)
}
}
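// Usage sketch (illustrative; `vs` and `images` are assumed to exist):
// let vae = AutoEncoderKL::new(vs.pp("vae"), 3, 3, AutoEncoderKLConfig::default())?;
// let dist = vae.encode(&images)?;   // images: (b, 3, h, w) -> diagonal Gaussian over latents
// let latents = dist.sample()?;      // latents: (b, latent_channels, h', w')
// let recon = vae.decode(&latents)?; // back to (b, 3, h, w)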


@ -1,18 +1,18 @@
#![allow(dead_code)]
// https://github.com/openai/whisper/blob/main/whisper/model.py
// TODO:
// - kv-cache support?
// - Language detection?
// - Batch size greater than 1.
// - More token filters (SuppressBlanks, ApplyTimestampRules).
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle::{safetensors::Load, DType, Device, Tensor};
use candle::{DType, Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::Parser;
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;
@ -20,6 +20,7 @@ use tokenizers::Tokenizer;
mod audio;
mod model;
use model::{Config, Whisper};
mod multilingual;
const DTYPE: DType = DType::F32;
@ -31,9 +32,6 @@ const HOP_LENGTH: usize = 160;
const CHUNK_LENGTH: usize = 30;
const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
const N_SAMPLES_PER_TOKEN: usize = HOP_LENGTH * 2; // the initial convolutions have stride 2
const FRAMES_PER_SECOND: usize = SAMPLE_RATE / HOP_LENGTH; // 10ms per audio frame
const TOKENS_PER_SECOND: usize = SAMPLE_RATE / N_SAMPLES_PER_TOKEN; // 20ms per audio token
const NO_SPEECH_THRESHOLD: f64 = 0.6;
const LOGPROB_THRESHOLD: f64 = -1.0;
@ -41,21 +39,12 @@ const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
// Tokenizer dependent bits.
const SOT_TOKEN: u32 = 50257;
const EOT_TOKEN: u32 = 50256;
const NO_SPEECH_TOKEN: u32 = 50361;
const NO_TIMESTAMP_TOKEN: u32 = 50362;
// From the _get_suppress_tokens function + 50362 (no timestamp)
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
const SUPPRESS_TOKENS: [u32; 91] = [
1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782,
1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959,
10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992,
19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549,
47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
];
const SOT_TOKEN: &str = "<|startoftranscript|>";
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
const EOT_TOKEN: &str = "<|endoftext|>";
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct DecodingResult {
tokens: Vec<u32>,
@ -66,6 +55,7 @@ struct DecodingResult {
compression_ratio: f64,
}
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct Segment {
start: f64,
@ -78,13 +68,24 @@ struct Decoder {
rng: rand::rngs::StdRng,
tokenizer: Tokenizer,
suppress_tokens: Tensor,
sot_token: u32,
transcribe_token: u32,
eot_token: u32,
no_speech_token: u32,
language_token: Option<u32>,
}
impl Decoder {
fn new(model: Whisper, tokenizer: Tokenizer, seed: u64, device: &Device) -> Result<Self> {
fn new(
model: Whisper,
tokenizer: Tokenizer,
seed: u64,
device: &Device,
language_token: Option<u32>,
) -> Result<Self> {
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
.map(|i| {
if SUPPRESS_TOKENS.contains(&i) {
if model.config.suppress_tokens.contains(&i) {
f32::NEG_INFINITY
} else {
0f32
@ -92,43 +93,59 @@ impl Decoder {
})
.collect();
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
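// (comment added for clarity) `suppress_tokens` holds 0 for allowed token ids and -inf for the
// ids listed in the config; broadcast-adding it to the logits in `decode` masks those tokens
// before sampling.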
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
Ok(Self {
model,
rng: rand::rngs::StdRng::seed_from_u64(seed),
tokenizer,
suppress_tokens,
sot_token,
transcribe_token,
eot_token,
no_speech_token,
language_token,
})
}
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
let model = &self.model;
let audio_features = model.encoder.forward(mel)?;
let model = &mut self.model;
let audio_features = model.encoder.forward(mel, true)?;
println!("audio features: {:?}", audio_features.dims());
let sample_len = model.config.max_target_positions / 2;
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
let mut tokens = vec![SOT_TOKEN];
let mut tokens = vec![self.sot_token];
if let Some(language_token) = self.language_token {
tokens.push(language_token)
}
tokens.push(self.transcribe_token);
for i in 0..sample_len {
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
// The model expects a batch dim but this inference loop does not handle
// it so we add it at this point.
let tokens_t = tokens_t.unsqueeze(0)?;
let logits = model.decoder.forward(&tokens_t, &audio_features)?;
let logits = logits.squeeze(0)?;
let ys = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
// Extract the no-speech probability on the first iteration by looking at the logits
// for the first token and taking the probability of the no_speech token.
if i == 0 {
no_speech_prob = softmax(&logits.get(0)?, 0)?
.get(NO_SPEECH_TOKEN as usize)?
let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
no_speech_prob = softmax(&logits, 0)?
.i(self.no_speech_token as usize)?
.to_scalar::<f32>()? as f64;
}
let (seq_len, _) = logits.dims2()?;
let logits = logits
.get(seq_len - 1)?
.broadcast_add(&self.suppress_tokens)?;
let (_, seq_len, _) = ys.dims3()?;
let logits = model
.decoder
.final_linear(&ys.i((..1, seq_len - 1..))?)?
.i(0)?
.i(0)?;
let logits = logits.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
@ -145,17 +162,14 @@ impl Decoder {
};
tokens.push(next_token);
let prob = softmax(&logits, candle::D::Minus1)?
.get(next_token as usize)?
.i(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
break;
}
sum_logprob += prob.ln();
}
let text = self
.tokenizer
.decode(tokens.clone(), true)
.map_err(E::msg)?;
let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
let avg_logprob = sum_logprob / tokens.len() as f64;
Ok(DecodingResult {
@ -219,6 +233,44 @@ impl Decoder {
}
}
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
match tokenizer.token_to_id(token) {
None => candle::bail!("no token-id for {token}"),
Some(id) => Ok(id),
}
}
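// Usage sketch (illustrative): look up the ids of the special tokens defined above.
// let sot_token = token_id(&tokenizer, SOT_TOKEN)?;  // "<|startoftranscript|>"
// let en_token = token_id(&tokenizer, "<|en|>")?;    // language token for English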
#[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel {
Tiny,
TinyEn,
Base,
BaseEn,
SmallEn,
MediumEn,
LargeV2,
}
impl WhichModel {
fn is_multilingual(&self) -> bool {
match self {
Self::Tiny | Self::Base | Self::LargeV2 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
}
}
fn model_and_revision(&self) -> (&'static str, &'static str) {
match self {
Self::Tiny => ("openai/whisper-tiny", "main"),
Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"),
Self::Base => ("openai/whisper-base", "refs/pr/22"),
Self::BaseEn => ("openai/whisper-base.en", "refs/pr/13"),
Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"),
Self::MediumEn => ("openai/whisper-medium.en", "refs/pr/11"),
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
}
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -234,6 +286,10 @@ struct Args {
#[arg(long)]
revision: Option<String>,
/// The model to be used, e.g. tiny, tiny-en, base, base-en, small-en, medium-en, large-v2.
#[arg(long, default_value = "tiny-en")]
model: WhichModel,
/// The input to be processed, in wav format; defaults to `jfk.wav`. Alternatively
/// this can be set to sample:jfk, sample:gb1, ... to fetch a sample from the following
/// repo: https://huggingface.co/datasets/Narsil/candle_demo/
@ -244,20 +300,33 @@ struct Args {
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The mel filters in safetensors format.
#[arg(
long,
default_value = "candle-examples/examples/whisper/mel_filters.safetensors"
)]
filters: String,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Language.
#[arg(long)]
language: Option<String>,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let device = candle_examples::device(args.cpu)?;
let default_model = "openai/whisper-tiny.en".to_string();
let (default_model, default_revision) = args.model.model_and_revision();
let default_model = default_model.to_string();
let default_revision = default_revision.to_string();
let path = std::path::PathBuf::from(default_model.clone());
let default_revision = "refs/pr/15".to_string();
let (model_id, revision) = match (args.model_id, args.revision) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
@ -301,11 +370,9 @@ fn main() -> Result<()> {
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
let mel_filters = mel_filters.deserialize()?;
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
println!("loaded mel filters {:?}", mel_filters.shape());
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
let mel_bytes = include_bytes!("melfilters.bytes");
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
let mut input = std::fs::File::open(input)?;
let (header, data) = wav::read(&mut input)?;
@ -328,8 +395,20 @@ fn main() -> Result<()> {
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let model = Whisper::load(&vb, config)?;
let mut dc = Decoder::new(model, tokenizer, args.seed, &device)?;
let mut model = Whisper::load(&vb, config)?;
let language_token = match (args.model.is_multilingual(), args.language) {
(true, None) => Some(multilingual::detect_language(&mut model, &tokenizer, &mel)?),
(false, None) => None,
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
Ok(token_id) => Some(token_id),
Err(_) => anyhow::bail!("language {language} is not supported"),
},
(false, Some(_)) => {
anyhow::bail!("a language cannot be set for non-multilingual models")
}
};
let mut dc = Decoder::new(model, tokenizer, args.seed, &device, language_token)?;
dc.run(&mel)?;
Ok(())
}


@ -1,8 +1,5 @@
// We use anyhow rather than candle errors as it provides better support for getting the backtrace
// back when using RUST_LIB_BACKTRACE=1.
use anyhow::Result;
use candle::{Device, Tensor};
use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
use candle::{Device, IndexOp, Result, Tensor};
use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
use serde::Deserialize;
// The names in comments correspond to the original implementation:
@ -19,10 +16,21 @@ pub struct Config {
// pub n_text_state: usize,
pub decoder_attention_heads: usize, // n_text_head
pub decoder_layers: usize, // n_text_layer
pub suppress_tokens: Vec<u32>,
}
impl Config {
#[allow(dead_code)]
pub fn tiny_en() -> Self {
let suppress_tokens = vec![
1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93,
357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391,
1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329,
7203, 9609, 9959, 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306,
16410, 16791, 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409,
34949, 40283, 40493, 40549, 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361,
50362,
];
Self {
num_mel_bins: 80,
vocab_size: 51864,
@ -34,6 +42,7 @@ impl Config {
// n_text_state: 384,
decoder_attention_heads: 6,
decoder_layers: 4,
suppress_tokens,
}
}
}
@ -42,16 +51,32 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
//
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
let bias = vb.get(size2, "bias")?;
Ok(Linear::new(weight, Some(bias)))
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear(size1, size2, vb)?;
Ok(Linear { inner, span })
}
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
Ok(Linear::new(weight, None))
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
Ok(Linear { inner, span })
}
fn conv1d(
@ -66,32 +91,6 @@ fn conv1d(
Ok(Conv1d::new(weight, Some(bias), config))
}
fn conv1d_no_bias(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
config: Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
Ok(Conv1d::new(weight, None, config))
}
struct Dropout {
pr: f64,
}
impl Dropout {
fn new(pr: f64) -> Self {
Self { pr }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
// TODO
Ok(x.clone())
}
}
fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, "weight")?;
let bias = vb.get(size, "bias")?;
@ -105,10 +104,17 @@ struct MultiHeadAttention {
value: Linear,
out: Linear,
n_head: usize,
span: tracing::Span,
softmax_span: tracing::Span,
matmul_span: tracing::Span,
kv_cache: Option<(Tensor, Tensor)>,
}
impl MultiHeadAttention {
fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax");
let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul");
let query = linear(n_state, n_state, vb.pp("q_proj"))?;
let value = linear(n_state, n_state, vb.pp("v_proj"))?;
let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
@ -119,13 +125,42 @@ impl MultiHeadAttention {
value,
out,
n_head,
span,
softmax_span,
matmul_span,
kv_cache: None,
})
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
fn forward(
&mut self,
x: &Tensor,
xa: Option<&Tensor>,
mask: Option<&Tensor>,
flush_cache: bool,
) -> Result<Tensor> {
let _enter = self.span.enter();
let q = self.query.forward(x)?;
let k = self.key.forward(xa.unwrap_or(x))?;
let v = self.value.forward(xa.unwrap_or(x))?;
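// (comment added for clarity) For cross-attention (`xa` is the encoder output) the keys and
// values are constant across decoding steps, so they are computed once and cached; passing
// `flush_cache = true` clears the cache at the start of a new sequence. Self-attention
// (`xa == None`) always recomputes k/v from the current tokens.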
let (k, v) = match xa {
None => {
let k = self.key.forward(x)?;
let v = self.value.forward(x)?;
(k, v)
}
Some(x) => {
if flush_cache {
self.kv_cache = None;
}
if let Some((k, v)) = &self.kv_cache {
(k.clone(), v.clone())
} else {
let k = self.key.forward(x)?;
let v = self.value.forward(x)?;
self.kv_cache = Some((k.clone(), v.clone()));
(k, v)
}
}
};
let wv = self.qkv_attention(&q, &k, &v, mask)?;
let out = self.out.forward(&wv)?;
Ok(out)
@ -134,7 +169,7 @@ impl MultiHeadAttention {
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
let (n_batch, n_ctx, n_state) = x.dims3()?;
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
x.reshape(target_dims)?.transpose(1, 2)
}
fn qkv_attention(
@ -149,13 +184,24 @@ impl MultiHeadAttention {
let q = (self.reshape_head(q)? * scale)?;
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
let v = self.reshape_head(v)?.contiguous()?;
let mut qk = q.matmul(&k)?;
let mut qk = {
let _enter = self.matmul_span.enter();
q.matmul(&k)?
};
if let Some(mask) = mask {
let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
let mask = mask.i((0..n_ctx, 0..n_ctx))?;
qk = qk.broadcast_add(&mask)?
}
let w = softmax(&qk, candle::D::Minus1)?;
let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
let w = {
let _enter = self.softmax_span.enter();
softmax(&qk, candle::D::Minus1)?
};
let wv = {
let _enter = self.matmul_span.enter();
w.matmul(&v)?
}
.transpose(1, 2)?
.flatten_from(2)?;
Ok(wv)
}
}
@ -168,10 +214,12 @@ struct ResidualAttentionBlock {
mlp_linear1: Linear,
mlp_linear2: Linear,
mlp_ln: LayerNorm,
span: tracing::Span,
}
impl ResidualAttentionBlock {
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
let cross_attn = if ca {
@ -192,14 +240,24 @@ impl ResidualAttentionBlock {
mlp_linear1,
mlp_linear2,
mlp_ln,
span,
})
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
fn forward(
&mut self,
x: &Tensor,
xa: Option<&Tensor>,
mask: Option<&Tensor>,
flush_kv_cache: bool,
) -> Result<Tensor> {
let _enter = self.span.enter();
let attn = self
.attn
.forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
let mut x = (x + attn)?;
if let Some((attn, ln)) = &self.cross_attn {
x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?;
if let Some((attn, ln)) = &mut self.cross_attn {
x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
}
let mlp = self.mlp_linear2.forward(
&self
@ -207,7 +265,7 @@ impl ResidualAttentionBlock {
.forward(&self.mlp_ln.forward(&x)?)?
.gelu()?,
)?;
Ok((x + mlp)?)
x + mlp
}
}
@ -234,10 +292,16 @@ pub struct AudioEncoder {
positional_embedding: Tensor,
blocks: Vec<ResidualAttentionBlock>,
ln_post: LayerNorm,
span: tracing::Span,
conv1_span: tracing::Span,
conv2_span: tracing::Span,
}
impl AudioEncoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
let n_state = cfg.d_model;
let n_head = cfg.encoder_attention_heads;
let n_ctx = cfg.max_source_positions;
@ -264,17 +328,28 @@ impl AudioEncoder {
positional_embedding,
blocks,
ln_post,
conv1_span,
conv2_span,
span,
})
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = self.conv1.forward(x)?.gelu()?;
let x = self.conv2.forward(&x)?.gelu()?;
pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
let _enter = self.span.enter();
let x = {
let _enter = self.conv1_span.enter();
self.conv1.forward(x)?.gelu()?
};
let x = {
let _enter = self.conv2_span.enter();
self.conv2.forward(&x)?.gelu()?
};
let x = x.transpose(1, 2)?;
let (_bsize, seq_len, _hidden) = x.dims3()?;
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
let mut x = x.broadcast_add(&positional_embedding)?;
for block in self.blocks.iter() {
x = block.forward(&x, None, None)?
for block in self.blocks.iter_mut() {
x = block.forward(&x, None, None, flush_kv_cache)?
}
let x = self.ln_post.forward(&x)?;
Ok(x)
@ -288,10 +363,14 @@ pub struct TextDecoder {
blocks: Vec<ResidualAttentionBlock>,
ln: LayerNorm,
mask: Tensor,
span: tracing::Span,
span_final: tracing::Span,
}
impl TextDecoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
let n_state = cfg.d_model;
let n_head = cfg.decoder_attention_heads;
let n_ctx = cfg.max_target_positions;
@ -307,31 +386,37 @@ impl TextDecoder {
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
.collect();
let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
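// (comment added for clarity) `mask` is an (n_ctx, n_ctx) causal mask: entries with j > i are
// -inf, so adding it to the attention scores prevents a token from attending to later positions.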
Ok(Self {
token_embedding,
positional_embedding,
blocks,
ln,
mask,
span,
span_final,
})
}
pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
let _enter = self.span.enter();
let x_dims = x.dims();
let last = x_dims[x_dims.len() - 1];
let token_embedding = self.token_embedding.forward(x)?;
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
for block in self.blocks.iter() {
x = block.forward(&x, Some(xa), Some(&self.mask))?;
for block in self.blocks.iter_mut() {
x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
}
let x = self.ln.forward(&x)?;
let w = self
.token_embedding
.embeddings()
.broadcast_left(x_dims[0])?;
let logits = x.matmul(&w.t()?)?;
self.ln.forward(&x)
}
pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> {
let b_size = x.dim(0)?;
let w = self.token_embedding.embeddings().broadcast_left(b_size)?;
let logits = {
let _enter = self.span_final.enter();
x.matmul(&w.t()?)?
};
Ok(logits)
}
}
@ -353,10 +438,4 @@ impl Whisper {
config,
})
}
pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
let enc = self.encoder.forward(mel)?;
let dec = self.decoder.forward(tokens, &enc)?;
Ok(dec)
}
}


@ -0,0 +1,135 @@
use crate::Whisper;
use candle::{IndexOp, Result, Tensor, D};
use tokenizers::Tokenizer;
const LANGUAGES: [(&str, &str); 99] = [
("en", "english"),
("zh", "chinese"),
("de", "german"),
("es", "spanish"),
("ru", "russian"),
("ko", "korean"),
("fr", "french"),
("ja", "japanese"),
("pt", "portuguese"),
("tr", "turkish"),
("pl", "polish"),
("ca", "catalan"),
("nl", "dutch"),
("ar", "arabic"),
("sv", "swedish"),
("it", "italian"),
("id", "indonesian"),
("hi", "hindi"),
("fi", "finnish"),
("vi", "vietnamese"),
("he", "hebrew"),
("uk", "ukrainian"),
("el", "greek"),
("ms", "malay"),
("cs", "czech"),
("ro", "romanian"),
("da", "danish"),
("hu", "hungarian"),
("ta", "tamil"),
("no", "norwegian"),
("th", "thai"),
("ur", "urdu"),
("hr", "croatian"),
("bg", "bulgarian"),
("lt", "lithuanian"),
("la", "latin"),
("mi", "maori"),
("ml", "malayalam"),
("cy", "welsh"),
("sk", "slovak"),
("te", "telugu"),
("fa", "persian"),
("lv", "latvian"),
("bn", "bengali"),
("sr", "serbian"),
("az", "azerbaijani"),
("sl", "slovenian"),
("kn", "kannada"),
("et", "estonian"),
("mk", "macedonian"),
("br", "breton"),
("eu", "basque"),
("is", "icelandic"),
("hy", "armenian"),
("ne", "nepali"),
("mn", "mongolian"),
("bs", "bosnian"),
("kk", "kazakh"),
("sq", "albanian"),
("sw", "swahili"),
("gl", "galician"),
("mr", "marathi"),
("pa", "punjabi"),
("si", "sinhala"),
("km", "khmer"),
("sn", "shona"),
("yo", "yoruba"),
("so", "somali"),
("af", "afrikaans"),
("oc", "occitan"),
("ka", "georgian"),
("be", "belarusian"),
("tg", "tajik"),
("sd", "sindhi"),
("gu", "gujarati"),
("am", "amharic"),
("yi", "yiddish"),
("lo", "lao"),
("uz", "uzbek"),
("fo", "faroese"),
("ht", "haitian creole"),
("ps", "pashto"),
("tk", "turkmen"),
("nn", "nynorsk"),
("mt", "maltese"),
("sa", "sanskrit"),
("lb", "luxembourgish"),
("my", "myanmar"),
("bo", "tibetan"),
("tl", "tagalog"),
("mg", "malagasy"),
("as", "assamese"),
("tt", "tatar"),
("haw", "hawaiian"),
("ln", "lingala"),
("ha", "hausa"),
("ba", "bashkir"),
("jw", "javanese"),
("su", "sundanese"),
];
/// Returns the token id for the selected language.
pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
let (_bsize, _, seq_len) = mel.dims3()?;
let mel = mel.narrow(2, 0, usize::min(seq_len, model.config.max_source_positions))?;
let device = mel.device();
let language_token_ids = LANGUAGES
.iter()
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
.collect::<Result<Vec<_>>>()?;
let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
let audio_features = model.encoder.forward(&mel, true)?;
let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
let logits = model
.decoder
.forward(&tokens, &audio_features, true)?
.i(0)?
.i(0)?;
let logits = logits.index_select(&language_token_ids, 0)?;
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
let probs = probs.to_vec1::<f32>()?;
let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for ((_, language), p) in probs.iter().take(5) {
println!("{language}: {p}")
}
let language = crate::token_id(tokenizer, &format!("<|{}|>", probs[0].0 .0))?;
Ok(language)
}
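// Usage sketch (illustrative, mirroring main.rs): run detection when no --language flag is set
// and pass the resulting id to Decoder::new as `language_token`.
// let language_token = multilingual::detect_language(&mut model, &tokenizer, &mel)?;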


@ -11,3 +11,102 @@ pub fn device(cpu: bool) -> Result<Device> {
Ok(device)
}
}
#[cfg(test)]
mod tests {
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
#[rustfmt::skip]
#[tokio::test]
async fn book_hub_1() {
// ANCHOR: book_hub_1
use candle::Device;
use hf_hub::api::tokio::Api;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").await.unwrap();
let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_1
assert_eq!(weights.len(), 206);
}
#[rustfmt::skip]
#[test]
fn book_hub_2() {
// ANCHOR: book_hub_2
use candle::Device;
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };
let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_2
assert_eq!(weights.len(), 206);
}
#[rustfmt::skip]
#[test]
fn book_hub_3() {
// ANCHOR: book_hub_3
use candle::{DType, Device, Tensor};
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use safetensors::slice::IndexOp;
use safetensors::SafeTensors;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };
// Use safetensors directly
let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
let view = tensors
.tensor("bert.encoder.layer.0.attention.self.query.weight")
.unwrap();
// We're going to load the shard for rank 1, within a world_size of 4
// We're going to split along dimension 0 doing VIEW[start..stop, :]
let rank = 1;
let world_size = 4;
let dim = 0;
let dtype = view.dtype();
let mut tp_shape = view.shape().to_vec();
let size = tp_shape[0];
if size % world_size != 0 {
panic!("The dimension is not divisble by `world_size`");
}
let block_size = size / world_size;
let start = rank * block_size;
let stop = (rank + 1) * block_size;
// Everything is expressed in tensor dimensions;
// byte offsets are handled automatically by safetensors.
let iterator = view.slice(start..stop).unwrap();
tp_shape[dim] = block_size;
// Convert safetensors Dtype to candle DType
let dtype: DType = dtype.try_into().unwrap();
// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_3
assert_eq!(view.shape(), &[768, 768]);
assert_eq!(tp_tensor.dims(), &[192, 768]);
}
}


@ -1,17 +1,17 @@
[package]
name = "candle-flash-attn"
version = "0.1.0"
version = "0.1.1"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT/Apache-2.0"
license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], version = "0.1.0", package = "candle-core" }
candle = { path = "../candle-core", features = ["cuda"], version = "0.1.1", package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
@ -21,4 +21,4 @@ rayon = "1.7.0"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", version = "0.1.0", features = ["cuda"] }
candle-nn = { path = "../candle-nn", version = "0.1.1", features = ["cuda"] }


@ -88,6 +88,7 @@ fn main() -> Result<()> {
.map(|(cu_file, obj_file)| {
let mut command = std::process::Command::new("nvcc");
command
.arg("-std=c++17")
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("-c")
.args(["-o", obj_file.to_str().unwrap()])


@ -1,13 +1,13 @@
[package]
name = "candle-kernels"
version = "0.1.0"
version = "0.1.1"
edition = "2021"
description = "CUDA kernels for Candle"
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT/Apache-2.0"
license = "MIT OR Apache-2.0"
[dependencies]


@ -6,21 +6,12 @@
// FIXME: the minimum compute capabilities are just guesses since the table is not specific enough
// #if __CUDA_ARCH__ < 600
// __device__ __forceinline__ __half __hmax(__half a, __half b) {
// return __float2half(fmaxf(__half2float(a), __half2float(b)));
// }
// __device__ __forceinline__ __half __hmin(__half a, __half b) {
// return __float2half(fminf(__half2float(a), __half2float(b)));
// }
// #endif
#if __CUDA_ARCH__ < 800
#if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 2)) && __CUDA_ARCH__ < 800
__device__ __forceinline__ __half __hmax_nan(__half a, __half b) {
// return __hisnan(a) ? a : (__hisnan(b) ? b : __hmax(a, b));
return __hisnan(a) ? a : (__hisnan(b) ? b : __hmax(a, b));
}
__device__ __forceinline__ __half __hmin_nan(__half a, __half b) {
// return __hisnan(a) ? a : (__hisnan(b) ? b : __hmin(a, b));
return __hisnan(a) ? a : (__hisnan(b) ? b : __hmin(a, b));
}
#endif

Some files were not shown because too many files have changed in this diff.