Merge branch 'main' into book-trainin-simplified

Committed by Evgeny Igumnov on 2023-09-22 11:01:23 +06:00 (via GitHub)
195 changed files with 14759 additions and 1789 deletions

.gitignore

@@ -23,6 +23,7 @@ flamegraph.svg
 *.dylib
 *.so
 *.swp
+*.swo
 trace-*.json
 candle-wasm-examples/*/build

CHANGELOG.md

@@ -1,13 +1,58 @@
 # Changelog
 This documents the main changes to the `candle` crate.

-## v0.2.1 - Unreleased
+## v0.2.3 - Unreleased
 ### Added
+### Modified
+
+## v0.2.2 - 2023-09-18
+### Added
+- Support for `top_p` sampling
+  [819](https://github.com/huggingface/candle/pull/819).
+- T5 model including decoding
+  [864](https://github.com/huggingface/candle/pull/864).
+- 1-d upsampling
+  [839](https://github.com/huggingface/candle/pull/839).
+### Modified
+- Bugfix for conv2d
+  [820](https://github.com/huggingface/candle/pull/820).
+- Support tensor based indexing using `.i`
+  [842](https://github.com/huggingface/candle/pull/842).
+
+## v0.2.1 - 2023-09-11
+### Added
+- Add some RNNs (GRU and LSTM) in `candle-nn`
+  [674](https://github.com/huggingface/candle/pull/674),
+  [688](https://github.com/huggingface/candle/pull/688).
+- gguf v2 support
+  [725](https://github.com/huggingface/candle/pull/725).
+- Quantized llama example in Python using the pyo3 api
+  [716](https://github.com/huggingface/candle/pull/716).
+- `candle-nn` layer for conv2d-transposed
+  [760](https://github.com/huggingface/candle/pull/760).
+- Add the Segment-Anything Model (SAM) as an example
+  [773](https://github.com/huggingface/candle/pull/773).
+- TinyViT backbone for the segment-anything example
+  [787](https://github.com/huggingface/candle/pull/787).
+- Shape with holes support
+  [770](https://github.com/huggingface/candle/pull/770).
 ### Modified
 - Dilations are now supported in conv-transpose2d.
   [671](https://github.com/huggingface/candle/pull/671).
+- Interactive mode for the quantized model
+  [690](https://github.com/huggingface/candle/pull/690).
+- Faster softmax operation
+  [747](https://github.com/huggingface/candle/pull/747).
+- Faster convolution operations on CPU and CUDA via im2col
+  [802](https://github.com/huggingface/candle/pull/802).
+- Moving some models to a more central location
+  [796](https://github.com/huggingface/candle/pull/796).
 ## v0.2.0 - 2023-08-30
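For background on the `top_p` entry above: nucleus sampling draws from the smallest set of tokens whose cumulative probability reaches `p`, after renormalizing. A minimal sketch (illustrative only, not candle's implementation; the `top_p` helper name is hypothetical):

```rust
/// Truncate a probability distribution to its top-p nucleus and renormalize.
/// `probs` is assumed to sum to 1; the result pairs each kept token index
/// with its renormalized probability.
fn top_p(probs: &[f32], p: f32) -> Vec<(usize, f32)> {
    // Sort token indices by descending probability.
    let mut idx: Vec<usize> = (0..probs.len()).collect();
    idx.sort_by(|&a, &b| probs[b].total_cmp(&probs[a]));
    // Keep tokens until the cumulative probability reaches p.
    let mut kept = Vec::new();
    let mut cum = 0f32;
    for &i in &idx {
        kept.push(i);
        cum += probs[i];
        if cum >= p {
            break;
        }
    }
    // Renormalize the kept mass so it sums to 1 again.
    kept.into_iter().map(|i| (i, probs[i] / cum)).collect()
}

fn main() {
    let probs = [0.5f32, 0.3, 0.1, 0.05, 0.05];
    // With p = 0.8 only the two most likely tokens survive.
    println!("{:?}", top_p(&probs, 0.8));
}
```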

Cargo.toml

@@ -8,17 +8,16 @@ members = [
     "candle-pyo3",
     "candle-transformers",
     "candle-wasm-examples/llama2-c",
+    "candle-wasm-examples/segment-anything",
     "candle-wasm-examples/whisper",
     "candle-wasm-examples/yolo",
+    "candle-wasm-examples/bert",
 ]
-exclude = [
-    "candle-flash-attn",
-    "candle-kernels",
-]
+exclude = ["candle-flash-attn", "candle-kernels"]
 resolver = "2"

 [workspace.package]
-version = "0.2.1"
+version = "0.2.3"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -33,7 +32,7 @@ byteorder = "1.4.3"
 clap = { version = "4.2.4", features = ["derive"] }
 cudarc = { version = "0.9.14", features = ["f16"] }
 # TODO: Switch back to the official gemm implementation once it has caught up.
-gemm = { version = "0.15.6", package = "candle-gemm" }
+gemm = { version = "0.16.0", package = "candle-gemm" }
 hf-hub = "0.3.0"
 half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
 image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }

README.md

@@ -8,7 +8,9 @@ Candle is a minimalist ML framework for Rust with a focus on performance (includ
 and ease of use. Try our online demos:
 [whisper](https://huggingface.co/spaces/lmz/candle-whisper),
 [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
-[yolo](https://huggingface.co/spaces/lmz/candle-yolo).
+[yolo](https://huggingface.co/spaces/lmz/candle-yolo),
+[Segment
+Anything](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).

 ## Get started
@@ -45,37 +47,54 @@ For more advanced examples, please have a look at the following section.

 ## Check out our examples

-Check out our [examples](./candle-examples/examples/):
+These online demos run entirely in your browser:
+- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
+  object recognition.
+- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
+- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
+- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
+
+We also provide some command line based examples using state of the art models:
-- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
 - [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
 - [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
   generation.
-- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
-  image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
-- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
-  using self-supervision (can be used for imagenet classification, depth
-  evaluation, segmentation).
 - [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
   the LLaMA model using the same quantization techniques as
   [llama.cpp](https://github.com/ggerganov/llama.cpp).
+  <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">
+- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
+  image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
+  <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">
+- [Wuerstchen](./candle-examples/examples/wuerstchen/): another text to
+  image generative model.
+  <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/wuerstchen/assets/cat.jpg" width="200">
 - [yolo-v3](./candle-examples/examples/yolo-v3/) and
   [yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
   estimation models.
-Run them using the following commands:
+  <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.od.jpg" width="200"><img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.pose.jpg" width="200">
+- [segment-anything](./candle-examples/examples/segment-anything/): image
+  segmentation model with prompt.
+  <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
+- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
+- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
+- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
+  using self-supervision (can be used for imagenet classification, depth
+  evaluation, segmentation).
+
+Run them using commands like:
 ```
-cargo run --example whisper --release
-cargo run --example llama --release
-cargo run --example falcon --release
-cargo run --example bert --release
-cargo run --example bigcode --release
-cargo run --example stable-diffusion --release -- --prompt "a rusty robot holding a fire torch"
-cargo run --example dinov2 --release -- --image path/to/myinput.jpg
 cargo run --example quantized --release
-cargo run --example yolo-v3 --release -- myimage.jpg
-cargo run --example yolo-v8 --release -- myimage.jpg # for pose estimation, add --task pose
 ```

 In order to use **CUDA** add `--features cuda` to the example command line. If
@@ -85,7 +104,8 @@ There are also some wasm examples for whisper and
 [llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
 `trunk` or try them online:
 [whisper](https://huggingface.co/spaces/lmz/candle-whisper),
-[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
+[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
+[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).

 For LLaMA2, run the following command to retrieve the weight files and start a
 test server:
@@ -98,6 +118,15 @@ trunk serve --release --port 8081
 And then head over to
 [http://localhost:8081/](http://localhost:8081/).

+<!--- ANCHOR: useful_libraries --->
+## Useful Libraries
+- [`candle-lora`](https://github.com/EricLBuehler/candle-lora) provides a LoRA implementation that conforms to the official `peft` implementation.
+
+If you have an addition to this list, please submit a pull request.
+<!--- ANCHOR_END: useful_libraries --->
+
 <!--- ANCHOR: features --->
 ## Features
@@ -110,10 +139,21 @@ And then head over to
 - CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
 - WASM support, run your models in a browser.
 - Included models.
-    - LLMs: LLaMA v1 and v2, Falcon, StarCoder.
+    - Language Models.
+        - LLaMA v1 and v2.
+        - Falcon.
+        - StarCoder.
+        - T5.
+        - Bert.
     - Whisper (multi-lingual support).
-    - Stable Diffusion.
-    - Computer Vision: DINOv2, EfficientNet, yolo-v3, yolo-v8.
+    - Stable Diffusion v1.5, v2.1, XL v1.0.
+    - Wurstchen v2.
+    - Computer Vision Models.
+        - DINOv2.
+        - EfficientNet.
+        - yolo-v3.
+        - yolo-v8.
+        - Segment-Anything Model (SAM).
 - File formats: load models from safetensors, npz, ggml, or PyTorch files.
 - Serverless (on CPU), small and fast deployments.
 - Quantization support using the llama.cpp quantized types.
@@ -243,6 +283,35 @@ authentication token. See issue
 git submodule update --init
 ```

+#### Compiling with flash-attention fails
+
+```
+/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ...:
+```
+
+This is a bug in gcc-11 triggered by the CUDA compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the `CANDLE_NVCC_CCBIN` environment variable.
+
+```
+env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
+```
+
+#### Linking error on Windows when running rustdoc or mdbook tests
+
+```
+Couldn't compile the test.
+---- .\candle-book\src\inference\hub.md - Using_the_hub::Using_in_a_real_model_ (line 50) stdout ----
+error: linking with `link.exe` failed: exit code: 1181
+//very long chain of linking
+ = note: LINK : fatal error LNK1181: cannot open input file 'windows.0.48.5.lib'
+```
+
+Make sure you link all native libraries that might be located outside the project target, e.g., to run mdbook tests, you should run:
+
+```
+mdbook test candle-book -L .\target\debug\deps\ `
+ -L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.42.2\lib `
+ -L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib
+```
+
 #### Tracking down errors

 You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle

candle-examples/Cargo.toml

@@ -11,11 +11,11 @@ readme = "README.md"

 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
-candle-nn = { path = "../candle-nn", version = "0.2.1" }
-candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
+candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
+candle-nn = { path = "../candle-nn", version = "0.2.3" }
+candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }

candle-book/src/SUMMARY.md

@@ -10,10 +10,10 @@

 # Reference Guide

-- [Running a model](inference/README.md)
+- [Running a model](inference/inference.md)
 - [Using the hub](inference/hub.md)
 - [Error management](error_manage.md)
-- [Training](training/README.md)
+- [Training](training/training.md)
 - [Simplified](training/simplified.md)
 - [MNIST](training/mnist.md)
 - [Fine-tuning]()

candle-book/src/error_manage.md

@@ -29,7 +29,7 @@ After adding `RUST_BACKTRACE=1`:
 Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
 ```

-Not super pretty at the moment, but we can see error occured on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
+Not super pretty at the moment, but we can see the error occurred on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`

 Another thing to note is that since Rust is compiled it is not necessarily as easy to recover proper stacktraces
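As a quick illustration of the error above (an editorial sketch using the same API as the book, not part of the page): `matmul` on a `(1, 784)` by `(1, 784)` pair fails because the inner dimensions disagree; transposing the right-hand side is one way to make the shapes line up.

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let lhs = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
    let rhs = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
    // `lhs.matmul(&rhs)` reproduces the ShapeMismatchBinaryOp error above:
    // the inner dimensions are 784 and 1.
    // Transposing gives (1, 784) x (784, 1) -> (1, 1).
    let out = lhs.matmul(&rhs.t()?)?;
    println!("{out:?}");
    Ok(())
}
```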

candle-book/src/guide/hello_world.md

@@ -6,7 +6,7 @@ Open `src/main.rs` and fill in this content:

 ```rust
 # extern crate candle_core;
-use candle_core::{DType, Device, Result, Tensor};
+use candle_core::{Device, Result, Tensor};

 struct Model {
     first: Tensor,
@@ -25,11 +25,11 @@ fn main() -> Result<()> {
     // Use Device::new_cuda(0)?; to use the GPU.
     let device = Device::Cpu;

-    let first = Tensor::zeros((784, 100), DType::F32, &device)?;
-    let second = Tensor::zeros((100, 10), DType::F32, &device)?;
+    let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
+    let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
     let model = Model { first, second };

-    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

     let digit = model.forward(&dummy_image)?;
     println!("Digit {digit:?} digit");
@@ -50,7 +50,7 @@ the classical `Linear` layer. We can do as such

 ```rust
 # extern crate candle_core;
-# use candle_core::{DType, Device, Result, Tensor};
+# use candle_core::{Device, Result, Tensor};
 struct Linear{
     weight: Tensor,
     bias: Tensor,
@@ -80,7 +80,7 @@ This will change the model running code into a new function

 ```rust
 # extern crate candle_core;
-# use candle_core::{DType, Device, Result, Tensor};
+# use candle_core::{Device, Result, Tensor};
 # struct Linear{
 #     weight: Tensor,
 #     bias: Tensor,
@@ -110,15 +110,15 @@ fn main() -> Result<()> {
     let device = Device::cuda_if_available(0)?;

     // Creating a dummy model
-    let weight = Tensor::zeros((784, 100), DType::F32, &device)?;
-    let bias = Tensor::zeros((100, ), DType::F32, &device)?;
+    let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
+    let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
     let first = Linear{weight, bias};
-    let weight = Tensor::zeros((100, 10), DType::F32, &device)?;
-    let bias = Tensor::zeros((10, ), DType::F32, &device)?;
+    let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
+    let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
     let second = Linear{weight, bias};
     let model = Model { first, second };

-    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

     // Inference on the model
     let digit = model.forward(&dummy_image)?;
@@ -146,7 +146,7 @@ And rewrite our examples using it

 ```rust
 # extern crate candle_core;
 # extern crate candle_nn;
-use candle_core::{DType, Device, Result, Tensor};
+use candle_core::{Device, Result, Tensor};
 use candle_nn::{Linear, Module};

 struct Model {
@@ -167,15 +167,15 @@ fn main() -> Result<()> {
     let device = Device::Cpu;

     // This has changed (784, 100) -> (100, 784) !
-    let weight = Tensor::zeros((100, 784), DType::F32, &device)?;
-    let bias = Tensor::zeros((100, ), DType::F32, &device)?;
+    let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
+    let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
     let first = Linear::new(weight, Some(bias));
-    let weight = Tensor::zeros((10, 100), DType::F32, &device)?;
-    let bias = Tensor::zeros((10, ), DType::F32, &device)?;
+    let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?;
+    let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
     let second = Linear::new(weight, Some(bias));
     let model = Model { first, second };

-    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

     let digit = model.forward(&dummy_image)?;
     println!("Digit {digit:?} digit");
@@ -188,8 +188,8 @@ Feel free to modify this example to use `Conv2d` to create a classical convnet i

 Now that we have the running dummy code we can get to more advanced topics:

-- [For PyTorch users](./guide/cheatsheet.md)
-- [Running existing models](./inference/README.md)
-- [Training models](./training/README.md)
+- [For PyTorch users](../guide/cheatsheet.md)
+- [Running existing models](../inference/inference.md)
+- [Training models](../training/training.md)
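A side note on the `(784, 100) -> (100, 784)` comment in the diff above: `candle_nn::Linear` stores its weight as `(out_features, in_features)` and applies `x @ W^T` in `forward`. A small standalone check (editorial sketch, assuming `candle-core` and `candle-nn` as dependencies):

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Weight is (out_features, in_features) = (100, 784).
    let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
    let bias = Tensor::randn(0f32, 1.0, (100,), &device)?;
    let layer = Linear::new(weight, Some(bias));
    // One flattened 28x28 image: (1, 784) -> (1, 100).
    let x = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
    let y = layer.forward(&x)?;
    assert_eq!(y.dims(), &[1, 100]);
    Ok(())
}
```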

candle-core/Cargo.toml

@@ -12,7 +12,7 @@ readme = "README.md"

 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
-candle-kernels = { path = "../candle-kernels", version = "0.2.1", optional = true }
+candle-kernels = { path = "../candle-kernels", version = "0.2.3", optional = true }
 cudarc = { workspace = true, optional = true }
 gemm = { workspace = true }
 half = { workspace = true }


@@ -1,166 +0,0 @@
/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle_core::quantized::GgmlType;
use candle_core::{Device, Result, Tensor, D};
use clap::{Parser, Subcommand};

fn softmax<D: candle_core::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
    let dim = dim.to_index(xs.shape(), "softmax")?;
    let max = xs.max_keepdim(dim)?;
    let diff = xs.broadcast_sub(&max)?;
    let num = diff.exp()?;
    let den = num.sum_keepdim(dim)?;
    num.broadcast_div(&den)
}
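// Note on the softmax above: subtracting the per-row max before
// exponentiating is the standard guard against overflow in `exp`; it does
// not change the result since softmax(x) = softmax(x - c) for any constant c.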
trait Benchmark {
    type PreProcessData;
    type RunResult;

    fn preprocess() -> Result<Self::PreProcessData>;
    fn run_one(_: &Self::PreProcessData) -> Result<Self::RunResult>;

    const ITERS: usize;
}

// Conv1d example as used in whisper.
struct Conv1d;
impl Benchmark for Conv1d {
    type PreProcessData = (Tensor, Tensor);
    type RunResult = Tensor;

    fn preprocess() -> Result<Self::PreProcessData> {
        let inp = Tensor::randn(0f32, 1., (1, 384, 3000), &Device::Cpu)?;
        let w = Tensor::randn(0f32, 1., (384, 384, 3), &Device::Cpu)?;
        Ok((inp, w))
    }

    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
        d.0.conv1d(&d.1, 0, 1, 1, 1)
    }

    const ITERS: usize = 5;
}

// Conv2d example as used in stable-diffusion.
struct Conv2d;
impl Benchmark for Conv2d {
    type PreProcessData = (Tensor, Tensor);
    type RunResult = Tensor;

    fn preprocess() -> Result<Self::PreProcessData> {
        let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
        let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
        Ok((inp, w))
    }

    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
        d.0.conv2d(&d.1, 0, 1, 1, 1)
    }

    const ITERS: usize = 1;
}

struct Matmul;
impl Benchmark for Matmul {
    type PreProcessData = (Tensor, Tensor);
    type RunResult = Tensor;

    fn preprocess() -> Result<Self::PreProcessData> {
        let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
        let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
        Ok((lhs, rhs))
    }

    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
        d.0.matmul(&d.1)
    }

    const ITERS: usize = 100;
}

// This benchmark is similar to:
// https://github.com/ggerganov/llama.cpp/blob/master/examples/benchmark/benchmark-matmult.cpp
struct QMatMul;
impl Benchmark for QMatMul {
    type PreProcessData = (candle_core::quantized::QMatMul, Tensor);
    type RunResult = Tensor;

    fn preprocess() -> Result<Self::PreProcessData> {
        let zeros = vec![candle_core::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
        let mm = candle_core::quantized::QTensor::new(zeros, (4096, 11008))?;
        let mm = candle_core::quantized::QMatMul::from_qtensor(mm);
        let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
        Ok((mm, arg))
    }

    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
        d.0.forward(&d.1)
    }

    const ITERS: usize = 100;
}

struct Softmax;
impl Benchmark for Softmax {
    type PreProcessData = Tensor;
    type RunResult = Tensor;

    fn preprocess() -> Result<Self::PreProcessData> {
        // Typical whisper tiny size.
        let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
        Ok(x)
    }

    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
        softmax(d, D::Minus1)
    }

    const ITERS: usize = 100;
}

fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
    use std::hint::black_box;

    let iters = iters.unwrap_or(B::ITERS);
    let d = B::preprocess()?;
    let start = std::time::Instant::now();
    for _iter in 0..iters {
        let _res = black_box(B::run_one(black_box(&d))?);
    }
    println!("{:?}", start.elapsed() / iters as u32);
    Ok(())
}

#[derive(Subcommand, Debug, Clone)]
enum Task {
    Conv1d,
    Conv2d,
    Matmul,
    Qmatmul,
    Softmax,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
    /// The benchmark to be run.
    #[command(subcommand)]
    task: Task,

    #[arg(long)]
    iters: Option<usize>,
}

fn main() -> Result<()> {
    let args = Args::parse();
    match args.task {
        Task::Conv1d => run::<Conv1d>(args.iters)?,
        Task::Conv2d => run::<Conv2d>(args.iters)?,
        Task::Matmul => run::<Matmul>(args.iters)?,
        Task::Softmax => run::<Softmax>(args.iters)?,
        Task::Qmatmul => run::<QMatMul>(args.iters)?,
    }
    Ok(())
}


@@ -218,12 +218,65 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
     Ok(())
 }

+fn run_quantize_safetensors(
+    in_file: std::path::PathBuf,
+    out_file: std::path::PathBuf,
+    q: Quantization,
+) -> Result<()> {
+    let mut out_file = std::fs::File::create(out_file)?;
+    let tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
+    println!("tensors: {}", tensors.len());
+
+    let quantize_fn = match q {
+        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
+        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
+        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
+        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
+        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
+        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
+        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
+        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
+        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
+        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
+        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
+        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
+        Quantization::F16 => QTensor::quantize::<half::f16>,
+        Quantization::F32 => QTensor::quantize::<f32>,
+    };
+
+    let qtensors = tensors
+        .into_par_iter()
+        .map(|(name, tensor)| {
+            println!("  quantizing {name} {tensor:?}");
+            let should_quantize = tensor.rank() == 2 && tensor.dim(0)? % 256 == 0;
+            let tensor = if should_quantize {
+                quantize_fn(&tensor)?
+            } else {
+                QTensor::quantize::<f32>(&tensor)?
+            };
+            Ok((name, tensor))
+        })
+        .collect::<Result<Vec<_>>>()?;
+    let qtensors = qtensors
+        .iter()
+        .map(|(k, v)| (k.as_str(), v))
+        .collect::<Vec<_>>();
+    gguf_file::write(&mut out_file, &[], &qtensors)?;
+    Ok(())
+}
+
 fn run_quantize(
     in_file: std::path::PathBuf,
     out_file: std::path::PathBuf,
     q: Quantization,
     qmode: QuantizationMode,
 ) -> Result<()> {
+    if let Some(extension) = in_file.extension() {
+        if extension == "safetensors" {
+            return run_quantize_safetensors(in_file, out_file, q);
+        }
+    }
+
     // Open the out file early so as to fail directly on missing directories etc.
     let mut out_file = std::fs::File::create(out_file)?;
     let mut in_ = std::fs::File::open(&in_file)?;
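A note on the `should_quantize` guard above: ggml's k-quant formats pack weights into 256-element super-blocks, so only 2-D tensors whose leading dimension splits into whole blocks are quantized here; everything else is written back as plain `f32`. The predicate, pulled out as a hypothetical helper for illustration:

```rust
/// Mirrors the guard in `run_quantize_safetensors` above (editorial sketch).
fn should_quantize(rank: usize, dim0: usize) -> bool {
    rank == 2 && dim0 % 256 == 0
}

fn main() {
    assert!(should_quantize(2, 4096)); // e.g. a 4096 x 11008 weight matrix
    assert!(!should_quantize(1, 4096)); // biases and norms stay in f32
}
```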

candle-core/src/accelerate.rs

@@ -370,6 +370,38 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
     y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
 }

+#[inline]
+pub fn vs_tanh_inplace(y: &mut [f32]) {
+    unsafe { ffi::vvtanhf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
+#[inline]
+pub fn vd_tanh_inplace(y: &mut [f64]) {
+    unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
+#[inline]
+pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+    }
+    vs_tanh_inplace(ys);
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = 0.5 * v * (1.0 + *y)
+    }
+}
+
+#[inline]
+pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+    }
+    vd_tanh_inplace(ys);
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = 0.5 * v * (1.0 + *y)
+    }
+}
+
 macro_rules! binary_op {
     ($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
         #[inline]
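The `vs_gelu`/`vd_gelu` helpers above implement the standard tanh approximation of GELU, shown here for reference:

```latex
\operatorname{gelu}(x) \approx \tfrac{1}{2}\,x\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)
```

The first loop stores sqrt(2/π)·x·(1 + 0.044715·x²) into `ys`, the vForce call applies `tanh` in place, and the second loop finishes with 0.5·x·(1 + tanh(·)).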

candle-core/src/backend.rs

@@ -57,6 +57,7 @@ pub trait BackendStorage: Sized {
     fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
     fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
+    fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
     fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;

     fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;

candle-core/src/backprop.rs

@@ -91,13 +91,14 @@ impl Tensor {
                     }
                 }
                 Op::Reshape(node)
+                | Op::UpsampleNearest1D(node)
                 | Op::UpsampleNearest2D(node)
                 | Op::AvgPool2D { arg: node, .. }
                 | Op::MaxPool2D { arg: node, .. }
                 | Op::Copy(node)
                 | Op::Broadcast(node)
                 | Op::Cmp(node, _)
-                | Op::Reduce(node, _, _)
+                | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
                 | Op::ToDType(node)
                 | Op::ToDevice(node)
                 | Op::Transpose(node, _, _)
@@ -111,6 +112,7 @@ impl Tensor {
                     track_grad |= tg;
                     nodes
                 }
+                Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
             }
         } else {
             nodes
@@ -262,6 +264,9 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&grad_arg)?;
                 }
+                Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
+                    op: "upsample-nearest1d",
+                })?,
                 Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
                     op: "upsample-nearest2d",
                 })?,
@@ -437,6 +442,10 @@ impl Tensor {
                     *sum_grad = sum_grad.add(&arg_grad)?
                 }
                 Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
+                Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
+                Op::Unary(_, UnaryOp::GeluErf) => {
+                    Err(Error::BackwardNotSupported { op: "gelu-erf" })?
+                }
                 Op::Unary(arg, UnaryOp::Relu) => {
                     let sum_grad = grads.or_insert(arg)?;
                     let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
@@ -517,6 +526,7 @@ impl Tensor {
     }
 }

+#[derive(Debug)]
 pub struct GradStore(HashMap<TensorId, Tensor>);

 impl GradStore {
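One line of reasoning behind the `ArgMin`/`ArgMax` arm above (editorial note): those reductions produce integer indices and are piecewise constant in their input, so there is no gradient to propagate; the dependency walk therefore returns the accumulated `nodes` without marking the argument for gradient tracking.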

candle-core/src/cpu/erf.rs

@@ -0,0 +1,763 @@
#![allow(clippy::excessive_precision)]
// Code taken from https://github.com/statrs-dev/statrs
//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
//! related functions
mod evaluate {
    //! Provides functions that don't have a numerical solution and must
    //! be solved computationally (e.g. evaluation of a polynomial)

    /// Evaluates a polynomial at `z` with Horner's method, where `coeff[k]`
    /// is the coefficient of the `z^k` term; e.g. `[3, -1, 2]` equates to
    /// `2z^2 - z + 3`.
    ///
    /// # Remarks
    ///
    /// Returns 0 for a 0 length coefficient slice
    pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
        let n = coeff.len();
        if n == 0 {
            return 0.0;
        }

        let mut sum = *coeff.last().unwrap();
        for c in coeff[0..n - 1].iter().rev() {
            sum = *c + z * sum;
        }
        sum
    }
}
use std::f64;
/// `erf` calculates the error function at `x`.
pub fn erf(x: f64) -> f64 {
    if x.is_nan() {
        f64::NAN
    } else if x >= 0.0 && x.is_infinite() {
        1.0
    } else if x <= 0.0 && x.is_infinite() {
        -1.0
    } else if x == 0. {
        0.0
    } else {
        erf_impl(x, false)
    }
}

/// `erf_inv` calculates the inverse error function
/// at `x`.
pub fn erf_inv(x: f64) -> f64 {
    if x == 0.0 {
        0.0
    } else if x >= 1.0 {
        f64::INFINITY
    } else if x <= -1.0 {
        f64::NEG_INFINITY
    } else if x < 0.0 {
        erf_inv_impl(-x, 1.0 + x, -1.0)
    } else {
        erf_inv_impl(x, 1.0 - x, 1.0)
    }
}

/// `erfc` calculates the complementary error function
/// at `x`.
pub fn erfc(x: f64) -> f64 {
    if x.is_nan() {
        f64::NAN
    } else if x == f64::INFINITY {
        0.0
    } else if x == f64::NEG_INFINITY {
        2.0
    } else {
        erf_impl(x, true)
    }
}

/// `erfc_inv` calculates the complementary inverse
/// error function at `x`.
pub fn erfc_inv(x: f64) -> f64 {
    if x <= 0.0 {
        f64::INFINITY
    } else if x >= 2.0 {
        f64::NEG_INFINITY
    } else if x > 1.0 {
        erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
    } else {
        erf_inv_impl(1.0 - x, x, 1.0)
    }
}
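// Editorial sketch, not part of the original commit: a quick sanity check of
// the entry points above against standard reference values.
#[cfg(test)]
mod tests {
    use super::{erf, erf_inv, erfc};

    #[test]
    fn erf_sanity() {
        // erf(1) = 0.8427007929497149... (reference value).
        assert!((erf(1.0) - 0.8427007929497149).abs() < 1e-12);
        // erf and erfc sum to 1 for finite inputs.
        assert!((erf(0.5) + erfc(0.5) - 1.0).abs() < 1e-12);
        // erf_inv inverts erf on (-1, 1).
        assert!((erf_inv(erf(0.3)) - 0.3).abs() < 1e-9);
    }
}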
// **********************************************************
// ********** Coefficients for erf_impl polynomial **********
// **********************************************************
/// Polynomial coefficients for a numerator of `erf_impl`
/// in the interval [1e-10, 0.5].
const ERF_IMPL_AN: &[f64] = &[
0.00337916709551257388990745,
-0.00073695653048167948530905,
-0.374732337392919607868241,
0.0817442448733587196071743,
-0.0421089319936548595203468,
0.0070165709512095756344528,
-0.00495091255982435110337458,
0.000871646599037922480317225,
];
/// Polynomial coefficients for a denominator of `erf_impl`
/// in the interval [1e-10, 0.5]
const ERF_IMPL_AD: &[f64] = &[
1.0,
-0.218088218087924645390535,
0.412542972725442099083918,
-0.0841891147873106755410271,
0.0655338856400241519690695,
-0.0120019604454941768171266,
0.00408165558926174048329689,
-0.000615900721557769691924509,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [0.5, 0.75].
const ERF_IMPL_BN: &[f64] = &[
-0.0361790390718262471360258,
0.292251883444882683221149,
0.281447041797604512774415,
0.125610208862766947294894,
0.0274135028268930549240776,
0.00250839672168065762786937,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [0.5, 0.75].
const ERF_IMPL_BD: &[f64] = &[
1.0,
1.8545005897903486499845,
1.43575803037831418074962,
0.582827658753036572454135,
0.124810476932949746447682,
0.0113724176546353285778481,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [0.75, 1.25].
const ERF_IMPL_CN: &[f64] = &[
-0.0397876892611136856954425,
0.153165212467878293257683,
0.191260295600936245503129,
0.10276327061989304213645,
0.029637090615738836726027,
0.0046093486780275489468812,
0.000307607820348680180548455,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [0.75, 1.25].
const ERF_IMPL_CD: &[f64] = &[
1.0,
1.95520072987627704987886,
1.64762317199384860109595,
0.768238607022126250082483,
0.209793185936509782784315,
0.0319569316899913392596356,
0.00213363160895785378615014,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [1.25, 2.25].
const ERF_IMPL_DN: &[f64] = &[
-0.0300838560557949717328341,
0.0538578829844454508530552,
0.0726211541651914182692959,
0.0367628469888049348429018,
0.00964629015572527529605267,
0.00133453480075291076745275,
0.778087599782504251917881e-4,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [1.25, 2.25].
const ERF_IMPL_DD: &[f64] = &[
1.0,
1.75967098147167528287343,
1.32883571437961120556307,
0.552528596508757581287907,
0.133793056941332861912279,
0.0179509645176280768640766,
0.00104712440019937356634038,
-0.106640381820357337177643e-7,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [2.25, 3.5].
const ERF_IMPL_EN: &[f64] = &[
-0.0117907570137227847827732,
0.014262132090538809896674,
0.0202234435902960820020765,
0.00930668299990432009042239,
0.00213357802422065994322516,
0.00025022987386460102395382,
0.120534912219588189822126e-4,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [2.25, 3.5].
const ERF_IMPL_ED: &[f64] = &[
1.0,
1.50376225203620482047419,
0.965397786204462896346934,
0.339265230476796681555511,
0.0689740649541569716897427,
0.00771060262491768307365526,
0.000371421101531069302990367,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [3.5, 5.25].
const ERF_IMPL_FN: &[f64] = &[
-0.00546954795538729307482955,
0.00404190278731707110245394,
0.0054963369553161170521356,
0.00212616472603945399437862,
0.000394984014495083900689956,
0.365565477064442377259271e-4,
0.135485897109932323253786e-5,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [3.5, 5.25].
const ERF_IMPL_FD: &[f64] = &[
1.0,
1.21019697773630784832251,
0.620914668221143886601045,
0.173038430661142762569515,
0.0276550813773432047594539,
0.00240625974424309709745382,
0.891811817251336577241006e-4,
-0.465528836283382684461025e-11,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [5.25, 8].
const ERF_IMPL_GN: &[f64] = &[
-0.00270722535905778347999196,
0.0013187563425029400461378,
0.00119925933261002333923989,
0.00027849619811344664248235,
0.267822988218331849989363e-4,
0.923043672315028197865066e-6,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [5.25, 8].
const ERF_IMPL_GD: &[f64] = &[
1.0,
0.814632808543141591118279,
0.268901665856299542168425,
0.0449877216103041118694989,
0.00381759663320248459168994,
0.000131571897888596914350697,
0.404815359675764138445257e-11,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [8, 11.5].
const ERF_IMPL_HN: &[f64] = &[
-0.00109946720691742196814323,
0.000406425442750422675169153,
0.000274499489416900707787024,
0.465293770646659383436343e-4,
0.320955425395767463401993e-5,
0.778286018145020892261936e-7,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [8, 11.5].
const ERF_IMPL_HD: &[f64] = &[
1.0,
0.588173710611846046373373,
0.139363331289409746077541,
0.0166329340417083678763028,
0.00100023921310234908642639,
0.24254837521587225125068e-4,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [11.5, 17].
const ERF_IMPL_IN: &[f64] = &[
-0.00056907993601094962855594,
0.000169498540373762264416984,
0.518472354581100890120501e-4,
0.382819312231928859704678e-5,
0.824989931281894431781794e-7,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [11.5, 17].
const ERF_IMPL_ID: &[f64] = &[
1.0,
0.339637250051139347430323,
0.043472647870310663055044,
0.00248549335224637114641629,
0.535633305337152900549536e-4,
-0.117490944405459578783846e-12,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [17, 24].
const ERF_IMPL_JN: &[f64] = &[
-0.000241313599483991337479091,
0.574224975202501512365975e-4,
0.115998962927383778460557e-4,
0.581762134402593739370875e-6,
0.853971555085673614607418e-8,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [17, 24].
const ERF_IMPL_JD: &[f64] = &[
1.0,
0.233044138299687841018015,
0.0204186940546440312625597,
0.000797185647564398289151125,
0.117019281670172327758019e-4,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [24, 38].
const ERF_IMPL_KN: &[f64] = &[
-0.000146674699277760365803642,
0.162666552112280519955647e-4,
0.269116248509165239294897e-5,
0.979584479468091935086972e-7,
0.101994647625723465722285e-8,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [24, 38].
const ERF_IMPL_KD: &[f64] = &[
1.0,
0.165907812944847226546036,
0.0103361716191505884359634,
0.000286593026373868366935721,
0.298401570840900340874568e-5,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [38, 60].
const ERF_IMPL_LN: &[f64] = &[
-0.583905797629771786720406e-4,
0.412510325105496173512992e-5,
0.431790922420250949096906e-6,
0.993365155590013193345569e-8,
0.653480510020104699270084e-10,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [38, 60].
const ERF_IMPL_LD: &[f64] = &[
1.0,
0.105077086072039915406159,
0.00414278428675475620830226,
0.726338754644523769144108e-4,
0.477818471047398785369849e-6,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [60, 85].
const ERF_IMPL_MN: &[f64] = &[
-0.196457797609229579459841e-4,
0.157243887666800692441195e-5,
0.543902511192700878690335e-7,
0.317472492369117710852685e-9,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [60, 85].
const ERF_IMPL_MD: &[f64] = &[
1.0,
0.052803989240957632204885,
0.000926876069151753290378112,
0.541011723226630257077328e-5,
0.535093845803642394908747e-15,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [85, 110].
const ERF_IMPL_NN: &[f64] = &[
-0.789224703978722689089794e-5,
0.622088451660986955124162e-6,
0.145728445676882396797184e-7,
0.603715505542715364529243e-10,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [85, 110].
const ERF_IMPL_ND: &[f64] = &[
1.0,
0.0375328846356293715248719,
0.000467919535974625308126054,
0.193847039275845656900547e-5,
];
// **********************************************************
// ********** Coefficients for erf_inv_impl polynomial ******
// **********************************************************
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0, 0.5].
const ERF_INV_IMPL_AN: &[f64] = &[
-0.000508781949658280665617,
-0.00836874819741736770379,
0.0334806625409744615033,
-0.0126926147662974029034,
-0.0365637971411762664006,
0.0219878681111168899165,
0.00822687874676915743155,
-0.00538772965071242932965,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0, 0.5].
const ERF_INV_IMPL_AD: &[f64] = &[
1.0,
-0.970005043303290640362,
-1.56574558234175846809,
1.56221558398423026363,
0.662328840472002992063,
-0.71228902341542847553,
-0.0527396382340099713954,
0.0795283687341571680018,
-0.00233393759374190016776,
0.000886216390456424707504,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.5, 0.75].
const ERF_INV_IMPL_BN: &[f64] = &[
-0.202433508355938759655,
0.105264680699391713268,
8.37050328343119927838,
17.6447298408374015486,
-18.8510648058714251895,
-44.6382324441786960818,
17.445385985570866523,
21.1294655448340526258,
-3.67192254707729348546,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.5, 0.75].
const ERF_INV_IMPL_BD: &[f64] = &[
1.0,
6.24264124854247537712,
3.9713437953343869095,
-28.6608180499800029974,
-20.1432634680485188801,
48.5609213108739935468,
10.8268667355460159008,
-22.6436933413139721736,
1.72114765761200282724,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x less than 3.
const ERF_INV_IMPL_CN: &[f64] = &[
-0.131102781679951906451,
-0.163794047193317060787,
0.117030156341995252019,
0.387079738972604337464,
0.337785538912035898924,
0.142869534408157156766,
0.0290157910005329060432,
0.00214558995388805277169,
-0.679465575181126350155e-6,
0.285225331782217055858e-7,
-0.681149956853776992068e-9,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x less than 3.
const ERF_INV_IMPL_CD: &[f64] = &[
1.0,
3.46625407242567245975,
5.38168345707006855425,
4.77846592945843778382,
2.59301921623620271374,
0.848854343457902036425,
0.152264338295331783612,
0.01105924229346489121,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 3 and 6.
const ERF_INV_IMPL_DN: &[f64] = &[
-0.0350353787183177984712,
-0.00222426529213447927281,
0.0185573306514231072324,
0.00950804701325919603619,
0.00187123492819559223345,
0.000157544617424960554631,
0.460469890584317994083e-5,
-0.230404776911882601748e-9,
0.266339227425782031962e-11,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 3 and 6.
const ERF_INV_IMPL_DD: &[f64] = &[
1.0,
1.3653349817554063097,
0.762059164553623404043,
0.220091105764131249824,
0.0341589143670947727934,
0.00263861676657015992959,
0.764675292302794483503e-4,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 6 and 18.
const ERF_INV_IMPL_EN: &[f64] = &[
-0.0167431005076633737133,
-0.00112951438745580278863,
0.00105628862152492910091,
0.000209386317487588078668,
0.149624783758342370182e-4,
0.449696789927706453732e-6,
0.462596163522878599135e-8,
-0.281128735628831791805e-13,
0.99055709973310326855e-16,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 6 and 18.
const ERF_INV_IMPL_ED: &[f64] = &[
1.0,
0.591429344886417493481,
0.138151865749083321638,
0.0160746087093676504695,
0.000964011807005165528527,
0.275335474764726041141e-4,
0.282243172016108031869e-6,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 18 and 44.
const ERF_INV_IMPL_FN: &[f64] = &[
-0.0024978212791898131227,
-0.779190719229053954292e-5,
0.254723037413027451751e-4,
0.162397777342510920873e-5,
0.396341011304801168516e-7,
0.411632831190944208473e-9,
0.145596286718675035587e-11,
-0.116765012397184275695e-17,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 18 and 44.
const ERF_INV_IMPL_FD: &[f64] = &[
1.0,
0.207123112214422517181,
0.0169410838120975906478,
0.000690538265622684595676,
0.145007359818232637924e-4,
0.144437756628144157666e-6,
0.509761276599778486139e-9,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x greater than 44.
const ERF_INV_IMPL_GN: &[f64] = &[
-0.000539042911019078575891,
-0.28398759004727721098e-6,
0.899465114892291446442e-6,
0.229345859265920864296e-7,
0.225561444863500149219e-9,
0.947846627503022684216e-12,
0.135880130108924861008e-14,
-0.348890393399948882918e-21,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x greater than 44.
const ERF_INV_IMPL_GD: &[f64] = &[
1.0,
0.0845746234001899436914,
0.00282092984726264681981,
0.468292921940894236786e-4,
0.399968812193862100054e-6,
0.161809290887904476097e-8,
0.231558608310259605225e-11,
];
/// `erf_impl` computes the error function at `z`.
/// If `inv` is true, `1 - erf` is calculated as opposed to `erf`
fn erf_impl(z: f64, inv: bool) -> f64 {
    if z < 0.0 {
        if !inv {
            return -erf_impl(-z, false);
        }
        if z < -0.5 {
            return 2.0 - erf_impl(-z, true);
        }
        return 1.0 + erf_impl(-z, false);
    }

    let result = if z < 0.5 {
        if z < 1e-10 {
            z * 1.125 + z * 0.003379167095512573896158903121545171688
        } else {
            z * 1.125
                + z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
        }
    } else if z < 110.0 {
        let (r, b) = if z < 0.75 {
            (
                evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
                    / evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
                0.3440242112,
            )
        } else if z < 1.25 {
            (
                evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
                    / evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
                0.419990927,
            )
        } else if z < 2.25 {
            (
                evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
                    / evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
                0.4898625016,
            )
        } else if z < 3.5 {
            (
                evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
                    / evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
                0.5317370892,
            )
        } else if z < 5.25 {
            (
                evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
                    / evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
                0.5489973426,
            )
        } else if z < 8.0 {
            (
                evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
                    / evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
                0.5571740866,
            )
        } else if z < 11.5 {
            (
                evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
                    / evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
                0.5609807968,
            )
        } else if z < 17.0 {
            (
                evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
                    / evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
                0.5626493692,
            )
        } else if z < 24.0 {
            (
                evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
                    / evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
                0.5634598136,
            )
        } else if z < 38.0 {
            (
                evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
                    / evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
                0.5638477802,
            )
        } else if z < 60.0 {
            (
                evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
                    / evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
                0.5640528202,
            )
        } else if z < 85.0 {
            (
                evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
                    / evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
                0.5641309023,
            )
        } else {
            (
                evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
                    / evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
                0.5641584396,
            )
        };
        let g = (-z * z).exp() / z;
        g * b + g * r
    } else {
        0.0
    };

    if inv && z >= 0.5 {
        result
    } else if z >= 0.5 || inv {
        1.0 - result
    } else {
        result
    }
}
// `erf_inv_impl` computes the inverse error function where
// `p`,`q`, and `s` are the first, second, and third intermediate
// parameters respectively
fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 {
    let result = if p <= 0.5 {
        let y = 0.0891314744949340820313;
        let g = p * (p + 10.0);
        let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD);
        g * y + g * r
    } else if q >= 0.25 {
        let y = 2.249481201171875;
        let g = (-2.0 * q.ln()).sqrt();
        let xs = q - 0.25;
        let r =
            evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD);
        g / (y + r)
    } else {
        let x = (-q.ln()).sqrt();
        if x < 3.0 {
            let y = 0.807220458984375;
            let xs = x - 1.125;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_CD);
            y * x + r * x
        } else if x < 6.0 {
            let y = 0.93995571136474609375;
            let xs = x - 3.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_DD);
            y * x + r * x
        } else if x < 18.0 {
            let y = 0.98362827301025390625;
            let xs = x - 6.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_ED);
            y * x + r * x
        } else if x < 44.0 {
            let y = 0.99714565277099609375;
            let xs = x - 18.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_FD);
            y * x + r * x
        } else {
            let y = 0.99941349029541015625;
            let xs = x - 44.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_GD);
            y * x + r * x
        }
    };
    s * result
}

candle-core/src/cpu/kernels.rs

@ -1,4 +1,7 @@
- pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
pub trait VecOps: num_traits::NumAssign + Copy {
fn min(self, rhs: Self) -> Self;
fn max(self, rhs: Self) -> Self;
/// Dot-product of two vectors.
///
/// # Safety
@ -37,10 +40,7 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
- let x = *xs.add(i);
- if x > *res {
- *res = x
- }
*res = (*res).max(*xs.add(i))
}
}
@ -54,15 +54,22 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
- let x = *xs.add(i);
- if x < *res {
- *res = x
- }
*res = (*res).min(*xs.add(i))
}
}
}
impl VecOps for f32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
super::vec_dot_f32(lhs, rhs, res, len)
@ -75,6 +82,16 @@ impl VecOps for f32 {
}
impl VecOps for half::f16 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
let mut res_f32 = 0f32;
@ -83,11 +100,61 @@ impl VecOps for half::f16 {
}
}
- impl VecOps for f64 {}
- impl VecOps for half::bf16 {}
- impl VecOps for u8 {}
- impl VecOps for u32 {}
- impl VecOps for i64 {}
impl VecOps for f64 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
}
impl VecOps for half::bf16 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
}
impl VecOps for u8 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for u32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i64 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
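The `min`/`max` methods added to `VecOps` above let these reductions be written as plain folds outside this module too. A minimal sketch (not part of the diff; the `candle_core::cpu::kernels` import path is an assumption that depends on module visibility):
// Hypothetical safe counterpart to vec_reduce_max; assumes a non-empty slice.
use candle_core::cpu::kernels::VecOps;

fn slice_max<T: VecOps>(xs: &[T]) -> T {
    xs.iter().copied().fold(xs[0], |acc, x| acc.max(x))
}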

View File

@ -1,3 +1,4 @@
pub mod erf;
pub mod kernels;
trait Cpu<const ARR: usize> {

View File

@ -2,6 +2,10 @@ use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use half::{bf16, f16};
use rayon::prelude::*;
const USE_IM2COL_CONV1D: bool = true;
const USE_IM2COL_CONV2D: bool = true;
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
// intercept the oom errors to avoid panicking and provide a proper error.
@ -445,7 +449,7 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
}
// This function maps over two strided index sequences.
- fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
@ -525,7 +529,7 @@ fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
}
// Similar to binary_map but with vectorized variants.
- fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
@ -723,6 +727,36 @@ impl Map1 for MaxPool2D {
}
}
struct UpsampleNearest1D(usize);
impl Map1 for UpsampleNearest1D {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
// TODO: Specialized implementation for the case 2*sz?
let dst_sz = self.0;
let (b_sz, c, src_sz) = layout.shape().dims3()?;
let stride = layout.stride();
let stride_sz = stride[2];
let src_index = layout.start_offset();
let scale_sz = src_sz as f64 / dst_sz as f64;
let mut dst = vec![T::zero(); b_sz * c * dst_sz];
let src_idxs = (0..dst_sz)
.map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
.collect::<Vec<_>>();
for b_idx in 0..b_sz {
let dst = &mut dst[b_idx * c * dst_sz..];
let src_index = src_index + b_idx * stride[0];
for c_idx in 0..c {
let dst = &mut dst[c_idx * dst_sz..];
let src_index = src_index + c_idx * stride[1];
for (idx, src_idx) in src_idxs.iter().enumerate() {
dst[idx] = src[src_index + src_idx * stride_sz]
}
}
}
Ok(dst)
}
}
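// Worked example of the mapping above (not part of the diff): with src_sz = 4 and
// dst_sz = 8, scale_sz = 0.5 and src_idxs = [0, 0, 1, 1, 2, 2, 3, 3], i.e. each
// source element is repeated twice in the upsampled output.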
struct UpsampleNearest2D(usize, usize);
impl Map1 for UpsampleNearest2D {
@ -1052,10 +1086,8 @@ impl<'a> Map2 for Conv1D<'a> {
}
}
- let num_threads = crate::utils::get_num_threads();
for offset in 0..p.k_size {
- crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let dst_idx = dst_c_idx * l_out;
let k_cont = (0..p.c_in)
.map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
@ -1090,6 +1122,140 @@ impl<'a> Map2 for Conv1D<'a> {
}
}
struct Im2Col1D {
l_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}
impl Im2Col1D {
fn l_out(&self, l: usize) -> usize {
(l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
}
}
impl Map1 for Im2Col1D {
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
let &Self {
l_k,
stride,
dilation,
padding,
} = self;
let (b, c, l) = layout.shape().dims3()?;
let l_out = self.l_out(l);
let src = &vs[layout.start_offset()..];
let mut dst = vec![T::zero(); b * l_out * c * l_k];
let (src_s0, src_s1, src_s2) = {
let s = layout.stride();
(s[0], s[1], s[2])
};
// TODO: provide specialized kernels for the common use cases.
// - l_k = 1
// - padding = 0
// - stride = 1
// - dilation = 1
for b_idx in 0..b {
let src_idx = b_idx * src_s0;
let dst_idx = b_idx * l_out * c * l_k;
for l_idx in 0..l_out {
let dst_idx = dst_idx + l_idx * c * l_k;
for c_idx in 0..c {
let dst_idx = dst_idx + c_idx * l_k;
let src_idx = c_idx * src_s1 + src_idx;
for l_k_idx in 0..l_k {
let src_l = l_idx * stride + l_k_idx * dilation;
if padding != 0 && (src_l < padding || src_l >= l + padding) {
continue;
}
let src_l = src_l - padding;
let src_idx = src_idx + src_l * src_s2;
let dst_idx = dst_idx + l_k_idx;
dst[dst_idx] = src[src_idx]
}
}
}
}
Ok(dst)
}
}
struct Im2Col {
h_k: usize,
w_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}
impl Im2Col {
fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
(h_out, w_out)
}
}
impl Map1 for Im2Col {
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
let &Self {
h_k,
w_k,
stride,
dilation,
padding,
} = self;
let (b, c, h, w) = layout.shape().dims4()?;
let (h_out, w_out) = self.hw_out(h, w);
let src = &vs[layout.start_offset()..];
let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
let (src_s0, src_s1, src_s2, src_s3) = {
let s = layout.stride();
(s[0], s[1], s[2], s[3])
};
// TODO: provide specialized kernels for the common use cases.
// - h_k = w_k = 1
// - padding = 0
// - stride = 1
// - dilation = 1
for b_idx in 0..b {
let src_idx = b_idx * src_s0;
let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
for h_idx in 0..h_out {
let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
for w_idx in 0..w_out {
let dst_idx = dst_idx + w_idx * c * h_k * w_k;
for c_idx in 0..c {
let dst_idx = dst_idx + c_idx * h_k * w_k;
let src_idx = c_idx * src_s1 + src_idx;
for h_k_idx in 0..h_k {
let src_h = h_idx * stride + h_k_idx * dilation;
if padding != 0 && (src_h < padding || src_h >= h + padding) {
continue;
}
let src_h = src_h - padding;
let src_idx = src_idx + src_h * src_s2;
let dst_idx = dst_idx + h_k_idx * w_k;
for w_k_idx in 0..w_k {
let src_w = w_idx * stride + w_k_idx * dilation;
if padding != 0 && (src_w < padding || src_w >= w + padding) {
continue;
}
let src_w = src_w - padding;
let src_idx = src_idx + src_w * src_s3;
let dst_idx = dst_idx + w_k_idx;
dst[dst_idx] = src[src_idx]
}
}
}
}
}
}
Ok(dst)
}
}
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> {
@ -1123,11 +1289,9 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
- let num_threads = crate::utils::get_num_threads();
for offset_h in 0..p.k_h {
for offset_w in 0..p.k_w {
- crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let dst_idx = dst_c_idx * out_w * out_h;
let k_cont = (0..p.c_in)
.map(|c_in_idx| {
@ -1216,11 +1380,10 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
}
}
}
- let num_threads = crate::utils::get_num_threads();
for k_y in 0..p.k_h {
for k_x in 0..p.k_w {
- crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let k_cont = (0..p.c_in)
.map(|c_in_idx| {
k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
@ -1298,8 +1461,9 @@ impl Map2 for MatMul {
) -> Result<Vec<T>> {
use gemm::{gemm, Parallelism};
- if T::DTYPE == DType::BF16 {
- return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?;
match T::DTYPE {
DType::F16 | DType::F32 | DType::F64 => {}
_ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
}
let (b, m, n, k) = self.0;
@ -2003,6 +2167,10 @@ impl BackendStorage for CpuStorage {
MaxPool2D(kernel_size, stride).map(self, layout)
}
fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
UpsampleNearest1D(sz).map(self, layout)
}
fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
UpsampleNearest2D(h, w).map(self, layout)
}
@ -2231,7 +2399,40 @@ impl BackendStorage for CpuStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv1D,
) -> Result<Self> {
- Conv1D(params).map(self, l, kernel, kernel_l)
if !USE_IM2COL_CONV1D {
return Conv1D(params).map(self, l, kernel, kernel_l);
}
let op = Im2Col1D {
l_k: params.k_size,
padding: params.padding,
stride: params.stride,
dilation: params.dilation,
};
let col = op.map(self, l)?;
let b = params.b_size;
let n = params.c_out;
let l_out = params.l_out();
let k = op.l_k * params.c_in;
let m = l_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
fn conv2d(
@ -2241,7 +2442,43 @@ impl BackendStorage for CpuStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
- Conv2D(params).map(self, l, kernel, kernel_l)
if !USE_IM2COL_CONV2D {
return Conv2D(params).map(self, l, kernel, kernel_l);
}
let op = Im2Col {
h_k: params.k_h,
w_k: params.k_w,
padding: params.padding,
stride: params.stride,
dilation: params.dilation,
};
let col = op.map(self, l)?;
let b = params.b_size;
let n = params.c_out;
let (h_out, w_out) = (params.out_h(), params.out_w());
let k = op.h_k * op.w_k * params.c_in;
let m = h_out * w_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
.transpose(1, 2)?
.transpose(1, 3)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
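Both im2col paths above reduce a convolution to a single batched matmul: the unfolded `col` tensor has shape `(b, m, k)` with `m` the number of output positions and `k` the number of kernel taps per output channel, while the kernel is viewed as `(b, k, n)` with `n = c_out`. A rough shape check under hypothetical numbers (not part of the diff):
// b = 1, c_in = 3, c_out = 8, h = w = 32, h_k = w_k = 3, stride = 1, dilation = 1, padding = 0.
let (h_out, w_out) = ((32 - 3) / 1 + 1, (32 - 3) / 1 + 1); // (30, 30)
let m = h_out * w_out; // 900 output positions
let k = 3 * 3 * 3; // c_in * h_k * w_k = 27
let n = 8; // c_out
assert_eq!((m, k, n), (900, 27, 8)); // (b, m, k) x (b, k, n) -> (b, m, n)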
fn conv_transpose2d(

View File

@ -1,7 +1,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
- use candle_kernels as kernels;
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{
@ -312,6 +312,13 @@ impl BackendDevice for CudaDevice {
// cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
// curand can only generate an odd number of values.
// https://github.com/huggingface/candle/issues/734
let elem_count_round = if elem_count % 2 == 1 {
elem_count + 1
} else {
elem_count
};
let slice = match dtype {
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
@ -321,7 +328,7 @@ impl BackendDevice for CudaDevice {
.w()?
}
DType::F32 => {
- let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
@ -329,7 +336,7 @@ impl BackendDevice for CudaDevice {
CudaStorageSlice::F32(data)
}
DType::F64 => {
- let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
@ -383,7 +390,7 @@ impl BackendDevice for CudaDevice {
}
#[derive(Debug)]
- enum CudaStorageSlice {
pub enum CudaStorageSlice {
U8(CudaSlice<u8>),
U32(CudaSlice<u32>),
I64(CudaSlice<i64>),
@ -394,7 +401,7 @@ enum CudaStorageSlice {
}
type S = CudaStorageSlice;
- trait Map1 {
pub trait Map1 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
@ -416,7 +423,7 @@ trait Map1 {
}
}
- trait Map2 {
pub trait Map2 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
@ -441,7 +448,7 @@ trait Map2 {
}
}
- trait Map2InPlace {
pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
@ -472,7 +479,7 @@ trait Map2InPlace {
}
}
- trait Map1Any {
pub trait Map1Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
@ -495,7 +502,7 @@ trait Map1Any {
}
}
- trait Map2Any {
pub trait Map2Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
@ -532,7 +539,7 @@ impl Map1 for Clone {
}
}
- fn kernel_name<T: WithDType>(root: &str) -> String {
pub fn kernel_name<T: WithDType>(root: &str) -> String {
let dtype = T::DTYPE.as_str();
format!("{root}_{dtype}")
}
@ -593,6 +600,105 @@ impl Map1 for Elu {
}
}
struct Im2Col1D {
l_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}
impl Im2Col1D {
fn l_out(&self, l: usize) -> usize {
(l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
}
}
impl Map1 for Im2Col1D {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let l_out = self.l_out(dims[2]);
let dst_el = dims[0] * l_out * dims[1] * self.l_k;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (
dst_el,
l_out,
self.l_k,
self.stride,
self.padding,
self.dilation,
&ds,
src,
&dst,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}
struct Im2Col {
h_k: usize,
w_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}
impl Im2Col {
fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
(h_out, w_out)
}
}
impl Map1 for Im2Col {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (
dst_el,
h_out,
w_out,
self.h_k,
self.w_k,
self.stride,
self.padding,
self.dilation,
&ds,
src,
&dst,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}
struct Powf(f64);
impl Map1 for Powf {
fn f<T: DeviceRepr + WithDType>(
@ -1310,8 +1416,8 @@ fn slice_src_and_dst<'a, T>(
#[derive(Debug)]
pub struct CudaStorage {
- slice: CudaStorageSlice,
- device: CudaDevice,
pub slice: CudaStorageSlice,
pub device: CudaDevice,
}
pub trait CudaDType: Sized {
@ -1650,9 +1756,46 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv1D,
) -> Result<Self> {
const USE_IM2COL_CONV1D: bool = true;
let device = self.device().clone();
if !USE_IM2COL_CONV1D {
let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
- Ok(Self { slice, device })
return Ok(Self { slice, device });
}
let col = Im2Col1D {
l_k: params.k_size,
stride: params.stride,
dilation: params.dilation,
padding: params.padding,
}
.map(&self.slice, &device, l)?;
let col = Self { slice: col, device };
let l_out = params.l_out();
let b = params.b_size;
let n = params.c_out;
let k = params.k_size * params.c_in;
let m = l_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
#[cfg(not(feature = "cudnn"))]
@ -1663,9 +1806,50 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
const USE_IM2COL_CONV2D: bool = true;
let device = self.device().clone();
if !USE_IM2COL_CONV2D {
let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
- Ok(Self { slice, device })
return Ok(Self { slice, device });
}
let col = Im2Col {
h_k: params.k_h,
w_k: params.k_w,
stride: params.stride,
dilation: params.dilation,
padding: params.padding,
}
.map(&self.slice, &device, l)?;
let col = Self { slice: col, device };
let h_out = params.out_h();
let w_out = params.out_w();
let b = params.b_size;
let n = params.c_out;
let k = params.k_h * params.k_w * params.c_in;
let m = h_out * w_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, h_out, w_out, n))
.transpose(1, 2)?
.transpose(1, 3)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
#[cfg(feature = "cudnn")]
@ -1770,6 +1954,10 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
fn upsample_nearest1d(&self, _: &Layout, _out_sz: usize) -> Result<Self> {
crate::bail!("upsample-nearest1d is not supported on cuda")
}
fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
let device = self.device().clone();
let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?;
@ -1889,6 +2077,9 @@ impl BackendStorage for CudaStorage {
let src_shape = src_l.shape();
let dims = src_shape.dims();
let el_count = src_shape.elem_count();
if el_count == 0 {
return Ok(());
}
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device;
let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;

View File

@ -54,8 +54,8 @@ pub(crate) fn launch_conv2d<
let x_shape = [
params.b_size as i32,
params.c_in as i32,
- params.i_w as i32,
params.i_h as i32,
params.i_w as i32,
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
@ -75,14 +75,14 @@ pub(crate) fn launch_conv2d<
[
params.c_out as i32,
params.c_in as i32,
- params.k_w as i32,
params.k_h as i32,
params.k_w as i32,
],
)?;
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
let y = cudnn.create_4d_tensor(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
- [params.b_size as i32, params.c_out as i32, w_out, h_out],
[params.b_size as i32, params.c_out as i32, h_out, w_out],
)?;
let conv2d = Conv2dForward {
conv: &conv,

View File

@ -1,15 +1,24 @@
//! Types for elements that can be stored and manipulated using tensors.
#![allow(clippy::redundant_closure_call)]
use crate::backend::BackendStorage;
use crate::{CpuStorage, Error, Result};
/// The different types of elements allowed in tensors.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
// Unsigned 8 bits integer.
U8,
// Unsigned 32 bits integer.
U32,
// Signed 64 bits integer.
I64,
// Brain floating-point using half precision (16 bits).
BF16,
// Floating-point using half precision (16 bits).
F16,
// Floating-point using single precision (32 bits).
F32,
// Floating-point using double precision (64 bits).
F64,
}
@ -33,6 +42,7 @@ impl std::str::FromStr for DType {
}
impl DType {
/// String representation for dtypes.
pub fn as_str(&self) -> &'static str {
match self {
Self::U8 => "u8",
@ -45,6 +55,7 @@ impl DType {
}
}
/// The size used by each element in bytes, i.e. 1 for `U8`, 4 for `F32`.
pub fn size_in_bytes(&self) -> usize {
match self {
Self::U8 => 1,

View File

@ -152,6 +152,10 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@ -30,7 +30,7 @@ pub enum Error {
UnsupportedDTypeForOp(DType, &'static str),
// === Dimension Index Errors ===
- #[error("{op}: dimension index {dim} out of range for {shape:?}")]
#[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
DimOutOfRange {
shape: Shape,
dim: i32,
@ -207,11 +207,11 @@ pub type Result<T> = std::result::Result<T, Error>;
impl Error {
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
- Self::Wrapped(Box::new(err))
Self::Wrapped(Box::new(err)).bt()
}
pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
- Self::Msg(err.to_string())
Self::Msg(err.to_string()).bt()
}
pub fn bt(self) -> Self {

View File

@ -46,19 +46,31 @@ impl Tensor {
current_dim += 1;
out
}
TensorIndexer::IndexSelect(indexes) => {
if indexes.rank() != 1 {
crate::bail!("multi-dimensional tensor indexing is not supported")
}
let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
current_dim += 1;
out
}
TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
};
}
Ok(x)
}
}
- #[derive(Debug, Clone)]
#[derive(Debug)]
/// Generic structure used to index a slice of the tensor
pub enum TensorIndexer {
/// This selects the elements for which an index has some specific value.
Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor
Narrow(Bound<usize>, Bound<usize>),
/// Indexing via a 1d tensor
IndexSelect(Tensor),
Err(Error),
}
impl From<usize> for TensorIndexer {
@ -67,6 +79,31 @@ impl From<usize> for TensorIndexer {
}
}
impl From<&[u32]> for TensorIndexer {
fn from(index: &[u32]) -> Self {
match Tensor::new(index, &crate::Device::Cpu) {
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
Err(e) => TensorIndexer::Err(e),
}
}
}
impl From<Vec<u32>> for TensorIndexer {
fn from(index: Vec<u32>) -> Self {
let len = index.len();
match Tensor::from_vec(index, len, &crate::Device::Cpu) {
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
Err(e) => TensorIndexer::Err(e),
}
}
}
impl From<&Tensor> for TensorIndexer {
fn from(tensor: &Tensor) -> Self {
TensorIndexer::IndexSelect(tensor.clone())
}
}
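// The conversions above back the tensor-based indexing added to `.i` (PR 842 in the
// changelog). Hypothetical usage, not part of the diff:
//     let t = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?;
//     let picked = t.i(vec![0u32, 2u32])?; // index_select on dim 0 -> shape (2, 4)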
macro_rules! impl_from_range {
($range_type:ty) => {
impl From<$range_type> for TensorIndexer {

View File

@ -59,6 +59,7 @@ mod op;
pub mod pickle;
pub mod quantized;
pub mod safetensors;
pub mod scalar;
pub mod shape;
mod storage;
mod strided_index;
@ -109,14 +110,8 @@ impl ToUsize2 for (usize, usize) {
}
// A simple trait defining a module with forward method using a single argument.
- pub trait Module: std::fmt::Debug {
pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
- /// Change the module to use training mode vs eval mode.
- ///
- /// The default implementation does nothing as this is only used for a couple modules such as
- /// dropout or batch-normalization.
- fn set_training(&mut self, _training: bool) {}
}
impl Module for quantized::QMatMul {
@ -124,3 +119,9 @@ impl Module for quantized::QMatMul {
self.forward(xs)
}
}
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self(xs)
}
}
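Dropping the `std::fmt::Debug` supertrait above is what makes this blanket impl possible: closures do not implement `Debug`, but any `Fn(&Tensor) -> Result<Tensor>` can now stand in for a `Module`. A minimal sketch (not part of the diff):
fn apply_twice<M: Module>(m: &M, xs: &Tensor) -> Result<Tensor> {
    m.forward(&m.forward(xs)?)
}
// A plain closure satisfies the bound:
// let ys = apply_twice(&|xs: &Tensor| xs.relu(), &xs)?;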

View File

@ -58,6 +58,8 @@ pub enum UnaryOp {
Sqr,
Sqrt,
Gelu,
GeluErf,
Erf,
Relu,
Tanh,
}
@ -116,6 +118,7 @@ pub enum Op {
stride: (usize, usize),
},
UpsampleNearest1D(Tensor),
UpsampleNearest2D(Tensor),
Cat(Vec<Tensor>, usize),
@ -324,6 +327,8 @@ pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Tanh;
@ -600,6 +605,92 @@ impl UnaryOpT for Gelu {
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_gelu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::accelerate::vs_gelu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::accelerate::vd_gelu(xs, ys)
}
}
impl UnaryOpT for Erf {
const NAME: &'static str = "erf";
const KERNEL: &'static str = "uerf";
const V: Self = Erf;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
Self::f64(v as f64) as f32
}
#[inline(always)]
fn f64(v: f64) -> f64 {
crate::cpu::erf::erf(v)
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
}
impl UnaryOpT for GeluErf {
const NAME: &'static str = "gelu_erf";
const KERNEL: &'static str = "ugelu_erf";
const V: Self = GeluErf;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
Self::f64(v as f64) as f32
}
#[inline(always)]
fn f64(v: f64) -> f64 {
(crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
}
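For reference, the `f64` case above evaluates the exact erf-based GELU, as opposed to the tanh-based approximation used by `Gelu`:
$\mathrm{gelu\_erf}(v) = \tfrac{v}{2}\left(1 + \operatorname{erf}\!\left(v/\sqrt{2}\right)\right)$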
impl UnaryOpT for Relu {

View File

@ -85,7 +85,7 @@ const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
pub struct BlockQ8_1 {
pub(crate) d: f16,
pub(crate) s: f16,
- pub(crate) qs: [u8; QK8_1],
pub(crate) qs: [i8; QK8_1],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
@ -278,6 +278,7 @@ impl GgmlType for BlockQ4_1 {
}
sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
+ f16::to_f32(xs.m) * f16::to_f32(ys.s)
}
Ok(sumf)
}
@ -471,6 +472,7 @@ impl GgmlType for BlockQ5_1 {
}
sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
+ f16::to_f32(xs.m) * f16::to_f32(ys.s)
}
Ok(sumf)
}
@ -652,8 +654,8 @@ impl GgmlType for BlockQ8_1 {
for j in 0..Self::BLCK_SIZE / 2 {
let v0 = xs[j] * id;
let v1 = xs[j + Self::BLCK_SIZE / 2] * id;
- ys.qs[j] = f32::round(v0) as u8;
- ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as u8;
ys.qs[j] = f32::round(v0) as i8;
ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as i8;
sum += ys.qs[j] as i32 + ys.qs[j + Self::BLCK_SIZE / 2] as i32;
}
ys.s = f16::from_f32(sum as f32) * ys.d;

View File

@ -229,7 +229,7 @@ impl QTensor {
}
}
- #[derive(Debug)]
#[derive(Clone, Debug)]
pub struct QMatMul(std::sync::Arc<QTensor>);
impl QMatMul {

View File

@ -78,11 +78,7 @@ impl st::View for &Tensor {
}
impl Tensor {
- pub fn save_safetensors<P: AsRef<std::path::Path>>(
- &self,
- name: &str,
- filename: P,
- ) -> Result<()> {
pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
let data = [(name, self.clone())];
Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
}
@ -267,7 +263,7 @@ impl MmapedFile {
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
- pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let inner = memmap2::MmapOptions::new()

23
candle-core/src/scalar.rs Normal file
View File

@ -0,0 +1,23 @@
use crate::{Result, Tensor, WithDType};
pub enum TensorScalar {
Tensor(Tensor),
Scalar(Tensor),
}
pub trait TensorOrScalar {
fn to_tensor_scalar(self) -> Result<TensorScalar>;
}
impl TensorOrScalar for &Tensor {
fn to_tensor_scalar(self) -> Result<TensorScalar> {
Ok(TensorScalar::Tensor(self.clone()))
}
}
impl<T: WithDType> TensorOrScalar for T {
fn to_tensor_scalar(self) -> Result<TensorScalar> {
let scalar = Tensor::new(self, &crate::Device::Cpu)?;
Ok(TensorScalar::Scalar(scalar))
}
}
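`TensorOrScalar` is the hook that lets `maximum`, `minimum`, `cmp`, and the new `clamp` in `tensor.rs` accept either a tensor or a bare number: scalars are lifted to a tensor and broadcast to the left-hand shape. Hypothetical usage (not part of the diff):
// let xs = Tensor::new(&[-1f32, 0.5, 2.0], &Device::Cpu)?;
// let relu6 = xs.maximum(0f32)?.minimum(6f32)?; // [0.0, 0.5, 2.0]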

View File

@ -1,3 +1,4 @@
//! The shape of a tensor is a tuple with the size of each of its dimensions.
#![allow(clippy::redundant_closure_call)]
use crate::{Error, Result};
@ -72,6 +73,14 @@ impl From<(usize, usize, usize, usize, usize)> for Shape {
}
}
impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
Self(vec![
d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5,
])
}
}
impl From<Vec<usize>> for Shape {
fn from(dims: Vec<usize>) -> Self {
Self(dims)
@ -119,6 +128,7 @@ impl Shape {
Self(dims.to_vec())
}
/// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc.
pub fn rank(&self) -> usize {
self.0.len()
}
@ -127,10 +137,12 @@ impl Shape {
self.0
}
/// The dimensions as a slice of `usize`.
pub fn dims(&self) -> &[usize] {
&self.0
}
/// The total number of elements, this is the product of all dimension sizes.
pub fn elem_count(&self) -> usize {
self.0.iter().product()
}
@ -182,6 +194,8 @@ impl Shape {
true
}
/// Modifies the shape by adding a list of additional dimensions at the end of the existing
/// dimensions.
pub fn extend(mut self, additional_dims: &[usize]) -> Self {
self.0.extend(additional_dims);
self
@ -419,6 +433,29 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
}
}
impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let d0 = self.0.to_index(shape, op)?;
let d1 = self.1.to_index(shape, op)?;
let d2 = self.2.to_index(shape, op)?;
let d3 = self.3.to_index(shape, op)?;
let d4 = self.4.to_index(shape, op)?;
Ok(vec![d0, d1, d2, d3, d4])
}
}
impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let d0 = self.0.to_index(shape, op)?;
let d1 = self.1.to_index(shape, op)?;
let d2 = self.2.to_index(shape, op)?;
let d3 = self.3.to_index(shape, op)?;
let d4 = self.4.to_index(shape, op)?;
let d5 = self.5.to_index(shape, op)?;
Ok(vec![d0, d1, d2, d3, d4, d5])
}
}
extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
@ -457,3 +494,171 @@ mod tests {
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
}
}
pub trait ShapeWithOneHole {
fn into_shape(self, el_count: usize) -> Result<Shape>;
}
impl<S: Into<Shape>> ShapeWithOneHole for S {
fn into_shape(self, _el_count: usize) -> Result<Shape> {
Ok(self.into())
}
}
impl ShapeWithOneHole for ((),) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
Ok(el_count.into())
}
}
impl ShapeWithOneHole for ((), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1) = self;
if el_count % d1 != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
}
Ok((el_count / d1, d1).into())
}
}
impl ShapeWithOneHole for (usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, ()) = self;
if el_count % d1 != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
}
Ok((d1, el_count / d1).into())
}
}
impl ShapeWithOneHole for ((), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2) = self;
let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2).into())
}
}
impl ShapeWithOneHole for (usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2) = self;
let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2).into())
}
}
impl ShapeWithOneHole for (usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, ()) = self;
let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d).into())
}
}
impl ShapeWithOneHole for ((), usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2, d3) = self;
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2, d3).into())
}
}
impl ShapeWithOneHole for (usize, (), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2, d3) = self;
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2, d3).into())
}
}
impl ShapeWithOneHole for (usize, usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, (), d3) = self;
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d, d3).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, ()) = self;
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, el_count / d).into())
}
}
impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2, d3, d4) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2, d3, d4) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, (), d3, d4) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, (), d4) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, el_count / d, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, d4, ()) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, d4, el_count / d).into())
}
}
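Every impl above follows the same recipe: multiply the known dimensions, check that the element count divides evenly, and drop the quotient into the `()` hole. A worked sketch (not part of the diff):
// 24 elements with one hole: (2, (), 3) resolves the hole to 24 / (2 * 3) = 4.
// let t = Tensor::zeros(24, DType::F32, &Device::Cpu)?.reshape((2, (), 3))?;
// assert_eq!(t.dims(), &[2, 4, 3]);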

View File

@ -369,6 +369,19 @@ impl Storage {
}
}
pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.upsample_nearest1d(layout, sz)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.upsample_nearest1d(layout, sz)?;
Ok(Self::Cuda(storage))
}
}
}
pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
match self {
Storage::Cpu(storage) => {

View File

@ -1,8 +1,10 @@
//! Tensors are N-dimensional matrices of elements using a single data type.
#![allow(clippy::redundant_closure_call)]
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@ -103,6 +105,28 @@ macro_rules! binary_op {
};
}
macro_rules! binary_op_scalar {
($fn_name:ident, $op_name:ident) => {
pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
let rhs = match rhs.to_tensor_scalar()? {
crate::scalar::TensorScalar::Tensor(rhs) => rhs,
crate::scalar::TensorScalar::Scalar(rhs) => rhs
.to_dtype(self.dtype())?
.to_device(self.device())?
.broadcast_as(self.shape())?,
};
let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
let storage = self.storage().binary_impl::<crate::op::$op_name>(
&*rhs.storage(),
self.layout(),
rhs.layout(),
)?;
let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
Ok(from_storage(storage, shape.clone(), op, false))
}
};
}
macro_rules! broadcast_binary_op {
($fn_name:ident, $inner_fn_name:ident) => {
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
@ -445,8 +469,8 @@ impl Tensor {
binary_op!(mul, Mul);
binary_op!(sub, Sub);
binary_op!(div, Div);
- binary_op!(maximum, Maximum);
- binary_op!(minimum, Minimum);
binary_op_scalar!(maximum, Maximum);
binary_op_scalar!(minimum, Minimum);
broadcast_binary_op!(broadcast_add, add);
broadcast_binary_op!(broadcast_mul, mul);
broadcast_binary_op!(broadcast_sub, sub);
@ -465,6 +489,8 @@ impl Tensor {
unary_op!(sqr, Sqr);
unary_op!(sqrt, Sqrt);
unary_op!(gelu, Gelu);
unary_op!(gelu_erf, GeluErf);
unary_op!(erf, Erf);
unary_op!(relu, Relu);
/// Retrieves the single scalar value held in the tensor. If the tensor contains multiple
@ -642,7 +668,12 @@ impl Tensor {
let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
let mut dims = self.dims().to_vec();
dims[dim] = 1;
- let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
let op = match op {
ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
}
ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
};
let res = from_storage(storage, dims, op, false);
if keepdim {
Ok(res)
@ -775,8 +806,15 @@ impl Tensor {
/// comparison operation is specified by the `op` argument.
///
/// The returned tensor has the same shape as the original tensors and uses `u8` elements.
- pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
- let shape = self.same_shape_binary_op(rhs, "cmp")?;
pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
let rhs = match rhs.to_tensor_scalar()? {
crate::scalar::TensorScalar::Tensor(rhs) => rhs,
crate::scalar::TensorScalar::Scalar(rhs) => rhs
.to_dtype(self.dtype())?
.to_device(self.device())?
.broadcast_as(self.shape())?,
};
let shape = self.same_shape_binary_op(&rhs, "cmp")?;
let storage = self
.storage()
.cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
@ -785,45 +823,68 @@ impl Tensor {
}
/// Element-wise equality.
- pub fn eq(&self, rhs: &Self) -> Result<Self> {
pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Eq)
}
/// Element-wise non-equality.
- pub fn ne(&self, rhs: &Self) -> Result<Self> {
pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Ne)
}
/// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
/// rhs` and 0 otherwise.
- pub fn lt(&self, rhs: &Self) -> Result<Self> {
pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Lt)
}
/// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
/// rhs` and 0 otherwise.
- pub fn gt(&self, rhs: &Self) -> Result<Self> {
pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Gt)
}
/// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
/// rhs` and 0 otherwise.
- pub fn ge(&self, rhs: &Self) -> Result<Self> {
pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Ge)
}
/// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
/// rhs` and 0 otherwise.
- pub fn le(&self, rhs: &Self) -> Result<Self> {
pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Le)
}
- /// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the
/// Clamp the tensor values to be between `min` and `max`.
pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
self.maximum(min)?.minimum(max)
}
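// Example (not part of the diff): scalar bounds work here through TensorOrScalar,
// e.g. `let y = x.clamp(0f32, 1f32)?;`, matching the clamp test added at the end
// of this commit.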
/// Interpolate the input tensor to the `target_size` size, taking the value of the nearest element.
///
/// The input tensor should have three dimensions, `(batch, channels, l)`, the returned
/// tensor also has three dimensions, `(batch, channels, target_size)`.
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
let (n, c, _l) = self.dims3()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
let storage = self
.storage()
.upsample_nearest1d(self.layout(), target_size)?;
Ok(from_storage(storage, (n, c, target_size), op, false))
}
/// Alias for `interpolate1d`.
pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
self.interpolate1d(target_size)
}
/// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the
/// nearest element.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
- pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
let storage = self
@ -832,6 +893,11 @@ impl Tensor {
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
}
/// Alias for `interpolate2d`.
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
self.interpolate2d(target_h, target_w)
}
/// 2D average pooling over an input tensor with multiple channels.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
@ -1684,12 +1750,15 @@ impl Tensor {
Ok(from_storage(storage, shape, BackpropOp::none(), true)) Ok(from_storage(storage, shape, BackpropOp::none(), true))
} }
// TODO: Do we want to allow target shape using -1 on some dimensions?
/// Reshape returns a tensor with the target shape provided that the number of elements of the /// Reshape returns a tensor with the target shape provided that the number of elements of the
/// original tensor is the same. /// original tensor is the same.
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses /// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
/// a new storage and copies the data over, the returned tensor is always contiguous. /// a new storage and copies the data over, the returned tensor is always contiguous.
/// ///
/// The shape can be specified using a tuple of `usize` and at most one `()`, in which case
/// the behavior is the same as when using `-1` in PyTorch: that dimension's size is adjusted
/// to match the number of elements in the tensor.
///
/// ```rust /// ```rust
/// # use candle_core::{Tensor, DType, Device, D}; /// # use candle_core::{Tensor, DType, Device, D};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
@ -1699,10 +1768,14 @@ impl Tensor {
/// ///
/// let c = a.reshape((3, 2))?; /// let c = a.reshape((3, 2))?;
/// assert_eq!(c.shape().dims(), &[3, 2]); /// assert_eq!(c.shape().dims(), &[3, 2]);
///
/// let c = a.reshape((2, (), 1))?;
/// assert_eq!(c.shape().dims(), &[2, 3, 1]);
///
/// # Ok::<(), candle_core::Error>(()) /// # Ok::<(), candle_core::Error>(())
/// ``` /// ```
pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> { pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
let shape = shape.into(); let shape = s.into_shape(self.elem_count())?;
if shape.elem_count() != self.elem_count() { if shape.elem_count() != self.elem_count() {
return Err(Error::ShapeMismatchBinaryOp { return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(), lhs: self.shape().clone(),
@ -1836,6 +1909,34 @@ impl Tensor {
for arg in args { for arg in args {
arg.as_ref().check_dim(dim, "cat")?; arg.as_ref().check_dim(dim, "cat")?;
} }
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg0.rank() != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: arg0.rank(),
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx != dim && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
}
if dim == 0 { if dim == 0 {
Self::cat0(args) Self::cat0(args)
} else { } else {
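
The validation added above makes shape errors in `cat` fail fast with a dedicated error rather than deep inside the copy. A small sketch of the success and failure paths (toy shapes):

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let b = Tensor::zeros((2, 5), DType::F32, &Device::Cpu)?;
    // Only the concatenation dimension may differ between the arguments.
    assert_eq!(Tensor::cat(&[&a, &b], 1)?.dims(), [2, 8]);
    // Along dim 0 the remaining dimensions disagree (3 vs 5), so this now
    // fails early with a ShapeMismatchCat error.
    assert!(Tensor::cat(&[&a, &b], 0).is_err());
    Ok(())
}
```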

View File

@ -1,4 +1,4 @@
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor}; use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
fn zeros(device: &Device) -> Result<()> { fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?; let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@ -33,6 +33,44 @@ fn tensor_2d(device: &Device) -> Result<()> {
Ok(()) Ok(())
} }
fn clamp(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?;
let tensor = tensor.clamp(1.5, 6.2)?;
assert_eq!(
tensor.to_vec2::<f32>()?,
[[3.0, 1.5, 4.0, 1.5, 5.0], [2.0, 1.5, 6.2, 6.2, 2.0]],
);
Ok(())
}
fn unary_op(device: &Device) -> Result<()> {
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu()?, 4)?,
[
[-0.0036, 0.8412, 3.9999, -0.046, 0.3457],
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
[
[-0.004, 0.8413, 3.9999, -0.046, 0.3457],
[2.6906, -0.0647, -0.1091, 1.7353, 2.7928]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.erf()?, 4)?,
[
[-1.0, 0.8427, 1.0, -0.1125, 0.5205],
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
]
);
Ok(())
}
fn binary_op(device: &Device) -> Result<()> { fn binary_op(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]]; let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor1 = Tensor::new(data, device)?; let tensor1 = Tensor::new(data, device)?;
@ -877,6 +915,14 @@ fn broadcasting(device: &Device) -> Result<()> {
Ok(()) Ok(())
} }
fn randn(device: &Device) -> Result<()> {
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]);
let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]);
Ok(())
}
test_device!(zeros, zeros_cpu, zeros_gpu); test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu); test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu); test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
@ -889,6 +935,7 @@ test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu); test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu); test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu); test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(unary_op, unary_op_cpu, unary_op_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu); test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu); test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu); test_device!(cmp, cmp_cpu, cmp_gpu);
@ -899,6 +946,8 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu); test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu); test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu); test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
test_device!(randn, randn_cpu, randn_gpu);
test_device!(clamp, clamp_cpu, clamp_gpu);
// There was originally a bug on the CPU implementation for randn // There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381 // https://github.com/huggingface/candle/issues/381

View File

@ -11,8 +11,8 @@ readme = "README.md"
[dependencies] [dependencies]
byteorder = { workspace = true } byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" } candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.2.1" } candle-nn = { path = "../candle-nn", version = "0.2.3" }
hf-hub = { workspace = true} hf-hub = { workspace = true}
intel-mkl-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true } memmap2 = { workspace = true }

View File

@ -8,13 +8,9 @@ use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File; use std::fs::File;
use std::io::{self, BufReader, Read}; use std::io::{self, BufReader, Read};
fn read_u32<T: Read>(reader: &mut T) -> Result<u32> { fn read_u32<T: Read>(reader: &mut T) -> std::io::Result<u32> {
let mut b = vec![0u8; 4]; use byteorder::ReadBytesExt;
reader.read_exact(&mut b)?; reader.read_u32::<byteorder::BigEndian>()
let (result, _) = b.iter().rev().fold((0u64, 1u64), |(s, basis), &x| {
(s + basis * u64::from(x), basis * 256)
});
Ok(result as u32)
} }
fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> { fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {

View File

@ -11,19 +11,19 @@ readme = "README.md"
[dependencies] [dependencies]
accelerate-src = { workspace = true, optional = true } accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" } candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.1" } candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
candle-nn = { path = "../candle-nn", version = "0.2.1" } candle-nn = { path = "../candle-nn", version = "0.2.3" }
candle-transformers = { path = "../candle-transformers", version = "0.2.1" } candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true } cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true } half = { workspace = true, optional = true }
image = { workspace = true } image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
[dev-dependencies] [dev-dependencies]
anyhow = { workspace = true } anyhow = { workspace = true }
@ -50,7 +50,7 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
cudnn = ["candle/cudnn"] cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "dep:candle-flash-attn"] flash-attn = ["cuda", "candle-transformers/flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"] nccl = ["cuda", "cudarc/nccl", "dep:half"]

View File

@ -0,0 +1,44 @@
# candle-bert
Bert is a general large language model. In this example, it can be used for two
different tasks:
- Compute sentence embeddings for a prompt.
- Compute similarities between a set of sentences.
## Sentence embeddings
Bert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.
```bash
cargo run --example bert --release -- --prompt "Here is a test sentence"
> [[[ 0.0798, -0.0665, -0.0247, ..., -0.1082, -0.1000, -0.2751],
> [ 0.4218, 0.2690, 0.2740, ..., 0.3889, 1.3503, 0.9908],
> [ 0.0466, 0.3041, -0.1143, ..., 0.4427, 0.6926, -0.1515],
> ...
> [ 0.3396, 0.4320, -0.4408, ..., 0.9212, 0.2331, -0.6777],
> [ 0.2789, 0.7539, 0.4306, ..., -0.0095, 0.3375, -1.7529],
> [ 0.6737, 0.7882, 0.0548, ..., 0.1836, 0.7299, -0.6617]]]
> Tensor[[1, 7, 384], f32]
```
## Similarities
In this example, Bert is used to compute the sentence embeddings for a set of
sentences (hardcoded in the example). Cosine similarities are then computed for
each sentence pair and reported in decreasing order, so the first reported pair
contains the two sentences with the highest similarity score.
The sentence embeddings are computed using average pooling over all the
sentence tokens, including any padding; this pooling step is sketched after the
example output below.
```bash
cargo run --example bert --release
> score: 0.85 'The new movie is awesome' 'The new movie is so great'
> score: 0.61 'The cat sits outside' 'The cat plays in the garden'
> score: 0.52 'I love pasta' 'Do you like pizza?'
> score: 0.23 'The new movie is awesome' 'Do you like pizza?'
> score: 0.22 'I love pasta' 'The new movie is awesome'
```

View File

@ -3,14 +3,13 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")] #[cfg(feature = "accelerate")]
extern crate accelerate_src; extern crate accelerate_src;
mod model; use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use anyhow::{anyhow, Error as E, Result}; use anyhow::{anyhow, Error as E, Result};
use candle::Tensor; use candle::Tensor;
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
use clap::Parser; use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType}; use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use model::{BertModel, Config, DTYPE};
use tokenizers::{PaddingParams, Tokenizer}; use tokenizers::{PaddingParams, Tokenizer};
#[derive(Parser, Debug)] #[derive(Parser, Debug)]

View File

@ -0,0 +1,19 @@
# candle-starcoder: code generation model
[StarCoder/BigCode](https://huggingface.co/bigcode/starcoderbase-1b) is an LLM
specialized in code generation. The initial model was trained on 80
programming languages.
## Running an example
```bash
cargo run --example bigcode --release -- --prompt "fn fact(n: u64) -> u64 "
> fn fact(n: u64) -> u64 {
> if n == 0 {
> 1
> } else {
> n * fact(n - 1)
> }
> }
```

View File

@ -7,8 +7,7 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result}; use anyhow::{Error as E, Result};
use clap::Parser; use clap::Parser;
mod model; use candle_transformers::models::bigcode::{Config, GPTBigCode};
use model::{Config, GPTBigCode};
use candle::{DType, Device, Tensor}; use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
@ -29,9 +28,10 @@ impl TextGeneration {
tokenizer: Tokenizer, tokenizer: Tokenizer,
seed: u64, seed: u64,
temp: Option<f64>, temp: Option<f64>,
top_p: Option<f64>,
device: &Device, device: &Device,
) -> Self { ) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp); let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self { Self {
model, model,
tokenizer, tokenizer,
@ -95,6 +95,10 @@ struct Args {
#[arg(long)] #[arg(long)]
temperature: Option<f64>, temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples. /// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)] #[arg(long, default_value_t = 299792458)]
seed: u64, seed: u64,
@ -150,7 +154,14 @@ fn main() -> Result<()> {
let model = GPTBigCode::load(vb, config)?; let model = GPTBigCode::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed()); println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device); let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?; pipeline.run(&args.prompt, args.sample_len)?;
Ok(()) Ok(())
} }
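
For reference, a minimal sketch of the updated `LogitsProcessor` constructor with the new nucleus-sampling argument (toy logits for illustration):

```rust
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn main() -> anyhow::Result<()> {
    // seed, temperature, top_p: passing None for top_p keeps plain
    // temperature-based sampling.
    let mut processor = LogitsProcessor::new(299792458, Some(0.8), Some(0.95));
    let logits = Tensor::new(&[0.1f32, 2.0, 0.3, 1.5], &Device::Cpu)?;
    let token = processor.sample(&logits)?;
    println!("sampled token id: {token}");
    Ok(())
}
```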

View File

@ -0,0 +1,19 @@
# candle-dinov2
[DINOv2](https://github.com/facebookresearch/dinov2) is a computer vision model.
In this example, it is used as an ImageNet classifier: the model returns the
probability for the image to belong to each of the 1000 ImageNet categories.
## Running an example
```bash
cargo run --example dinov2 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
> mountain bike, all-terrain bike, off-roader: 43.67%
> bicycle-built-for-two, tandem bicycle, tandem: 33.20%
> crash helmet : 13.23%
> unicycle, monocycle : 2.44%
> maillot : 2.42%
```
![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)

View File

@ -9,285 +9,10 @@ extern crate accelerate_src;
use clap::Parser; use clap::Parser;
use candle::{DType, IndexOp, Result, Tensor, D}; use candle::{DType, IndexOp, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; use candle_nn::{Module, VarBuilder};
use candle_transformers::models::dinov2;
const IMG_SIZE: usize = 518;
const PATCH_SIZE: usize = 14;
const NUM_CLASSES: usize = 1000;
fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
if bias {
candle_nn::linear(in_dim, out_dim, vb)
} else {
candle_nn::linear_no_bias(in_dim, out_dim, vb)
}
}
#[derive(Debug)]
struct Attention {
qkv: Linear,
proj: Linear,
num_heads: usize,
scale: f64,
}
impl Attention {
fn new(
vb: VarBuilder,
dim: usize,
num_heads: usize,
qkv_bias: bool,
proj_bias: bool,
) -> Result<Self> {
let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
let scale = 1. / ((dim / num_heads) as f64).sqrt();
Ok(Self {
qkv,
proj,
num_heads,
scale,
})
}
}
impl Module for Attention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b, n, c) = xs.dims3()?;
let qkv = self
.qkv
.forward(xs)?
.reshape((b, n, 3, self.num_heads, c / self.num_heads))?
.transpose(1, 2)? // 02134
.transpose(0, 1)? // 20134
.transpose(2, 3)?; // 20314
let q = (qkv.i(0)? * self.scale)?;
let k = qkv.i(1)?;
let v = qkv.i(2)?;
let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
self.proj.forward(&attn)
}
}
#[derive(Debug)]
struct LayerScale {
gamma: Tensor,
}
impl LayerScale {
fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
let gamma = vb.get(dim, "gamma")?;
Ok(Self { gamma })
}
}
impl Module for LayerScale {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.broadcast_mul(&self.gamma)
}
}
#[derive(Debug)]
struct Mlp {
fc1: Linear,
fc2: Linear,
}
impl Mlp {
fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
let out_features = in_features;
let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
Ok(Self { fc1, fc2 })
}
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.fc1.forward(xs)?.gelu()?;
self.fc2.forward(&xs)
}
}
#[derive(Debug)]
struct Block {
norm1: LayerNorm,
attn: Attention,
ls1: LayerScale,
norm2: LayerNorm,
mlp: Mlp,
ls2: LayerScale,
}
impl Block {
fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
Ok(Self {
norm1,
attn,
ls1,
norm2,
mlp,
ls2,
})
}
}
impl Module for Block {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = self
.ls1
.forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self
.ls2
.forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
xs + residual
}
}
#[derive(Debug)]
struct PatchEmbed {
proj: candle_nn::Conv2d,
patch_size: (usize, usize),
num_patches: usize,
}
impl PatchEmbed {
fn new(
vb: VarBuilder,
img_size: usize,
patch_size: usize,
in_chans: usize,
embed_dim: usize,
) -> Result<Self> {
let config = candle_nn::Conv2dConfig {
stride: patch_size,
..Default::default()
};
let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
let num_patches = (img_size / patch_size) * (img_size / patch_size);
Ok(Self {
proj,
patch_size: (patch_size, patch_size),
num_patches,
})
}
}
impl Module for PatchEmbed {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (_b, _c, h, w) = xs.dims4()?;
let (patch_h, patch_w) = self.patch_size;
if (h % patch_h) != 0 {
candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
}
if (w % patch_w) != 0 {
candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
}
let xs = self.proj.forward(xs)?;
let (b, c, h, w) = xs.dims4()?;
// flatten embeddings.
xs.reshape((b, c, h * w))?.transpose(1, 2)
}
}
#[derive(Debug)]
pub struct DinoVisionTransformer {
patch_embed: PatchEmbed,
cls_token: Tensor,
pos_embed: Tensor,
blocks: Vec<Block>,
norm: LayerNorm,
head: Linear,
}
impl DinoVisionTransformer {
pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
let patch_embed =
PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
let num_tokens = 1;
let pos_embed = vb.get(
(1, patch_embed.num_patches + num_tokens, embed_dim),
"pos_embed",
)?;
let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
.map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,
cls_token,
pos_embed,
blocks,
norm,
head,
})
}
fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
let npatch = xs.dim(1)? - 1;
let n = self.pos_embed.dim(1)? - 1;
let sqrt_n = (n as f64).sqrt();
if npatch == n && w == h {
return Ok(xs.clone());
}
let class_pos_embed = self.pos_embed.i((.., ..1))?;
let patch_pos_embed = self.pos_embed.i((.., 1..))?;
let dim = xs.dim(D::Minus1)?;
let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
let patch_pos_embed = patch_pos_embed
.reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
.transpose(2, 3)?
.transpose(1, 2)?;
// This uses bicubic interpolation in the original implementation.
let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
let el_count = patch_pos_embed.shape().elem_count();
let patch_pos_embed =
patch_pos_embed
.transpose(1, 2)?
.transpose(2, 3)?
.reshape((1, el_count / dim, dim))?;
Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
}
fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
let (_b, _nc, w, h) = xs.dims4()?;
let xs = self.patch_embed.forward(xs)?;
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
}
}
impl Module for DinoVisionTransformer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.prepare_tokens_with_mask(xs)?;
for blk in self.blocks.iter() {
xs = blk.forward(&xs)?
}
let xs = self.norm.forward(&xs)?;
let xs_norm_clstoken = xs.i((.., 0))?;
let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
self.head.forward(&xs)
}
}
pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
DinoVisionTransformer::new(vb, 12, 384, 6)
}
#[derive(Parser)] #[derive(Parser)]
struct Args { struct Args {
#[arg(long)] #[arg(long)]
@ -320,7 +45,7 @@ pub fn main() -> anyhow::Result<()> {
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? }; let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
let weights = weights.deserialize()?; let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let model = vit_small(vb)?; let model = dinov2::vit_small(vb)?;
println!("model built"); println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?; let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)? let prs = candle_nn::ops::softmax(&logits, D::Minus1)?

View File

@ -8,340 +8,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")] #[cfg(feature = "accelerate")]
extern crate accelerate_src; extern crate accelerate_src;
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig};
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn as nn;
use nn::{Module, VarBuilder};
// Based on the Python version from torchvision.
// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
#[derive(Debug, Clone, Copy)]
pub struct MBConvConfig {
expand_ratio: f64,
kernel: usize,
stride: usize,
input_channels: usize,
out_channels: usize,
num_layers: usize,
}
fn make_divisible(v: f64, divisor: usize) -> usize {
let min_value = divisor;
let new_v = usize::max(
min_value,
(v + divisor as f64 * 0.5) as usize / divisor * divisor,
);
if (new_v as f64) < 0.9 * v {
new_v + divisor
} else {
new_v
}
}
fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
let bneck_conf = |e, k, s, i, o, n| {
let input_channels = make_divisible(i as f64 * width_mult, 8);
let out_channels = make_divisible(o as f64 * width_mult, 8);
let num_layers = (n as f64 * depth_mult).ceil() as usize;
MBConvConfig {
expand_ratio: e,
kernel: k,
stride: s,
input_channels,
out_channels,
num_layers,
}
};
vec![
bneck_conf(1., 3, 1, 32, 16, 1),
bneck_conf(6., 3, 2, 16, 24, 2),
bneck_conf(6., 5, 2, 24, 40, 2),
bneck_conf(6., 3, 2, 40, 80, 3),
bneck_conf(6., 5, 1, 80, 112, 3),
bneck_conf(6., 5, 2, 112, 192, 4),
bneck_conf(6., 3, 1, 192, 320, 1),
]
}
impl MBConvConfig {
fn b0() -> Vec<Self> {
bneck_confs(1.0, 1.0)
}
fn b1() -> Vec<Self> {
bneck_confs(1.0, 1.1)
}
fn b2() -> Vec<Self> {
bneck_confs(1.1, 1.2)
}
fn b3() -> Vec<Self> {
bneck_confs(1.2, 1.4)
}
fn b4() -> Vec<Self> {
bneck_confs(1.4, 1.8)
}
fn b5() -> Vec<Self> {
bneck_confs(1.6, 2.2)
}
fn b6() -> Vec<Self> {
bneck_confs(1.8, 2.6)
}
fn b7() -> Vec<Self> {
bneck_confs(2.0, 3.1)
}
}
/// Conv2D with same padding.
#[derive(Debug)]
struct Conv2DSame {
conv2d: nn::Conv2d,
s: usize,
k: usize,
}
impl Conv2DSame {
fn new(
vb: VarBuilder,
i: usize,
o: usize,
k: usize,
stride: usize,
groups: usize,
bias: bool,
) -> Result<Self> {
let conv_config = nn::Conv2dConfig {
stride,
groups,
..Default::default()
};
let conv2d = if bias {
nn::conv2d(i, o, k, conv_config, vb)?
} else {
nn::conv2d_no_bias(i, o, k, conv_config, vb)?
};
Ok(Self {
conv2d,
s: stride,
k,
})
}
}
impl Module for Conv2DSame {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let s = self.s;
let k = self.k;
let (_, _, ih, iw) = xs.dims4()?;
let oh = (ih + s - 1) / s;
let ow = (iw + s - 1) / s;
let pad_h = usize::max((oh - 1) * s + k - ih, 0);
let pad_w = usize::max((ow - 1) * s + k - iw, 0);
if pad_h > 0 || pad_w > 0 {
let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
self.conv2d.forward(&xs)
} else {
self.conv2d.forward(xs)
}
}
}
#[derive(Debug)]
struct ConvNormActivation {
conv2d: Conv2DSame,
bn2d: nn::BatchNorm,
activation: bool,
}
impl ConvNormActivation {
fn new(
vb: VarBuilder,
i: usize,
o: usize,
k: usize,
stride: usize,
groups: usize,
) -> Result<Self> {
let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
Ok(Self {
conv2d,
bn2d,
activation: true,
})
}
fn no_activation(self) -> Self {
Self {
activation: false,
..self
}
}
}
impl Module for ConvNormActivation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.conv2d.forward(xs)?;
let xs = self.bn2d.forward(&xs)?;
if self.activation {
swish(&xs)
} else {
Ok(xs)
}
}
}
#[derive(Debug)]
struct SqueezeExcitation {
fc1: Conv2DSame,
fc2: Conv2DSame,
}
impl SqueezeExcitation {
fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
Ok(Self { fc1, fc2 })
}
}
impl Module for SqueezeExcitation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
// equivalent to adaptive_avg_pool2d([1, 1])
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
let xs = self.fc1.forward(&xs)?;
let xs = swish(&xs)?;
let xs = self.fc2.forward(&xs)?;
let xs = nn::ops::sigmoid(&xs)?;
residual.broadcast_mul(&xs)
}
}
#[derive(Debug)]
struct MBConv {
expand_cna: Option<ConvNormActivation>,
depthwise_cna: ConvNormActivation,
squeeze_excitation: SqueezeExcitation,
project_cna: ConvNormActivation,
config: MBConvConfig,
}
impl MBConv {
fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
let vb = vb.pp("block");
let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
let expand_cna = if exp != c.input_channels {
Some(ConvNormActivation::new(
vb.pp("0"),
c.input_channels,
exp,
1,
1,
1,
)?)
} else {
None
};
let start_index = if expand_cna.is_some() { 1 } else { 0 };
let depthwise_cna =
ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
let squeeze_channels = usize::max(1, c.input_channels / 4);
let squeeze_excitation =
SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
let project_cna =
ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
.no_activation();
Ok(Self {
expand_cna,
depthwise_cna,
squeeze_excitation,
project_cna,
config: c,
})
}
}
impl Module for MBConv {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let use_res_connect =
self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
let ys = match &self.expand_cna {
Some(expand_cna) => expand_cna.forward(xs)?,
None => xs.clone(),
};
let ys = self.depthwise_cna.forward(&ys)?;
let ys = self.squeeze_excitation.forward(&ys)?;
let ys = self.project_cna.forward(&ys)?;
if use_res_connect {
ys + xs
} else {
Ok(ys)
}
}
}
fn swish(s: &Tensor) -> Result<Tensor> {
s * nn::ops::sigmoid(s)?
}
#[derive(Debug)]
struct EfficientNet {
init_cna: ConvNormActivation,
blocks: Vec<MBConv>,
final_cna: ConvNormActivation,
classifier: nn::Linear,
}
impl EfficientNet {
fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
let f_p = p.pp("features");
let first_in_c = configs[0].input_channels;
let last_out_c = configs.last().unwrap().out_channels;
let final_out_c = 4 * last_out_c;
let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
let nconfigs = configs.len();
let mut blocks = vec![];
for (index, cnf) in configs.into_iter().enumerate() {
let f_p = f_p.pp(index + 1);
for r_index in 0..cnf.num_layers {
let cnf = if r_index == 0 {
cnf
} else {
MBConvConfig {
input_channels: cnf.out_channels,
stride: 1,
..cnf
}
};
blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
}
}
let final_cna =
ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
Ok(Self {
init_cna,
blocks,
final_cna,
classifier,
})
}
}
impl Module for EfficientNet {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.init_cna.forward(xs)?;
for block in self.blocks.iter() {
xs = block.forward(&xs)?
}
let xs = self.final_cna.forward(&xs)?;
// Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
self.classifier.forward(&xs)
}
}
#[derive(Clone, Copy, Debug, ValueEnum)] #[derive(Clone, Copy, Debug, ValueEnum)]
enum Which { enum Which {
B0, B0,

View File

@ -0,0 +1,3 @@
# candle-falcon
Falcon is a general large language model.

View File

@ -14,8 +14,7 @@ use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType}; use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
mod model; use candle_transformers::models::falcon::{Config, Falcon};
use model::{Config, Falcon};
struct TextGeneration { struct TextGeneration {
model: Falcon, model: Falcon,
@ -26,17 +25,25 @@ struct TextGeneration {
repeat_last_n: usize, repeat_last_n: usize,
} }
struct GenerationOptions {
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration { impl TextGeneration {
fn new( fn new(
model: Falcon, model: Falcon,
tokenizer: Tokenizer, tokenizer: Tokenizer,
generation_options: GenerationOptions,
seed: u64, seed: u64,
temp: Option<f64>,
device: &Device, device: &Device,
repeat_penalty: f32,
repeat_last_n: usize,
) -> Self { ) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp); let logits_processor =
LogitsProcessor::new(seed, generation_options.temp, generation_options.top_p);
let repeat_penalty = generation_options.repeat_penalty;
let repeat_last_n = generation_options.repeat_last_n;
Self { Self {
model, model,
tokenizer, tokenizer,
@ -119,6 +126,10 @@ struct Args {
#[arg(long)] #[arg(long)]
temperature: Option<f64>, temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples. /// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)] #[arg(long, default_value_t = 299792458)]
seed: u64, seed: u64,
@ -186,15 +197,14 @@ fn main() -> Result<()> {
let model = Falcon::load(vb, config)?; let model = Falcon::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed()); println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new( let generation_options = GenerationOptions {
model, temp: args.temperature,
tokenizer, top_p: args.top_p,
args.seed, repeat_penalty: args.repeat_penalty,
args.temperature, repeat_last_n: args.repeat_last_n,
&device, };
args.repeat_penalty, let mut pipeline =
args.repeat_last_n, TextGeneration::new(model, tokenizer, generation_options, args.seed, &device);
);
pipeline.run(&args.prompt, args.sample_len)?; pipeline.run(&args.prompt, args.sample_len)?;
Ok(()) Ok(())
} }

View File

@ -21,11 +21,10 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType}; use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write; use std::io::Write;
mod model; use candle_transformers::models::llama as model;
use model::{Config, Llama, LlamaConfig}; use model::{Config, Llama, LlamaConfig};
const EOS_TOKEN: &str = "</s>"; const EOS_TOKEN: &str = "</s>";
const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = "My favorite theorem is "; const DEFAULT_PROMPT: &str = "My favorite theorem is ";
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -43,6 +42,10 @@ struct Args {
#[arg(long)] #[arg(long)]
temperature: Option<f64>, temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples. /// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)] #[arg(long, default_value_t = 299792458)]
seed: u64, seed: u64,
@ -194,7 +197,7 @@ fn main() -> Result<()> {
println!("starting the inference loop"); println!("starting the inference loop");
print!("{prompt}"); print!("{prompt}");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature); let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let start_gen = std::time::Instant::now(); let start_gen = std::time::Instant::now();
let mut index_pos = 0; let mut index_pos = 0;
let mut token_generated = 0; let mut token_generated = 0;

View File

@ -27,6 +27,10 @@ struct InferenceCmd {
#[arg(long)] #[arg(long)]
temperature: Option<f64>, temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
#[arg(long, default_value = "")] #[arg(long, default_value = "")]
prompt: String, prompt: String,
@ -133,6 +137,7 @@ fn main() -> anyhow::Result<()> {
None => { None => {
let cmd = InferenceCmd { let cmd = InferenceCmd {
temperature: None, temperature: None,
top_p: None,
prompt: "".to_string(), prompt: "".to_string(),
config: None, config: None,
model_id: "karpathy/tinyllamas".to_string(), model_id: "karpathy/tinyllamas".to_string(),
@ -256,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let model = Llama::load(vb, &cache, config)?; let model = Llama::load(vb, &cache, config)?;
println!("starting the inference loop"); println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature); let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
let mut index_pos = 0; let mut index_pos = 0;
print!("{}", args.prompt); print!("{}", args.prompt);

View File

@ -89,6 +89,10 @@ struct Args {
#[arg(long)] #[arg(long)]
temperature: Option<f64>, temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples. /// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)] #[arg(long, default_value_t = 299792458)]
seed: u64, seed: u64,
@ -222,7 +226,7 @@ fn main() -> Result<()> {
.to_vec(); .to_vec();
println!("starting the inference loop"); println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature); let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let mut new_tokens = vec![]; let mut new_tokens = vec![];
let start_gen = std::time::Instant::now(); let start_gen = std::time::Instant::now();
let mut index_pos = 0; let mut index_pos = 0;

View File

@ -13,7 +13,6 @@ extern crate accelerate_src;
mod encodec_model; mod encodec_model;
mod musicgen_model; mod musicgen_model;
mod nn; mod nn;
mod t5_model;
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration}; use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
@ -78,7 +77,7 @@ fn main() -> Result<()> {
let model = model.deserialize()?; let model = model.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device); let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
let config = GenConfig::small(); let config = GenConfig::small();
let model = MusicgenForConditionalGeneration::load(vb, config)?; let mut model = MusicgenForConditionalGeneration::load(vb, config)?;
let tokens = tokenizer let tokens = tokenizer
.encode(args.prompt.as_str(), true) .encode(args.prompt.as_str(), true)

View File

@ -1,9 +1,10 @@
use crate::{encodec_model, t5_model}; use crate::encodec_model;
use candle::{DType, Device, Result, Tensor, D}; use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{ use candle_nn::{
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module, embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
VarBuilder, VarBuilder,
}; };
use candle_transformers::models::t5;
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83 // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
@ -370,7 +371,7 @@ impl MusicgenForCausalLM {
#[derive(Debug)] #[derive(Debug)]
pub struct MusicgenForConditionalGeneration { pub struct MusicgenForConditionalGeneration {
pub text_encoder: crate::t5_model::T5EncoderModel, pub text_encoder: t5::T5EncoderModel,
pub audio_encoder: crate::encodec_model::EncodecModel, pub audio_encoder: crate::encodec_model::EncodecModel,
pub decoder: MusicgenForCausalLM, pub decoder: MusicgenForCausalLM,
cfg: GenConfig, cfg: GenConfig,
@ -379,7 +380,7 @@ pub struct MusicgenForConditionalGeneration {
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct GenConfig { pub struct GenConfig {
musicgen: Config, musicgen: Config,
t5: crate::t5_model::Config, t5: t5::Config,
encodec: crate::encodec_model::Config, encodec: crate::encodec_model::Config,
} }
@ -387,7 +388,7 @@ impl GenConfig {
pub fn small() -> Self { pub fn small() -> Self {
Self { Self {
musicgen: Config::musicgen_small(), musicgen: Config::musicgen_small(),
t5: t5_model::Config::musicgen_small(), t5: t5::Config::musicgen_small(),
encodec: encodec_model::Config::musicgen_small(), encodec: encodec_model::Config::musicgen_small(),
} }
} }
@ -399,7 +400,7 @@ impl MusicgenForConditionalGeneration {
} }
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> { pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
let text_encoder = t5_model::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?; let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
let audio_encoder = let audio_encoder =
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?; encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?; let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;

View File

@ -1,397 +0,0 @@
// T5 Text Encoder
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
use candle::{DType, Result, Tensor, D};
use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
vocab_size: usize,
d_model: usize,
d_kv: usize,
d_ff: usize,
num_layers: usize,
num_decoder_layers: Option<usize>,
num_heads: usize,
relative_attention_num_buckets: usize,
relative_attention_max_distance: usize,
dropout_rate: f64,
layer_norm_epsilon: f64,
initializer_factor: f64,
feed_forward_proj: Activation,
is_decoder: bool,
is_encoder_decoder: bool,
use_cache: bool,
pad_token_id: usize,
eos_token_id: usize,
}
impl Default for Config {
fn default() -> Self {
Self {
vocab_size: 32128,
d_model: 512,
d_kv: 64,
d_ff: 2048,
num_layers: 6,
num_decoder_layers: None,
num_heads: 8,
relative_attention_num_buckets: 32,
relative_attention_max_distance: 128,
dropout_rate: 0.1,
layer_norm_epsilon: 1e-6,
initializer_factor: 1.0,
feed_forward_proj: Activation::Relu,
is_decoder: false,
is_encoder_decoder: true,
use_cache: true,
pad_token_id: 0,
eos_token_id: 1,
}
}
}
impl Config {
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L184
pub fn musicgen_small() -> Self {
Self {
d_ff: 3072,
d_kv: 64,
d_model: 768,
dropout_rate: 0.1,
eos_token_id: 1,
feed_forward_proj: Activation::Relu,
initializer_factor: 1.0,
is_decoder: false,
is_encoder_decoder: true,
layer_norm_epsilon: 1e-6,
num_decoder_layers: Some(12),
num_heads: 12,
num_layers: 12,
pad_token_id: 0,
relative_attention_max_distance: 128,
relative_attention_num_buckets: 32,
use_cache: true,
vocab_size: 32128,
}
}
}
#[derive(Debug)]
struct T5LayerNorm {
weight: Tensor,
variance_epsilon: f64,
}
impl T5LayerNorm {
fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let weight = vb.get(h, "weight")?;
Ok(Self {
weight,
variance_epsilon: eps,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let dtype = xs.dtype();
let xs_f32 = xs.to_dtype(DType::F32)?;
// variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
let xs = xs.to_dtype(dtype)?;
let xs = xs.broadcast_mul(&self.weight)?;
Ok(xs)
}
}
#[derive(Debug)]
struct T5DenseActDense {
wi: Linear,
wo: Linear,
act: Activation,
}
impl T5DenseActDense {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
Ok(Self {
wi,
wo,
act: Activation::Relu,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.wi.forward(xs)?;
let xs = self.act.forward(&xs)?;
let xs = self.wo.forward(&xs)?;
Ok(xs)
}
}
#[derive(Debug)]
struct T5LayerFF {
dense_relu_dense: T5DenseActDense,
layer_norm: T5LayerNorm,
}
impl T5LayerFF {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
// is_gated_act is not supported.
let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
Ok(Self {
dense_relu_dense,
layer_norm,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let ys = self.layer_norm.forward(xs)?;
let ys = self.dense_relu_dense.forward(&ys)?;
let xs = (xs + ys)?;
Ok(xs)
}
}
#[derive(Debug)]
struct T5Attention {
q: Linear,
k: Linear,
v: Linear,
o: Linear,
n_heads: usize,
d_kv: usize,
relative_attention_bias: Option<Embedding>,
relative_attention_num_buckets: usize,
relative_attention_max_distance: usize,
inner_dim: usize,
}
impl T5Attention {
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let inner_dim = cfg.num_heads * cfg.d_kv;
let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
let relative_attention_bias = if h {
let emb = embedding(
cfg.relative_attention_num_buckets,
cfg.num_heads,
vb.pp("relative_attention_bias"),
)?;
Some(emb)
} else {
None
};
Ok(Self {
q,
k,
v,
o,
n_heads: cfg.num_heads,
d_kv: cfg.d_kv,
relative_attention_bias,
relative_attention_num_buckets: cfg.relative_attention_num_buckets,
relative_attention_max_distance: cfg.relative_attention_max_distance,
inner_dim,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
// TODO: Apply the mask(s)?
// TODO: kv caching.
let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
let q = self.q.forward(xs)?;
let k = self.k.forward(xs)?;
let v = self.v.forward(xs)?;
let q = q
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
let v = v
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
let scores = q.matmul(&k.t()?)?;
let scores = match &self.relative_attention_bias {
None => scores,
Some(relative_attention_bias) => {
let query_length = seq_len;
let key_length = seq_len;
// This only handles the bidirectional case.
let num_buckets = self.relative_attention_num_buckets / 2;
let relative_position = (0..query_length as u32)
.map(|i| {
(0..key_length as u32)
.map(|j| {
if i < j {
j - i + num_buckets as u32
} else {
i - j
}
})
.collect::<Vec<u32>>()
})
.collect::<Vec<Vec<_>>>();
let relative_buckets = Tensor::new(relative_position, q.device())?;
let position_bias = relative_attention_bias
.forward(&relative_buckets)?
.permute((2, 0, 1))?
.unsqueeze(0)?;
(scores + position_bias)?
// TODO: position_bias_masked?
}
};
let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
let attn_output = attn_weights.matmul(&v)?;
let attn_output = attn_output
.transpose(1, 2)?
.reshape((b_sz, seq_len, self.inner_dim))?;
let attn_output = self.o.forward(&attn_output)?;
Ok(attn_output)
}
}
#[derive(Debug)]
struct T5LayerSelfAttention {
self_attention: T5Attention,
layer_norm: T5LayerNorm,
}
impl T5LayerSelfAttention {
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?;
let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
Ok(Self {
self_attention,
layer_norm,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let normed_xs = self.layer_norm.forward(xs)?;
let ys = self.self_attention.forward(&normed_xs)?;
let ys = (xs + ys)?;
Ok(ys)
}
}
#[derive(Debug)]
struct T5LayerCrossAttention {}
impl T5LayerCrossAttention {
fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
todo!()
}
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
}
}
#[derive(Debug)]
struct T5Block {
self_attn: T5LayerSelfAttention,
cross_attn: Option<T5LayerCrossAttention>,
ff: T5LayerFF,
}
impl T5Block {
fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let vb = vb.pp("layer");
let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
let cross_attn = if cfg.is_decoder {
Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
} else {
None
};
let ff_i = if cross_attn.is_some() { 2 } else { 1 };
let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?;
Ok(Self {
self_attn,
cross_attn,
ff,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.self_attn.forward(xs)?;
// TODO: clamp for f16?
if let Some(cross_attn) = &self.cross_attn {
xs = cross_attn.forward(&xs)?;
// TODO: clamp for f16?
}
let xs = self.ff.forward(&xs)?;
// TODO: clamp for f16?
Ok(xs)
}
}
#[derive(Debug)]
struct T5Stack {
block: Vec<T5Block>,
shared: Arc<Embedding>,
final_layer_norm: T5LayerNorm,
}
impl T5Stack {
fn load(vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
let block = (0..cfg.num_layers)
.map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg))
.collect::<Result<Vec<_>>>()?;
let final_layer_norm = T5LayerNorm::load(
cfg.d_model,
cfg.layer_norm_epsilon,
vb.pp("final_layer_norm"),
)?;
Ok(Self {
block,
shared: shared.clone(),
final_layer_norm,
})
}
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let input_embeds = self.shared.as_ref().forward(input_ids)?;
let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
let mut hidden_states = input_embeds;
for block in self.block.iter() {
hidden_states = block.forward(&hidden_states)?
}
let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
Ok(hidden_states)
}
}
#[derive(Debug)]
pub struct T5EncoderModel {
shared: Arc<Embedding>,
encoder: T5Stack,
}
impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared = Arc::new(shared);
let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?;
Ok(Self { shared, encoder })
}
pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let encoder_outputs = self.encoder.forward(input_ids)?;
Ok(encoder_outputs)
}
}

View File

@ -0,0 +1,17 @@
# candle-quantized-t5
This example uses a quantized version of the T5 model.
```bash
$ cargo run --example quantized-t5 --release -- --prompt "translate to German: A beautiful candle."
...
Eine schöne Kerze.
```
The weight file is automatically retrieved from the hub. It is also possible to
generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:
```bash
cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```

View File

@ -0,0 +1,214 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use std::io::Write;
use std::path::PathBuf;
use candle_transformers::models::quantized_t5 as t5;
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
T5Small,
FlanT5Small,
FlanT5Base,
FlanT5Large,
FlanT5Xl,
FlanT5Xxl,
}
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The model repository to use on the HuggingFace hub.
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
weight_file: Option<String>,
/// Disable the key-value cache.
#[arg(long, default_value = "false")]
disable_cache: bool,
/// The prompt to use for generation.
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model size to use.
#[arg(long, default_value = "t5-small")]
which: Which,
}
struct T5ModelBuilder {
device: Device,
config: t5::Config,
weights_filename: PathBuf,
}
impl T5ModelBuilder {
pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
let device = Device::Cpu;
let default_model = "lmz/candle-quantized-t5".to_string();
let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, "main".to_string()),
};
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let api = Api::new()?;
let api = api.repo(repo);
let config_filename = match args.which {
Which::T5Small => api.get("config.json")?,
Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
};
let tokenizer_filename = api.get("tokenizer.json")?;
let weights_filename = match &args.weight_file {
Some(filename) => std::path::PathBuf::from(filename),
None => match args.which {
Which::T5Small => api.get("model.gguf")?,
Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?,
Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?,
Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?,
Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?,
},
};
let config = std::fs::read_to_string(config_filename)?;
let mut config: t5::Config = serde_json::from_str(&config)?;
config.use_cache = !args.disable_cache;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
Ok((
Self {
device,
config,
weights_filename,
},
tokenizer,
))
}
pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
let device = &builder.device;
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(args.prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let mut model = builder.build_model()?;
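// T5 uses the padding token as the decoder start token, so decoding begins from it.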
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
let temperature = if args.temperature <= 0. {
None
} else {
Some(args.temperature)
};
let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
let encoder_output = model.encode(&input_token_ids)?;
let start = std::time::Instant::now();
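// Autoregressive decoding: with the KV cache enabled, only the most
// recently sampled token needs to be fed back to the decoder each step.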
for index in 0.. {
if output_token_ids.len() > 512 {
break;
}
let decoder_token_ids = if index == 0 || !builder.config.use_cache {
Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
} else {
let last_token = *output_token_ids.last().unwrap();
Tensor::new(&[last_token], device)?.unsqueeze(0)?
};
let logits = model
.decode(&decoder_token_ids, &encoder_output)?
.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&output_token_ids[start_at..],
)?
};
let next_token_id = logits_processor.sample(&logits)?;
if next_token_id as usize == builder.config.eos_token_id {
break;
}
output_token_ids.push(next_token_id);
if let Some(text) = tokenizer.id_to_token(next_token_id) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
}
let dt = start.elapsed();
println!(
"\n{} tokens generated ({:.2} token/s)\n",
output_token_ids.len(),
output_token_ids.len() as f64 / dt.as_secs_f64(),
);
Ok(())
}

View File

@ -0,0 +1,37 @@
# candle-quantized-llama: Fast Inference of quantized LLaMA models
This example provides a quantized LLaMA model similar to
[llama.cpp](https://github.com/ggerganov/llama.cpp). It is based on candle's
built-in quantization methods. Supported features include:
- 2-bit, 3-bit, 4-bit, 5-bit, 6-bit and 8-bit integer quantization.
- SIMD optimizations on Apple Silicon and x86.
- Support for the `gguf` and `ggml` file formats.
The weights are automatically downloaded for you from the [HuggingFace
Hub](https://huggingface.co/) on the first run. There are various command line
flags to use local files instead, run with `--help` to learn about them.
![Axiom of Choice](./assets/aoc.gif)
## Running an example
```bash
cargo run --example quantized --release -- --prompt "The best thing about coding in rust is "
> avx: true, neon: false, simd128: false, f16c: true
> temp: 0.80 repeat-penalty: 1.10 repeat-last-n: 64
> loaded 291 tensors (3.79GB) in 2.17s
> params: HParams { n_vocab: 32000, n_embd: 4096, n_mult: 256, n_head: 32, n_layer: 32, n_rot: 128, ftype: 2 }
> The best thing about coding in rust is 1.) that I dont need to worry about memory leaks, 2.) speed and 3.) my program will compile even on old machines.
```
## Command-line flags
Run with `--help` to see all options.
- `--which`: specify the model to use, e.g. `7b`, `13b-chat`, `7b-code`.
- `--prompt interactive`: interactive mode where multiple prompts can be
entered.
- `--model mymodelfile.gguf`: use a local model file rather than getting one
from the hub.
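
For programmatic use, loading a local `gguf` file follows the same path as
this example; here is a minimal sketch based on the imports used by the
example (exact signatures may differ between candle versions):

```rust
use candle::quantized::gguf_file;
use candle_transformers::models::quantized_llama::ModelWeights;

// Parse the gguf container, then build the quantized weights from it.
fn load_gguf(path: &str) -> anyhow::Result<ModelWeights> {
    let mut file = std::fs::File::open(path)?;
    let content = gguf_file::Content::read(&mut file)?;
    Ok(ModelWeights::from_gguf(content, &mut file)?)
}
```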

(Binary image file added, 119 KiB — not shown.)

View File

@@ -12,7 +12,7 @@ use candle::quantized::{ggml_file, gguf_file};
 use candle::{Device, Tensor};
 use candle_transformers::generation::LogitsProcessor;
-mod model;
+use candle_transformers::models::quantized_llama as model;
 use model::ModelWeights;
 const DEFAULT_PROMPT: &str = "My favorite theorem is ";
@@ -71,6 +71,10 @@ struct Args {
     #[arg(long, default_value_t = 0.8)]
     temperature: f64,
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -310,7 +314,7 @@ fn main() -> anyhow::Result<()> {
         prompt_tokens
     };
     let mut all_tokens = vec![];
-    let mut logits_processor = LogitsProcessor::new(args.seed, temperature);
+    let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
     let start_prompt_processing = std::time::Instant::now();
     let mut next_token = {
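
The `top_p` flag added above enables nucleus sampling in the
`LogitsProcessor`: only the smallest set of tokens whose cumulative
probability exceeds `p` is kept before sampling. A rough sketch of the idea
(an illustration, not the actual candle implementation):

```rust
// Top-p (nucleus) filtering over (token, probability) pairs: sort by
// decreasing probability, keep tokens until the cumulative mass reaches p,
// then renormalize the kept probabilities.
fn top_p_filter(mut probs: Vec<(usize, f32)>, p: f32) -> Vec<(usize, f32)> {
    probs.sort_by(|a, b| b.1.total_cmp(&a.1));
    let mut cumulative = 0.0;
    let mut kept = Vec::new();
    for (token, prob) in probs {
        if cumulative >= p {
            break;
        }
        cumulative += prob;
        kept.push((token, prob));
    }
    let total: f32 = kept.iter().map(|(_, prob)| prob).sum();
    kept.into_iter()
        .map(|(token, prob)| (token, prob / total))
        .collect()
}
```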

View File

@@ -0,0 +1,40 @@
# candle-segment-anything: Segment-Anything Model
This example is based on Meta AI's [Segment-Anything
Model](https://github.com/facebookresearch/segment-anything). This model
provides a robust and fast image segmentation pipeline that can be steered via
prompting (requesting that some points be included in the target mask, that
some points be part of the background and so _not_ in the target mask, or
specifying a bounding box).
The default backbone can be replaced by the smaller and faster TinyViT model
based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
## Running an example
```bash
cargo run --example segment-anything --release -- \
  --image candle-examples/examples/yolo-v8/assets/bike.jpg \
  --use-tiny \
  --point-x 0.4 \
  --point-y 0.3
```
Running this command generates a `sam_merged.jpg` file containing the original
image with a blue overlay of the selected mask. The red dot represents the prompt
specified by `--point-x 0.4 --point-y 0.3`; this prompt is assumed to be part
of the target mask.
The values used for `--point-x` and `--point-y` should be between 0 and 1 and
are proportional to the image dimensions, i.e. use 0.5 for the image center.
![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)
![Leading group, Giro d'Italia 2021](./assets/sam_merged.jpg)
### Command-line flags
- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default
one.
- `--point-x`, `--point-y`: specify the location of the target point.
- `--threshold`: sets the threshold value for being part of the mask; a negative
  value results in a larger mask and can be specified via `--threshold=-1.2`.
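
The relative prompt coordinates are turned into pixel positions by scaling
with the image size, mirroring what the example does when drawing the red
dot:

```rust
// Convert the relative --point-x/--point-y values into pixel coordinates.
fn to_pixels(point_x: f64, point_y: f64, width: u32, height: u32) -> (i32, i32) {
    (
        (point_x * width as f64) as i32,
        (point_y * height as f64) as i32,
    )
}
```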

(Binary image file added, 157 KiB — not shown.)

View File

@@ -0,0 +1,164 @@
//! SAM: Segment Anything Model
//! https://github.com/facebookresearch/segment-anything
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::DType;
use candle_nn::VarBuilder;
use candle_transformers::models::segment_anything::sam;
use clap::Parser;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(long)]
generate_masks: bool,
/// The target point x coordinate, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long, default_value_t = 0.5)]
point_x: f64,
/// The target point y coordinate, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long, default_value_t = 0.5)]
point_y: f64,
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
/// mask, positive makes the mask more selective.
#[arg(long, default_value_t = 0.)]
threshold: f32,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Use the TinyViT based models from MobileSAM
#[arg(long)]
use_tiny: bool,
}
pub fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let device = candle_examples::device(args.cpu)?;
let (image, initial_h, initial_w) =
candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
let image = image.to_device(&device)?;
println!("loaded image {image:?}");
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-sam".to_string());
let filename = if args.use_tiny {
"mobile_sam-tiny-vitt.safetensors"
} else {
"sam_vit_b_01ec64.safetensors"
};
api.get(filename)?
}
};
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let sam = if args.use_tiny {
sam::Sam::new_tiny(vb)? // tiny vit_t
} else {
sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
};
if args.generate_masks {
// Default options similar to the Python version.
let bboxes = sam.generate_masks(
&image,
/* points_per_side */ 32,
/* crop_n_layer */ 0,
/* crop_overlap_ratio */ 512. / 1500.,
/* crop_n_points_downscale_factor */ 1,
)?;
for (idx, bbox) in bboxes.iter().enumerate() {
println!("{idx} {bbox:?}");
let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
let (h, w) = mask.dims2()?;
let mask = mask.broadcast_as((3, h, w))?;
candle_examples::save_image_resize(
&mask,
format!("sam_mask{idx}.png"),
initial_h,
initial_w,
)?;
}
} else {
let point = Some((args.point_x, args.point_y));
let start_time = std::time::Instant::now();
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
println!(
"mask generated in {:.2}s",
start_time.elapsed().as_secs_f32()
);
println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}");
let mask = (mask.ge(args.threshold)? * 255.)?;
let (_one, h, w) = mask.dims3()?;
let mask = mask.expand((3, h, w))?;
let mut img = image::io::Reader::open(&args.image)?
.decode()
.map_err(candle::Error::wrap)?;
let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
Some(image) => image,
None => anyhow::bail!("error saving merged image"),
};
let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
img.width(),
img.height(),
image::imageops::FilterType::CatmullRom,
);
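        // Blend the mask into the original picture: wherever the resized mask
        // is set, push the blue channel toward 255 and halve red/green so the
        // selected region shows up as a blue overlay.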
for x in 0..img.width() {
for y in 0..img.height() {
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
if mask_p.0[0] > 100 {
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
img_p.0[1] /= 2;
img_p.0[0] /= 2;
imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
}
}
}
let (x, y) = (
(args.point_x * img.width() as f64) as i32,
(args.point_y * img.height() as f64) as i32,
);
imageproc::drawing::draw_filled_circle(&img, (x, y), 3, image::Rgba([255, 0, 0, 200]))
.save("sam_merged.jpg")?
}
Ok(())
}

View File

@@ -0,0 +1,63 @@
# candle-stable-diffusion: A Diffusers API in Rust/Candle
![rusty robot holding a candle](./assets/stable-diffusion-xl.jpg)
_A rusty robot holding a fire torch in its hand_, generated by Stable Diffusion
XL using Rust and [candle](https://github.com/huggingface/candle).
The `stable-diffusion` example is a conversion of
[diffusers-rs](https://github.com/LaurentMazare/diffusers-rs) using candle
rather than libtorch. This implementation supports Stable Diffusion v1.5, v2.1,
as well as Stable Diffusion XL 1.0.
## Getting the weights
The weights are automatically downloaded for you from the [HuggingFace
Hub](https://huggingface.co/) on the first run. There are various command line
flags to use local files instead, run with `--help` to learn about them.
## Running an example
```bash
cargo run --example stable-diffusion --release --features=cuda,cudnn \
-- --prompt "a cosmonaut on a horse (hd, realistic, high-def)"
```
The final image is named `sd_final.png` by default.
The default scheduler is the Denoising Diffusion Implicit Model scheduler (DDIM). The
original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim).
### Command-line flags
- `--prompt`: the prompt to be used to generate the image.
- `--uncond-prompt`: the optional unconditional prompt.
- `--sd-version`: the Stable Diffusion version to use, can be `v1-5`, `v2-1`, or
`xl`.
- `--cpu`: use the cpu rather than the gpu (much slower).
- `--height`, `--width`: set the height and width for the generated image.
- `--n-steps`: the number of steps to be used in the diffusion process.
- `--num-samples`: the number of samples to generate.
- `--final-image`: the filename for the generated image(s).
### Using flash-attention
Using flash attention makes image generation a lot faster and uses less memory.
The downside is a rather long compilation time. You can set the
`CANDLE_FLASH_ATTN_BUILD_DIR` environment variable to something like
`/home/user/.candle` to ensure that the compilation artifacts are properly
cached.
Enabling flash-attention requires both the `--features flash-attn` build flag
and the `--use-flash-attn` command-line flag.
## Image to Image Pipeline
...
## FAQ
### Memory Issues
This requires a GPU with more than 8GB of memory; as a fallback, the CPU version
can be used with the `--cpu` flag but is much slower.
Alternatively, reducing the height and width with the `--height` and `--width`
flags is likely to reduce memory usage significantly.

(Binary image file added, 36 KiB — not shown.)

View File

@@ -4,20 +4,10 @@ extern crate accelerate_src;
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
-mod attention;
-mod clip;
-mod ddim;
-mod embeddings;
-mod resnet;
-mod schedulers;
-mod stable_diffusion;
-mod unet_2d;
-mod unet_2d_blocks;
-mod utils;
-mod vae;
+use candle_transformers::models::stable_diffusion;
 use anyhow::{Error as E, Result};
-use candle::{DType, Device, IndexOp, Tensor, D};
+use candle::{DType, Device, IndexOp, Module, Tensor, D};
 use clap::Parser;
 use tokenizers::Tokenizer;
@@ -96,6 +86,15 @@ struct Args {
     #[arg(long)]
     use_f16: bool,
+    #[arg(long, value_name = "FILE")]
+    img2img: Option<String>,
+    /// The strength, indicates how much to transform the initial image. The
+    /// value must be between 0 and 1, a value of 1 discards the initial image
+    /// information.
+    #[arg(long, default_value_t = 0.8)]
+    img2img_strength: f64,
 }
 #[derive(Debug, Clone, Copy, clap::ValueEnum)]
@@ -306,6 +305,26 @@ fn text_embeddings(
     Ok(text_embeddings)
 }
+fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> anyhow::Result<Tensor> {
+    let img = image::io::Reader::open(path)?.decode()?;
+    let (height, width) = (img.height() as usize, img.width() as usize);
+    let height = height - height % 32;
+    let width = width - width % 32;
+    let img = img.resize_to_fill(
+        width as u32,
+        height as u32,
+        image::imageops::FilterType::CatmullRom,
+    );
+    let img = img.to_rgb8();
+    let img = img.into_raw();
+    let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
+        .permute((2, 0, 1))?
+        .to_dtype(DType::F32)?
+        .affine(2. / 255., -1.)?
+        .unsqueeze(0)?;
+    Ok(img)
+}
 fn run(args: Args) -> Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
@@ -328,9 +347,15 @@ fn run(args: Args) -> Result<()> {
         tracing,
         use_f16,
         use_flash_attn,
+        img2img,
+        img2img_strength,
         ..
     } = args;
+    if !(0. ..=1.).contains(&img2img_strength) {
+        anyhow::bail!("img2img-strength should be between 0 and 1, got {img2img_strength}")
+    }
     let _guard = if tracing {
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
         tracing_subscriber::registry().with(chrome_layer).init();
@@ -382,25 +407,53 @@ fn run(args: Args) -> Result<()> {
     println!("Building the autoencoder.");
     let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
     let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
+    let init_latent_dist = match &img2img {
+        None => None,
+        Some(image) => {
+            let image = image_preprocess(image)?.to_device(&device)?;
+            Some(vae.encode(&image)?)
+        }
+    };
     println!("Building the unet.");
     let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
     let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
+    let t_start = if img2img.is_some() {
+        n_steps - (n_steps as f64 * img2img_strength) as usize
+    } else {
+        0
+    };
     let bsize = 1;
     for idx in 0..num_samples {
-        let mut latents = Tensor::randn(
+        let timesteps = scheduler.timesteps();
+        let latents = match &init_latent_dist {
+            Some(init_latent_dist) => {
+                let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
+                if t_start < timesteps.len() {
+                    let noise = latents.randn_like(0f64, 1f64)?;
+                    scheduler.add_noise(&latents, noise, timesteps[t_start])?
+                } else {
+                    latents
+                }
+            }
+            None => {
+                let latents = Tensor::randn(
                     0f32,
                     1f32,
                     (bsize, 4, sd_config.height / 8, sd_config.width / 8),
                     &device,
-        )?
-        .to_dtype(dtype)?;
+                )?;
                 // scale the initial noise by the standard deviation required by the scheduler
-        latents = (latents * scheduler.init_noise_sigma())?;
+                (latents * scheduler.init_noise_sigma())?
+            }
+        };
+        let mut latents = latents.to_dtype(dtype)?;
         println!("starting sampling");
-        for (timestep_index, &timestep) in scheduler.timesteps().iter().enumerate() {
+        for (timestep_index, &timestep) in timesteps.iter().enumerate() {
+            if timestep_index < t_start {
+                continue;
+            }
             let start_time = std::time::Instant::now();
             let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
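
The `img2img_strength` flag added above controls how far into the schedule
denoising starts: the first steps are skipped, so a higher strength keeps
less of the initial image. The mapping boils down to this helper (extracted
here only for illustration):

```rust
// With n_steps denoising steps, a strength s in [0, 1] skips the first
// (1 - s) fraction of the schedule: s = 1.0 discards the initial image
// information entirely while s = 0.0 would leave it untouched.
fn start_step(n_steps: usize, strength: f64) -> usize {
    n_steps - (n_steps as f64 * strength) as usize
}
```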

View File

@@ -0,0 +1,25 @@
# candle-t5
## Encoder-decoder example:
```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
...
Running on CPU, to run on GPU, build this example with `--features cuda`
Eine schöne Kerze.
9 tokens generated (2.42 token/s)
```
## Sentence embedding example:
```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."
...
[[[ 0.0515, -0.0541, -0.0761, ..., -0.0392, 0.1511, -0.0265],
[-0.0974, 0.0998, -0.1659, ..., -0.2450, 0.1738, -0.0164],
[ 0.0624, -0.1024, 0.0430, ..., -0.1388, 0.0564, -0.2962],
[-0.0389, -0.1173, 0.0026, ..., 0.1064, -0.1065, 0.0990],
[ 0.1300, 0.0027, -0.0326, ..., 0.0026, -0.0317, 0.0851]]]
Tensor[[1, 5, 512], f32]
Took 303.766583ms
```
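
When no prompt is passed, the example embeds a small set of sentences,
mean-pools each encoder output into a single vector (optionally
L2-normalized), and ranks sentence pairs by cosine similarity. On plain
slices the comparison reduces to the usual formula:

```rust
// Cosine similarity between two pooled embedding vectors:
// dot(a, b) / (|a| * |b|).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|y| y * y).sum::<f32>().sqrt();
    dot / (norm_a * norm_b)
}
```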

View File

@@ -0,0 +1,314 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use std::io::Write;
use std::path::PathBuf;
use candle_transformers::models::t5;
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
const DTYPE: DType = DType::F32;
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The model repository to use on the HuggingFace hub.
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
/// Enable decoding.
#[arg(long)]
decode: bool,
/// Disable the key-value cache during decoding.
#[arg(long, default_value = "false")]
disable_cache: bool,
/// Use this prompt, otherwise compute sentence similarities.
#[arg(long)]
prompt: Option<String>,
/// If set along with --decode, will use this prompt to initialize the decoder.
#[arg(long)]
decoder_prompt: Option<String>,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
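// Note on the sampling flags above: `repeat_penalty` is applied through
// candle_transformers::utils::apply_repeat_penalty which, roughly, divides the
// positive logits (and multiplies the negative ones) of every token seen in
// the last `repeat_last_n` outputs by the penalty before sampling.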
struct T5ModelBuilder {
device: Device,
config: t5::Config,
weights_filename: Vec<PathBuf>,
}
impl T5ModelBuilder {
pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
let device = candle_examples::device(args.cpu)?;
let default_model = "t5-small".to_string();
let default_revision = "refs/pr/15".to_string();
let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
};
let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
let api = Api::new()?;
let api = api.repo(repo);
let config_filename = api.get("config.json")?;
let tokenizer_filename = api.get("tokenizer.json")?;
let weights_filename = if model_id == "google/flan-t5-xxl" {
vec![
api.get("model-00001-of-00005.safetensors")?,
api.get("model-00002-of-00005.safetensors")?,
api.get("model-00003-of-00005.safetensors")?,
api.get("model-00004-of-00005.safetensors")?,
api.get("model-00005-of-00005.safetensors")?,
]
} else {
vec![api.get("model.safetensors")?]
};
let config = std::fs::read_to_string(config_filename)?;
let mut config: t5::Config = serde_json::from_str(&config)?;
config.use_cache = !args.disable_cache;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
Ok((
Self {
device,
config,
weights_filename,
},
tokenizer,
))
}
pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
let weights = self
.weights_filename
.iter()
.map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
.collect::<candle::Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|w| w.deserialize())
.collect::<candle::Result<Vec<_>>>()?;
let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
Ok(t5::T5EncoderModel::load(vb, &self.config)?)
}
pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
let weights = self
.weights_filename
.iter()
.map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
.collect::<candle::Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|w| w.deserialize())
.collect::<candle::Result<Vec<_>>>()?;
let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
let device = &builder.device;
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
match args.prompt {
Some(prompt) => {
let tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
if !args.decode {
let mut model = builder.build_encoder()?;
let start = std::time::Instant::now();
let ys = model.forward(&input_token_ids)?;
println!("{ys}");
println!("Took {:?}", start.elapsed());
} else {
let mut model = builder.build_conditional_generation()?;
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
if let Some(decoder_prompt) = &args.decoder_prompt {
print!("{decoder_prompt}");
output_token_ids.extend(
tokenizer
.encode(decoder_prompt.to_string(), false)
.map_err(E::msg)?
.get_ids()
.to_vec(),
);
}
let temperature = if args.temperature <= 0. {
None
} else {
Some(args.temperature)
};
let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
let encoder_output = model.encode(&input_token_ids)?;
let start = std::time::Instant::now();
for index in 0.. {
if output_token_ids.len() > 512 {
break;
}
let decoder_token_ids = if index == 0 || !builder.config.use_cache {
Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
} else {
let last_token = *output_token_ids.last().unwrap();
Tensor::new(&[last_token], device)?.unsqueeze(0)?
};
let logits = model
.decode(&decoder_token_ids, &encoder_output)?
.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&output_token_ids[start_at..],
)?
};
let next_token_id = logits_processor.sample(&logits)?;
if next_token_id as usize == builder.config.eos_token_id {
break;
}
output_token_ids.push(next_token_id);
if let Some(text) = tokenizer.id_to_token(next_token_id) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
}
let dt = start.elapsed();
println!(
"\n{} tokens generated ({:.2} token/s)\n",
output_token_ids.len(),
output_token_ids.len() as f64 / dt.as_secs_f64(),
);
}
}
None => {
let mut model = builder.build_encoder()?;
let sentences = [
"The cat sits outside",
"A man is playing guitar",
"I love pasta",
"The new movie is awesome",
"The cat plays in the garden",
"A woman watches TV",
"The new movie is so great",
"Do you like pizza?",
];
let n_sentences = sentences.len();
let mut all_embeddings = Vec::with_capacity(n_sentences);
for sentence in sentences {
let tokens = tokenizer
.encode(sentence, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
let embeddings = model.forward(&token_ids)?;
println!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
let embeddings = if args.normalize_embeddings {
normalize_l2(&embeddings)?
} else {
embeddings
};
println!("pooled embeddings {:?}", embeddings.shape());
all_embeddings.push(embeddings)
}
let mut similarities = vec![];
for (i, e_i) in all_embeddings.iter().enumerate() {
for (j, e_j) in all_embeddings
.iter()
.enumerate()
.take(n_sentences)
.skip(i + 1)
{
let sum_ij = (e_i * e_j)?.sum_all()?.to_scalar::<f32>()?;
let sum_i2 = (e_i * e_i)?.sum_all()?.to_scalar::<f32>()?;
let sum_j2 = (e_j * e_j)?.sum_all()?.to_scalar::<f32>()?;
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
similarities.push((cosine_similarity, i, j))
}
}
similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
for &(score, i, j) in similarities[..5].iter() {
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
}
}
}
Ok(())
}
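/// L2-normalization along dimension 1: v / sqrt(sum(v^2, dim=1)).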
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}

View File

@@ -0,0 +1,39 @@
# candle-whisper: speech recognition
An implementation of [OpenAI Whisper](https://github.com/openai/whisper) using
candle. Whisper is a general-purpose speech recognition model; it can be used to
convert audio files (in the `.wav` format) to text. Supported features include
language detection as well as multilingual speech recognition.
## Running an example
If no audio file is passed as input, a [sample
file](https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav) is automatically downloaded
from the hub.
```bash
cargo run --example whisper --release
> No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
> loaded wav data: Header { audio_format: 1, channel_count: 1, sampling_rate: 16000, bytes_per_second: 32000, bytes_per_sample: 2, bits_per_sample: 16 }
> pcm data loaded 176000
> loaded mel: [1, 80, 3000]
> 0.0s -- 30.0s: And so my fellow Americans ask not what your country can do for you ask what you can do for your country
```
In order to use the multilingual mode, specify a multilingual model via the
`--model` flag, see the details below.
## Command line flags
- `--input`: the audio file to be converted to text, in wav format.
- `--language`: force the language to some specific value rather than being
detected, e.g. `en`.
- `--task`: the task to be performed, can be `transcribe` (return the text data
in the original language) or `translate` (translate the text to English).
- `--timestamps`: enable the timestamp mode where some timestamps are reported
for each recognized audio extract.
- `--model`: the model to be used. Models whose name ends with `.en` are
  English-only; the other ones are multilingual. The supported models
  are `tiny`, `tiny.en`, `base`, `base.en`, `small`, `small.en`, `medium`,
  `medium.en`, `large`, and `large-v2`.
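
For reference, transcription proceeds over 30-second windows of a log-mel
spectrogram. A sketch of the bookkeeping, using the audio constants that this
commit moves into `candle_transformers::models::whisper` (values copied from
the code):

```rust
// One mel frame is produced every HOP_LENGTH audio samples, so N_FRAMES
// frames cover exactly CHUNK_LENGTH seconds of 16 kHz audio.
const SAMPLE_RATE: usize = 16000;
const HOP_LENGTH: usize = 160;
const CHUNK_LENGTH: usize = 30;
const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples per chunk
const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames per chunk

// Start time in seconds of a segment beginning at a given mel frame, as used
// when reporting timestamps.
fn time_offset(seek_frames: usize) -> f64 {
    (seek_frames * HOP_LENGTH) as f64 / SAMPLE_RATE as f64
}
```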

View File

@@ -10,41 +10,16 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;
 use anyhow::{Error as E, Result};
-use candle::{DType, Device, IndexOp, Tensor};
+use candle::{Device, IndexOp, Tensor};
 use candle_nn::{ops::softmax, VarBuilder};
 use clap::{Parser, ValueEnum};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use rand::{distributions::Distribution, SeedableRng};
 use tokenizers::Tokenizer;
-mod audio;
-mod model;
-use model::{Config, Whisper};
 mod multilingual;
-const DTYPE: DType = DType::F32;
-// Audio parameters.
-const SAMPLE_RATE: usize = 16000;
-const N_FFT: usize = 400;
-const N_MELS: usize = 80;
-const HOP_LENGTH: usize = 160;
-const CHUNK_LENGTH: usize = 30;
-const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
-const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
-const NO_SPEECH_THRESHOLD: f64 = 0.6;
-const LOGPROB_THRESHOLD: f64 = -1.0;
-const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
-const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
-// Tokenizer dependent bits.
-const SOT_TOKEN: &str = "<|startoftranscript|>";
-const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
-const TRANSLATE_TOKEN: &str = "<|translate|>";
-const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
-const EOT_TOKEN: &str = "<|endoftext|>";
-const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
+use candle_transformers::models::whisper::{self as m, audio, model};
+use model::{Config, Whisper};
 #[allow(dead_code)]
 #[derive(Debug, Clone)]
@@ -94,7 +69,7 @@ impl Decoder {
         timestamps: bool,
         verbose: bool,
     ) -> Result<Self> {
-        let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
+        let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
         // Suppress the notimestamps token when in timestamps mode.
         // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
         let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
@@ -109,11 +84,11 @@ impl Decoder {
             })
             .collect();
         let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
-        let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
-        let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
-        let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
-        let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
-        let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
+        let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
+        let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
+        let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
+        let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
+        let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
         Ok(Self {
             model,
             rng: rand::rngs::StdRng::seed_from_u64(seed),
@@ -220,17 +195,17 @@ impl Decoder {
     }
     fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
-        for (i, &t) in TEMPERATURES.iter().enumerate() {
+        for (i, &t) in m::TEMPERATURES.iter().enumerate() {
             let dr: Result<DecodingResult> = self.decode(segment, t);
-            if i == TEMPERATURES.len() - 1 {
+            if i == m::TEMPERATURES.len() - 1 {
                 return dr;
             }
             // On errors, we try again with a different temperature.
             match dr {
                 Ok(dr) => {
-                    let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
-                        || dr.avg_logprob < LOGPROB_THRESHOLD;
-                    if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
+                    let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
+                        || dr.avg_logprob < m::LOGPROB_THRESHOLD;
+                    if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
                         return Ok(dr);
                     }
                 }
@@ -248,13 +223,13 @@ impl Decoder {
         let mut segments = vec![];
         while seek < content_frames {
             let start = std::time::Instant::now();
-            let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
-            let segment_size = usize::min(content_frames - seek, N_FRAMES);
+            let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
+            let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
             let mel_segment = mel.narrow(2, seek, segment_size)?;
-            let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
+            let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
             let dr = self.decode_with_fallback(&mel_segment)?;
             seek += segment_size;
-            if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
+            if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
                 println!("no speech detected, skipping {seek} {dr:?}");
                 continue;
             }
@@ -431,7 +406,6 @@ fn main() -> Result<()> {
     let args = Args::parse();
     let _guard = if args.tracing {
-        println!("tracing...");
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
         tracing_subscriber::registry().with(chrome_layer).init();
         Some(guard)
@@ -493,8 +467,8 @@ fn main() -> Result<()> {
     let mut input = std::fs::File::open(input)?;
     let (header, data) = wav::read(&mut input)?;
     println!("loaded wav data: {header:?}");
-    if header.sampling_rate != SAMPLE_RATE as u32 {
-        anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE)
+    if header.sampling_rate != m::SAMPLE_RATE as u32 {
+        anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
     }
     let data = data.as_sixteen().expect("expected 16 bit wav file");
     let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
@@ -502,14 +476,14 @@ fn main() -> Result<()> {
         .map(|v| *v as f32 / 32768.)
         .collect();
     println!("pcm data loaded {}", pcm_data.len());
-    let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?;
+    let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
     let mel_len = mel.len();
-    let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
+    let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
    println!("loaded mel: {:?}", mel.dims());
    let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
    let weights = weights.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+    let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device);
    let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
    let mut model = Whisper::load(&vb, config)?;

View File

@@ -113,7 +113,7 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor)
         .iter()
         .map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
         .collect::<Result<Vec<_>>>()?;
-    let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
+    let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
     let audio_features = model.encoder.forward(&mel, true)?;
     let tokens = Tensor::new(&[[sot_token]], device)?;
     let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;

View File

@@ -0,0 +1,27 @@
# candle-wuerstchen: Efficient Pretraining of Text-to-Image Models
![anthropomorphic cat dressed as a fire fighter](./assets/cat.jpg)
The `wuerstchen` example is a port of the [diffusers
implementation](https://github.com/huggingface/diffusers/tree/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen) for Würstchen v2.
The candle implementation reproduces the same structure/files for models and
pipelines. Useful resources:
- [Official implementation](https://github.com/dome272/Wuerstchen).
- [Arxiv paper](https://arxiv.org/abs/2306.00637).
- Blog post: [Introducing Würstchen: Fast Diffusion for Image Generation](https://huggingface.co/blog/wuerstchen).
## Getting the weights
The weights are automatically downloaded for you from the [HuggingFace
Hub](https://huggingface.co/) on the first run. There are various command line
flags to use local files instead, run with `--help` to learn about them.
## Running an example
```bash
cargo run --example wuerstchen --release --features cuda,cudnn -- \
--prompt "Anthropomorphic cat dressed as a fire fighter"
```
The final image is named `sd_final.png` by default.
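
Würstchen works in a strongly compressed latent space, which is where its
efficiency comes from. The latent sizing used in the example code below can
be sketched as:

```rust
// The prior operates on latents roughly 42.67x smaller than the output
// resolution in each dimension; the decoder latents then scale the prior
// latents back up by roughly 10.67x.
const RESOLUTION_MULTIPLE: f64 = 42.67;
const LATENT_DIM_SCALE: f64 = 10.67;

fn prior_latent_hw(height: usize, width: usize) -> (usize, usize) {
    (
        (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize,
        (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize,
    )
}
```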

(Binary image file added, 38 KiB — not shown.)

View File

@@ -0,0 +1,396 @@
#![allow(unused)]
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use candle_transformers::models::stable_diffusion;
use candle_transformers::models::wuerstchen;
use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Module, Tensor, D};
use clap::Parser;
use tokenizers::Tokenizer;
const PRIOR_GUIDANCE_SCALE: f64 = 4.0;
const RESOLUTION_MULTIPLE: f64 = 42.67;
const LATENT_DIM_SCALE: f64 = 10.67;
const PRIOR_CIN: usize = 16;
const DECODER_CIN: usize = 4;
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The prompt to be used for image generation.
#[arg(
long,
default_value = "A very realistic photo of a rusty robot walking on a sandy beach"
)]
prompt: String,
#[arg(long, default_value = "")]
uncond_prompt: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
/// The height in pixels of the generated image.
#[arg(long)]
height: Option<usize>,
/// The width in pixels of the generated image.
#[arg(long)]
width: Option<usize>,
/// The decoder weight file, in .safetensors format.
#[arg(long, value_name = "FILE")]
decoder_weights: Option<String>,
/// The CLIP weight file, in .safetensors format.
#[arg(long, value_name = "FILE")]
clip_weights: Option<String>,
/// The CLIP weight file used by the prior model, in .safetensors format.
#[arg(long, value_name = "FILE")]
prior_clip_weights: Option<String>,
/// The prior weight file, in .safetensors format.
#[arg(long, value_name = "FILE")]
prior_weights: Option<String>,
/// The VQGAN weight file, in .safetensors format.
#[arg(long, value_name = "FILE")]
vqgan_weights: Option<String>,
#[arg(long, value_name = "FILE")]
/// The file specifying the tokenizer to be used for tokenization.
tokenizer: Option<String>,
#[arg(long, value_name = "FILE")]
/// The file specifying the tokenizer to be used for prior tokenization.
prior_tokenizer: Option<String>,
/// The size of the sliced attention or 0 for automatic slicing (disabled by default)
#[arg(long)]
sliced_attention_size: Option<usize>,
/// The number of steps to run the diffusion for.
#[arg(long, default_value_t = 30)]
n_steps: usize,
/// The number of samples to generate.
#[arg(long, default_value_t = 1)]
num_samples: i64,
/// The name of the final image to generate.
#[arg(long, value_name = "FILE", default_value = "sd_final.png")]
final_image: String,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelFile {
Tokenizer,
PriorTokenizer,
Clip,
PriorClip,
Decoder,
VqGan,
Prior,
}
impl ModelFile {
fn get(&self, filename: Option<String>) -> Result<std::path::PathBuf> {
use hf_hub::api::sync::Api;
match filename {
Some(filename) => Ok(std::path::PathBuf::from(filename)),
None => {
let repo_main = "warp-ai/wuerstchen";
let repo_prior = "warp-ai/wuerstchen-prior";
let (repo, path) = match self {
Self::Tokenizer => (repo_main, "tokenizer/tokenizer.json"),
Self::PriorTokenizer => (repo_prior, "tokenizer/tokenizer.json"),
Self::Clip => (repo_main, "text_encoder/model.safetensors"),
Self::PriorClip => (repo_prior, "text_encoder/model.safetensors"),
Self::Decoder => (repo_main, "decoder/diffusion_pytorch_model.safetensors"),
Self::VqGan => (repo_main, "vqgan/diffusion_pytorch_model.safetensors"),
Self::Prior => (repo_prior, "prior/diffusion_pytorch_model.safetensors"),
};
let filename = Api::new()?.model(repo.to_string()).get(path)?;
Ok(filename)
}
}
}
}
fn output_filename(
basename: &str,
sample_idx: i64,
num_samples: i64,
timestep_idx: Option<usize>,
) -> String {
let filename = if num_samples > 1 {
match basename.rsplit_once('.') {
None => format!("{basename}.{sample_idx}.png"),
Some((filename_no_extension, extension)) => {
format!("{filename_no_extension}.{sample_idx}.{extension}")
}
}
} else {
basename.to_string()
};
match timestep_idx {
None => filename,
Some(timestep_idx) => match filename.rsplit_once('.') {
None => format!("{filename}-{timestep_idx}.png"),
Some((filename_no_extension, extension)) => {
format!("{filename_no_extension}-{timestep_idx}.{extension}")
}
},
}
}
fn encode_prompt(
prompt: &str,
uncond_prompt: Option<&str>,
tokenizer: std::path::PathBuf,
clip_weights: std::path::PathBuf,
clip_config: stable_diffusion::clip::Config,
device: &Device,
) -> Result<Tensor> {
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let pad_id = match &clip_config.pad_with {
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
};
println!("Running with prompt \"{prompt}\".");
let mut tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let tokens_len = tokens.len();
while tokens.len() < clip_config.max_position_embeddings {
tokens.push(pad_id)
}
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
println!("Building the clip transformer.");
let text_model =
stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?;
match uncond_prompt {
None => Ok(text_embeddings),
Some(uncond_prompt) => {
let mut uncond_tokens = tokenizer
.encode(uncond_prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let uncond_tokens_len = uncond_tokens.len();
while uncond_tokens.len() < clip_config.max_position_embeddings {
uncond_tokens.push(pad_id)
}
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
let uncond_embeddings =
text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
Ok(text_embeddings)
}
}
}
fn run(args: Args) -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let Args {
prompt,
uncond_prompt,
cpu,
height,
width,
n_steps,
tokenizer,
final_image,
sliced_attention_size,
num_samples,
clip_weights,
prior_weights,
vqgan_weights,
decoder_weights,
tracing,
..
} = args;
let _guard = if tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let device = candle_examples::device(cpu)?;
let height = height.unwrap_or(1024);
let width = width.unwrap_or(1024);
let prior_text_embeddings = {
let tokenizer = ModelFile::PriorTokenizer.get(args.prior_tokenizer)?;
let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?;
encode_prompt(
&prompt,
Some(&uncond_prompt),
tokenizer.clone(),
weights,
stable_diffusion::clip::Config::wuerstchen_prior(),
&device,
)?
};
println!("generated prior text embeddings {prior_text_embeddings:?}");
let text_embeddings = {
let tokenizer = ModelFile::Tokenizer.get(tokenizer)?;
let weights = ModelFile::Clip.get(clip_weights)?;
encode_prompt(
&prompt,
None,
tokenizer.clone(),
weights,
stable_diffusion::clip::Config::wuerstchen(),
&device,
)?
};
println!("generated text embeddings {text_embeddings:?}");
println!("Building the prior.");
let b_size = 1;
let image_embeddings = {
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
let mut latents = Tensor::randn(
0f32,
1f32,
(b_size, PRIOR_CIN, latent_height, latent_width),
&device,
)?;
let prior = {
let prior_weights = ModelFile::Prior.get(prior_weights)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
let weights = weights.deserialize()?;
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
wuerstchen::prior::WPrior::new(
/* c_in */ PRIOR_CIN,
/* c */ 1536,
/* c_cond */ 1280,
/* c_r */ 64,
/* depth */ 32,
/* nhead */ 24,
args.use_flash_attn,
vb,
)?
};
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
let timesteps = prior_scheduler.timesteps();
let timesteps = &timesteps[..timesteps.len() - 1];
println!("prior denoising");
for (index, &t) in timesteps.iter().enumerate() {
let start_time = std::time::Instant::now();
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
let noise_pred = noise_pred.chunk(2, 0)?;
let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
let noise_pred = (noise_pred_uncond
+ ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
latents = prior_scheduler.step(&noise_pred, t, &latents)?;
let dt = start_time.elapsed().as_secs_f32();
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
}
((latents * 42.)? - 1.)?
};
println!("Building the vqgan.");
let vqgan = {
let vqgan_weights = ModelFile::VqGan.get(vqgan_weights)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(vqgan_weights)? };
let weights = weights.deserialize()?;
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
wuerstchen::paella_vq::PaellaVQ::new(vb)?
};
println!("Building the decoder.");
// https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json
let decoder = {
let decoder_weights = ModelFile::Decoder.get(decoder_weights)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(decoder_weights)? };
let weights = weights.deserialize()?;
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
wuerstchen::diffnext::WDiffNeXt::new(
/* c_in */ DECODER_CIN,
/* c_out */ DECODER_CIN,
/* c_r */ 64,
/* c_cond */ 1024,
/* clip_embd */ 1024,
/* patch_size */ 2,
args.use_flash_attn,
vb,
)?
};
for idx in 0..num_samples {
// https://huggingface.co/warp-ai/wuerstchen/blob/main/model_index.json
let latent_height = (image_embeddings.dim(2)? as f64 * LATENT_DIM_SCALE) as usize;
let latent_width = (image_embeddings.dim(3)? as f64 * LATENT_DIM_SCALE) as usize;
let mut latents = Tensor::randn(
0f32,
1f32,
(b_size, DECODER_CIN, latent_height, latent_width),
&device,
)?;
println!("diffusion process with prior {image_embeddings:?}");
let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(12, Default::default())?;
let timesteps = scheduler.timesteps();
let timesteps = &timesteps[..timesteps.len() - 1];
for (index, &t) in timesteps.iter().enumerate() {
let start_time = std::time::Instant::now();
let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
let noise_pred =
decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;
latents = scheduler.step(&noise_pred, t, &latents)?;
let dt = start_time.elapsed().as_secs_f32();
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
}
println!(
"Generating the final image for sample {}/{}.",
idx + 1,
num_samples
);
let image = vqgan.decode(&(&latents * 0.3764)?)?;
// TODO: Add the clamping between 0 and 1.
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
candle_examples::save_image(&image, image_filename)?
}
Ok(())
}
fn main() -> Result<()> {
let args = Args::parse();
run(args)
}

View File

@@ -4,7 +4,7 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
-use candle_examples::object_detection::{non_maximum_suppression, Bbox};
+use candle_transformers::object_detection::{non_maximum_suppression, Bbox};
 mod darknet;
 use anyhow::Result;
@@ -46,7 +46,7 @@ pub fn report(
     let (npreds, pred_size) = pred.dims2()?;
     let nclasses = pred_size - 5;
     // The bounding boxes grouped by (maximum) class index.
-    let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
+    let mut bboxes: Vec<Vec<Bbox<()>>> = (0..nclasses).map(|_| vec![]).collect();
     // Extract the bounding boxes for which confidence is above the threshold.
     for index in 0..npreds {
         let pred = Vec::<f32>::try_from(pred.get(index)?)?;
@@ -65,7 +65,7 @@ pub fn report(
                 xmax: pred[0] + pred[2] / 2.,
                 ymax: pred[1] + pred[3] / 2.,
                 confidence,
-                keypoints: vec![],
+                data: (),
             };
             bboxes[class_index].push(bbox)
         }

View File

@@ -0,0 +1,47 @@
# candle-yolo-v8: Object Detection and Pose Estimation
This is a port of [Ultralytics
YOLOv8](https://github.com/ultralytics/ultralytics). The implementation is based
on the [tinygrad
version](https://github.com/tinygrad/tinygrad/blob/master/examples/yolov8.py)
and on the model architecture described in this
[issue](https://github.com/ultralytics/ultralytics/issues/189). The supported
tasks are object detection and pose estimation.
You can try this model online on the [Candle YOLOv8
Space](https://huggingface.co/spaces/lmz/candle-yolo). The model then runs fully
in your browser using WebAssembly; if you use a custom image, it never
leaves your phone/computer!
## Running an example
### Object Detection
```bash
cargo run --example yolo-v8 --release -- candle-examples/examples/yolo-v8/assets/bike.jpg
```
This prints details about the detected objects and generates a `bike.pp.jpg` file.
![Leading group, Giro d'Italia 2021](./assets/bike.jpg)
Image source:
[wikimedia](https://commons.wikimedia.org/wiki/File:Leading_group,_Giro_d%27Italia_2021,_Stage_15.jpg).
![Leading group, Giro d'Italia 2021](./assets/bike.od.jpg)
### Pose Estimation
```bash
cargo run --example yolo-v8 --release -- \
candle-examples/examples/yolo-v8/assets/peoples.jpeg --task pose
```
![Leading group, Giro d'Italia 2021](./assets/bike.pose.jpg)
### Command-line flags
- `--which`: select the model variant to be used: `n`, `s`, `m`, `l`, or `x`, in
  increasing order of size and quality.
- `--task`: `detect` for object detection and `pose` for pose estimation.
- `--legend-size`: the size of the characters to print.
- `--model`: use a local model file rather than downloading it from the hub.
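
Both tasks de-duplicate overlapping candidate boxes with non-maximum
suppression (the real helper lives in
`candle_transformers::object_detection`). A minimal sketch of the idea,
assuming boxes are already sorted by decreasing confidence:

```rust
// A simplified bounding box, only for this sketch.
struct Box2d {
    xmin: f32,
    ymin: f32,
    xmax: f32,
    ymax: f32,
}

// Intersection-over-union of two boxes.
fn iou(a: &Box2d, b: &Box2d) -> f32 {
    let inter_w = (a.xmax.min(b.xmax) - a.xmin.max(b.xmin)).max(0.);
    let inter_h = (a.ymax.min(b.ymax) - a.ymin.max(b.ymin)).max(0.);
    let inter = inter_w * inter_h;
    let area_a = (a.xmax - a.xmin) * (a.ymax - a.ymin);
    let area_b = (b.xmax - b.xmin) * (b.ymax - b.ymin);
    inter / (area_a + area_b - inter)
}

// Keep a box only when it does not overlap an already kept box too much.
fn nms(boxes: Vec<Box2d>, iou_threshold: f32) -> Vec<Box2d> {
    let mut kept: Vec<Box2d> = Vec::new();
    for b in boxes {
        if kept.iter().all(|k| iou(k, &b) < iou_threshold) {
            kept.push(b);
        }
    }
    kept
}
```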

(Three binary image files added, 179 KiB, 175 KiB, and 189 KiB — not shown.)

View File

@@ -8,8 +8,8 @@ mod model;
 use model::{Multiples, YoloV8, YoloV8Pose};
 use candle::{DType, Device, IndexOp, Result, Tensor};
-use candle_examples::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
 use candle_nn::{Module, VarBuilder};
+use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
 use clap::{Parser, ValueEnum};
 use image::DynamicImage;
@@ -64,7 +64,7 @@ pub fn report_detect(
     let (pred_size, npreds) = pred.dims2()?;
     let nclasses = pred_size - 4;
     // The bounding boxes grouped by (maximum) class index.
-    let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
+    let mut bboxes: Vec<Vec<Bbox<Vec<KeyPoint>>>> = (0..nclasses).map(|_| vec![]).collect();
     // Extract the bounding boxes for which confidence is above the threshold.
     for index in 0..npreds {
         let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
@@ -83,7 +83,7 @@ pub fn report_detect(
                 xmax: pred[0] + pred[2] / 2.,
                 ymax: pred[1] + pred[3] / 2.,
                 confidence,
-                keypoints: vec![],
+                data: vec![],
             };
             bboxes[class_index].push(bbox)
         }
@@ -176,7 +176,7 @@ pub fn report_pose(
             xmax: pred[0] + pred[2] / 2.,
             ymax: pred[1] + pred[3] / 2.,
             confidence,
-            keypoints,
+            data: keypoints,
         };
         bboxes.push(bbox)
     }
@@ -204,7 +204,7 @@ pub fn report_pose(
                 image::Rgb([255, 0, 0]),
             );
         }
-        for kp in b.keypoints.iter() {
+        for kp in b.data.iter() {
             if kp.mask < 0.6 {
                 continue;
             }
@@ -219,8 +219,8 @@ pub fn report_pose(
         }
         for &(idx1, idx2) in KP_CONNECTIONS.iter() {
-            let kp1 = &b.keypoints[idx1];
-            let kp2 = &b.keypoints[idx2];
+            let kp1 = &b.data[idx1];
+            let kp2 = &b.data[idx2];
             if kp1.mask < 0.6 || kp2.mask < 0.6 {
                 continue;
             }

View File

@ -1,6 +1,5 @@
pub mod coco_classes; pub mod coco_classes;
pub mod imagenet; pub mod imagenet;
pub mod object_detection;
use candle::{Device, Result, Tensor}; use candle::{Device, Result, Tensor};
@ -16,6 +15,36 @@ pub fn device(cpu: bool) -> Result<Device> {
} }
} }
pub fn load_image<P: AsRef<std::path::Path>>(
p: P,
resize_longest: Option<usize>,
) -> Result<(Tensor, usize, usize)> {
let img = image::io::Reader::open(p)?
.decode()
.map_err(candle::Error::wrap)?;
let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
let img = match resize_longest {
None => img,
Some(resize_longest) => {
let (height, width) = (img.height(), img.width());
let resize_longest = resize_longest as u32;
let (height, width) = if height < width {
let h = (resize_longest * height) / width;
(h, resize_longest)
} else {
let w = (resize_longest * width) / height;
(resize_longest, w)
};
img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
}
};
let (height, width) = (img.height() as usize, img.width() as usize);
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
Ok((data, initial_h, initial_w))
}
pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
    p: P,
    width: usize,
@@ -35,14 +64,14 @@ pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
}
/// Saves an image to disk using the image crate, this expects an input with shape
-/// (c, width, height).
+/// (c, height, width).
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
    let p = p.as_ref();
-   let (channel, width, height) = img.dims3()?;
+   let (channel, height, width) = img.dims3()?;
    if channel != 3 {
-       candle::bail!("save_image expects an input of shape (3, width, height)")
+       candle::bail!("save_image expects an input of shape (3, height, width)")
    }
-   let img = img.transpose(0, 1)?.t()?.flatten_all()?;
+   let img = img.permute((1, 2, 0))?.flatten_all()?;
    let pixels = img.to_vec1::<u8>()?;
    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
@@ -52,3 +81,27 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
    image.save(p).map_err(candle::Error::wrap)?;
    Ok(())
}
pub fn save_image_resize<P: AsRef<std::path::Path>>(
img: &Tensor,
p: P,
h: usize,
w: usize,
) -> Result<()> {
let p = p.as_ref();
let (channel, height, width) = img.dims3()?;
if channel != 3 {
candle::bail!("save_image expects an input of shape (3, height, width)")
}
let img = img.permute((1, 2, 0))?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => candle::bail!("error saving image {p:?}"),
};
let image = image::DynamicImage::from(image);
let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
image.save(p).map_err(candle::Error::wrap)?;
Ok(())
}
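
Taken together, `load_image` and the two save helpers round-trip images through (3, height, width) u8 tensors. A minimal sketch, assuming an `input.jpg` exists next to the binary (both file names here are illustrative):

fn roundtrip() -> candle::Result<()> {
    // Decode, keeping the longest side at most 1024 pixels.
    let (img, initial_h, initial_w) = load_image("input.jpg", Some(1024))?;
    // Write the u8 tensor back out, resized to the original dimensions.
    save_image_resize(&img, "roundtrip.png", initial_h, initial_w)?;
    Ok(())
}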

@@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
-version = "0.2.1"
+version = "0.2.3"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
-candle = { path = "../candle-core", features = ["cuda"], version = "0.2.1", package = "candle-core" }
+candle = { path = "../candle-core", features = ["cuda"], version = "0.2.3", package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
@@ -21,4 +21,4 @@ rayon = "1.7.0"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
-candle-nn = { path = "../candle-nn", version = "0.2.1", features = ["cuda"] }
+candle-nn = { path = "../candle-nn", version = "0.2.3", features = ["cuda"] }

@@ -6,7 +6,7 @@ use rayon::prelude::*;
use std::path::PathBuf;
use std::str::FromStr;
-const KERNEL_FILES: [&str; 9] = [
+const KERNEL_FILES: [&str; 17] = [
    "flash_api.cu",
    "flash_fwd_hdim128_fp16_sm80.cu",
    "flash_fwd_hdim160_fp16_sm80.cu",
@@ -16,14 +16,14 @@ const KERNEL_FILES: [&str; 9] = [
    "flash_fwd_hdim32_fp16_sm80.cu",
    "flash_fwd_hdim64_fp16_sm80.cu",
    "flash_fwd_hdim96_fp16_sm80.cu",
-   // "flash_fwd_hdim128_bf16_sm80.cu",
-   // "flash_fwd_hdim160_bf16_sm80.cu",
-   // "flash_fwd_hdim192_bf16_sm80.cu",
-   // "flash_fwd_hdim224_bf16_sm80.cu",
-   // "flash_fwd_hdim256_bf16_sm80.cu",
-   // "flash_fwd_hdim32_bf16_sm80.cu",
-   // "flash_fwd_hdim64_bf16_sm80.cu",
-   // "flash_fwd_hdim96_bf16_sm80.cu",
+   "flash_fwd_hdim128_bf16_sm80.cu",
+   "flash_fwd_hdim160_bf16_sm80.cu",
+   "flash_fwd_hdim192_bf16_sm80.cu",
+   "flash_fwd_hdim224_bf16_sm80.cu",
+   "flash_fwd_hdim256_bf16_sm80.cu",
+   "flash_fwd_hdim32_bf16_sm80.cu",
+   "flash_fwd_hdim64_bf16_sm80.cu",
+   "flash_fwd_hdim96_bf16_sm80.cu",
];
fn main() -> Result<()> {
@@ -57,9 +57,20 @@ fn main() -> Result<()> {
            #[allow(clippy::redundant_clone)]
            out_dir.clone()
        }
-       Ok(build_dir) => PathBuf::from(build_dir),
+       Ok(build_dir) => {
+           let path = PathBuf::from(build_dir);
+           path.canonicalize().expect(&format!(
+               "Directory doesn't exists: {} (the current directory is {})",
+               &path.display(),
+               std::env::current_dir()?.display()
+           ))
+       }
    };
    set_cuda_include_dir()?;
+   let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
+   println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
    let compute_cap = compute_cap()?;
    let out_file = build_dir.join("libflashattention.a");
@@ -95,14 +106,21 @@ fn main() -> Result<()> {
        .args(["--default-stream", "per-thread"])
        .arg("-Icutlass/include")
        .arg("--expt-relaxed-constexpr")
-       .arg(cu_file);
+       .arg("--verbose");
+   if let Ok(ccbin_path) = &ccbin_env {
+       command
+           .arg("-allow-unsupported-compiler")
+           .args(["-ccbin", ccbin_path]);
+   }
+   command.arg(cu_file);
    let output = command
        .spawn()
        .context("failed spawning nvcc")?
        .wait_with_output()?;
    if !output.status.success() {
        anyhow::bail!(
-           "nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+           "nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+           &command,
            String::from_utf8_lossy(&output.stdout),
            String::from_utf8_lossy(&output.stderr)
        )
@@ -122,7 +140,8 @@ fn main() -> Result<()> {
        .wait_with_output()?;
    if !output.status.success() {
        anyhow::bail!(
-           "nvcc error while linking:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+           "nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+           &command,
            String::from_utf8_lossy(&output.stdout),
            String::from_utf8_lossy(&output.stderr)
        )
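
The build scripts now honor a `CANDLE_NVCC_CCBIN` environment variable, forwarding `-ccbin` (together with `-allow-unsupported-compiler`) to nvcc so a specific host compiler can be selected when the system default is too new for the CUDA toolkit; for instance `CANDLE_NVCC_CCBIN=/usr/bin/gcc-11 cargo build --features cuda` (the compiler path here is only an illustration).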

@@ -1,20 +1,19 @@
#include "flash_fwd_launch_template.h"
-// TODO: Switch back to handling bf16.
-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-    FWD_HEADDIM_SWITCH(params.d, [&] {
-        run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
-    });
-}
// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-//     FP16_SWITCH(!params.is_bf16, [&] {
//     FWD_HEADDIM_SWITCH(params.d, [&] {
-//         run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+//         run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
-//     });
//     });
// }
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+    FP16_SWITCH(!params.is_bf16, [&] {
+        FWD_HEADDIM_SWITCH(params.d, [&] {
+            run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+        });
+    });
+}
extern "C" void run_mha(
    void *q_ptr,
    void *k_ptr,
@@ -52,7 +51,8 @@ extern "C" void run_mha(
    uint32_t seqlen_q_rounded,
    uint32_t seqlen_k_rounded,
-   int is_causal
+   int is_causal,
+   int is_bf16
) {
    Flash_fwd_params params;
    // Reset the parameters
@@ -102,7 +102,7 @@ extern "C" void run_mha(
    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
    params.rp_dropout = 1.f / params.p_dropout;
    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
-   params.is_bf16 = 0;
+   params.is_bf16 = is_bf16;
    params.cu_seqlens_q = cu_seqlens_q_ptr;
    params.cu_seqlens_k = cu_seqlens_k_ptr;
    params.p_ptr = nullptr; // used for `return_softmax`.

@@ -38,6 +38,7 @@ extern "C" {
        seqlen_k_rounded: u32,
        is_causal: c_int,
+       is_bf16: c_int,
    );
}

@@ -4,7 +4,7 @@ use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, Layout, Result, Shape, Tensor};
-use half::f16;
+use half::{bf16, f16};
pub struct FlashAttn {
    pub softmax_scale: f32,
@@ -15,24 +15,10 @@ fn round_multiple(x: usize, m: usize) -> usize {
    (x + m - 1) / m * m
}
-impl candle::CustomOp3 for FlashAttn {
-    fn name(&self) -> &'static str {
-        "flash-attn"
-    }
-    fn cpu_fwd(
-        &self,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-    ) -> Result<(CpuStorage, Shape)> {
-        candle::bail!("no cpu support for flash-attn")
-    }
-    fn cuda_fwd(
+impl FlashAttn {
+    fn cuda_fwd_t<
+        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
+    >(
        &self,
        q: &candle::CudaStorage,
        q_l: &Layout,
@@ -40,15 +26,16 @@ impl candle::CustomOp3 for FlashAttn {
        k_l: &Layout,
        v: &candle::CudaStorage,
        v_l: &Layout,
+       is_bf16: bool,
    ) -> Result<(candle::CudaStorage, Shape)> {
        // https://github.com/Dao-AILab/flash-attention/blob/b252072409e69c25f2b9d473cc534e49b24decd2/csrc/flash_attn/flash_api.cpp#L187
        let dev = q.device();
        let out_shape = q_l.shape().clone();
        let out_l = Layout::contiguous(&out_shape);
-       let q = q.as_cuda_slice::<f16>()?;
-       let k = k.as_cuda_slice::<f16>()?;
-       let v = v.as_cuda_slice::<f16>()?;
+       let q = q.as_cuda_slice::<T>()?;
+       let k = k.as_cuda_slice::<T>()?;
+       let v = v.as_cuda_slice::<T>()?;
        let q = q.slice(q_l.start_offset()..);
        let k = k.slice(k_l.start_offset()..);
        let v = v.slice(v_l.start_offset()..);
@@ -104,10 +91,11 @@ impl candle::CustomOp3 for FlashAttn {
        let seqlen_k_rounded = round_multiple(seqlen_k, 128);
        let elem_count = out_shape.elem_count();
-       let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+       let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
        let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
        let causal = if self.causal { 1 } else { 0 };
+       let is_bf16 = if is_bf16 { 1 } else { 0 };
        unsafe {
            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
@@ -146,6 +134,7 @@ impl candle::CustomOp3 for FlashAttn {
                /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                /* is_causal */ causal,
+               /* is_bf16 */ is_bf16,
            )
        }
@@ -154,6 +143,40 @@ impl candle::CustomOp3 for FlashAttn {
    }
}
impl candle::CustomOp3 for FlashAttn {
fn name(&self) -> &'static str {
"flash-attn"
}
fn cpu_fwd(
&self,
_: &CpuStorage,
_: &Layout,
_: &CpuStorage,
_: &Layout,
_: &CpuStorage,
_: &Layout,
) -> Result<(CpuStorage, Shape)> {
candle::bail!("no cpu support for flash-attn")
}
fn cuda_fwd(
&self,
q: &candle::CudaStorage,
q_l: &Layout,
k: &candle::CudaStorage,
k_l: &Layout,
v: &candle::CudaStorage,
v_l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
match q.dtype() {
candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false),
candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true),
dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
}
}
}
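
With the dtype dispatch above, callers no longer need to pre-cast to f16. A hedged sketch of driving it through the crate's public `flash_attn` wrapper, assuming that wrapper keeps its `(q, k, v, softmax_scale, causal)` signature:

// Sketch only: q, k, v are (batch, seqlen, nheads, headdim) tensors on a CUDA device.
fn attn(q: &candle::Tensor, k: &candle::Tensor, v: &candle::Tensor) -> candle::Result<candle::Tensor> {
    let q = q.to_dtype(candle::DType::BF16)?;
    let k = k.to_dtype(candle::DType::BF16)?;
    let v = v.to_dtype(candle::DType::BF16)?;
    // The BF16 dtype routes to cuda_fwd_t::<bf16> via the match above;
    // 0.125 = 1/sqrt(64) for a head dimension of 64.
    candle_flash_attn::flash_attn(&q, &k, &v, 0.125, /* causal */ true)
}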
/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
@@ -190,24 +213,10 @@ struct FlashAttnVarLen {
    seqlens_k: Tensor,
}
-impl candle::CustomOp3 for FlashAttnVarLen {
-    fn name(&self) -> &'static str {
-        "flash-attn-varlen"
-    }
-    fn cpu_fwd(
-        &self,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-    ) -> Result<(CpuStorage, Shape)> {
-        candle::bail!("no cpu support for flash-attn")
-    }
-    fn cuda_fwd(
+impl FlashAttnVarLen {
+    fn cuda_fwd_t<
+        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
+    >(
        &self,
        q: &candle::CudaStorage,
        q_l: &Layout,
@@ -215,6 +224,7 @@ impl candle::CustomOp3 for FlashAttnVarLen {
        k_l: &Layout,
        v: &candle::CudaStorage,
        v_l: &Layout,
+       is_bf16: bool,
    ) -> Result<(candle::CudaStorage, Shape)> {
        // https://github.com/Dao-AILab/flash-attention/blob/184b992dcb2a0890adaa19eb9b541c3e4f9d2a08/csrc/flash_attn/flash_api.cpp#L327
        let dev = q.device();
@@ -314,6 +324,7 @@ impl candle::CustomOp3 for FlashAttnVarLen {
            .w()?;
        let causal = if self.causal { 1 } else { 0 };
+       let is_bf16 = if is_bf16 { 1 } else { 0 };
        unsafe {
            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
@@ -354,6 +365,7 @@ impl candle::CustomOp3 for FlashAttnVarLen {
                /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                /* is_causal */ causal,
+               /* is_bf16 */ is_bf16,
            )
        }
@@ -362,6 +374,40 @@ impl candle::CustomOp3 for FlashAttnVarLen {
    }
}
impl candle::CustomOp3 for FlashAttnVarLen {
fn name(&self) -> &'static str {
"flash-attn-varlen"
}
fn cpu_fwd(
&self,
_: &CpuStorage,
_: &Layout,
_: &CpuStorage,
_: &Layout,
_: &CpuStorage,
_: &Layout,
) -> Result<(CpuStorage, Shape)> {
candle::bail!("no cpu support for flash-attn")
}
fn cuda_fwd(
&self,
q: &candle::CudaStorage,
q_l: &Layout,
k: &candle::CudaStorage,
k_l: &Layout,
v: &candle::CudaStorage,
v_l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
match q.dtype() {
candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false),
candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true),
dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
}
}
}
#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///

@@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
-version = "0.2.1"
+version = "0.2.3"
edition = "2021"
description = "CUDA kernels for Candle"

@@ -164,6 +164,8 @@ mod cuda {
    println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
+   let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
+   println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
    let children = kernel_paths
        .par_iter()
        .flat_map(|p| {
@@ -188,8 +190,13 @@ mod cuda {
        .args(["--output-directory", &out_dir])
        // Flash attention only
        // .arg("--expt-relaxed-constexpr")
-       .args(&include_options)
-       .arg(p);
+       .args(&include_options);
+   if let Ok(ccbin_path) = &ccbin_env {
+       command
+           .arg("-allow-unsupported-compiler")
+           .args(["-ccbin", ccbin_path]);
+   }
+   command.arg(p);
    Some((p, command.spawn()
        .expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
    }})

@@ -77,20 +77,30 @@ CAST_OP(double, __half, cast_f64_f16)
CAST_OP(uint32_t, uint32_t, cast_u32_u32)
CAST_OP(uint32_t, uint8_t,  cast_u32_u8 )
+CAST_OP(uint32_t, int64_t,  cast_u32_i64 )
CAST_OP(uint32_t, float,    cast_u32_f32)
CAST_OP(uint32_t, double,   cast_u32_f64)
CAST_OP(uint8_t,  uint32_t, cast_u8_u32)
CAST_OP(uint8_t,  uint8_t,  cast_u8_u8 )
+CAST_OP(uint8_t,  int64_t,  cast_u8_i64 )
CAST_OP(uint8_t,  float,    cast_u8_f32)
CAST_OP(uint8_t,  double,   cast_u8_f64)
+CAST_OP(int64_t,  uint32_t, cast_i64_u32)
+CAST_OP(int64_t,  uint8_t,  cast_i64_u8 )
+CAST_OP(int64_t,  int64_t,  cast_i64_i64 )
+CAST_OP(int64_t,  float,    cast_i64_f32)
+CAST_OP(int64_t,  double,   cast_i64_f64)
CAST_OP(float,    uint8_t,  cast_f32_u8 )
CAST_OP(float,    uint32_t, cast_f32_u32)
+CAST_OP(float,    int64_t,  cast_f32_i64 )
CAST_OP(float,    float,    cast_f32_f32)
CAST_OP(float,    double,   cast_f32_f64)
CAST_OP(double,   uint8_t,  cast_f64_u8 )
CAST_OP(double,   uint32_t, cast_f64_u32)
+CAST_OP(double,   int64_t,  cast_f64_i64 )
CAST_OP(double,   float,    cast_f64_f32)
CAST_OP(double,   double,   cast_f64_f64)

@@ -51,6 +51,118 @@ __device__ void conv1d(
    dst[dst_i] = static_cast<T>(d);
}
template <typename T>
__device__ void im2col1d(
const size_t dst_numel,
const size_t l_out,
const size_t l_k,
const size_t stride,
const size_t padding,
const size_t dilation,
const size_t *info,
const T *src,
T *dst
) {
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
// dst: (b_size, l_out, c_in, l_k)
// src: (b_size, c_in, l_in)
if (dst_i >= dst_numel) {
return;
}
const size_t *src_dims = info;
const size_t *src_s = info + 3;
const size_t b_in = src_dims[0];
const size_t c_in = src_dims[1];
const size_t l_in = src_dims[2];
const size_t dst_s2 = l_k;
const size_t dst_s1 = c_in * dst_s2;
const size_t dst_s0 = l_out * dst_s1;
size_t tmp_dst_i = dst_i;
const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0;
const size_t l_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= l_idx * dst_s1;
const size_t c_idx = tmp_dst_i / dst_s2;
tmp_dst_i -= c_idx * dst_s2;
const size_t l_k_idx = tmp_dst_i;
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
dst[dst_i] = static_cast<T>(0);
}
else {
src_l_idx -= padding;
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
dst[dst_i] = src[src_i];
}
}
template <typename T>
__device__ void im2col(
const size_t dst_numel,
const size_t h_out,
const size_t w_out,
const size_t h_k,
const size_t w_k,
const size_t stride,
const size_t padding,
const size_t dilation,
const size_t *info,
const T *src,
T *dst
) {
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
// dst: (b_size, h_out, w_out, c_in, h_k, w_k)
// src: (b_size, c_in, h_in, w_in)
if (dst_i >= dst_numel) {
return;
}
const size_t *src_dims = info;
const size_t *src_s = info + 4;
const size_t b_in = src_dims[0];
const size_t c_in = src_dims[1];
const size_t h_in = src_dims[2];
const size_t w_in = src_dims[3];
const size_t dst_s4 = w_k;
const size_t dst_s3 = h_k * dst_s4;
const size_t dst_s2 = c_in * dst_s3;
const size_t dst_s1 = w_out * dst_s2;
const size_t dst_s0 = h_out * dst_s1;
size_t tmp_dst_i = dst_i;
const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0;
const size_t h_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= h_idx * dst_s1;
const size_t w_idx = tmp_dst_i / dst_s2;
tmp_dst_i -= w_idx * dst_s2;
const size_t c_idx = tmp_dst_i / dst_s3;
tmp_dst_i -= c_idx * dst_s3;
const size_t h_k_idx = tmp_dst_i / dst_s4;
tmp_dst_i -= h_k_idx * dst_s4;
const size_t w_k_idx = tmp_dst_i;
size_t src_h_idx = h_idx * stride + h_k_idx * dilation;
size_t src_w_idx = w_idx * stride + w_k_idx * dilation;
if (src_h_idx < padding || src_h_idx >= h_in + padding) {
dst[dst_i] = static_cast<T>(0);
}
else if (src_w_idx < padding || src_w_idx >= w_in + padding) {
dst[dst_i] = static_cast<T>(0);
}
else {
src_h_idx -= padding;
src_w_idx -= padding;
const size_t src_i =
b_idx * src_s[0]
+ c_idx * src_s[1]
+ src_h_idx * src_s[2]
+ src_w_idx * src_s[3];
dst[dst_i] = src[src_i];
}
}
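
The two kernels above implement the classic im2col lowering: each output location's receptive field is copied out into one row, so that the whole convolution collapses into a single matmul. In shapes, a (b, c_in, h_in, w_in) input becomes a (b * h_out * w_out, c_in * h_k * w_k) matrix, which is multiplied by the (c_out, c_in * h_k * w_k) flattened kernel and reshaped back to (b, c_out, h_out, w_out). The output sizes follow the usual formula, sketched here in Rust (it matches `hw_out` in the benchmark example further down):

fn conv_out_dim(x: usize, k: usize, stride: usize, padding: usize, dilation: usize) -> usize {
    (x + 2 * padding - dilation * (k - 1) - 1) / stride + 1
}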
// Naive implementation of conv2d.
template <typename T, typename A>
__device__ void conv2d(
@@ -363,6 +475,38 @@ extern "C" __global__ void FN_NAME( \
    conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
} \
#define IM2COL1D_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t dst_numel, \
const size_t l_out, \
const size_t l_k, \
const size_t stride, \
const size_t padding, \
const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
TYPENAME *dst \
) { \
im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
} \
#define IM2COL_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t dst_numel, \
const size_t h_out, \
const size_t w_out, \
const size_t h_k, \
const size_t w_k, \
const size_t stride, \
const size_t padding, \
const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
TYPENAME *dst \
) { \
im2col<TYPENAME>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, info, src, dst); \
} \
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t src_numel, \
@@ -428,6 +572,8 @@ CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16)
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
+IM2COL_OP(__nv_bfloat16, im2col_bf16)
+IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
#endif
#if __CUDA_ARCH__ >= 530
@@ -437,6 +583,8 @@ CONVT2D_OP(__half, float, conv_transpose2d_f16)
AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
MAX_POOL2D_OP(__half, max_pool2d_f16)
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
+IM2COL_OP(__half, im2col_f16)
+IM2COL1D_OP(__half, im2col1d_f16)
#endif
CONV1D_OP(float, float, conv1d_f32)
@@ -468,3 +616,13 @@ UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
UPSAMPLE_NEAREST2D_OP(double, upsample_nearest2d_f64)
UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
IM2COL_OP(float, im2col_f32)
IM2COL_OP(double, im2col_f64)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)
IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(double, im2col1d_f64)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)

@@ -129,6 +129,10 @@ __device__ __forceinline__ float powg(float a, float b) { return powf(a, b); }
__device__ __forceinline__ double powg(double a, double b) { return pow(a, b); }
__device__ __forceinline__ float tanhg(float a) { return tanhf(a); }
__device__ __forceinline__ double tanhg(double a) { return tanh(a); }
+__device__ __forceinline__ float erfg(float a) { return erff(a); }
+__device__ __forceinline__ double erfg(double a) { return erf(a); }
+__device__ __forceinline__ float normcdfg(float a) { return normcdff(a); }
+__device__ __forceinline__ double normcdfg(double a) { return normcdf(a); }
__device__ __forceinline__ float maxg(float a, float b) { return fmaxf(a, b); }
__device__ __forceinline__ double maxg(double a, double b) { return fmax(a, b); }
__device__ __forceinline__ float ming(float a, float b) { return fminf(a, b); }
@@ -157,6 +161,8 @@ __device__ __forceinline__ __half sing(__half a) { return hsin(a); }
__device__ __forceinline__ __half recipg(__half a) { __half one = 1.0; return one / a; }
__device__ __forceinline__ __half maxg(__half a, __half b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __half tanhg(__half a) { return __float2half(tanhf(__half2float(a))); }
+__device__ __forceinline__ __half erfg(__half a) { return __float2half(erff(__half2float(a))); }
+__device__ __forceinline__ __half normcdfg(__half a) { return __float2half(normcdff(__half2float(a))); }
__device__ __forceinline__ __half ming(__half a, __half b) { return __hmin_nan(a, b); }
__device__ __forceinline__ __half logg(__half a) { return hlog(a); }
__device__ __forceinline__ __half expg(__half a) { return hexp(a); }
@@ -173,6 +179,8 @@ __device__ __forceinline__ __nv_bfloat16 sing(__nv_bfloat16 a) { return hsin(a);
__device__ __forceinline__ __nv_bfloat16 recipg(__nv_bfloat16 a) { __nv_bfloat16 one = 1.0; return one / a; }
__device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __nv_bfloat16 tanhg(__nv_bfloat16 a) { return __float2bfloat16(tanhf(__bfloat162float(a))); }
+__device__ __forceinline__ __nv_bfloat16 erfg(__nv_bfloat16 a) { return __float2bfloat16(erff(__bfloat162float(a))); }
+__device__ __forceinline__ __nv_bfloat16 normcdfg(__nv_bfloat16 a) { return __float2bfloat16(normcdff(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 ming(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmin_nan(a, b); }
__device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); }
__device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a); }

@@ -49,6 +49,50 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
    dst[dst_id] = shr[0];
}
// Softmax implementation adapted from ggml.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L4159
template <typename T, typename ACC>
__device__ void softmax(const T * x, T * dst, const int ncols) {
const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int block_size = blockDim.y;
const int tid = threadIdx.y;
T max_val = -INFINITY;
for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
max_val = maxg(max_val, x[i]);
}
// find the max value in the block
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
max_val = maxg(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
}
ACC tmp = 0.;
for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
const T val = expg(x[i] - max_val);
tmp += static_cast<ACC>(val);
dst[i] = val;
}
// sum up partial sums
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
const ACC inv_tmp = 1. / tmp;
for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
dst[i] *= inv_tmp;
}
}
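
The kernel above is the standard three-pass, numerically stable softmax (row max, exponentiate-and-sum, normalize), with `__shfl_xor_sync` butterfly reductions combining per-lane partials within each warp. From Rust it is reached through `candle_nn::ops::softmax_last_dim`; a minimal sketch, assuming a CUDA-enabled build:

fn softmax_rows() -> candle::Result<()> {
    let device = candle::Device::cuda_if_available(0)?;
    let x = candle::Tensor::randn(0f32, 1., (8, 1024), &device)?;
    // On CUDA this dispatches to the fused kernel above; on CPU it falls back
    // to the generic implementation.
    let y = candle_nn::ops::softmax_last_dim(&x)?;
    println!("{:?}", y.dims());
    Ok(())
}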
template <typename T>
__device__ void
fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
@@ -290,12 +334,21 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
    } \
}
+#define SOFTMAX_OP(TYPENAME, ACC_TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const TYPENAME *src, TYPENAME *dst, \
+    const int n_cols) { \
+    softmax<TYPENAME, ACC_TYPENAME>(src, dst, n_cols); \
+} \
#if __CUDA_ARCH__ >= 800
+SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
#endif
#if __CUDA_ARCH__ >= 530
+SOFTMAX_OP(__half, float, softmax_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
#endif
@@ -303,6 +356,8 @@ FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fa
SUM_OP(float, sum_f32)
SUM_OP(double, sum_f64)
SUM_OP(uint32_t, sum_u32)
+SOFTMAX_OP(float, float, softmax_f32)
+SOFTMAX_OP(double, double, softmax_f64)
FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)

@@ -28,6 +28,11 @@ extern "C" __global__ void FN_NAME( \
    } \
} \
+template<typename T>
+__device__ __forceinline__ T gelu_erf_fwd(T x) {
+    return x * normcdfg(x);
+}
template<typename T>
__device__ __forceinline__ T gelu_fwd(T x) {
    T x_sq = x * x;
@@ -86,10 +91,13 @@ UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x))
UNARY_OP(__nv_bfloat16, usin_bf16, sing(x))
UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x))
UNARY_OP(__nv_bfloat16, utanh_bf16, tanhg(x))
+UNARY_OP(__nv_bfloat16, uerf_bf16, erfg(x))
+UNARY_OP(__nv_bfloat16, unormcdf_bf16, normcdfg(x))
UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x))
UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)
UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
+UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
@@ -104,10 +112,13 @@ UNARY_OP(__half, ulog_f16, logg(x))
UNARY_OP(__half, usin_f16, sing(x))
UNARY_OP(__half, ucos_f16, cosg(x))
UNARY_OP(__half, utanh_f16, tanhg(x))
+UNARY_OP(__half, uerf_f16, erfg(x))
+UNARY_OP(__half, unormcdf_f16, normcdfg(x))
UNARY_OP(__half, uabs_f16, absg(x))
UNARY_OP(__half, usqr_f16, x*x)
UNARY_OP(__half, usqrt_f16, sqrtg(x))
UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
+UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x))
UNARY_OP(__half, urelu_f16, relu_fwd(x))
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
UNARY_OP1(__half, upowf_f16, powg(x, param))
@@ -131,6 +142,10 @@ UNARY_OP(float, ucos_f32, cosg(x))
UNARY_OP(double, ucos_f64, cosg(x))
UNARY_OP(float, utanh_f32, tanhg(x))
UNARY_OP(double, utanh_f64, tanhg(x))
+UNARY_OP(float, uerf_f32, erfg(x))
+UNARY_OP(double, uerf_f64, erfg(x))
+UNARY_OP(float, unormcdf_f32, normcdfg(x))
+UNARY_OP(double, unormcdf_f64, normcdfg(x))
UNARY_OP(float, uabs_f32, absg(x))
UNARY_OP(double, uabs_f64, absg(x))
UNARY_OP(float, usqr_f32, x*x)
@@ -139,6 +154,8 @@ UNARY_OP(float, usqrt_f32, sqrtg(x))
UNARY_OP(double, usqrt_f64, sqrtg(x))
UNARY_OP(float, ugelu_f32, gelu_fwd(x))
UNARY_OP(double, ugelu_f64, gelu_fwd(x))
+UNARY_OP(float, ugelu_erf_f32, gelu_erf_fwd(x))
+UNARY_OP(double, ugelu_erf_f64, gelu_erf_fwd(x))
UNARY_OP(float, urelu_f32, relu_fwd(x))
UNARY_OP(double, urelu_f64, relu_fwd(x))
UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
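
The new `gelu_erf_fwd` is the exact GELU, x * Phi(x) with Phi the standard normal CDF (via the `normcdfg` helpers added above), i.e. x * 0.5 * (1 + erf(x / sqrt(2))); the existing `gelu_fwd` keeps the faster tanh approximation, and both are now exposed per dtype through the `ugelu_*` / `ugelu_erf_*` kernels.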

@@ -11,13 +11,18 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
+candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
+half = { workspace = true }
thiserror = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
+num-traits = { workspace = true }
+rayon = { workspace = true }
safetensors = { workspace = true }
+serde = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
+clap = { workspace = true }
[features]
default = []

@@ -0,0 +1,302 @@
/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::quantized::GgmlType;
use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D};
use clap::{Parser, Subcommand};
const CHECK_CONV2D: bool = false;
trait Benchmark {
type PreProcessData;
type RunResult;
fn preprocess() -> Result<Self::PreProcessData>;
fn run_one(_: &Self::PreProcessData) -> Result<Self::RunResult>;
const ITERS: usize;
}
struct Im2Col {
h_k: usize,
w_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}
impl Im2Col {
fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
(h_out, w_out)
}
}
impl candle::CustomOp1 for Im2Col {
fn name(&self) -> &'static str {
"im2col"
}
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
let &Self {
h_k,
w_k,
stride,
dilation,
padding,
} = self;
let (b, c, h, w) = layout.shape().dims4()?;
let (h_out, w_out) = self.hw_out(h, w);
let slice = storage.as_slice::<f32>()?;
let src = &slice[layout.start_offset()..];
let mut dst = vec![0f32; b * h_out * w_out * c * h_k * w_k];
let (src_s0, src_s1, src_s2, src_s3) = {
let s = layout.stride();
(s[0], s[1], s[2], s[3])
};
// TODO: provide specialized kernels for the common use cases.
// - h_k = w_k = 1
// - padding = 0
// - stride = 1
// - dilation = 1
for b_idx in 0..b {
let src_idx = b_idx * src_s0;
let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
for h_idx in 0..h_out {
let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
for w_idx in 0..w_out {
let dst_idx = dst_idx + w_idx * c * h_k * w_k;
for c_idx in 0..c {
let dst_idx = dst_idx + c_idx * h_k * w_k;
let src_idx = c_idx * src_s1 + src_idx;
for h_k_idx in 0..h_k {
let src_h = h_idx * stride + h_k_idx * dilation;
if padding != 0 && (src_h < padding || src_h >= h + padding) {
continue;
}
let src_h = src_h - padding;
let src_idx = src_idx + src_h * src_s2;
let dst_idx = dst_idx + h_k_idx * w_k;
for w_k_idx in 0..w_k {
let src_w = w_idx * stride + w_k_idx * dilation;
if padding != 0 && (src_w < padding || src_w >= w + padding) {
continue;
}
let src_w = src_w - padding;
let src_idx = src_idx + src_w * src_s3;
let dst_idx = dst_idx + w_k_idx;
dst[dst_idx] = src[src_idx]
}
}
}
}
}
}
let storage = candle::WithDType::to_cpu_storage_owned(dst);
Ok((storage, (b * h_out * w_out, c * h_k * w_k).into()))
}
}
// Conv1d example as used in whisper.
struct Conv1d;
impl Benchmark for Conv1d {
type PreProcessData = (Tensor, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let inp = Tensor::randn(0f32, 1., (1, 384, 3000), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (384, 384, 3), &Device::Cpu)?;
Ok((inp, w))
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
d.0.conv1d(&d.1, 0, 1, 1, 1)
}
const ITERS: usize = 5;
}
// Conv2d example as used in stable-diffusion.
struct Conv2d;
impl Benchmark for Conv2d {
type PreProcessData = (Tensor, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
Ok((inp, w))
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
d.0.conv2d(&d.1, 0, 1, 1, 1)
}
const ITERS: usize = 5;
}
// Conv2d example as used in stable-diffusion, im2col implementation.
struct Conv2dIm2Col;
impl Benchmark for Conv2dIm2Col {
type PreProcessData = (Tensor, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
Ok((inp, w))
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
// d.0.conv2d(&d.1, 0, 1, 1, 1)
let (b, _, h, w) = d.0.dims4()?;
let (_, _, h_k, w_k) = d.1.dims4()?;
let op = Im2Col {
h_k,
w_k,
stride: 1,
dilation: 1,
padding: 0,
};
let (h_out, w_out) = op.hw_out(h, w);
let col = d.0.apply_op1_no_bwd(&op)?;
let res = col.matmul(&d.1.flatten_from(1)?.t()?)?;
let res = res
.reshape((b, h_out, w_out, ()))?
.permute((0, 3, 1, 2))?
.contiguous()?;
if CHECK_CONV2D {
let res2 = d.0.conv2d(&d.1, op.padding, op.stride, op.dilation, 1);
let diff = (&res - res2)?.sqr()?.mean_all()?;
println!("{diff}");
}
Ok(res)
}
const ITERS: usize = 5;
}
struct Matmul;
impl Benchmark for Matmul {
type PreProcessData = (Tensor, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
Ok((lhs, rhs))
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
d.0.matmul(&d.1)
}
const ITERS: usize = 100;
}
// This benchmark is similar to:
// https://github.com/ggerganov/llama.cpp/blob/master/examples/benchmark/benchmark-matmult.cpp
struct QMatMul;
impl Benchmark for QMatMul {
type PreProcessData = (candle::quantized::QMatMul, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
let mm = candle::quantized::QMatMul::from_qtensor(mm);
let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
Ok((mm, arg))
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
d.0.forward(&d.1)
}
const ITERS: usize = 100;
}
struct Softmax;
impl Benchmark for Softmax {
type PreProcessData = Tensor;
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
// Typical whisper tiny size.
let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
Ok(x)
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
candle_nn::ops::softmax(d, D::Minus1)
}
const ITERS: usize = 100;
}
struct SoftmaxLastDim;
impl Benchmark for SoftmaxLastDim {
type PreProcessData = Tensor;
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
// Typical whisper tiny size.
let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
Ok(x)
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
candle_nn::ops::softmax_last_dim(d)
}
const ITERS: usize = 100;
}
fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
use std::hint::black_box;
let iters = iters.unwrap_or(B::ITERS);
let d = B::preprocess()?;
let start = std::time::Instant::now();
for _iter in 0..iters {
let _res = black_box(B::run_one(black_box(&d))?);
}
println!("{:?}", start.elapsed() / iters as u32);
Ok(())
}
#[derive(Subcommand, Debug, Clone)]
enum Task {
Conv1d,
Conv2d,
Conv2dIm2Col,
Matmul,
Qmatmul,
Softmax,
SoftmaxLastDim,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
/// The benchmark to be run.
#[command(subcommand)]
task: Task,
#[arg(long)]
iters: Option<usize>,
}
fn main() -> Result<()> {
let args = Args::parse();
match args.task {
Task::Conv1d => run::<Conv1d>(args.iters)?,
Task::Conv2d => run::<Conv2d>(args.iters)?,
Task::Conv2dIm2Col => run::<Conv2dIm2Col>(args.iters)?,
Task::Matmul => run::<Matmul>(args.iters)?,
Task::Softmax => run::<Softmax>(args.iters)?,
Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?,
Task::Qmatmul => run::<QMatMul>(args.iters)?,
}
Ok(())
}
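
The file name of this new benchmark example is truncated out of the diff; assuming it lands under candle-nn/examples as, say, `benchmarks.rs`, clap's kebab-case subcommands make the invocation `cargo run --release --example benchmarks -- conv2d-im2col --iters 10` (the example name here is illustrative). The `black_box` calls in `run` keep the optimizer from eliding the benchmarked work.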

@@ -1,18 +1,29 @@
use candle::Tensor;
+use serde::Deserialize;
-#[derive(Debug, Clone, Copy, PartialEq)]
+#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Default)]
+#[serde(rename_all = "lowercase")]
pub enum Activation {
+   #[default]
    Gelu,
+   #[serde(rename = "gated-gelu")]
+   NewGelu,
    Relu,
    Elu(f64),
+   LeakyRelu(f64),
}
impl super::Module for Activation {
    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
        match self {
            Self::Gelu => xs.gelu(),
+           // TODO: This is "gelu_new", not the original "gelu".
+           // There's some small numerical difference:
+           // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
+           Self::NewGelu => xs.gelu(),
            Self::Relu => xs.relu(),
            &Self::Elu(alpha) => xs.elu(alpha),
+           &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
        }
    }
}
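
With the serde derive, the activation can be parsed straight out of a model config file. A hedged sketch, assuming `Activation` is re-exported at the crate root and `serde_json` is available; the surrounding config struct is an illustration:

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Config {
    // "gelu", "relu", and the renamed "gated-gelu" all map onto the enum.
    hidden_act: candle_nn::Activation,
}

fn parse() -> anyhow::Result<()> {
    let cfg: Config = serde_json::from_str(r#"{ "hidden_act": "gated-gelu" }"#)?;
    println!("{:?}", cfg.hidden_act); // NewGelu
    Ok(())
}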

@@ -38,7 +38,7 @@ impl From<f64> for BatchNormConfig {
    }
}
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct BatchNorm {
    running_mean: Tensor,
    running_var: Tensor,

@@ -20,7 +20,7 @@ impl Default for Conv1dConfig {
    }
}
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct Conv1d {
    weight: Tensor,
    bias: Option<Tensor>,
@@ -39,6 +39,14 @@ impl Conv1d {
    pub fn config(&self) -> &Conv1dConfig {
        &self.config
    }
+   pub fn weight(&self) -> &Tensor {
+       &self.weight
+   }
+   pub fn bias(&self) -> Option<&Tensor> {
+       self.bias.as_ref()
+   }
}
impl crate::Module for Conv1d {
@@ -80,8 +88,7 @@ impl Default for Conv2dConfig {
    }
}
-#[allow(dead_code)]
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct Conv2d {
    weight: Tensor,
    bias: Option<Tensor>,
@@ -100,6 +107,14 @@ impl Conv2d {
    pub fn config(&self) -> &Conv2dConfig {
        &self.config
    }
+   pub fn weight(&self) -> &Tensor {
+       &self.weight
+   }
+   pub fn bias(&self) -> Option<&Tensor> {
+       self.bias.as_ref()
+   }
}
impl crate::Module for Conv2d {
@@ -122,15 +137,76 @@ impl crate::Module for Conv2d {
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConvTranspose2dConfig {
pub padding: usize,
pub output_padding: usize,
pub stride: usize,
pub dilation: usize,
// TODO: support groups.
}
impl Default for ConvTranspose2dConfig {
fn default() -> Self {
Self {
padding: 0,
output_padding: 0,
stride: 1,
dilation: 1,
}
}
}
#[derive(Clone, Debug)]
pub struct ConvTranspose2d {
weight: Tensor,
bias: Option<Tensor>,
config: ConvTranspose2dConfig,
}
impl ConvTranspose2d {
pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose2dConfig) -> Self {
Self {
weight,
bias,
config,
}
}
pub fn config(&self) -> &ConvTranspose2dConfig {
&self.config
}
}
impl crate::Module for ConvTranspose2d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv_transpose2d(
&self.weight,
self.config.padding,
self.config.output_padding,
self.config.stride,
self.config.dilation,
)?;
match &self.bias {
None => Ok(x),
Some(bias) => {
let b = bias.dims1()?;
let bias = bias.reshape((1, b, 1, 1))?;
Ok(x.broadcast_add(&bias)?)
}
}
}
}
pub fn conv1d(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv1dConfig,
-   vs: crate::VarBuilder,
+   vb: crate::VarBuilder,
) -> Result<Conv1d> {
    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
-   let ws = vs.get_with_hints(
+   let ws = vb.get_with_hints(
        (out_channels, in_channels / cfg.groups, kernel_size),
        "weight",
        init_ws,
@@ -140,7 +216,7 @@ pub fn conv1d(
        lo: -bound,
        up: bound,
    };
-   let bs = vs.get_with_hints(out_channels, "bias", init_bs)?;
+   let bs = vb.get_with_hints(out_channels, "bias", init_bs)?;
    Ok(Conv1d::new(ws, Some(bs), cfg))
}
@@ -149,10 +225,10 @@ pub fn conv2d(
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv2dConfig,
-   vs: crate::VarBuilder,
+   vb: crate::VarBuilder,
) -> Result<Conv2d> {
    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
-   let ws = vs.get_with_hints(
+   let ws = vb.get_with_hints(
        (
            out_channels,
            in_channels / cfg.groups,
@@ -167,7 +243,7 @@ pub fn conv2d(
        lo: -bound,
        up: bound,
    };
-   let bs = vs.get_with_hints(out_channels, "bias", init_bs)?;
+   let bs = vb.get_with_hints(out_channels, "bias", init_bs)?;
    Ok(Conv2d::new(ws, Some(bs), cfg))
}
@@ -176,10 +252,10 @@ pub fn conv2d_no_bias(
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv2dConfig,
-   vs: crate::VarBuilder,
+   vb: crate::VarBuilder,
) -> Result<Conv2d> {
    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
-   let ws = vs.get_with_hints(
+   let ws = vb.get_with_hints(
        (
            out_channels,
            in_channels / cfg.groups,
@@ -191,3 +267,44 @@ pub fn conv2d_no_bias(
    )?;
    Ok(Conv2d::new(ws, None, cfg))
}
pub fn conv_transpose2d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: ConvTranspose2dConfig,
vb: crate::VarBuilder,
) -> Result<ConvTranspose2d> {
let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64;
let init = crate::Init::Uniform {
lo: -bound,
up: bound,
};
let ws = vb.get_with_hints(
(in_channels, out_channels, kernel_size, kernel_size),
"weight",
init,
)?;
let bs = vb.get_with_hints(out_channels, "bias", init)?;
Ok(ConvTranspose2d::new(ws, Some(bs), cfg))
}
pub fn conv_transpose2d_no_bias(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: ConvTranspose2dConfig,
vb: crate::VarBuilder,
) -> Result<ConvTranspose2d> {
let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64;
let init = crate::Init::Uniform {
lo: -bound,
up: bound,
};
let ws = vb.get_with_hints(
(in_channels, out_channels, kernel_size, kernel_size),
"weight",
init,
)?;
Ok(ConvTranspose2d::new(ws, None, cfg))
}
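
A hedged sketch of using the new transposed-convolution builder end to end; the VarMap/VarBuilder setup, the re-export paths, and the `upsample` prefix are illustrative assumptions:

use candle::{DType, Device, Tensor};
use candle_nn::{conv_transpose2d, ConvTranspose2dConfig, Module, VarBuilder, VarMap};

fn upsample_demo() -> candle::Result<()> {
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    let cfg = ConvTranspose2dConfig { stride: 2, ..Default::default() };
    // 16 -> 8 channels, 4x4 kernel, stride 2: roughly doubles the spatial size.
    let layer = conv_transpose2d(16, 8, 4, cfg, vb.pp("upsample"))?;
    let x = Tensor::randn(0f32, 1., (1, 16, 32, 32), &Device::Cpu)?;
    let y = layer.forward(&x)?; // (1, 8, 66, 66) with these settings
    println!("{:?}", y.dims());
    Ok(())
}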

@@ -1,7 +1,7 @@
//! Embedding Layer.
use candle::{Result, Tensor};
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct Embedding {
    embeddings: Tensor,
    hidden_size: usize,
@@ -18,6 +18,11 @@ impl Embedding {
    pub fn embeddings(&self) -> &Tensor {
        &self.embeddings
    }
+   /// Get the hidden size of the embedding matrix
+   pub fn hidden_size(&self) -> usize {
+       self.hidden_size
+   }
}
impl crate::Module for Embedding {

@@ -4,7 +4,7 @@
use candle::{DType, Result, Tensor};
// This group norm version handles both weight and bias so removes the mean.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct GroupNorm {
    weight: Tensor,
    bias: Tensor,

@@ -28,7 +28,7 @@
//! ```
//!
//! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
-use candle::{DType, Result, Tensor};
+use candle::{DType, Result, Tensor, D};
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LayerNormConfig {
@@ -60,7 +60,7 @@ impl From<f64> for LayerNormConfig {
}
// This layer norm version handles both weight and bias so removes the mean.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct LayerNorm {
    weight: Tensor,
    bias: Option<Tensor>,
@@ -104,15 +104,15 @@ impl crate::Module for LayerNorm {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
-       let (_bsize, _seq_len, hidden_size) = x.dims3()?;
+       let hidden_size = x.dim(D::Minus1)?;
        let x = x.to_dtype(internal_dtype)?;
        let x = if self.remove_mean {
-           let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+           let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
            x.broadcast_sub(&mean_x)?
        } else {
            x
        };
-       let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+       let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
        let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?;
        match &self.bias {
@@ -143,7 +143,7 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
}
/// RmsNorm is a specialized version of the LayerNorm module.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct RmsNorm(LayerNorm);
impl RmsNorm {
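
Replacing the hard `dims3` destructuring with `dim(D::Minus1)` lets `LayerNorm` (and the `RmsNorm` built on it) normalize the last dimension of a tensor of any rank, not just (batch, seq, hidden). A minimal sketch, assuming `LayerNorm::new(weight, bias, eps)` keeps its current signature:

use candle::{DType, Device, Tensor};
use candle_nn::Module;

fn rank2_layer_norm() -> candle::Result<()> {
    let dev = Device::Cpu;
    let w = Tensor::ones(64, DType::F32, &dev)?;
    let b = Tensor::zeros(64, DType::F32, &dev)?;
    let ln = candle_nn::LayerNorm::new(w, b, 1e-5);
    // A rank-2 input such as this previously failed the dims3 check.
    let x = Tensor::randn(0f32, 1., (10, 64), &dev)?;
    let y = ln.forward(&x)?;
    println!("{:?}", y.dims());
    Ok(())
}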

Some files were not shown because too many files have changed in this diff Show More