Merge branch 'main' into book-trainin-simplified
.gitignore (vendored, 1 change)

@@ -23,6 +23,7 @@ flamegraph.svg
 *.dylib
 *.so
 *.swp
 *.swo
 trace-*.json

+candle-wasm-examples/*/build
CHANGELOG.md (47 changes)

@@ -1,13 +1,58 @@
 # Changelog
 This documents the main changes to the `candle` crate.

-## v0.2.1 - Unreleased
+## v0.2.3 - Unreleased

 ### Added

 ### Modified

+## v0.2.2 - 2023-09-18
+
+### Added
+- Support for `top_p` sampling
+  [819](https://github.com/huggingface/candle/pull/819).
+- T5 model including decoding
+  [864](https://github.com/huggingface/candle/pull/864).
+- 1-d upsampling
+  [839](https://github.com/huggingface/candle/pull/839).
+
+### Modified
+- Bugfix for conv2d
+  [820](https://github.com/huggingface/candle/pull/820).
+- Support tensor based indexing using `.i`
+  [842](https://github.com/huggingface/candle/pull/842).
+
+## v0.2.1 - 2023-09-11
+
+### Added
+- Add some RNNs (GRU and LSTM) in `candle-nn`
+  [674](https://github.com/huggingface/candle/pull/674),
+  [688](https://github.com/huggingface/candle/pull/688).
+- gguf v2 support
+  [725](https://github.com/huggingface/candle/pull/725).
+- Quantized llama example in Python using the pyo3 api
+  [716](https://github.com/huggingface/candle/pull/716).
+- `candle-nn` layer for conv2d-transposed
+  [760](https://github.com/huggingface/candle/pull/760).
+- Add the Segment-Anything Model (SAM) as an example
+  [773](https://github.com/huggingface/candle/pull/773).
+- TinyViT backbone for the segment anything example
+  [787](https://github.com/huggingface/candle/pull/787).
+- Shape with holes support
+  [770](https://github.com/huggingface/candle/pull/770).
+
+### Modified
+- Dilations are now supported in conv-transpose2d.
+  [671](https://github.com/huggingface/candle/pull/671).
+- Interactive mode for the quantized model
+  [690](https://github.com/huggingface/candle/pull/690).
+- Faster softmax operation
+  [747](https://github.com/huggingface/candle/pull/747).
+- Faster convolution operations on CPU and CUDA via im2col
+  [802](https://github.com/huggingface/candle/pull/802).
+- Moving some models to a more central location
+  [796](https://github.com/huggingface/candle/pull/796).

 ## v0.2.0 - 2023-08-30
Cargo.toml (11 changes)

@@ -8,17 +8,16 @@ members = [
     "candle-pyo3",
     "candle-transformers",
     "candle-wasm-examples/llama2-c",
+    "candle-wasm-examples/segment-anything",
     "candle-wasm-examples/whisper",
     "candle-wasm-examples/yolo",
+    "candle-wasm-examples/bert",
 ]
-exclude = [
-    "candle-flash-attn",
-    "candle-kernels",
-]
+exclude = ["candle-flash-attn", "candle-kernels"]
 resolver = "2"

 [workspace.package]
-version = "0.2.1"
+version = "0.2.3"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"

@@ -33,7 +32,7 @@ byteorder = "1.4.3"
 clap = { version = "4.2.4", features = ["derive"] }
 cudarc = { version = "0.9.14", features = ["f16"] }
 # TODO: Switch back to the official gemm implementation once it has caught up.
-gemm = { version = "0.15.6", package = "candle-gemm" }
+gemm = { version = "0.16.0", package = "candle-gemm" }
 hf-hub = "0.3.0"
 half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
 image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
README.md (115 changes)

@@ -8,7 +8,9 @@ Candle is a minimalist ML framework for Rust with a focus on performance (includ
 and ease of use. Try our online demos:
 [whisper](https://huggingface.co/spaces/lmz/candle-whisper),
 [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
-[yolo](https://huggingface.co/spaces/lmz/candle-yolo).
+[yolo](https://huggingface.co/spaces/lmz/candle-yolo),
+[Segment
+Anything](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).

 ## Get started

@@ -45,37 +47,54 @@ For more advanced examples, please have a look at the following section.

 ## Check out our examples

-Check out our [examples](./candle-examples/examples/):
+These online demos run entirely in your browser:
+- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
+  object recognition.
+- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech to text.
+- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
+- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): image segmentation.
+
+We also provide some command line based examples using state of the art models:

-- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
 - [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
 - [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
   generation.
-- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
-  image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
-- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
-  using self-supervision (can be used for imagenet classification, depth
-  evaluation, segmentation).
 - [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
   the LLaMA model using the same quantization techniques as
   [llama.cpp](https://github.com/ggerganov/llama.cpp).

 <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">

+- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
+  image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
+
+  <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">
+
+- [Wuerstchen](./candle-examples/examples/wuerstchen/): another text to
+  image generative model.
+
+  <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/wuerstchen/assets/cat.jpg" width="200">
+
 - [yolo-v3](./candle-examples/examples/yolo-v3/) and
   [yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
   estimation models.
-  Run them using the following commands:

   <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.od.jpg" width="200"><img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.pose.jpg" width="200">
+- [segment-anything](./candle-examples/examples/segment-anything/): image
+  segmentation model with prompt.
+
+  <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
+
+- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
+- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
+- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
+  using self-supervision (can be used for imagenet classification, depth
+  evaluation, segmentation).
+
+Run them using commands like:
 ```
 cargo run --example whisper --release
 cargo run --example llama --release
 cargo run --example falcon --release
 cargo run --example bert --release
 cargo run --example bigcode --release
 cargo run --example stable-diffusion --release -- --prompt "a rusty robot holding a fire torch"
 cargo run --example dinov2 --release -- --image path/to/myinput.jpg
 cargo run --example quantized --release
 cargo run --example yolo-v3 --release -- myimage.jpg
 cargo run --example yolo-v8 --release -- myimage.jpg # for pose estimation, add --task pose
 ```

 In order to use **CUDA** add `--features cuda` to the example command line. If

@@ -85,7 +104,8 @@ There are also some wasm examples for whisper and
 [llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
 `trunk` or try them online:
 [whisper](https://huggingface.co/spaces/lmz/candle-whisper),
-[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
+[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
+[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).

 For LLaMA2, run the following command to retrieve the weight files and start a
 test server:

@@ -98,6 +118,15 @@ trunk serve --release --port 8081
 And then head over to
 [http://localhost:8081/](http://localhost:8081/).

+<!--- ANCHOR: useful_libraries --->
+
+## Useful Libraries
+- [`candle-lora`](https://github.com/EricLBuehler/candle-lora) provides a LoRA implementation that conforms to the official `peft` implementation.
+
+If you have an addition to this list, please submit a pull request.
+
+<!--- ANCHOR_END: useful_libraries --->

 <!--- ANCHOR: features --->

 ## Features

@@ -110,10 +139,21 @@ And then head over to
 - CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
 - WASM support, run your models in a browser.
-- Included models.
-  - LLMs: LLaMA v1 and v2, Falcon, StarCoder.
+- Language Models.
+  - LLaMA v1 and v2.
+  - Falcon.
+  - StarCoder.
+  - T5.
+  - Bert.
   - Whisper (multi-lingual support).
-  - Stable Diffusion.
-  - Computer Vision: DINOv2, EfficientNet, yolo-v3, yolo-v8.
+  - Stable Diffusion v1.5, v2.1, XL v1.0.
+  - Wurstchen v2.
+- Computer Vision Models.
+  - DINOv2.
+  - EfficientNet.
+  - yolo-v3.
+  - yolo-v8.
+  - Segment-Anything Model (SAM).
 - File formats: load models from safetensors, npz, ggml, or PyTorch files.
 - Serverless (on CPU), small and fast deployments.
 - Quantization support using the llama.cpp quantized types.

@@ -243,6 +283,35 @@ authentication token. See issue
 git submodule update --init
 ```

+#### Compiling with flash-attention fails
+
+```
+/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
+```
+
+This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
+
+```
+env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
+```
+
+#### Linking error on windows when running rustdoc or mdbook tests
+
+```
+Couldn't compile the test.
+---- .\candle-book\src\inference\hub.md - Using_the_hub::Using_in_a_real_model_ (line 50) stdout ----
+error: linking with `link.exe` failed: exit code: 1181
+//very long chain of linking
+ = note: LINK : fatal error LNK1181: cannot open input file 'windows.0.48.5.lib'
+```
+
+Make sure you link all native libraries that might be located outside a project target, e.g., to run mdbook tests, you should run:
+
+```
+mdbook test candle-book -L .\target\debug\deps\ `
+-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.42.2\lib `
+-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib
+```

 #### Tracking down errors

 You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
@@ -11,11 +11,11 @@ readme = "README.md"

 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
-candle-nn = { path = "../candle-nn", version = "0.2.1" }
-candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
+candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
+candle-nn = { path = "../candle-nn", version = "0.2.3" }
+candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
@@ -10,10 +10,10 @@

 # Reference Guide

-- [Running a model](inference/README.md)
+- [Running a model](inference/inference.md)
 - [Using the hub](inference/hub.md)
 - [Error management](error_manage.md)
-- [Training](training/README.md)
+- [Training](training/training.md)
+    - [Simplified](training/simplified.md)
 - [MNIST](training/mnist.md)
 - [Fine-tuning]()
@@ -29,7 +29,7 @@ After adding `RUST_BACKTRACE=1`:

 Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
 ```

-Not super pretty at the moment, but we can see error occured on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
+Not super pretty at the moment, but we can see the error occurred on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`

 Another thing to note is that, since Rust is compiled, it is not necessarily as easy to recover proper stacktraces
@@ -6,7 +6,7 @@ Open `src/main.rs` and fill in this content:

 ```rust
 # extern crate candle_core;
-use candle_core::{DType, Device, Result, Tensor};
+use candle_core::{Device, Result, Tensor};

 struct Model {
     first: Tensor,

@@ -25,11 +25,11 @@ fn main() -> Result<()> {
     // Use Device::new_cuda(0)?; to use the GPU.
     let device = Device::Cpu;

-    let first = Tensor::zeros((784, 100), DType::F32, &device)?;
-    let second = Tensor::zeros((100, 10), DType::F32, &device)?;
+    let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
+    let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
     let model = Model { first, second };

-    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

     let digit = model.forward(&dummy_image)?;
     println!("Digit {digit:?} digit");

@@ -50,7 +50,7 @@ the classical `Linear` layer. We can do as such

 ```rust
 # extern crate candle_core;
-# use candle_core::{DType, Device, Result, Tensor};
+# use candle_core::{Device, Result, Tensor};
 struct Linear{
     weight: Tensor,
     bias: Tensor,

@@ -80,7 +80,7 @@ This will change the model running code into a new function

 ```rust
 # extern crate candle_core;
-# use candle_core::{DType, Device, Result, Tensor};
+# use candle_core::{Device, Result, Tensor};
 # struct Linear{
 #     weight: Tensor,
 #     bias: Tensor,

@@ -110,15 +110,15 @@ fn main() -> Result<()> {
     let device = Device::cuda_if_available(0)?;

     // Creating a dummy model
-    let weight = Tensor::zeros((784, 100), DType::F32, &device)?;
-    let bias = Tensor::zeros((100, ), DType::F32, &device)?;
+    let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
+    let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
     let first = Linear{weight, bias};
-    let weight = Tensor::zeros((100, 10), DType::F32, &device)?;
-    let bias = Tensor::zeros((10, ), DType::F32, &device)?;
+    let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
+    let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
     let second = Linear{weight, bias};
     let model = Model { first, second };

-    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

     // Inference on the model
     let digit = model.forward(&dummy_image)?;

@@ -146,7 +146,7 @@ And rewrite our examples using it

 ```rust
 # extern crate candle_core;
 # extern crate candle_nn;
-use candle_core::{DType, Device, Result, Tensor};
+use candle_core::{Device, Result, Tensor};
 use candle_nn::{Linear, Module};

 struct Model {

@@ -167,15 +167,15 @@ fn main() -> Result<()> {
     let device = Device::Cpu;

     // This has changed (784, 100) -> (100, 784) !
-    let weight = Tensor::zeros((100, 784), DType::F32, &device)?;
-    let bias = Tensor::zeros((100, ), DType::F32, &device)?;
+    let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
+    let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
     let first = Linear::new(weight, Some(bias));
-    let weight = Tensor::zeros((10, 100), DType::F32, &device)?;
-    let bias = Tensor::zeros((10, ), DType::F32, &device)?;
+    let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?;
+    let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
     let second = Linear::new(weight, Some(bias));
     let model = Model { first, second };

-    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

     let digit = model.forward(&dummy_image)?;
     println!("Digit {digit:?} digit");

@@ -188,8 +188,8 @@ Feel free to modify this example to use `Conv2d` to create a classical convnet instead.

 Now that we have the running dummy code we can get to more advanced topics:

-- [For PyTorch users](./guide/cheatsheet.md)
-- [Running existing models](./inference/README.md)
-- [Training models](./training/README.md)
+- [For PyTorch users](../guide/cheatsheet.md)
+- [Running existing models](../inference/inference.md)
+- [Training models](../training/training.md)
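The shape flip flagged in the comment above, `(784, 100) -> (100, 784)`, comes from `candle_nn::Linear` storing its weight as `(out_features, in_features)` and transposing it during `forward`, PyTorch-style. A minimal sketch of the equivalence (standalone, using the same crates as the guide; not part of this diff):

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Weight stored as (out_features, in_features) = (100, 784).
    let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
    let layer = Linear::new(weight.clone(), None);

    let x = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
    // Linear::forward computes x @ weight^T, hence the transposed storage.
    let y1 = layer.forward(&x)?;
    let y2 = x.matmul(&weight.t()?)?;
    assert_eq!(y1.dims(), y2.dims()); // both (1, 100)
    Ok(())
}
```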
@@ -12,7 +12,7 @@ readme = "README.md"

 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
-candle-kernels = { path = "../candle-kernels", version = "0.2.1", optional = true }
+candle-kernels = { path = "../candle-kernels", version = "0.2.3", optional = true }
 cudarc = { workspace = true, optional = true }
 gemm = { workspace = true }
 half = { workspace = true }
@@ -1,166 +0,0 @@
/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle_core::quantized::GgmlType;
use candle_core::{Device, Result, Tensor, D};
use clap::{Parser, Subcommand};

fn softmax<D: candle_core::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
    let dim = dim.to_index(xs.shape(), "softmax")?;
    let max = xs.max_keepdim(dim)?;
    let diff = xs.broadcast_sub(&max)?;
    let num = diff.exp()?;
    let den = num.sum_keepdim(dim)?;
    num.broadcast_div(&den)
}

trait Benchmark {
    type PreProcessData;
    type RunResult;

    fn preprocess() -> Result<Self::PreProcessData>;
    fn run_one(_: &Self::PreProcessData) -> Result<Self::RunResult>;

    const ITERS: usize;
}

// Conv1d example as used in whisper.
struct Conv1d;
impl Benchmark for Conv1d {
    type PreProcessData = (Tensor, Tensor);
    type RunResult = Tensor;
    fn preprocess() -> Result<Self::PreProcessData> {
        let inp = Tensor::randn(0f32, 1., (1, 384, 3000), &Device::Cpu)?;
        let w = Tensor::randn(0f32, 1., (384, 384, 3), &Device::Cpu)?;
        Ok((inp, w))
    }

    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
        d.0.conv1d(&d.1, 0, 1, 1, 1)
    }

    const ITERS: usize = 5;
}

// Conv2d example as used in stable-diffusion.
struct Conv2d;
impl Benchmark for Conv2d {
    type PreProcessData = (Tensor, Tensor);
    type RunResult = Tensor;

    fn preprocess() -> Result<Self::PreProcessData> {
        let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
        let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
        Ok((inp, w))
    }

    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
        d.0.conv2d(&d.1, 0, 1, 1, 1)
    }

    const ITERS: usize = 1;
}

struct Matmul;
impl Benchmark for Matmul {
    type PreProcessData = (Tensor, Tensor);
    type RunResult = Tensor;
    fn preprocess() -> Result<Self::PreProcessData> {
        let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
        let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
        Ok((lhs, rhs))
    }

    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
        d.0.matmul(&d.1)
    }

    const ITERS: usize = 100;
}

// This benchmark is similar to:
// https://github.com/ggerganov/llama.cpp/blob/master/examples/benchmark/benchmark-matmult.cpp
struct QMatMul;
impl Benchmark for QMatMul {
    type PreProcessData = (candle_core::quantized::QMatMul, Tensor);
    type RunResult = Tensor;
    fn preprocess() -> Result<Self::PreProcessData> {
        let zeros = vec![candle_core::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
        let mm = candle_core::quantized::QTensor::new(zeros, (4096, 11008))?;
        let mm = candle_core::quantized::QMatMul::from_qtensor(mm);
        let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
        Ok((mm, arg))
    }

    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
        d.0.forward(&d.1)
    }

    const ITERS: usize = 100;
}

struct Softmax;
impl Benchmark for Softmax {
    type PreProcessData = Tensor;
    type RunResult = Tensor;
    fn preprocess() -> Result<Self::PreProcessData> {
        // Typical whisper tiny size.
        let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
        Ok(x)
    }

    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
        softmax(d, D::Minus1)
    }

    const ITERS: usize = 100;
}

fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
    use std::hint::black_box;

    let iters = iters.unwrap_or(B::ITERS);
    let d = B::preprocess()?;
    let start = std::time::Instant::now();
    for _iter in 0..iters {
        let _res = black_box(B::run_one(black_box(&d))?);
    }
    println!("{:?}", start.elapsed() / iters as u32);
    Ok(())
}

#[derive(Subcommand, Debug, Clone)]
enum Task {
    Conv1d,
    Conv2d,
    Matmul,
    Qmatmul,
    Softmax,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
    /// The benchmark to be run.
    #[command(subcommand)]
    task: Task,

    #[arg(long)]
    iters: Option<usize>,
}

fn main() -> Result<()> {
    let args = Args::parse();
    match args.task {
        Task::Conv1d => run::<Conv1d>(args.iters)?,
        Task::Conv2d => run::<Conv2d>(args.iters)?,
        Task::Matmul => run::<Matmul>(args.iters)?,
        Task::Softmax => run::<Softmax>(args.iters)?,
        Task::Qmatmul => run::<QMatMul>(args.iters)?,
    }
    Ok(())
}
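The `softmax` helper in the deleted benchmark above subtracts the per-row maximum before exponentiating; without that shift, `exp` overflows to infinity for moderately large logits and the division produces NaNs. A minimal sketch of the difference on plain `f32` slices (standalone Rust, not part of the diff):

```rust
// Naive softmax: exponentiates the raw logits directly.
fn naive_softmax(xs: &[f32]) -> Vec<f32> {
    let exps: Vec<f32> = xs.iter().map(|x| x.exp()).collect();
    let den: f32 = exps.iter().sum();
    exps.iter().map(|e| e / den).collect()
}

// Max-shifted softmax: mathematically identical (the shift cancels in the
// ratio), but every exponent is <= 0 so nothing overflows.
fn stable_softmax(xs: &[f32]) -> Vec<f32> {
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|x| (x - max).exp()).collect();
    let den: f32 = exps.iter().sum();
    exps.iter().map(|e| e / den).collect()
}

fn main() {
    let logits = [100.0f32, 101.0, 102.0];
    // exp(100) overflows f32, so the naive version yields NaNs here.
    println!("naive:  {:?}", naive_softmax(&logits));
    println!("stable: {:?}", stable_softmax(&logits));
}
```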
@@ -218,12 +218,65 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
     Ok(())
 }

+fn run_quantize_safetensors(
+    in_file: std::path::PathBuf,
+    out_file: std::path::PathBuf,
+    q: Quantization,
+) -> Result<()> {
+    let mut out_file = std::fs::File::create(out_file)?;
+    let tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
+    println!("tensors: {}", tensors.len());
+
+    let quantize_fn = match q {
+        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
+        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
+        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
+        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
+        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
+        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
+        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
+        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
+        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
+        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
+        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
+        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
+        Quantization::F16 => QTensor::quantize::<half::f16>,
+        Quantization::F32 => QTensor::quantize::<f32>,
+    };
+
+    let qtensors = tensors
+        .into_par_iter()
+        .map(|(name, tensor)| {
+            println!("  quantizing {name} {tensor:?}");
+            let should_quantize = tensor.rank() == 2 && tensor.dim(0)? % 256 == 0;
+            let tensor = if should_quantize {
+                quantize_fn(&tensor)?
+            } else {
+                QTensor::quantize::<f32>(&tensor)?
+            };
+            Ok((name, tensor))
+        })
+        .collect::<Result<Vec<_>>>()?;
+    let qtensors = qtensors
+        .iter()
+        .map(|(k, v)| (k.as_str(), v))
+        .collect::<Vec<_>>();
+    gguf_file::write(&mut out_file, &[], &qtensors)?;
+    Ok(())
+}
+
 fn run_quantize(
     in_file: std::path::PathBuf,
     out_file: std::path::PathBuf,
     q: Quantization,
     qmode: QuantizationMode,
 ) -> Result<()> {
+    if let Some(extension) = in_file.extension() {
+        if extension == "safetensors" {
+            return run_quantize_safetensors(in_file, out_file, q);
+        }
+    }
+
     // Open the out file early so as to fail directly on missing directories etc.
     let mut out_file = std::fs::File::create(out_file)?;
     let mut in_ = std::fs::File::open(&in_file)?;
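For intuition on what the quantization above buys: in the llama.cpp block formats that candle mirrors, a `Q4_0` block packs 32 weights into 16 bytes of 4-bit quants plus one `f16` scale, 18 bytes total. A back-of-the-envelope sketch (the 18-byte figure is llama.cpp's `block_q4_0` layout and should be treated as an assumption here, not something stated in the diff):

```rust
fn main() {
    // One Q4_0 block: 32 weights -> 16 bytes of packed nibbles + 2-byte f16 scale.
    let weights_per_block = 32u64;
    let block_bytes = 32u64 / 2 + 2; // 18 bytes per block

    // E.g. the (4096, 11008) matrix from the QMatMul benchmark above.
    let n_weights = 4096u64 * 11008;
    let quantized = n_weights / weights_per_block * block_bytes;
    let full = n_weights * 4; // f32 storage
    println!(
        "Q4_0: {} MiB vs f32: {} MiB ({:.1}x smaller)",
        quantized >> 20,
        full >> 20,
        full as f64 / quantized as f64
    );
}
```

This also explains the `% 256 == 0` guard in `run_quantize_safetensors`: the k-quant formats group blocks into super-blocks of 256 elements, so tensors whose rows are not a multiple of 256 are kept in `f32` instead.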
@@ -370,6 +370,38 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
     y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
 }

+#[inline]
+pub fn vs_tanh_inplace(y: &mut [f32]) {
+    unsafe { ffi::vvtanhf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
+#[inline]
+pub fn vd_tanh_inplace(y: &mut [f64]) {
+    unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
+#[inline]
+pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+    }
+    vs_tanh_inplace(ys);
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = 0.5 * v * (1.0 + *y)
+    }
+}
+
+#[inline]
+pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+    }
+    vd_tanh_inplace(ys);
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = 0.5 * v * (1.0 + *y)
+    }
+}
+
 macro_rules! binary_op {
     ($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
         #[inline]
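The two-pass `vs_gelu`/`vd_gelu` above implement the usual tanh approximation, gelu(x) ≈ 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))), staging the inner argument in `ys` so the vectorized Accelerate `vvtanhf`/`vvtanh` call can sweep the whole buffer at once. A scalar reference version, handy for cross-checking the vectorized path (a sketch, not part of the diff):

```rust
// Scalar tanh-approximation GELU (Hendrycks & Gimpel), matching the
// vectorized two-pass version: inner = sqrt(2/pi) * x * (1 + 0.044715 x^2).
fn gelu_scalar(x: f32) -> f32 {
    let inner = (2.0f32 / std::f32::consts::PI).sqrt() * x * (1.0 + 0.044715 * x * x);
    0.5 * x * (1.0 + inner.tanh())
}

fn main() {
    for x in [-2.0f32, -1.0, 0.0, 1.0, 2.0] {
        println!("gelu({x}) = {}", gelu_scalar(x));
    }
}
```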
@@ -57,6 +57,7 @@ pub trait BackendStorage: Sized {

     fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
     fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
+    fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
     fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;

     fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
@@ -91,13 +91,14 @@ impl Tensor {
             }
         }
         Op::Reshape(node)
+        | Op::UpsampleNearest1D(node)
         | Op::UpsampleNearest2D(node)
         | Op::AvgPool2D { arg: node, .. }
        | Op::MaxPool2D { arg: node, .. }
         | Op::Copy(node)
         | Op::Broadcast(node)
         | Op::Cmp(node, _)
-        | Op::Reduce(node, _, _)
+        | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
         | Op::ToDType(node)
         | Op::ToDevice(node)
         | Op::Transpose(node, _, _)

@@ -111,6 +112,7 @@ impl Tensor {
                 track_grad |= tg;
                 nodes
             }
+            Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
         }
     } else {
         nodes

@@ -262,6 +264,9 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&grad_arg)?;
                 }
+                Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
+                    op: "upsample-nearest1d",
+                })?,
                 Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
                     op: "upsample-nearest2d",
                 })?,

@@ -437,6 +442,10 @@ impl Tensor {
                     *sum_grad = sum_grad.add(&arg_grad)?
                 }
                 Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
+                Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
+                Op::Unary(_, UnaryOp::GeluErf) => {
+                    Err(Error::BackwardNotSupported { op: "gelu-erf" })?
+                }
                 Op::Unary(arg, UnaryOp::Relu) => {
                     let sum_grad = grads.or_insert(arg)?;
                     let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;

@@ -517,6 +526,7 @@ impl Tensor {
     }
 }

+#[derive(Debug)]
 pub struct GradStore(HashMap<TensorId, Tensor>);

 impl GradStore {
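The `Op::Reduce` split above makes the intent explicit: `Min`, `Sum`, and `Max` reductions are differentiable and stay on the gradient tape, while `ArgMin`/`ArgMax` produce integer indices, a piecewise-constant result with no useful gradient, so the traversal returns its node list unchanged. A minimal sketch of the visible behavior, assuming candle's public `backward`/`GradStore::get` API (not part of the diff):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::randn(0f32, 1.0, (3, 4), &dev)?;

    // Sum is differentiable: d(sum)/dx = 1 everywhere, so the gradient
    // of x comes back as a tensor of ones.
    let loss = x.sum_all()?;
    let grads = loss.backward()?;
    let gx = grads.get(&x).expect("x should have a gradient");
    println!("grad of sum_all: {gx}");

    // An ArgMin/ArgMax reduction, by contrast, returns integer indices:
    // piecewise constant in x, which is why backprop drops those nodes.
    Ok(())
}
```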
763
candle-core/src/cpu/erf.rs
Normal file
@ -0,0 +1,763 @@
|
||||
#![allow(clippy::excessive_precision)]
|
||||
// Code taken from https://github.com/statrs-dev/statrs
|
||||
//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
|
||||
//! related functions
|
||||
|
||||
mod evaluate {
|
||||
//! Provides functions that don't have a numerical solution and must
|
||||
//! be solved computationally (e.g. evaluation of a polynomial)
|
||||
|
||||
/// evaluates a polynomial at `z` where `coeff` are the coeffecients
|
||||
/// to a polynomial of order `k` where `k` is the length of `coeff` and the
|
||||
/// coeffecient
|
||||
/// to the `k`th power is the `k`th element in coeff. E.g. [3,-1,2] equates to
|
||||
/// `2z^2 - z + 3`
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// Returns 0 for a 0 length coefficient slice
|
||||
pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
|
||||
let n = coeff.len();
|
||||
if n == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mut sum = *coeff.last().unwrap();
|
||||
for c in coeff[0..n - 1].iter().rev() {
|
||||
sum = *c + z * sum;
|
||||
}
|
||||
sum
|
||||
}
|
||||
}
|
||||
use std::f64;
|
||||
|
||||
/// `erf` calculates the error function at `x`.
|
||||
pub fn erf(x: f64) -> f64 {
|
||||
if x.is_nan() {
|
||||
f64::NAN
|
||||
} else if x >= 0.0 && x.is_infinite() {
|
||||
1.0
|
||||
} else if x <= 0.0 && x.is_infinite() {
|
||||
-1.0
|
||||
} else if x == 0. {
|
||||
0.0
|
||||
} else {
|
||||
erf_impl(x, false)
|
||||
}
|
||||
}
|
||||
|
||||
/// `erf_inv` calculates the inverse error function
|
||||
/// at `x`.
|
||||
pub fn erf_inv(x: f64) -> f64 {
|
||||
if x == 0.0 {
|
||||
0.0
|
||||
} else if x >= 1.0 {
|
||||
f64::INFINITY
|
||||
} else if x <= -1.0 {
|
||||
f64::NEG_INFINITY
|
||||
} else if x < 0.0 {
|
||||
erf_inv_impl(-x, 1.0 + x, -1.0)
|
||||
} else {
|
||||
erf_inv_impl(x, 1.0 - x, 1.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// `erfc` calculates the complementary error function
|
||||
/// at `x`.
|
||||
pub fn erfc(x: f64) -> f64 {
|
||||
if x.is_nan() {
|
||||
f64::NAN
|
||||
} else if x == f64::INFINITY {
|
||||
0.0
|
||||
} else if x == f64::NEG_INFINITY {
|
||||
2.0
|
||||
} else {
|
||||
erf_impl(x, true)
|
||||
}
|
||||
}
|
||||
|
||||
/// `erfc_inv` calculates the complementary inverse
|
||||
/// error function at `x`.
|
||||
pub fn erfc_inv(x: f64) -> f64 {
|
||||
if x <= 0.0 {
|
||||
f64::INFINITY
|
||||
} else if x >= 2.0 {
|
||||
f64::NEG_INFINITY
|
||||
} else if x > 1.0 {
|
||||
erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
|
||||
} else {
|
||||
erf_inv_impl(1.0 - x, x, 1.0)
|
||||
}
|
||||
}
|
||||
|
||||
// **********************************************************
|
||||
// ********** Coefficients for erf_impl polynomial **********
|
||||
// **********************************************************
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_impl`
|
||||
/// in the interval [1e-10, 0.5].
|
||||
const ERF_IMPL_AN: &[f64] = &[
|
||||
0.00337916709551257388990745,
|
||||
-0.00073695653048167948530905,
|
||||
-0.374732337392919607868241,
|
||||
0.0817442448733587196071743,
|
||||
-0.0421089319936548595203468,
|
||||
0.0070165709512095756344528,
|
||||
-0.00495091255982435110337458,
|
||||
0.000871646599037922480317225,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_impl`
|
||||
/// in the interval [1e-10, 0.5]
|
||||
const ERF_IMPL_AD: &[f64] = &[
|
||||
1.0,
|
||||
-0.218088218087924645390535,
|
||||
0.412542972725442099083918,
|
||||
-0.0841891147873106755410271,
|
||||
0.0655338856400241519690695,
|
||||
-0.0120019604454941768171266,
|
||||
0.00408165558926174048329689,
|
||||
-0.000615900721557769691924509,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [0.5, 0.75].
|
||||
const ERF_IMPL_BN: &[f64] = &[
|
||||
-0.0361790390718262471360258,
|
||||
0.292251883444882683221149,
|
||||
0.281447041797604512774415,
|
||||
0.125610208862766947294894,
|
||||
0.0274135028268930549240776,
|
||||
0.00250839672168065762786937,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [0.5, 0.75].
|
||||
const ERF_IMPL_BD: &[f64] = &[
|
||||
1.0,
|
||||
1.8545005897903486499845,
|
||||
1.43575803037831418074962,
|
||||
0.582827658753036572454135,
|
||||
0.124810476932949746447682,
|
||||
0.0113724176546353285778481,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [0.75, 1.25].
|
||||
const ERF_IMPL_CN: &[f64] = &[
|
||||
-0.0397876892611136856954425,
|
||||
0.153165212467878293257683,
|
||||
0.191260295600936245503129,
|
||||
0.10276327061989304213645,
|
||||
0.029637090615738836726027,
|
||||
0.0046093486780275489468812,
|
||||
0.000307607820348680180548455,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [0.75, 1.25].
|
||||
const ERF_IMPL_CD: &[f64] = &[
|
||||
1.0,
|
||||
1.95520072987627704987886,
|
||||
1.64762317199384860109595,
|
||||
0.768238607022126250082483,
|
||||
0.209793185936509782784315,
|
||||
0.0319569316899913392596356,
|
||||
0.00213363160895785378615014,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [1.25, 2.25].
|
||||
const ERF_IMPL_DN: &[f64] = &[
|
||||
-0.0300838560557949717328341,
|
||||
0.0538578829844454508530552,
|
||||
0.0726211541651914182692959,
|
||||
0.0367628469888049348429018,
|
||||
0.00964629015572527529605267,
|
||||
0.00133453480075291076745275,
|
||||
0.778087599782504251917881e-4,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [1.25, 2.25].
|
||||
const ERF_IMPL_DD: &[f64] = &[
|
||||
1.0,
|
||||
1.75967098147167528287343,
|
||||
1.32883571437961120556307,
|
||||
0.552528596508757581287907,
|
||||
0.133793056941332861912279,
|
||||
0.0179509645176280768640766,
|
||||
0.00104712440019937356634038,
|
||||
-0.106640381820357337177643e-7,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [2.25, 3.5].
|
||||
const ERF_IMPL_EN: &[f64] = &[
|
||||
-0.0117907570137227847827732,
|
||||
0.014262132090538809896674,
|
||||
0.0202234435902960820020765,
|
||||
0.00930668299990432009042239,
|
||||
0.00213357802422065994322516,
|
||||
0.00025022987386460102395382,
|
||||
0.120534912219588189822126e-4,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [2.25, 3.5].
|
||||
const ERF_IMPL_ED: &[f64] = &[
|
||||
1.0,
|
||||
1.50376225203620482047419,
|
||||
0.965397786204462896346934,
|
||||
0.339265230476796681555511,
|
||||
0.0689740649541569716897427,
|
||||
0.00771060262491768307365526,
|
||||
0.000371421101531069302990367,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [3.5, 5.25].
|
||||
const ERF_IMPL_FN: &[f64] = &[
|
||||
-0.00546954795538729307482955,
|
||||
0.00404190278731707110245394,
|
||||
0.0054963369553161170521356,
|
||||
0.00212616472603945399437862,
|
||||
0.000394984014495083900689956,
|
||||
0.365565477064442377259271e-4,
|
||||
0.135485897109932323253786e-5,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [3.5, 5.25].
|
||||
const ERF_IMPL_FD: &[f64] = &[
|
||||
1.0,
|
||||
1.21019697773630784832251,
|
||||
0.620914668221143886601045,
|
||||
0.173038430661142762569515,
|
||||
0.0276550813773432047594539,
|
||||
0.00240625974424309709745382,
|
||||
0.891811817251336577241006e-4,
|
||||
-0.465528836283382684461025e-11,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [5.25, 8].
|
||||
const ERF_IMPL_GN: &[f64] = &[
|
||||
-0.00270722535905778347999196,
|
||||
0.0013187563425029400461378,
|
||||
0.00119925933261002333923989,
|
||||
0.00027849619811344664248235,
|
||||
0.267822988218331849989363e-4,
|
||||
0.923043672315028197865066e-6,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [5.25, 8].
|
||||
const ERF_IMPL_GD: &[f64] = &[
|
||||
1.0,
|
||||
0.814632808543141591118279,
|
||||
0.268901665856299542168425,
|
||||
0.0449877216103041118694989,
|
||||
0.00381759663320248459168994,
|
||||
0.000131571897888596914350697,
|
||||
0.404815359675764138445257e-11,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [8, 11.5].
|
||||
const ERF_IMPL_HN: &[f64] = &[
|
||||
-0.00109946720691742196814323,
|
||||
0.000406425442750422675169153,
|
||||
0.000274499489416900707787024,
|
||||
0.465293770646659383436343e-4,
|
||||
0.320955425395767463401993e-5,
|
||||
0.778286018145020892261936e-7,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [8, 11.5].
|
||||
const ERF_IMPL_HD: &[f64] = &[
|
||||
1.0,
|
||||
0.588173710611846046373373,
|
||||
0.139363331289409746077541,
|
||||
0.0166329340417083678763028,
|
||||
0.00100023921310234908642639,
|
||||
0.24254837521587225125068e-4,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [11.5, 17].
|
||||
const ERF_IMPL_IN: &[f64] = &[
|
||||
-0.00056907993601094962855594,
|
||||
0.000169498540373762264416984,
|
||||
0.518472354581100890120501e-4,
|
||||
0.382819312231928859704678e-5,
|
||||
0.824989931281894431781794e-7,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [11.5, 17].
|
||||
const ERF_IMPL_ID: &[f64] = &[
|
||||
1.0,
|
||||
0.339637250051139347430323,
|
||||
0.043472647870310663055044,
|
||||
0.00248549335224637114641629,
|
||||
0.535633305337152900549536e-4,
|
||||
-0.117490944405459578783846e-12,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [17, 24].
|
||||
const ERF_IMPL_JN: &[f64] = &[
|
||||
-0.000241313599483991337479091,
|
||||
0.574224975202501512365975e-4,
|
||||
0.115998962927383778460557e-4,
|
||||
0.581762134402593739370875e-6,
|
||||
0.853971555085673614607418e-8,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [17, 24].
|
||||
const ERF_IMPL_JD: &[f64] = &[
|
||||
1.0,
|
||||
0.233044138299687841018015,
|
||||
0.0204186940546440312625597,
|
||||
0.000797185647564398289151125,
|
||||
0.117019281670172327758019e-4,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [24, 38].
|
||||
const ERF_IMPL_KN: &[f64] = &[
|
||||
-0.000146674699277760365803642,
|
||||
0.162666552112280519955647e-4,
|
||||
0.269116248509165239294897e-5,
|
||||
0.979584479468091935086972e-7,
|
||||
0.101994647625723465722285e-8,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [24, 38].
|
||||
const ERF_IMPL_KD: &[f64] = &[
|
||||
1.0,
|
||||
0.165907812944847226546036,
|
||||
0.0103361716191505884359634,
|
||||
0.000286593026373868366935721,
|
||||
0.298401570840900340874568e-5,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [38, 60].
|
||||
const ERF_IMPL_LN: &[f64] = &[
|
||||
-0.583905797629771786720406e-4,
|
||||
0.412510325105496173512992e-5,
|
||||
0.431790922420250949096906e-6,
|
||||
0.993365155590013193345569e-8,
|
||||
0.653480510020104699270084e-10,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [38, 60].
|
||||
const ERF_IMPL_LD: &[f64] = &[
|
||||
1.0,
|
||||
0.105077086072039915406159,
|
||||
0.00414278428675475620830226,
|
||||
0.726338754644523769144108e-4,
|
||||
0.477818471047398785369849e-6,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [60, 85].
|
||||
const ERF_IMPL_MN: &[f64] = &[
|
||||
-0.196457797609229579459841e-4,
|
||||
0.157243887666800692441195e-5,
|
||||
0.543902511192700878690335e-7,
|
||||
0.317472492369117710852685e-9,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [60, 85].
|
||||
const ERF_IMPL_MD: &[f64] = &[
|
||||
1.0,
|
||||
0.052803989240957632204885,
|
||||
0.000926876069151753290378112,
|
||||
0.541011723226630257077328e-5,
|
||||
0.535093845803642394908747e-15,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [85, 110].
|
||||
const ERF_IMPL_NN: &[f64] = &[
|
||||
-0.789224703978722689089794e-5,
|
||||
0.622088451660986955124162e-6,
|
||||
0.145728445676882396797184e-7,
|
||||
0.603715505542715364529243e-10,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [85, 110].
|
||||
const ERF_IMPL_ND: &[f64] = &[
|
||||
1.0,
|
||||
0.0375328846356293715248719,
|
||||
0.000467919535974625308126054,
|
||||
0.193847039275845656900547e-5,
|
||||
];
|
||||
|
||||
// **********************************************************
|
||||
// ********** Coefficients for erf_inv_impl polynomial ******
|
||||
// **********************************************************
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0, 0.5].
|
||||
const ERF_INV_IMPL_AN: &[f64] = &[
|
||||
-0.000508781949658280665617,
|
||||
-0.00836874819741736770379,
|
||||
0.0334806625409744615033,
|
||||
-0.0126926147662974029034,
|
||||
-0.0365637971411762664006,
|
||||
0.0219878681111168899165,
|
||||
0.00822687874676915743155,
|
||||
-0.00538772965071242932965,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0, 0.5].
|
||||
const ERF_INV_IMPL_AD: &[f64] = &[
|
||||
1.0,
|
||||
-0.970005043303290640362,
|
||||
-1.56574558234175846809,
|
||||
1.56221558398423026363,
|
||||
0.662328840472002992063,
|
||||
-0.71228902341542847553,
|
||||
-0.0527396382340099713954,
|
||||
0.0795283687341571680018,
|
||||
-0.00233393759374190016776,
|
||||
0.000886216390456424707504,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.5, 0.75].
|
||||
const ERF_INV_IMPL_BN: &[f64] = &[
|
||||
-0.202433508355938759655,
|
||||
0.105264680699391713268,
|
||||
8.37050328343119927838,
|
||||
17.6447298408374015486,
|
||||
-18.8510648058714251895,
|
||||
-44.6382324441786960818,
|
||||
17.445385985570866523,
|
||||
21.1294655448340526258,
|
||||
-3.67192254707729348546,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.5, 0.75].
|
||||
const ERF_INV_IMPL_BD: &[f64] = &[
|
||||
1.0,
|
||||
6.24264124854247537712,
|
||||
3.9713437953343869095,
|
||||
-28.6608180499800029974,
|
||||
-20.1432634680485188801,
|
||||
48.5609213108739935468,
|
||||
10.8268667355460159008,
|
||||
-22.6436933413139721736,
|
||||
1.72114765761200282724,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x less than 3.
|
||||
const ERF_INV_IMPL_CN: &[f64] = &[
|
||||
-0.131102781679951906451,
|
||||
-0.163794047193317060787,
|
||||
0.117030156341995252019,
|
||||
0.387079738972604337464,
|
||||
0.337785538912035898924,
|
||||
0.142869534408157156766,
|
||||
0.0290157910005329060432,
|
||||
0.00214558995388805277169,
|
||||
-0.679465575181126350155e-6,
|
||||
0.285225331782217055858e-7,
|
||||
-0.681149956853776992068e-9,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x less than 3.
|
||||
const ERF_INV_IMPL_CD: &[f64] = &[
|
||||
1.0,
|
||||
3.46625407242567245975,
|
||||
5.38168345707006855425,
|
||||
4.77846592945843778382,
|
||||
2.59301921623620271374,
|
||||
0.848854343457902036425,
|
||||
0.152264338295331783612,
|
||||
0.01105924229346489121,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 3 and 6.
|
||||
const ERF_INV_IMPL_DN: &[f64] = &[
|
||||
-0.0350353787183177984712,
|
||||
-0.00222426529213447927281,
|
||||
0.0185573306514231072324,
|
||||
0.00950804701325919603619,
|
||||
0.00187123492819559223345,
|
||||
0.000157544617424960554631,
|
||||
0.460469890584317994083e-5,
|
||||
-0.230404776911882601748e-9,
|
||||
0.266339227425782031962e-11,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 3 and 6.
|
||||
const ERF_INV_IMPL_DD: &[f64] = &[
|
||||
1.0,
|
||||
1.3653349817554063097,
|
||||
0.762059164553623404043,
|
||||
0.220091105764131249824,
|
||||
0.0341589143670947727934,
|
||||
0.00263861676657015992959,
|
||||
0.764675292302794483503e-4,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 6 and 18.
|
||||
const ERF_INV_IMPL_EN: &[f64] = &[
|
||||
-0.0167431005076633737133,
|
||||
-0.00112951438745580278863,
|
||||
0.00105628862152492910091,
|
||||
0.000209386317487588078668,
|
||||
0.149624783758342370182e-4,
|
||||
0.449696789927706453732e-6,
|
||||
0.462596163522878599135e-8,
|
||||
-0.281128735628831791805e-13,
|
||||
0.99055709973310326855e-16,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 6 and 18.
|
||||
const ERF_INV_IMPL_ED: &[f64] = &[
|
||||
1.0,
|
||||
0.591429344886417493481,
|
||||
0.138151865749083321638,
|
||||
0.0160746087093676504695,
|
||||
0.000964011807005165528527,
|
||||
0.275335474764726041141e-4,
|
||||
0.282243172016108031869e-6,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 18 and 44.
|
||||
const ERF_INV_IMPL_FN: &[f64] = &[
|
||||
-0.0024978212791898131227,
|
||||
-0.779190719229053954292e-5,
|
||||
0.254723037413027451751e-4,
|
||||
0.162397777342510920873e-5,
|
||||
0.396341011304801168516e-7,
|
||||
0.411632831190944208473e-9,
|
||||
0.145596286718675035587e-11,
|
||||
-0.116765012397184275695e-17,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 18 and 44.
|
||||
const ERF_INV_IMPL_FD: &[f64] = &[
|
||||
1.0,
|
||||
0.207123112214422517181,
|
||||
0.0169410838120975906478,
|
||||
0.000690538265622684595676,
|
||||
0.145007359818232637924e-4,
|
||||
0.144437756628144157666e-6,
|
||||
0.509761276599778486139e-9,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x greater than 44.
|
||||
const ERF_INV_IMPL_GN: &[f64] = &[
|
||||
-0.000539042911019078575891,
|
||||
-0.28398759004727721098e-6,
|
||||
0.899465114892291446442e-6,
|
||||
0.229345859265920864296e-7,
|
||||
0.225561444863500149219e-9,
|
||||
0.947846627503022684216e-12,
|
||||
0.135880130108924861008e-14,
|
||||
-0.348890393399948882918e-21,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x greater than 44.
|
||||
const ERF_INV_IMPL_GD: &[f64] = &[
|
||||
1.0,
|
||||
0.0845746234001899436914,
|
||||
0.00282092984726264681981,
|
||||
0.468292921940894236786e-4,
|
||||
    0.399968812193862100054e-6,
    0.161809290887904476097e-8,
    0.231558608310259605225e-11,
];

/// `erf_impl` computes the error function at `z`.
/// If `inv` is true, `1 - erf` is calculated as opposed to `erf`.
fn erf_impl(z: f64, inv: bool) -> f64 {
    if z < 0.0 {
        if !inv {
            return -erf_impl(-z, false);
        }
        if z < -0.5 {
            return 2.0 - erf_impl(-z, true);
        }
        return 1.0 + erf_impl(-z, false);
    }

    let result = if z < 0.5 {
        if z < 1e-10 {
            z * 1.125 + z * 0.003379167095512573896158903121545171688
        } else {
            z * 1.125
                + z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
        }
    } else if z < 110.0 {
        let (r, b) = if z < 0.75 {
            (
                evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
                    / evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
                0.3440242112,
            )
        } else if z < 1.25 {
            (
                evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
                    / evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
                0.419990927,
            )
        } else if z < 2.25 {
            (
                evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
                    / evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
                0.4898625016,
            )
        } else if z < 3.5 {
            (
                evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
                    / evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
                0.5317370892,
            )
        } else if z < 5.25 {
            (
                evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
                    / evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
                0.5489973426,
            )
        } else if z < 8.0 {
            (
                evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
                    / evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
                0.5571740866,
            )
        } else if z < 11.5 {
            (
                evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
                    / evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
                0.5609807968,
            )
        } else if z < 17.0 {
            (
                evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
                    / evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
                0.5626493692,
            )
        } else if z < 24.0 {
            (
                evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
                    / evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
                0.5634598136,
            )
        } else if z < 38.0 {
            (
                evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
                    / evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
                0.5638477802,
            )
        } else if z < 60.0 {
            (
                evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
                    / evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
                0.5640528202,
            )
        } else if z < 85.0 {
            (
                evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
                    / evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
                0.5641309023,
            )
        } else {
            (
                evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
                    / evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
                0.5641584396,
            )
        };
        let g = (-z * z).exp() / z;
        g * b + g * r
    } else {
        0.0
    };

    if inv && z >= 0.5 {
        result
    } else if z >= 0.5 || inv {
        1.0 - result
    } else {
        result
    }
}

// `erf_inv_impl` computes the inverse error function where
// `p`, `q`, and `s` are the first, second, and third intermediate
// parameters respectively.
fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 {
    let result = if p <= 0.5 {
        let y = 0.0891314744949340820313;
        let g = p * (p + 10.0);
        let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD);
        g * y + g * r
    } else if q >= 0.25 {
        let y = 2.249481201171875;
        let g = (-2.0 * q.ln()).sqrt();
        let xs = q - 0.25;
        let r =
            evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD);
        g / (y + r)
    } else {
        let x = (-q.ln()).sqrt();
        if x < 3.0 {
            let y = 0.807220458984375;
            let xs = x - 1.125;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_CD);
            y * x + r * x
        } else if x < 6.0 {
            let y = 0.93995571136474609375;
            let xs = x - 3.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_DD);
            y * x + r * x
        } else if x < 18.0 {
            let y = 0.98362827301025390625;
            let xs = x - 6.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_ED);
            y * x + r * x
        } else if x < 44.0 {
            let y = 0.99714565277099609375;
            let xs = x - 18.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_FD);
            y * x + r * x
        } else {
            let y = 0.99941349029541015625;
            let xs = x - 44.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_GD);
            y * x + r * x
        }
    };
    s * result
}
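The negative-`z` branches above reduce everything to `z >= 0` via the identities `erf(-z) = -erf(z)` and `erfc(-z) = 2 - erfc(z)`; the rest is a piecewise rational approximation. A minimal, untested sketch of exercising this through the public API that this PR wires up further down (the new `erf` unary op on tensors):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let xs = Tensor::new(&[-1.0f64, 0.0, 1.0], &Device::Cpu)?;
    // erf is odd, so the printed values should satisfy erf(-1) == -erf(1)
    // and erf(0) == 0.
    let ys = xs.erf()?;
    println!("{ys}");
    Ok(())
}
```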
@@ -1,4 +1,7 @@
-pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
+pub trait VecOps: num_traits::NumAssign + Copy {
+    fn min(self, rhs: Self) -> Self;
+    fn max(self, rhs: Self) -> Self;
+
     /// Dot-product of two vectors.
     ///
     /// # Safety
@@ -37,10 +40,7 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
     unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
         *res = *xs;
         for i in 1..len {
-            let x = *xs.add(i);
-            if x > *res {
-                *res = x
-            }
+            *res = (*res).max(*xs.add(i))
         }
     }

@@ -54,15 +54,22 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
     unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
         *res = *xs;
         for i in 1..len {
-            let x = *xs.add(i);
-            if x < *res {
-                *res = x
-            }
+            *res = (*res).min(*xs.add(i))
         }
     }
 }

 impl VecOps for f32 {
+    #[inline(always)]
+    fn min(self, other: Self) -> Self {
+        Self::min(self, other)
+    }
+
+    #[inline(always)]
+    fn max(self, other: Self) -> Self {
+        Self::max(self, other)
+    }
+
     #[inline(always)]
     unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
         super::vec_dot_f32(lhs, rhs, res, len)
@@ -75,6 +82,16 @@ impl VecOps for f32 {
 }

 impl VecOps for half::f16 {
+    #[inline(always)]
+    fn min(self, other: Self) -> Self {
+        Self::min(self, other)
+    }
+
+    #[inline(always)]
+    fn max(self, other: Self) -> Self {
+        Self::max(self, other)
+    }
+
     #[inline(always)]
     unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
         let mut res_f32 = 0f32;
@@ -83,11 +100,61 @@ impl VecOps for half::f16 {
     }
 }

-impl VecOps for f64 {}
-impl VecOps for half::bf16 {}
-impl VecOps for u8 {}
-impl VecOps for u32 {}
-impl VecOps for i64 {}
+impl VecOps for f64 {
+    #[inline(always)]
+    fn min(self, other: Self) -> Self {
+        Self::min(self, other)
+    }
+
+    #[inline(always)]
+    fn max(self, other: Self) -> Self {
+        Self::max(self, other)
+    }
+}
+impl VecOps for half::bf16 {
+    #[inline(always)]
+    fn min(self, other: Self) -> Self {
+        Self::min(self, other)
+    }
+
+    #[inline(always)]
+    fn max(self, other: Self) -> Self {
+        Self::max(self, other)
+    }
+}
+impl VecOps for u8 {
+    #[inline(always)]
+    fn min(self, other: Self) -> Self {
+        <Self as Ord>::min(self, other)
+    }
+
+    #[inline(always)]
+    fn max(self, other: Self) -> Self {
+        <Self as Ord>::max(self, other)
+    }
+}
+impl VecOps for u32 {
+    #[inline(always)]
+    fn min(self, other: Self) -> Self {
+        <Self as Ord>::min(self, other)
+    }
+
+    #[inline(always)]
+    fn max(self, other: Self) -> Self {
+        <Self as Ord>::max(self, other)
+    }
+}
+impl VecOps for i64 {
+    #[inline(always)]
+    fn min(self, other: Self) -> Self {
+        <Self as Ord>::min(self, other)
+    }
+
+    #[inline(always)]
+    fn max(self, other: Self) -> Self {
+        <Self as Ord>::max(self, other)
+    }
+}

 #[inline(always)]
 pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
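The reduction hunks above replace the hand-rolled compare-and-assign with the element type's own `min`/`max`, which the new trait methods provide (`Ord::min`/`Ord::max` for the integer impls, `f32::max` etc. for the floats). A standalone sketch of the resulting pattern — the function name here is hypothetical, not candle API:

```rust
// Mirrors the shape of candle's `vec_reduce_max` after this change:
// seed with the first element, then fold with the type's `max`.
unsafe fn vec_reduce_max_f32(xs: *const f32, res: *mut f32, len: usize) {
    *res = *xs;
    for i in 1..len {
        *res = (*res).max(*xs.add(i));
    }
}

fn main() {
    let xs = [1.0f32, 4.0, 2.0];
    let mut m = 0.0f32;
    unsafe { vec_reduce_max_f32(xs.as_ptr(), &mut m, xs.len()) };
    assert_eq!(m, 4.0);
}
```

One behavioral nuance for floats: `f32::max` returns the non-NaN operand when the other is NaN, whereas the old `>` comparison left the accumulator untouched on any NaN comparison, so NaN handling differs slightly between the two versions.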
@@ -1,3 +1,4 @@
+pub mod erf;
 pub mod kernels;

 trait Cpu<const ARR: usize> {
@@ -2,6 +2,10 @@ use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
 use half::{bf16, f16};
+use rayon::prelude::*;

+const USE_IM2COL_CONV1D: bool = true;
+const USE_IM2COL_CONV2D: bool = true;
+
 // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
 // intercept the oom errors to avoid panicking and provide a proper error.
@@ -445,7 +449,7 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
 }

 // This function maps over two strided index sequences.
-fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
+pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
     lhs_l: &Layout,
     rhs_l: &Layout,
     lhs: &[T],
@@ -525,7 +529,7 @@ fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
 }

 // Similar to binary_map but with vectorized variants.
-fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
+pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
     lhs_l: &Layout,
     rhs_l: &Layout,
     lhs: &[T],
@@ -723,6 +727,36 @@ impl Map1 for MaxPool2D {
     }
 }

+struct UpsampleNearest1D(usize);
+
+impl Map1 for UpsampleNearest1D {
+    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+        // TODO: Specialized implementation for the case 2*sz?
+        let dst_sz = self.0;
+        let (b_sz, c, src_sz) = layout.shape().dims3()?;
+        let stride = layout.stride();
+        let stride_sz = stride[2];
+        let src_index = layout.start_offset();
+        let scale_sz = src_sz as f64 / dst_sz as f64;
+        let mut dst = vec![T::zero(); b_sz * c * dst_sz];
+        let src_idxs = (0..dst_sz)
+            .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
+            .collect::<Vec<_>>();
+        for b_idx in 0..b_sz {
+            let dst = &mut dst[b_idx * c * dst_sz..];
+            let src_index = src_index + b_idx * stride[0];
+            for c_idx in 0..c {
+                let dst = &mut dst[c_idx * dst_sz..];
+                let src_index = src_index + c_idx * stride[1];
+                for (idx, src_idx) in src_idxs.iter().enumerate() {
+                    dst[idx] = src[src_index + src_idx * stride_sz]
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
 struct UpsampleNearest2D(usize, usize);

 impl Map1 for UpsampleNearest2D {
@@ -1052,10 +1086,8 @@ impl<'a> Map2 for Conv1D<'a> {
             }
         }

-        let num_threads = crate::utils::get_num_threads();
-
         for offset in 0..p.k_size {
-            crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
+            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                 let dst_idx = dst_c_idx * l_out;
                 let k_cont = (0..p.c_in)
                     .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
@@ -1090,6 +1122,140 @@ impl<'a> Map2 for Conv1D<'a> {
     }
 }

+struct Im2Col1D {
+    l_k: usize,
+    stride: usize,
+    dilation: usize,
+    padding: usize,
+}
+
+impl Im2Col1D {
+    fn l_out(&self, l: usize) -> usize {
+        (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
+    }
+}
+
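As a worked check of the `l_out` formula just above: with l = 5, l_k = 3, stride = 1, padding = 0 and dilation = 1 it gives (5 + 0 − 2 − 1)/1 + 1 = 3, the familiar valid-convolution output length; with padding = 1 it becomes (5 + 2 − 2 − 1)/1 + 1 = 5, i.e. a "same"-sized output. The 2D `hw_out` further down is the same formula applied per axis.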
+impl Map1 for Im2Col1D {
+    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+        let &Self {
+            l_k,
+            stride,
+            dilation,
+            padding,
+        } = self;
+        let (b, c, l) = layout.shape().dims3()?;
+        let l_out = self.l_out(l);
+        let src = &vs[layout.start_offset()..];
+        let mut dst = vec![T::zero(); b * l_out * c * l_k];
+        let (src_s0, src_s1, src_s2) = {
+            let s = layout.stride();
+            (s[0], s[1], s[2])
+        };
+        // TODO: provide specialized kernels for the common use cases.
+        // - l_k = 1
+        // - padding = 0
+        // - stride = 1
+        // - dilation = 1
+        for b_idx in 0..b {
+            let src_idx = b_idx * src_s0;
+            let dst_idx = b_idx * l_out * c * l_k;
+            for l_idx in 0..l_out {
+                let dst_idx = dst_idx + l_idx * c * l_k;
+                for c_idx in 0..c {
+                    let dst_idx = dst_idx + c_idx * l_k;
+                    let src_idx = c_idx * src_s1 + src_idx;
+                    for l_k_idx in 0..l_k {
+                        let src_l = l_idx * stride + l_k_idx * dilation;
+                        if padding != 0 && (src_l < padding || src_l >= l + padding) {
+                            continue;
+                        }
+                        let src_l = src_l - padding;
+                        let src_idx = src_idx + src_l * src_s2;
+                        let dst_idx = dst_idx + l_k_idx;
+                        dst[dst_idx] = src[src_idx]
+                    }
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
+struct Im2Col {
+    h_k: usize,
+    w_k: usize,
+    stride: usize,
+    dilation: usize,
+    padding: usize,
+}
+
+impl Im2Col {
+    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
+        let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
+        let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
+        (h_out, w_out)
+    }
+}
+
+impl Map1 for Im2Col {
+    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+        let &Self {
+            h_k,
+            w_k,
+            stride,
+            dilation,
+            padding,
+        } = self;
+        let (b, c, h, w) = layout.shape().dims4()?;
+        let (h_out, w_out) = self.hw_out(h, w);
+        let src = &vs[layout.start_offset()..];
+        let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
+        let (src_s0, src_s1, src_s2, src_s3) = {
+            let s = layout.stride();
+            (s[0], s[1], s[2], s[3])
+        };
+        // TODO: provide specialized kernels for the common use cases.
+        // - h_k = w_k = 1
+        // - padding = 0
+        // - stride = 1
+        // - dilation = 1
+        for b_idx in 0..b {
+            let src_idx = b_idx * src_s0;
+            let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
+            for h_idx in 0..h_out {
+                let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
+                for w_idx in 0..w_out {
+                    let dst_idx = dst_idx + w_idx * c * h_k * w_k;
+                    for c_idx in 0..c {
+                        let dst_idx = dst_idx + c_idx * h_k * w_k;
+                        let src_idx = c_idx * src_s1 + src_idx;
+                        for h_k_idx in 0..h_k {
+                            let src_h = h_idx * stride + h_k_idx * dilation;
+                            if padding != 0 && (src_h < padding || src_h >= h + padding) {
+                                continue;
+                            }
+                            let src_h = src_h - padding;
+                            let src_idx = src_idx + src_h * src_s2;
+                            let dst_idx = dst_idx + h_k_idx * w_k;
+                            for w_k_idx in 0..w_k {
+                                let src_w = w_idx * stride + w_k_idx * dilation;
+                                if padding != 0 && (src_w < padding || src_w >= w + padding) {
+                                    continue;
+                                }
+                                let src_w = src_w - padding;
+                                let src_idx = src_idx + src_w * src_s3;
+                                let dst_idx = dst_idx + w_k_idx;
+                                dst[dst_idx] = src[src_idx]
+                            }
+                        }
+                    }
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
 struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);

 impl<'a> Map2 for Conv2D<'a> {
@@ -1123,11 +1289,9 @@ impl<'a> Map2 for Conv2D<'a> {
             }
         }

-        let num_threads = crate::utils::get_num_threads();
-
         for offset_h in 0..p.k_h {
             for offset_w in 0..p.k_w {
-                crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
+                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                     let dst_idx = dst_c_idx * out_w * out_h;
                     let k_cont = (0..p.c_in)
                         .map(|c_in_idx| {
@@ -1216,11 +1380,10 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
             }
         }
     }
-        let num_threads = crate::utils::get_num_threads();
-
     for k_y in 0..p.k_h {
         for k_x in 0..p.k_w {
-            crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
+            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                 let k_cont = (0..p.c_in)
                     .map(|c_in_idx| {
                         k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
@@ -1298,8 +1461,9 @@ impl Map2 for MatMul {
     ) -> Result<Vec<T>> {
         use gemm::{gemm, Parallelism};

-        if T::DTYPE == DType::BF16 {
-            return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?;
+        match T::DTYPE {
+            DType::F16 | DType::F32 | DType::F64 => {}
+            _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
         }

         let (b, m, n, k) = self.0;
@@ -2003,6 +2167,10 @@ impl BackendStorage for CpuStorage {
         MaxPool2D(kernel_size, stride).map(self, layout)
     }

+    fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
+        UpsampleNearest1D(sz).map(self, layout)
+    }
+
     fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
         UpsampleNearest2D(h, w).map(self, layout)
     }
@@ -2231,7 +2399,40 @@ impl BackendStorage for CpuStorage {
         kernel_l: &Layout,
         params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
-        Conv1D(params).map(self, l, kernel, kernel_l)
+        if !USE_IM2COL_CONV1D {
+            return Conv1D(params).map(self, l, kernel, kernel_l);
+        }
+        let op = Im2Col1D {
+            l_k: params.k_size,
+            padding: params.padding,
+            stride: params.stride,
+            dilation: params.dilation,
+        };
+        let col = op.map(self, l)?;
+        let b = params.b_size;
+        let n = params.c_out;
+        let l_out = params.l_out();
+        let k = op.l_k * params.c_in;
+        let m = l_out;
+        let col_l = Layout::contiguous((b, m, k));
+        let res = if kernel_l.is_contiguous() {
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        } else {
+            // Make the kernel contiguous if not already the case, and matmul
+            // against the contiguous copy.
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+        };
+        let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        res.copy_strided_src(&mut res_t, 0, &res_l)?;
+        Ok(res_t)
     }

     fn conv2d(
@@ -2241,7 +2442,43 @@ impl BackendStorage for CpuStorage {
         kernel_l: &Layout,
         params: &crate::conv::ParamsConv2D,
     ) -> Result<Self> {
-        Conv2D(params).map(self, l, kernel, kernel_l)
+        if !USE_IM2COL_CONV2D {
+            return Conv2D(params).map(self, l, kernel, kernel_l);
+        }
+        let op = Im2Col {
+            h_k: params.k_h,
+            w_k: params.k_w,
+            padding: params.padding,
+            stride: params.stride,
+            dilation: params.dilation,
+        };
+        let col = op.map(self, l)?;
+        let b = params.b_size;
+        let n = params.c_out;
+        let (h_out, w_out) = (params.out_h(), params.out_w());
+        let k = op.h_k * op.w_k * params.c_in;
+        let m = h_out * w_out;
+        let col_l = Layout::contiguous((b, m, k));
+        let res = if kernel_l.is_contiguous() {
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        } else {
+            // Make the kernel contiguous if not already the case, and matmul
+            // against the contiguous copy.
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+        };
+        let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
+            .transpose(1, 2)?
+            .transpose(1, 3)?;
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        res.copy_strided_src(&mut res_t, 0, &res_l)?;
+        Ok(res_t)
     }
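Note the non-contiguous branches originally built `kernel_c` and then multiplied against the uncopied `kernel`; the copy is the value that must be used, as reflected above. A hedged usage sketch of the im2col conv path through the public API, assuming the 0.2.x `conv2d(kernel, padding, stride, dilation, groups)` signature:

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    // (batch, c_in, h, w) input and (c_out, c_in, k_h, k_w) kernel.
    let xs = Tensor::randn(0f32, 1., (1, 3, 8, 8), &dev)?;
    let k = Tensor::randn(0f32, 1., (4, 3, 3, 3), &dev)?;
    // padding=1, stride=1, dilation=1, groups=1 -> "same" spatial size.
    let ys = xs.conv2d(&k, 1, 1, 1, 1)?;
    assert_eq!(ys.dims(), &[1, 4, 8, 8]);
    Ok(())
}
```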

     fn conv_transpose2d(
@@ -1,7 +1,7 @@
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
-use candle_kernels as kernels;
+pub use candle_kernels as kernels;
+pub use cudarc;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
 use cudarc::driver::{
@@ -312,6 +312,13 @@ impl BackendDevice for CudaDevice {
         // cudarc changes.
         let elem_count = shape.elem_count();
         let curand = self.curand.lock().unwrap();
+        // curand can only generate an even number of values, so round odd
+        // element counts up by one.
+        // https://github.com/huggingface/candle/issues/734
+        let elem_count_round = if elem_count % 2 == 1 {
+            elem_count + 1
+        } else {
+            elem_count
+        };
         let slice = match dtype {
             DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
                 Err(CudaError::UnsupportedDtype {
@@ -321,7 +328,7 @@ impl BackendDevice for CudaDevice {
                 .w()?
             }
             DType::F32 => {
-                let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
+                let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
                 curand
                     .0
                     .fill_with_normal(&mut data, mean as f32, std as f32)
@@ -329,7 +336,7 @@ impl BackendDevice for CudaDevice {
                 CudaStorageSlice::F32(data)
             }
             DType::F64 => {
-                let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
+                let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
                 curand.0.fill_with_normal(&mut data, mean, std).w()?;
                 CudaStorageSlice::F64(data)
            }
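A sketch of what the rounding fixes, untested here since it requires a CUDA device; issue 734 is the one linked in the comment above:

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::new_cuda(0)?;
    // 15 elements: an odd count used to make fill_with_normal fail.
    let t = Tensor::randn(0f32, 1., (3, 5), &dev)?;
    assert_eq!(t.elem_count(), 15);
    Ok(())
}
```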
@@ -383,7 +390,7 @@ impl BackendDevice for CudaDevice {
 }

 #[derive(Debug)]
-enum CudaStorageSlice {
+pub enum CudaStorageSlice {
     U8(CudaSlice<u8>),
     U32(CudaSlice<u32>),
     I64(CudaSlice<i64>),
@@ -394,7 +401,7 @@ enum CudaStorageSlice {
 }
 type S = CudaStorageSlice;

-trait Map1 {
+pub trait Map1 {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         src: &CudaSlice<T>,
@@ -416,7 +423,7 @@ trait Map1 {
     }
 }

-trait Map2 {
+pub trait Map2 {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         src1: &CudaSlice<T>,
@@ -441,7 +448,7 @@ trait Map2 {
     }
 }

-trait Map2InPlace {
+pub trait Map2InPlace {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         dst: &mut CudaSlice<T>,
@@ -472,7 +479,7 @@ trait Map2InPlace {
     }
 }

-trait Map1Any {
+pub trait Map1Any {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
         &self,
         src: &CudaSlice<T>,
@@ -495,7 +502,7 @@ trait Map1Any {
     }
 }

-trait Map2Any {
+pub trait Map2Any {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         src1: &CudaSlice<T>,
@@ -532,7 +539,7 @@ impl Map1 for Clone {
     }
 }

-fn kernel_name<T: WithDType>(root: &str) -> String {
+pub fn kernel_name<T: WithDType>(root: &str) -> String {
     let dtype = T::DTYPE.as_str();
     format!("{root}_{dtype}")
 }
@@ -593,6 +600,105 @@ impl Map1 for Elu {
     }
 }

+struct Im2Col1D {
+    l_k: usize,
+    stride: usize,
+    dilation: usize,
+    padding: usize,
+}
+
+impl Im2Col1D {
+    fn l_out(&self, l: usize) -> usize {
+        (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
+    }
+}
+
+impl Map1 for Im2Col1D {
+    fn f<T: DeviceRepr + WithDType>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        layout: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let shape = layout.shape();
+        let dims = shape.dims();
+        let l_out = self.l_out(dims[2]);
+        let dst_el = dims[0] * l_out * dims[1] * self.l_k;
+        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let src = &src.slice(layout.start_offset()..);
+        let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), kernels::CONV)?;
+        // SAFETY: Set later by running the kernel.
+        let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        let params = (
+            dst_el,
+            l_out,
+            self.l_k,
+            self.stride,
+            self.padding,
+            self.dilation,
+            &ds,
+            src,
+            &dst,
+        );
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(dst)
+    }
+}
+
+struct Im2Col {
+    h_k: usize,
+    w_k: usize,
+    stride: usize,
+    dilation: usize,
+    padding: usize,
+}
+
+impl Im2Col {
+    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
+        let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
+        let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
+        (h_out, w_out)
+    }
+}
+
+impl Map1 for Im2Col {
+    fn f<T: DeviceRepr + WithDType>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        layout: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let shape = layout.shape();
+        let dims = shape.dims();
+        let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
+        let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
+        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let src = &src.slice(layout.start_offset()..);
+        let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), kernels::CONV)?;
+        // SAFETY: Set later by running the kernel.
+        let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        let params = (
+            dst_el,
+            h_out,
+            w_out,
+            self.h_k,
+            self.w_k,
+            self.stride,
+            self.padding,
+            self.dilation,
+            &ds,
+            src,
+            &dst,
+        );
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(dst)
+    }
+}
+
 struct Powf(f64);
 impl Map1 for Powf {
     fn f<T: DeviceRepr + WithDType>(
@@ -1310,8 +1416,8 @@ fn slice_src_and_dst<'a, T>(

 #[derive(Debug)]
 pub struct CudaStorage {
-    slice: CudaStorageSlice,
-    device: CudaDevice,
+    pub slice: CudaStorageSlice,
+    pub device: CudaDevice,
 }

 pub trait CudaDType: Sized {
@@ -1650,9 +1756,46 @@ impl BackendStorage for CudaStorage {
         kernel_l: &Layout,
         params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
+        const USE_IM2COL_CONV1D: bool = true;
+
         let device = self.device().clone();
+        if !USE_IM2COL_CONV1D {
             let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
-        Ok(Self { slice, device })
+            return Ok(Self { slice, device });
+        }
+
+        let col = Im2Col1D {
+            l_k: params.k_size,
+            stride: params.stride,
+            dilation: params.dilation,
+            padding: params.padding,
+        }
+        .map(&self.slice, &device, l)?;
+        let col = Self { slice: col, device };
+        let l_out = params.l_out();
+        let b = params.b_size;
+        let n = params.c_out;
+        let k = params.k_size * params.c_in;
+        let m = l_out;
+        let col_l = Layout::contiguous((b, m, k));
+        let res = if kernel_l.is_contiguous() {
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        } else {
+            // Make the kernel contiguous if not already the case, and matmul
+            // against the contiguous copy.
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+        };
+        let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        res.copy_strided_src(&mut res_t, 0, &res_l)?;
+        Ok(res_t)
     }

     #[cfg(not(feature = "cudnn"))]
@@ -1663,9 +1806,50 @@ impl BackendStorage for CudaStorage {
         kernel_l: &Layout,
         params: &crate::conv::ParamsConv2D,
     ) -> Result<Self> {
+        const USE_IM2COL_CONV2D: bool = true;
+
         let device = self.device().clone();
+        if !USE_IM2COL_CONV2D {
             let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
-        Ok(Self { slice, device })
+            return Ok(Self { slice, device });
+        }
+
+        let col = Im2Col {
+            h_k: params.k_h,
+            w_k: params.k_w,
+            stride: params.stride,
+            dilation: params.dilation,
+            padding: params.padding,
+        }
+        .map(&self.slice, &device, l)?;
+        let col = Self { slice: col, device };
+        let h_out = params.out_h();
+        let w_out = params.out_w();
+        let b = params.b_size;
+        let n = params.c_out;
+        let k = params.k_h * params.k_w * params.c_in;
+        let m = h_out * w_out;
+        let col_l = Layout::contiguous((b, m, k));
+        let res = if kernel_l.is_contiguous() {
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        } else {
+            // Make the kernel contiguous if not already the case, and matmul
+            // against the contiguous copy.
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+        };
+        let res_l = Layout::contiguous((b, h_out, w_out, n))
+            .transpose(1, 2)?
+            .transpose(1, 3)?;
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        res.copy_strided_src(&mut res_t, 0, &res_l)?;
+        Ok(res_t)
     }

     #[cfg(feature = "cudnn")]
@@ -1770,6 +1954,10 @@ impl BackendStorage for CudaStorage {
         Ok(Self { slice, device })
     }

+    fn upsample_nearest1d(&self, _: &Layout, _out_sz: usize) -> Result<Self> {
+        crate::bail!("upsample-nearest1d is not supported on cuda")
+    }
+
     fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
         let device = self.device().clone();
         let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?;
@@ -1889,6 +2077,9 @@ impl BackendStorage for CudaStorage {
         let src_shape = src_l.shape();
         let dims = src_shape.dims();
         let el_count = src_shape.elem_count();
+        if el_count == 0 {
+            return Ok(());
+        }
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
         let dev = &self.device;
         let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
|
||||
|
@ -54,8 +54,8 @@ pub(crate) fn launch_conv2d<
|
||||
let x_shape = [
|
||||
params.b_size as i32,
|
||||
params.c_in as i32,
|
||||
params.i_w as i32,
|
||||
params.i_h as i32,
|
||||
params.i_w as i32,
|
||||
];
|
||||
// Note that `src` already starts at the proper offset.
|
||||
let x = if src_l.is_contiguous() {
|
||||
@ -75,14 +75,14 @@ pub(crate) fn launch_conv2d<
|
||||
[
|
||||
params.c_out as i32,
|
||||
params.c_in as i32,
|
||||
params.k_w as i32,
|
||||
params.k_h as i32,
|
||||
params.k_w as i32,
|
||||
],
|
||||
)?;
|
||||
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
|
||||
let y = cudnn.create_4d_tensor(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
[params.b_size as i32, params.c_out as i32, w_out, h_out],
|
||||
[params.b_size as i32, params.c_out as i32, h_out, w_out],
|
||||
)?;
|
||||
let conv2d = Conv2dForward {
|
||||
conv: &conv,
|
||||
|
@@ -1,15 +1,24 @@
 //! Types for elements that can be stored and manipulated using tensors.
+#![allow(clippy::redundant_closure_call)]
 use crate::backend::BackendStorage;
 use crate::{CpuStorage, Error, Result};

 /// The different types of elements allowed in tensors.
 #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
 pub enum DType {
+    // Unsigned 8 bits integer.
     U8,
+    // Unsigned 32 bits integer.
     U32,
+    // Signed 64 bits integer.
+    I64,
+    // Brain floating-point using half precision (16 bits).
     BF16,
+    // Floating-point using half precision (16 bits).
     F16,
+    // Floating-point using single precision (32 bits).
     F32,
+    // Floating-point using double precision (64 bits).
     F64,
 }

@@ -33,6 +42,7 @@ impl std::str::FromStr for DType {
 }

 impl DType {
+    /// String representation for dtypes.
     pub fn as_str(&self) -> &'static str {
         match self {
             Self::U8 => "u8",
@@ -45,6 +55,7 @@ impl DType {
         }
     }

+    /// The size used by each element in bytes, i.e. 1 for `U8`, 4 for `F32`.
     pub fn size_in_bytes(&self) -> usize {
         match self {
             Self::U8 => 1,
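A small sketch of the new `I64` dtype plumbing; the `unwrap` assumes the `FromStr` error type implements `Debug`, which is not shown in this hunk:

```rust
use candle_core::DType;

fn main() {
    assert_eq!(DType::I64.as_str(), "i64");
    assert_eq!(DType::I64.size_in_bytes(), 8);
    let dt = "bf16".parse::<DType>().unwrap();
    assert_eq!(dt, DType::BF16);
}
```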
@@ -152,6 +152,10 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }

+    fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
@@ -30,7 +30,7 @@ pub enum Error {
     UnsupportedDTypeForOp(DType, &'static str),

     // === Dimension Index Errors ===
-    #[error("{op}: dimension index {dim} out of range for {shape:?}")]
+    #[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
     DimOutOfRange {
         shape: Shape,
         dim: i32,
@@ -207,11 +207,11 @@ pub type Result<T> = std::result::Result<T, Error>;

 impl Error {
     pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
-        Self::Wrapped(Box::new(err))
+        Self::Wrapped(Box::new(err)).bt()
     }

     pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
-        Self::Msg(err.to_string())
+        Self::Msg(err.to_string()).bt()
     }

     pub fn bt(self) -> Self {
@@ -46,19 +46,31 @@ impl Tensor {
                     current_dim += 1;
                     out
                 }
+                TensorIndexer::IndexSelect(indexes) => {
+                    if indexes.rank() != 1 {
+                        crate::bail!("multi-dimensional tensor indexing is not supported")
+                    }
+                    let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
+                    current_dim += 1;
+                    out
+                }
+                TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
             };
         }
         Ok(x)
     }
 }

-#[derive(Debug, Clone)]
+#[derive(Debug)]
 /// Generic structure used to index a slice of the tensor
 pub enum TensorIndexer {
     /// This selects the elements for which an index has some specific value.
     Select(usize),
     /// This is a regular slice, purely indexing a chunk of the tensor
     Narrow(Bound<usize>, Bound<usize>),
+    /// Indexing via a 1d tensor
+    IndexSelect(Tensor),
+    Err(Error),
 }

 impl From<usize> for TensorIndexer {
@@ -67,6 +79,31 @@ impl From<usize> for TensorIndexer {
     }
 }

+impl From<&[u32]> for TensorIndexer {
+    fn from(index: &[u32]) -> Self {
+        match Tensor::new(index, &crate::Device::Cpu) {
+            Ok(tensor) => TensorIndexer::IndexSelect(tensor),
+            Err(e) => TensorIndexer::Err(e),
+        }
+    }
+}
+
+impl From<Vec<u32>> for TensorIndexer {
+    fn from(index: Vec<u32>) -> Self {
+        let len = index.len();
+        match Tensor::from_vec(index, len, &crate::Device::Cpu) {
+            Ok(tensor) => TensorIndexer::IndexSelect(tensor),
+            Err(e) => TensorIndexer::Err(e),
+        }
+    }
+}
+
+impl From<&Tensor> for TensorIndexer {
+    fn from(tensor: &Tensor) -> Self {
+        TensorIndexer::IndexSelect(tensor.clone())
+    }
+}
+
 macro_rules! impl_from_range {
     ($range_type:ty) => {
         impl From<$range_type> for TensorIndexer {
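A hedged sketch of the tensor-based indexing these `From` impls enable; it assumes `IndexOp::i` accepts anything convertible into `TensorIndexer`:

```rust
use candle_core::{Device, IndexOp, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::new(&[10u32, 11, 12, 13], &Device::Cpu)?;
    // A Vec<u32> index goes through the new IndexSelect path.
    let picked = t.i(vec![2u32, 0, 3])?;
    assert_eq!(picked.to_vec1::<u32>()?, vec![12, 10, 13]);
    Ok(())
}
```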
@@ -59,6 +59,7 @@ mod op;
 pub mod pickle;
 pub mod quantized;
 pub mod safetensors;
+pub mod scalar;
 pub mod shape;
 mod storage;
 mod strided_index;
@@ -109,14 +110,8 @@ impl ToUsize2 for (usize, usize) {
 }

 // A simple trait defining a module with forward method using a single argument.
-pub trait Module: std::fmt::Debug {
+pub trait Module {
     fn forward(&self, xs: &Tensor) -> Result<Tensor>;
-
-    /// Change the module to use training mode vs eval mode.
-    ///
-    /// The default implementation does nothing as this is only used for a couple modules such as
-    /// dropout or batch-normalization.
-    fn set_training(&mut self, _training: bool) {}
 }

 impl Module for quantized::QMatMul {
@@ -124,3 +119,9 @@ impl Module for quantized::QMatMul {
         self.forward(xs)
     }
 }
+
+impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self(xs)
+    }
+}
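With the `Debug` bound and `set_training` removed, and the blanket impl added, any plain closure can now act as a `Module`. A minimal sketch:

```rust
use candle_core::{Device, Module, Result, Tensor};

fn main() -> Result<()> {
    // Any Fn(&Tensor) -> Result<Tensor> is a Module after this change.
    let double = |xs: &Tensor| xs.affine(2.0, 0.0);
    let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    let out = double.forward(&t)?;
    assert_eq!(out.to_vec1::<f32>()?, vec![2., 4., 6.]);
    Ok(())
}
```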
@@ -58,6 +58,8 @@ pub enum UnaryOp {
     Sqr,
     Sqrt,
     Gelu,
+    GeluErf,
+    Erf,
     Relu,
     Tanh,
 }
@@ -116,6 +118,7 @@ pub enum Op {
         stride: (usize, usize),
     },

+    UpsampleNearest1D(Tensor),
     UpsampleNearest2D(Tensor),

     Cat(Vec<Tensor>, usize),
@@ -324,6 +327,8 @@ pub(crate) struct Recip;
 pub(crate) struct Sqr;
 pub(crate) struct Sqrt;
 pub(crate) struct Gelu;
+pub(crate) struct GeluErf;
+pub(crate) struct Erf;
 pub(crate) struct Relu;
 pub(crate) struct Tanh;

@@ -600,6 +605,92 @@ impl UnaryOpT for Gelu {
     fn f64_vec(xs: &[f64], ys: &mut [f64]) {
         crate::mkl::vd_gelu(xs, ys)
     }
+
+    #[cfg(feature = "accelerate")]
+    const F32_VEC: bool = true;
+
+    #[cfg(feature = "accelerate")]
+    #[inline(always)]
+    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
+        crate::accelerate::vs_gelu(xs, ys)
+    }
+
+    #[cfg(feature = "accelerate")]
+    const F64_VEC: bool = true;
+
+    #[cfg(feature = "accelerate")]
+    #[inline(always)]
+    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
+        crate::accelerate::vd_gelu(xs, ys)
+    }
 }

+impl UnaryOpT for Erf {
+    const NAME: &'static str = "erf";
+    const KERNEL: &'static str = "uerf";
+    const V: Self = Erf;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        bf16::from_f64(Self::f64(v.to_f64()))
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        f16::from_f64(Self::f64(v.to_f64()))
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        Self::f64(v as f64) as f32
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        crate::cpu::erf::erf(v)
+    }
+    #[inline(always)]
+    fn u8(_: u8) -> u8 {
+        0
+    }
+    #[inline(always)]
+    fn u32(_: u32) -> u32 {
+        0
+    }
+    #[inline(always)]
+    fn i64(_: i64) -> i64 {
+        0
+    }
+}
+
+impl UnaryOpT for GeluErf {
+    const NAME: &'static str = "gelu_erf";
+    const KERNEL: &'static str = "ugelu_erf";
+    const V: Self = GeluErf;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        bf16::from_f64(Self::f64(v.to_f64()))
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        f16::from_f64(Self::f64(v.to_f64()))
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        Self::f64(v as f64) as f32
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        (crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
+    }
+    #[inline(always)]
+    fn u8(_: u8) -> u8 {
+        0
+    }
+    #[inline(always)]
+    fn u32(_: u32) -> u32 {
+        0
+    }
+    #[inline(always)]
+    fn i64(_: i64) -> i64 {
+        0
+    }
+}
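A sketch checking `GeluErf::f64` against its definition, gelu_erf(x) = 0.5 · x · (1 + erf(x/√2)), built only from ops present in this diff (`erf`, `affine`, `mul`):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let xs = Tensor::new(&[-1f32, 0., 1.], &Device::Cpu)?;
    let by_op = xs.gelu_erf()?;
    // By hand: 0.5 * x * (1 + erf(x / sqrt(2))).
    let u = xs.affine(1.0 / 2f64.sqrt(), 0.0)?; // x / sqrt(2)
    let by_hand = u.erf()?.affine(0.5, 0.5)?.mul(&xs)?;
    let diff = (by_op - by_hand)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    assert!(diff < 1e-6);
    Ok(())
}
```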

 impl UnaryOpT for Relu {
@@ -85,7 +85,7 @@ const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
 pub struct BlockQ8_1 {
     pub(crate) d: f16,
     pub(crate) s: f16,
-    pub(crate) qs: [u8; QK8_1],
+    pub(crate) qs: [i8; QK8_1],
 }
 const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);

@@ -278,6 +278,7 @@ impl GgmlType for BlockQ4_1 {
             }

             sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
+                + f16::to_f32(xs.m) * f16::to_f32(ys.s)
         }
         Ok(sumf)
     }
@@ -471,6 +472,7 @@ impl GgmlType for BlockQ5_1 {
             }

             sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
+                + f16::to_f32(xs.m) * f16::to_f32(ys.s)
         }
         Ok(sumf)
     }
@@ -652,8 +654,8 @@ impl GgmlType for BlockQ8_1 {
         for j in 0..Self::BLCK_SIZE / 2 {
             let v0 = xs[j] * id;
             let v1 = xs[j + Self::BLCK_SIZE / 2] * id;
-            ys.qs[j] = f32::round(v0) as u8;
-            ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as u8;
+            ys.qs[j] = f32::round(v0) as i8;
+            ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as i8;
             sum += ys.qs[j] as i32 + ys.qs[j + Self::BLCK_SIZE / 2] as i32;
         }
         ys.s = f16::from_f32(sum as f32) * ys.d;
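Why the `u8` to `i8` change matters: float-to-integer `as` casts in Rust saturate, so with `as u8` every negative quantized value was silently collapsed to zero. A tiny demonstration:

```rust
fn main() {
    let v = f32::round(-3.4); // -3.0
    assert_eq!(v as i8, -3);
    // The old `as u8` cast saturates negative floats to 0,
    // zeroing out half of the signed quantization range.
    assert_eq!(v as u8, 0);
}
```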
@@ -229,7 +229,7 @@ impl QTensor {
     }
 }

-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct QMatMul(std::sync::Arc<QTensor>);

 impl QMatMul {
@@ -78,11 +78,7 @@ impl st::View for &Tensor {
 }

 impl Tensor {
-    pub fn save_safetensors<P: AsRef<std::path::Path>>(
-        &self,
-        name: &str,
-        filename: P,
-    ) -> Result<()> {
+    pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
         let data = [(name, self.clone())];
         Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
     }
@@ -267,7 +263,7 @@ impl MmapedFile {
     /// # Safety
     ///
     /// The unsafe is inherited from [`memmap2::MmapOptions`].
-    pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
         let p = p.as_ref();
         let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
         let inner = memmap2::MmapOptions::new()
23
candle-core/src/scalar.rs
Normal file
@@ -0,0 +1,23 @@
+use crate::{Result, Tensor, WithDType};
+
+pub enum TensorScalar {
+    Tensor(Tensor),
+    Scalar(Tensor),
+}
+
+pub trait TensorOrScalar {
+    fn to_tensor_scalar(self) -> Result<TensorScalar>;
+}
+
+impl TensorOrScalar for &Tensor {
+    fn to_tensor_scalar(self) -> Result<TensorScalar> {
+        Ok(TensorScalar::Tensor(self.clone()))
+    }
+}
+
+impl<T: WithDType> TensorOrScalar for T {
+    fn to_tensor_scalar(self) -> Result<TensorScalar> {
+        let scalar = Tensor::new(self, &crate::Device::Cpu)?;
+        Ok(TensorScalar::Scalar(scalar))
+    }
+}
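A hedged sketch of what `TensorOrScalar` enables once `cmp`, `maximum`, `minimum` and the new `clamp` accept it (see the tensor.rs hunks further below):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::new(&[0f32, 1., 2., 3.], &Device::Cpu)?;
    // Plain scalars are broadcast to the tensor's shape, dtype and device.
    let mask = t.gt(1f32)?;
    assert_eq!(mask.to_vec1::<u8>()?, vec![0, 0, 1, 1]);
    let clamped = t.clamp(1f32, 2f32)?;
    assert_eq!(clamped.to_vec1::<f32>()?, vec![1., 1., 2., 2.]);
    Ok(())
}
```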
@@ -1,3 +1,4 @@
 //! The shape of a tensor is a tuple with the size of each of its dimensions.
+#![allow(clippy::redundant_closure_call)]
 use crate::{Error, Result};

@@ -72,6 +73,14 @@ impl From<(usize, usize, usize, usize, usize)> for Shape {
     }
 }

+impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
+    fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
+        Self(vec![
+            d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5,
+        ])
+    }
+}
+
 impl From<Vec<usize>> for Shape {
     fn from(dims: Vec<usize>) -> Self {
         Self(dims)
@@ -119,6 +128,7 @@ impl Shape {
         Self(dims.to_vec())
     }

+    /// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc.
     pub fn rank(&self) -> usize {
         self.0.len()
     }
@@ -127,10 +137,12 @@ impl Shape {
         self.0
     }

+    /// The dimensions as a slice of `usize`.
     pub fn dims(&self) -> &[usize] {
         &self.0
     }

+    /// The total number of elements, this is the product of all dimension sizes.
     pub fn elem_count(&self) -> usize {
         self.0.iter().product()
     }
@@ -182,6 +194,8 @@ impl Shape {
         true
     }

+    /// Modifies the shape by adding a list of additional dimensions at the end of the existing
+    /// dimensions.
     pub fn extend(mut self, additional_dims: &[usize]) -> Self {
         self.0.extend(additional_dims);
         self
@@ -419,6 +433,29 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
     }
 }

+impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
+    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+        let d0 = self.0.to_index(shape, op)?;
+        let d1 = self.1.to_index(shape, op)?;
+        let d2 = self.2.to_index(shape, op)?;
+        let d3 = self.3.to_index(shape, op)?;
+        let d4 = self.4.to_index(shape, op)?;
+        Ok(vec![d0, d1, d2, d3, d4])
+    }
+}
+
+impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
+    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+        let d0 = self.0.to_index(shape, op)?;
+        let d1 = self.1.to_index(shape, op)?;
+        let d2 = self.2.to_index(shape, op)?;
+        let d3 = self.3.to_index(shape, op)?;
+        let d4 = self.4.to_index(shape, op)?;
+        let d5 = self.5.to_index(shape, op)?;
+        Ok(vec![d0, d1, d2, d3, d4, d5])
+    }
+}
+
 extract_dims!(dims0, 0, |_: &[usize]| (), ());
 extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
 extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
@@ -457,3 +494,171 @@ mod tests {
         assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
     }
 }
+
+pub trait ShapeWithOneHole {
+    fn into_shape(self, el_count: usize) -> Result<Shape>;
+}
+
+impl<S: Into<Shape>> ShapeWithOneHole for S {
+    fn into_shape(self, _el_count: usize) -> Result<Shape> {
+        Ok(self.into())
+    }
+}
+
+impl ShapeWithOneHole for ((),) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        Ok(el_count.into())
+    }
+}
+
+impl ShapeWithOneHole for ((), usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let ((), d1) = self;
+        if el_count % d1 != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
+        }
+        Ok((el_count / d1, d1).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, ()) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, ()) = self;
+        if el_count % d1 != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
+        }
+        Ok((d1, el_count / d1).into())
+    }
+}
+
+impl ShapeWithOneHole for ((), usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let ((), d1, d2) = self;
+        let d = d1 * d2;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((el_count / d, d1, d2).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, (), usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, (), d2) = self;
+        let d = d1 * d2;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, el_count / d, d2).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, ()) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, ()) = self;
+        let d = d1 * d2;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, el_count / d).into())
+    }
+}
+
+impl ShapeWithOneHole for ((), usize, usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let ((), d1, d2, d3) = self;
+        let d = d1 * d2 * d3;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((el_count / d, d1, d2, d3).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, (), usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, (), d2, d3) = self;
+        let d = d1 * d2 * d3;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, el_count / d, d2, d3).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, (), usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, (), d3) = self;
+        let d = d1 * d2 * d3;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, el_count / d, d3).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, ()) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, d3, ()) = self;
+        let d = d1 * d2 * d3;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, d3, el_count / d).into())
+    }
+}
+
+impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let ((), d1, d2, d3, d4) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((el_count / d, d1, d2, d3, d4).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, (), d2, d3, d4) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, el_count / d, d2, d3, d4).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, (), d3, d4) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, el_count / d, d3, d4).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, d3, (), d4) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, d3, el_count / d, d4).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, d3, d4, ()) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, d3, d4, el_count / d).into())
+    }
+}
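A short sketch of the hole syntax these impls enable; `()` plays the role of PyTorch's `-1`, with the missing dimension inferred from the element count:

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?; // 24 elements
    let r = t.reshape((2, ()))?;
    assert_eq!(r.dims(), &[2, 12]);
    let r = t.reshape(((), 4))?;
    assert_eq!(r.dims(), &[6, 4]);
    Ok(())
}
```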
@@ -369,6 +369,19 @@ impl Storage {
         }
     }

+    pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
+        match self {
+            Storage::Cpu(storage) => {
+                let storage = storage.upsample_nearest1d(layout, sz)?;
+                Ok(Self::Cpu(storage))
+            }
+            Self::Cuda(storage) => {
+                let storage = storage.upsample_nearest1d(layout, sz)?;
+                Ok(Self::Cuda(storage))
+            }
+        }
+    }
+
     pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
         match self {
             Storage::Cpu(storage) => {
@ -1,8 +1,10 @@
|
||||
//! Tensors are N-dimenional matrixes of elements using a single data type.
|
||||
#![allow(clippy::redundant_closure_call)]
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{
|
||||
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
|
||||
};
|
||||
use crate::scalar::TensorOrScalar;
|
||||
use crate::shape::{Dim, Dims};
|
||||
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use std::sync::{Arc, RwLock};
|
||||
@ -103,6 +105,28 @@ macro_rules! binary_op {
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! binary_op_scalar {
|
||||
($fn_name:ident, $op_name:ident) => {
|
||||
pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||
let rhs = match rhs.to_tensor_scalar()? {
|
||||
crate::scalar::TensorScalar::Tensor(rhs) => rhs,
|
||||
crate::scalar::TensorScalar::Scalar(rhs) => rhs
|
||||
.to_dtype(self.dtype())?
|
||||
.to_device(self.device())?
|
||||
.broadcast_as(self.shape())?,
|
||||
};
|
||||
let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
|
||||
let storage = self.storage().binary_impl::<crate::op::$op_name>(
|
||||
&*rhs.storage(),
|
||||
self.layout(),
|
||||
rhs.layout(),
|
||||
)?;
|
||||
let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
|
||||
Ok(from_storage(storage, shape.clone(), op, false))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! broadcast_binary_op {
|
||||
($fn_name:ident, $inner_fn_name:ident) => {
|
||||
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
|
||||
@ -445,8 +469,8 @@ impl Tensor {
|
||||
binary_op!(mul, Mul);
|
||||
binary_op!(sub, Sub);
|
||||
binary_op!(div, Div);
|
||||
binary_op!(maximum, Maximum);
|
||||
binary_op!(minimum, Minimum);
|
||||
binary_op_scalar!(maximum, Maximum);
|
||||
binary_op_scalar!(minimum, Minimum);
|
||||
broadcast_binary_op!(broadcast_add, add);
|
||||
broadcast_binary_op!(broadcast_mul, mul);
|
||||
broadcast_binary_op!(broadcast_sub, sub);
|
||||
@ -465,6 +489,8 @@ impl Tensor {
|
||||
unary_op!(sqr, Sqr);
|
||||
unary_op!(sqrt, Sqrt);
|
||||
unary_op!(gelu, Gelu);
|
||||
unary_op!(gelu_erf, GeluErf);
|
||||
unary_op!(erf, Erf);
|
||||
unary_op!(relu, Relu);
|
||||
|
||||
/// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple
|
||||
@ -642,7 +668,12 @@ impl Tensor {
let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
let mut dims = self.dims().to_vec();
dims[dim] = 1;
let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
let op = match op {
ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
}
ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
};
let res = from_storage(storage, dims, op, false);
if keepdim {
Ok(res)
@ -775,8 +806,15 @@ impl Tensor {
/// comparison operation is specified by the `op` argument.
///
/// The returned tensor has the same shape as the original tensors and uses `u8` elements.
pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
let shape = self.same_shape_binary_op(rhs, "cmp")?;
pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
let rhs = match rhs.to_tensor_scalar()? {
crate::scalar::TensorScalar::Tensor(rhs) => rhs,
crate::scalar::TensorScalar::Scalar(rhs) => rhs
.to_dtype(self.dtype())?
.to_device(self.device())?
.broadcast_as(self.shape())?,
};
let shape = self.same_shape_binary_op(&rhs, "cmp")?;
let storage = self
.storage()
.cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
@ -785,45 +823,68 @@ impl Tensor {
}

/// Element-wise equality.
pub fn eq(&self, rhs: &Self) -> Result<Self> {
pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Eq)
}

/// Element-wise non-equality.
pub fn ne(&self, rhs: &Self) -> Result<Self> {
pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Ne)
}

/// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
/// rhs` and 0 otherwise.
pub fn lt(&self, rhs: &Self) -> Result<Self> {
pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Lt)
}

/// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
/// rhs` and 0 otherwise.
pub fn gt(&self, rhs: &Self) -> Result<Self> {
pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Gt)
}

/// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
/// rhs` and 0 otherwise.
pub fn ge(&self, rhs: &Self) -> Result<Self> {
pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Ge)
}

/// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
/// rhs` and 0 otherwise.
pub fn le(&self, rhs: &Self) -> Result<Self> {
pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Le)
}

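The practical effect of the `TensorOrScalar` change above is easiest to see in use: every comparison now accepts either a tensor or a plain scalar on the right-hand side. A minimal sketch, not part of the diff, assuming candle-core 0.2.3 on CPU:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[[3f32, 1., 4.], [1., 5., 9.]], &Device::Cpu)?;
    // Scalar rhs: dtype conversion, device placement and broadcasting all
    // happen inside `cmp`, as implemented in the hunk above.
    let mask = t.gt(2f32)?;
    assert_eq!(mask.to_vec2::<u8>()?, [[1, 0, 1], [0, 1, 1]]);
    // A tensor rhs keeps working exactly as before.
    let mask = t.lt(&t)?;
    assert_eq!(mask.to_vec2::<u8>()?, [[0, 0, 0], [0, 0, 0]]);
    Ok(())
}
```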
/// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the
/// Clamp the tensor values to be between `min` and `max`.
pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
self.maximum(min)?.minimum(max)
}

/// Interpolate the input tensor to the `target_size` size, taking the value of the nearest element.
///
/// The input tensor should have three dimensions, `(batch, channels, l)`, the returned
/// tensor also has three dimensions, `(batch, channels, target_size)`.
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
let (n, c, _l) = self.dims3()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
let storage = self
.storage()
.upsample_nearest1d(self.layout(), target_size)?;
Ok(from_storage(storage, (n, c, target_size), op, false))
}

/// Alias for `interpolate1d`.
pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
self.interpolate1d(target_size)
}

/// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the
/// nearest element.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
let storage = self
@ -832,6 +893,11 @@ impl Tensor {
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
}

/// Alias for `interpolate2d`.
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
self.interpolate2d(target_h, target_w)
}

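A short usage sketch for the new interpolation entry points and their aliases (illustrative, not taken from the diff; shapes follow the doc comments above):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // (batch, channels, l) -> (batch, channels, target_size)
    let t = Tensor::zeros((1, 3, 8), DType::F32, &dev)?;
    let up = t.interpolate1d(16)?; // same as t.upsample_nearest1d(16)
    assert_eq!(up.dims(), [1, 3, 16]);
    // (batch, channels, h, w) -> (batch, channels, target_h, target_w)
    let img = Tensor::zeros((1, 3, 4, 4), DType::F32, &dev)?;
    let up = img.interpolate2d(8, 8)?; // same as img.upsample_nearest2d(8, 8)
    assert_eq!(up.dims(), [1, 3, 8, 8]);
    Ok(())
}
```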
/// 2D average pooling over an input tensor with multiple channels.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
@ -1684,12 +1750,15 @@ impl Tensor {
Ok(from_storage(storage, shape, BackpropOp::none(), true))
}

// TODO: Do we want to allow target shape using -1 on some dimensions?
/// Reshape returns a tensor with the target shape provided that the number of elements of the
/// original tensor is the same.
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
/// a new storage and copies the data over, the returned tensor is always contiguous.
///
/// The shape can be specified using a tuple of `usize` and at most one `()` in which case
/// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so
/// as to match the number of elements in the tensor.
///
/// ```rust
/// # use candle_core::{Tensor, DType, Device, D};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
@ -1699,10 +1768,14 @@ impl Tensor {
///
/// let c = a.reshape((3, 2))?;
/// assert_eq!(c.shape().dims(), &[3, 2]);
///
/// let c = a.reshape((2, (), 1))?;
/// assert_eq!(c.shape().dims(), &[2, 3, 1]);
///
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
let shape = shape.into();
pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
let shape = s.into_shape(self.elem_count())?;
if shape.elem_count() != self.elem_count() {
return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
@ -1836,6 +1909,34 @@ impl Tensor {
for arg in args {
arg.as_ref().check_dim(dim, "cat")?;
}
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg0.rank() != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: arg0.rank(),
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx != dim && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
}
if dim == 0 {
Self::cat0(args)
} else {
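The new checks above make `cat` fail early with a precise error instead of misbehaving later. An illustrative sketch of both the accepted and the rejected case, again not from the diff:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &dev)?; // shape (2, 2)
    let b = Tensor::new(&[[5f32, 6.]], &dev)?; // shape (1, 2)
    // Ok: every dimension other than the cat dimension matches.
    let ab = Tensor::cat(&[&a, &b], 0)?;
    assert_eq!(ab.dims(), [3, 2]);
    // Err(ShapeMismatchCat): dim 0 differs (2 vs 1) when concatenating on dim 1.
    assert!(Tensor::cat(&[&a, &b], 1).is_err());
    Ok(())
}
```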
@ -1,4 +1,4 @@
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};

fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@ -33,6 +33,44 @@ fn tensor_2d(device: &Device) -> Result<()> {
Ok(())
}

fn clamp(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?;
let tensor = tensor.clamp(1.5, 6.2)?;
assert_eq!(
tensor.to_vec2::<f32>()?,
[[3.0, 1.5, 4.0, 1.5, 5.0], [2.0, 1.5, 6.2, 6.2, 2.0]],
);
Ok(())
}

fn unary_op(device: &Device) -> Result<()> {
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu()?, 4)?,
[
[-0.0036, 0.8412, 3.9999, -0.046, 0.3457],
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
[
[-0.004, 0.8413, 3.9999, -0.046, 0.3457],
[2.6906, -0.0647, -0.1091, 1.7353, 2.7928]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.erf()?, 4)?,
[
[-1.0, 0.8427, 1.0, -0.1125, 0.5205],
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
]
);
Ok(())
}

fn binary_op(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor1 = Tensor::new(data, device)?;
@ -877,6 +915,14 @@ fn broadcasting(device: &Device) -> Result<()> {
Ok(())
}

fn randn(device: &Device) -> Result<()> {
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]);
let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]);
Ok(())
}

test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
@ -889,6 +935,7 @@ test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(unary_op, unary_op_cpu, unary_op_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
@ -899,6 +946,8 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
test_device!(randn, randn_cpu, randn_gpu);
test_device!(clamp, clamp_cpu, clamp_gpu);

// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
@ -11,8 +11,8 @@ readme = "README.md"

[dependencies]
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.2.1" }
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.2.3" }
hf-hub = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }
@ -8,13 +8,9 @@ use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File;
use std::io::{self, BufReader, Read};

fn read_u32<T: Read>(reader: &mut T) -> Result<u32> {
let mut b = vec![0u8; 4];
reader.read_exact(&mut b)?;
let (result, _) = b.iter().rev().fold((0u64, 1u64), |(s, basis), &x| {
(s + basis * u64::from(x), basis * 256)
});
Ok(result as u32)
fn read_u32<T: Read>(reader: &mut T) -> std::io::Result<u32> {
use byteorder::ReadBytesExt;
reader.read_u32::<byteorder::BigEndian>()
}

fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {
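The rewrite delegates the big-endian decoding to `byteorder` instead of folding bytes by hand. A standalone sketch of the same idea; the MNIST header stores its magic number and dimensions as big-endian `u32` values:

```rust
use std::io::Read;

fn read_u32<T: Read>(reader: &mut T) -> std::io::Result<u32> {
    use byteorder::ReadBytesExt;
    reader.read_u32::<byteorder::BigEndian>()
}

fn main() -> std::io::Result<()> {
    // 0x00000803 is the magic number of the MNIST image file.
    let header: &[u8] = &[0x00, 0x00, 0x08, 0x03];
    let mut cursor = std::io::Cursor::new(header);
    assert_eq!(read_u32(&mut cursor)?, 2051);
    Ok(())
}
```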
@ -11,19 +11,19 @@ readme = "README.md"

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
candle-nn = { path = "../candle-nn", version = "0.2.1" }
candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
candle-nn = { path = "../candle-nn", version = "0.2.3" }
candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

[dev-dependencies]
anyhow = { workspace = true }
@ -50,7 +50,7 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
flash-attn = ["cuda", "candle-transformers/flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]

44
candle-examples/examples/bert/README.md
Normal file
@ -0,0 +1,44 @@
# candle-bert

Bert is a general large language model. In this example, it can be used for two
different tasks:
- Compute sentence embeddings for a prompt.
- Compute similarities between a set of sentences.

## Sentence embeddings

Bert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.

```bash
cargo run --example bert --release -- --prompt "Here is a test sentence"

> [[[ 0.0798, -0.0665, -0.0247, ..., -0.1082, -0.1000, -0.2751],
> [ 0.4218, 0.2690, 0.2740, ..., 0.3889, 1.3503, 0.9908],
> [ 0.0466, 0.3041, -0.1143, ..., 0.4427, 0.6926, -0.1515],
> ...
> [ 0.3396, 0.4320, -0.4408, ..., 0.9212, 0.2331, -0.6777],
> [ 0.2789, 0.7539, 0.4306, ..., -0.0095, 0.3375, -1.7529],
> [ 0.6737, 0.7882, 0.0548, ..., 0.1836, 0.7299, -0.6617]]]
> Tensor[[1, 7, 384], f32]
```

## Similarities

In this example, Bert is used to compute the sentence embeddings for a set of
sentences (hardcoded in the example). Cosine similarities are then computed for
each sentence pair and reported in decreasing order, so the first reported pair
contains the two sentences with the highest similarity score. The sentence
embeddings are computed using average pooling over all the sentence tokens,
including any padding.

```bash
cargo run --example bert --release

> score: 0.85 'The new movie is awesome' 'The new movie is so great'
> score: 0.61 'The cat sits outside' 'The cat plays in the garden'
> score: 0.52 'I love pasta' 'Do you like pizza?'
> score: 0.23 'The new movie is awesome' 'Do you like pizza?'
> score: 0.22 'I love pasta' 'The new movie is awesome'
```
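The score printed above is a cosine similarity between two embedding vectors. A rough sketch of how it can be computed with candle tensors, using the `candle` crate alias the examples rely on (an illustration close to, but not verbatim from, the example code):

```rust
use candle::{Result, Tensor};

/// cos(a, b) = <a, b> / (|a| * |b|)
fn cosine_similarity(a: &Tensor, b: &Tensor) -> Result<f32> {
    let dot = (a * b)?.sum_all()?.to_scalar::<f32>()?;
    let a2 = (a * a)?.sum_all()?.to_scalar::<f32>()?;
    let b2 = (b * b)?.sum_all()?.to_scalar::<f32>()?;
    Ok(dot / (a2 * b2).sqrt())
}
```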
@ -3,14 +3,13 @@ extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;
mod model;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};

use anyhow::{anyhow, Error as E, Result};
use candle::Tensor;
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use model::{BertModel, Config, DTYPE};
use tokenizers::{PaddingParams, Tokenizer};

#[derive(Parser, Debug)]
19
candle-examples/examples/bigcode/README.md
Normal file
@ -0,0 +1,19 @@
# candle-starcoder: code generation model

[StarCoder/BigCode](https://huggingface.co/bigcode/starcoderbase-1b) is an LLM
specialized in code generation. The initial model was trained on 80
programming languages.

## Running an example

```bash
cargo run --example bigcode --release -- --prompt "fn fact(n: u64) -> u64 "

> fn fact(n: u64) -> u64 {
>     if n == 0 {
>         1
>     } else {
>         n * fact(n - 1)
>     }
> }
```
@ -7,8 +7,7 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;

mod model;
use model::{Config, GPTBigCode};
use candle_transformers::models::bigcode::{Config, GPTBigCode};

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
@ -29,9 +28,10 @@ impl TextGeneration {
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp);
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
@ -95,6 +95,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,

/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,

/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@ -150,7 +154,14 @@ fn main() -> Result<()> {
let model = GPTBigCode::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());

let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}
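For reference, `top_p` (nucleus) sampling keeps only the smallest set of highest-probability tokens whose cumulative probability reaches `p`, then renormalizes before sampling. A sketch of the filtering step; this is not candle's implementation:

```rust
// `probs` holds (token id, probability) pairs summing to 1.
fn top_p_filter(probs: &mut Vec<(usize, f32)>, p: f32) {
    // Sort by decreasing probability.
    probs.sort_by(|a, b| b.1.total_cmp(&a.1));
    // Keep the smallest prefix whose cumulative mass reaches p.
    let mut cumsum = 0f32;
    let mut keep = probs.len();
    for (i, &(_, pr)) in probs.iter().enumerate() {
        cumsum += pr;
        if cumsum >= p {
            keep = i + 1;
            break;
        }
    }
    probs.truncate(keep);
    // Renormalize the surviving probabilities before sampling.
    let total: f32 = probs.iter().map(|&(_, pr)| pr).sum();
    for pair in probs.iter_mut() {
        pair.1 /= total;
    }
}
```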
19
candle-examples/examples/dinov2/README.md
Normal file
@ -0,0 +1,19 @@
# candle-dinov2

[DINOv2](https://github.com/facebookresearch/dinov2) is a computer vision model.
In this example, it is used as an ImageNet classifier: the model returns the
probability of the image belonging to each of the 1000 ImageNet categories.

## Running an example

```bash
cargo run --example dinov2 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg

> mountain bike, all-terrain bike, off-roader: 43.67%
> bicycle-built-for-two, tandem bicycle, tandem: 33.20%
> crash helmet : 13.23%
> unicycle, monocycle : 2.44%
> maillot : 2.42%
```
@ -9,285 +9,10 @@ extern crate accelerate_src;

use clap::Parser;

use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::dinov2;

const IMG_SIZE: usize = 518;
const PATCH_SIZE: usize = 14;
const NUM_CLASSES: usize = 1000;

fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
|
||||
if bias {
|
||||
candle_nn::linear(in_dim, out_dim, vb)
|
||||
} else {
|
||||
candle_nn::linear_no_bias(in_dim, out_dim, vb)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Attention {
|
||||
qkv: Linear,
|
||||
proj: Linear,
|
||||
num_heads: usize,
|
||||
scale: f64,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(
|
||||
vb: VarBuilder,
|
||||
dim: usize,
|
||||
num_heads: usize,
|
||||
qkv_bias: bool,
|
||||
proj_bias: bool,
|
||||
) -> Result<Self> {
|
||||
let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
|
||||
let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
|
||||
let scale = 1. / ((dim / num_heads) as f64).sqrt();
|
||||
Ok(Self {
|
||||
qkv,
|
||||
proj,
|
||||
num_heads,
|
||||
scale,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Attention {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (b, n, c) = xs.dims3()?;
|
||||
let qkv = self
|
||||
.qkv
|
||||
.forward(xs)?
|
||||
.reshape((b, n, 3, self.num_heads, c / self.num_heads))?
|
||||
.transpose(1, 2)? // 02134
|
||||
.transpose(0, 1)? // 20134
|
||||
.transpose(2, 3)?; // 20314
|
||||
let q = (qkv.i(0)? * self.scale)?;
|
||||
let k = qkv.i(1)?;
|
||||
let v = qkv.i(2)?;
|
||||
let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
|
||||
let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
|
||||
self.proj.forward(&attn)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct LayerScale {
|
||||
gamma: Tensor,
|
||||
}
|
||||
|
||||
impl LayerScale {
|
||||
fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
|
||||
let gamma = vb.get(dim, "gamma")?;
|
||||
Ok(Self { gamma })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for LayerScale {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.broadcast_mul(&self.gamma)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Mlp {
|
||||
fc1: Linear,
|
||||
fc2: Linear,
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
|
||||
let out_features = in_features;
|
||||
let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
|
||||
let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
|
||||
Ok(Self { fc1, fc2 })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Mlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.fc1.forward(xs)?.gelu()?;
|
||||
self.fc2.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Block {
|
||||
norm1: LayerNorm,
|
||||
attn: Attention,
|
||||
ls1: LayerScale,
|
||||
norm2: LayerNorm,
|
||||
mlp: Mlp,
|
||||
ls2: LayerScale,
|
||||
}
|
||||
|
||||
impl Block {
|
||||
fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
|
||||
let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
|
||||
let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
|
||||
let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
|
||||
let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
|
||||
let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
|
||||
let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
|
||||
Ok(Self {
|
||||
norm1,
|
||||
attn,
|
||||
ls1,
|
||||
norm2,
|
||||
mlp,
|
||||
ls2,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Block {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = self
|
||||
.ls1
|
||||
.forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
|
||||
let xs = (xs + residual)?;
|
||||
let residual = &xs;
|
||||
let xs = self
|
||||
.ls2
|
||||
.forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
|
||||
xs + residual
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PatchEmbed {
|
||||
proj: candle_nn::Conv2d,
|
||||
patch_size: (usize, usize),
|
||||
num_patches: usize,
|
||||
}
|
||||
|
||||
impl PatchEmbed {
|
||||
fn new(
|
||||
vb: VarBuilder,
|
||||
img_size: usize,
|
||||
patch_size: usize,
|
||||
in_chans: usize,
|
||||
embed_dim: usize,
|
||||
) -> Result<Self> {
|
||||
let config = candle_nn::Conv2dConfig {
|
||||
stride: patch_size,
|
||||
..Default::default()
|
||||
};
|
||||
let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
|
||||
let num_patches = (img_size / patch_size) * (img_size / patch_size);
|
||||
Ok(Self {
|
||||
proj,
|
||||
patch_size: (patch_size, patch_size),
|
||||
num_patches,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for PatchEmbed {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_b, _c, h, w) = xs.dims4()?;
|
||||
let (patch_h, patch_w) = self.patch_size;
|
||||
if (h % patch_h) != 0 {
|
||||
candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
|
||||
}
|
||||
if (w % patch_w) != 0 {
|
||||
candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
|
||||
}
|
||||
let xs = self.proj.forward(xs)?;
|
||||
let (b, c, h, w) = xs.dims4()?;
|
||||
// flatten embeddings.
|
||||
xs.reshape((b, c, h * w))?.transpose(1, 2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DinoVisionTransformer {
|
||||
patch_embed: PatchEmbed,
|
||||
cls_token: Tensor,
|
||||
pos_embed: Tensor,
|
||||
blocks: Vec<Block>,
|
||||
norm: LayerNorm,
|
||||
head: Linear,
|
||||
}
|
||||
|
||||
impl DinoVisionTransformer {
|
||||
pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
|
||||
let patch_embed =
|
||||
PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
|
||||
let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
|
||||
let num_tokens = 1;
|
||||
let pos_embed = vb.get(
|
||||
(1, patch_embed.num_patches + num_tokens, embed_dim),
|
||||
"pos_embed",
|
||||
)?;
|
||||
let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
|
||||
let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
|
||||
let vb_b = vb.pp("blocks");
|
||||
let blocks = (0..depth)
|
||||
.map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self {
|
||||
patch_embed,
|
||||
cls_token,
|
||||
pos_embed,
|
||||
blocks,
|
||||
norm,
|
||||
head,
|
||||
})
|
||||
}
|
||||
|
||||
fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
|
||||
let npatch = xs.dim(1)? - 1;
|
||||
let n = self.pos_embed.dim(1)? - 1;
|
||||
let sqrt_n = (n as f64).sqrt();
|
||||
if npatch == n && w == h {
|
||||
return Ok(xs.clone());
|
||||
}
|
||||
let class_pos_embed = self.pos_embed.i((.., ..1))?;
|
||||
let patch_pos_embed = self.pos_embed.i((.., 1..))?;
|
||||
let dim = xs.dim(D::Minus1)?;
|
||||
let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
|
||||
let patch_pos_embed = patch_pos_embed
|
||||
.reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
|
||||
.transpose(2, 3)?
|
||||
.transpose(1, 2)?;
|
||||
// This uses bicubic interpolation in the original implementation.
|
||||
let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
|
||||
let el_count = patch_pos_embed.shape().elem_count();
|
||||
let patch_pos_embed =
|
||||
patch_pos_embed
|
||||
.transpose(1, 2)?
|
||||
.transpose(2, 3)?
|
||||
.reshape((1, el_count / dim, dim))?;
|
||||
Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
|
||||
}
|
||||
|
||||
fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_b, _nc, w, h) = xs.dims4()?;
|
||||
let xs = self.patch_embed.forward(xs)?;
|
||||
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
|
||||
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for DinoVisionTransformer {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = self.prepare_tokens_with_mask(xs)?;
|
||||
for blk in self.blocks.iter() {
|
||||
xs = blk.forward(&xs)?
|
||||
}
|
||||
let xs = self.norm.forward(&xs)?;
|
||||
let xs_norm_clstoken = xs.i((.., 0))?;
|
||||
let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
|
||||
let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
|
||||
self.head.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
|
||||
DinoVisionTransformer::new(vb, 12, 384, 6)
|
||||
}
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
@ -320,7 +45,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
let model = vit_small(vb)?;
|
||||
let model = dinov2::vit_small(vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
|
@ -8,340 +8,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig};
use clap::{Parser, ValueEnum};

use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn as nn;
use nn::{Module, VarBuilder};

// Based on the Python version from torchvision.
|
||||
// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct MBConvConfig {
|
||||
expand_ratio: f64,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
input_channels: usize,
|
||||
out_channels: usize,
|
||||
num_layers: usize,
|
||||
}
|
||||
|
||||
fn make_divisible(v: f64, divisor: usize) -> usize {
|
||||
let min_value = divisor;
|
||||
let new_v = usize::max(
|
||||
min_value,
|
||||
(v + divisor as f64 * 0.5) as usize / divisor * divisor,
|
||||
);
|
||||
if (new_v as f64) < 0.9 * v {
|
||||
new_v + divisor
|
||||
} else {
|
||||
new_v
|
||||
}
|
||||
}
|
||||
|
||||
fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
|
||||
let bneck_conf = |e, k, s, i, o, n| {
|
||||
let input_channels = make_divisible(i as f64 * width_mult, 8);
|
||||
let out_channels = make_divisible(o as f64 * width_mult, 8);
|
||||
let num_layers = (n as f64 * depth_mult).ceil() as usize;
|
||||
MBConvConfig {
|
||||
expand_ratio: e,
|
||||
kernel: k,
|
||||
stride: s,
|
||||
input_channels,
|
||||
out_channels,
|
||||
num_layers,
|
||||
}
|
||||
};
|
||||
vec![
|
||||
bneck_conf(1., 3, 1, 32, 16, 1),
|
||||
bneck_conf(6., 3, 2, 16, 24, 2),
|
||||
bneck_conf(6., 5, 2, 24, 40, 2),
|
||||
bneck_conf(6., 3, 2, 40, 80, 3),
|
||||
bneck_conf(6., 5, 1, 80, 112, 3),
|
||||
bneck_conf(6., 5, 2, 112, 192, 4),
|
||||
bneck_conf(6., 3, 1, 192, 320, 1),
|
||||
]
|
||||
}
|
||||
|
||||
impl MBConvConfig {
|
||||
fn b0() -> Vec<Self> {
|
||||
bneck_confs(1.0, 1.0)
|
||||
}
|
||||
fn b1() -> Vec<Self> {
|
||||
bneck_confs(1.0, 1.1)
|
||||
}
|
||||
fn b2() -> Vec<Self> {
|
||||
bneck_confs(1.1, 1.2)
|
||||
}
|
||||
fn b3() -> Vec<Self> {
|
||||
bneck_confs(1.2, 1.4)
|
||||
}
|
||||
fn b4() -> Vec<Self> {
|
||||
bneck_confs(1.4, 1.8)
|
||||
}
|
||||
fn b5() -> Vec<Self> {
|
||||
bneck_confs(1.6, 2.2)
|
||||
}
|
||||
fn b6() -> Vec<Self> {
|
||||
bneck_confs(1.8, 2.6)
|
||||
}
|
||||
fn b7() -> Vec<Self> {
|
||||
bneck_confs(2.0, 3.1)
|
||||
}
|
||||
}
|
||||
|
||||
/// Conv2D with same padding.
|
||||
#[derive(Debug)]
|
||||
struct Conv2DSame {
|
||||
conv2d: nn::Conv2d,
|
||||
s: usize,
|
||||
k: usize,
|
||||
}
|
||||
|
||||
impl Conv2DSame {
|
||||
fn new(
|
||||
vb: VarBuilder,
|
||||
i: usize,
|
||||
o: usize,
|
||||
k: usize,
|
||||
stride: usize,
|
||||
groups: usize,
|
||||
bias: bool,
|
||||
) -> Result<Self> {
|
||||
let conv_config = nn::Conv2dConfig {
|
||||
stride,
|
||||
groups,
|
||||
..Default::default()
|
||||
};
|
||||
let conv2d = if bias {
|
||||
nn::conv2d(i, o, k, conv_config, vb)?
|
||||
} else {
|
||||
nn::conv2d_no_bias(i, o, k, conv_config, vb)?
|
||||
};
|
||||
Ok(Self {
|
||||
conv2d,
|
||||
s: stride,
|
||||
k,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Conv2DSame {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let s = self.s;
|
||||
let k = self.k;
|
||||
let (_, _, ih, iw) = xs.dims4()?;
|
||||
let oh = (ih + s - 1) / s;
|
||||
let ow = (iw + s - 1) / s;
|
||||
let pad_h = usize::max((oh - 1) * s + k - ih, 0);
|
||||
let pad_w = usize::max((ow - 1) * s + k - iw, 0);
|
||||
if pad_h > 0 || pad_w > 0 {
|
||||
let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
|
||||
let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
|
||||
self.conv2d.forward(&xs)
|
||||
} else {
|
||||
self.conv2d.forward(xs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ConvNormActivation {
|
||||
conv2d: Conv2DSame,
|
||||
bn2d: nn::BatchNorm,
|
||||
activation: bool,
|
||||
}
|
||||
|
||||
impl ConvNormActivation {
|
||||
fn new(
|
||||
vb: VarBuilder,
|
||||
i: usize,
|
||||
o: usize,
|
||||
k: usize,
|
||||
stride: usize,
|
||||
groups: usize,
|
||||
) -> Result<Self> {
|
||||
let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
|
||||
let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
|
||||
Ok(Self {
|
||||
conv2d,
|
||||
bn2d,
|
||||
activation: true,
|
||||
})
|
||||
}
|
||||
|
||||
fn no_activation(self) -> Self {
|
||||
Self {
|
||||
activation: false,
|
||||
..self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ConvNormActivation {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.conv2d.forward(xs)?;
|
||||
let xs = self.bn2d.forward(&xs)?;
|
||||
if self.activation {
|
||||
swish(&xs)
|
||||
} else {
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SqueezeExcitation {
|
||||
fc1: Conv2DSame,
|
||||
fc2: Conv2DSame,
|
||||
}
|
||||
|
||||
impl SqueezeExcitation {
|
||||
fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
|
||||
let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
|
||||
let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
|
||||
Ok(Self { fc1, fc2 })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for SqueezeExcitation {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
// equivalent to adaptive_avg_pool2d([1, 1])
|
||||
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
|
||||
let xs = self.fc1.forward(&xs)?;
|
||||
let xs = swish(&xs)?;
|
||||
let xs = self.fc2.forward(&xs)?;
|
||||
let xs = nn::ops::sigmoid(&xs)?;
|
||||
residual.broadcast_mul(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MBConv {
|
||||
expand_cna: Option<ConvNormActivation>,
|
||||
depthwise_cna: ConvNormActivation,
|
||||
squeeze_excitation: SqueezeExcitation,
|
||||
project_cna: ConvNormActivation,
|
||||
config: MBConvConfig,
|
||||
}
|
||||
|
||||
impl MBConv {
|
||||
fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
|
||||
let vb = vb.pp("block");
|
||||
let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
|
||||
let expand_cna = if exp != c.input_channels {
|
||||
Some(ConvNormActivation::new(
|
||||
vb.pp("0"),
|
||||
c.input_channels,
|
||||
exp,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let start_index = if expand_cna.is_some() { 1 } else { 0 };
|
||||
let depthwise_cna =
|
||||
ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
|
||||
let squeeze_channels = usize::max(1, c.input_channels / 4);
|
||||
let squeeze_excitation =
|
||||
SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
|
||||
let project_cna =
|
||||
ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
|
||||
.no_activation();
|
||||
Ok(Self {
|
||||
expand_cna,
|
||||
depthwise_cna,
|
||||
squeeze_excitation,
|
||||
project_cna,
|
||||
config: c,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MBConv {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let use_res_connect =
|
||||
self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
|
||||
let ys = match &self.expand_cna {
|
||||
Some(expand_cna) => expand_cna.forward(xs)?,
|
||||
None => xs.clone(),
|
||||
};
|
||||
let ys = self.depthwise_cna.forward(&ys)?;
|
||||
let ys = self.squeeze_excitation.forward(&ys)?;
|
||||
let ys = self.project_cna.forward(&ys)?;
|
||||
if use_res_connect {
|
||||
ys + xs
|
||||
} else {
|
||||
Ok(ys)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn swish(s: &Tensor) -> Result<Tensor> {
|
||||
s * nn::ops::sigmoid(s)?
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EfficientNet {
|
||||
init_cna: ConvNormActivation,
|
||||
blocks: Vec<MBConv>,
|
||||
final_cna: ConvNormActivation,
|
||||
classifier: nn::Linear,
|
||||
}
|
||||
|
||||
impl EfficientNet {
|
||||
fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
|
||||
let f_p = p.pp("features");
|
||||
let first_in_c = configs[0].input_channels;
|
||||
let last_out_c = configs.last().unwrap().out_channels;
|
||||
let final_out_c = 4 * last_out_c;
|
||||
let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
|
||||
let nconfigs = configs.len();
|
||||
let mut blocks = vec![];
|
||||
for (index, cnf) in configs.into_iter().enumerate() {
|
||||
let f_p = f_p.pp(index + 1);
|
||||
for r_index in 0..cnf.num_layers {
|
||||
let cnf = if r_index == 0 {
|
||||
cnf
|
||||
} else {
|
||||
MBConvConfig {
|
||||
input_channels: cnf.out_channels,
|
||||
stride: 1,
|
||||
..cnf
|
||||
}
|
||||
};
|
||||
blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
|
||||
}
|
||||
}
|
||||
let final_cna =
|
||||
ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
|
||||
let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
|
||||
Ok(Self {
|
||||
init_cna,
|
||||
blocks,
|
||||
final_cna,
|
||||
classifier,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EfficientNet {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = self.init_cna.forward(xs)?;
|
||||
for block in self.blocks.iter() {
|
||||
xs = block.forward(&xs)?
|
||||
}
|
||||
let xs = self.final_cna.forward(&xs)?;
|
||||
// Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
|
||||
let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
|
||||
self.classifier.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
B0,
|
||||
|
3
candle-examples/examples/falcon/README.md
Normal file
@ -0,0 +1,3 @@
# candle-falcon

Falcon is a general large language model.
@ -14,8 +14,7 @@ use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

mod model;
use model::{Config, Falcon};
use candle_transformers::models::falcon::{Config, Falcon};

struct TextGeneration {
model: Falcon,
@ -26,17 +25,25 @@ struct TextGeneration {
repeat_last_n: usize,
}

struct GenerationOptions {
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
}

impl TextGeneration {
fn new(
model: Falcon,
tokenizer: Tokenizer,
generation_options: GenerationOptions,
seed: u64,
temp: Option<f64>,
device: &Device,
repeat_penalty: f32,
repeat_last_n: usize,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp);
let logits_processor =
LogitsProcessor::new(seed, generation_options.temp, generation_options.top_p);
let repeat_penalty = generation_options.repeat_penalty;
let repeat_last_n = generation_options.repeat_last_n;
Self {
model,
tokenizer,
@ -119,6 +126,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,

/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,

/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@ -186,15 +197,14 @@ fn main() -> Result<()> {
let model = Falcon::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());

let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
&device,
args.repeat_penalty,
args.repeat_last_n,
);
let generation_options = GenerationOptions {
temp: args.temperature,
top_p: args.top_p,
repeat_penalty: args.repeat_penalty,
repeat_last_n: args.repeat_last_n,
};
let mut pipeline =
TextGeneration::new(model, tokenizer, generation_options, args.seed, &device);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}
@ -21,11 +21,10 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;

mod model;
use candle_transformers::models::llama as model;
use model::{Config, Llama, LlamaConfig};

const EOS_TOKEN: &str = "</s>";
const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = "My favorite theorem is ";

#[derive(Parser, Debug)]
@ -43,6 +42,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,

/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,

/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@ -194,7 +197,7 @@ fn main() -> Result<()> {

println!("starting the inference loop");
print!("{prompt}");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;
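All the examples in this batch move to the three-argument constructor, with the nucleus cutoff as the new third argument. A minimal usage sketch, assuming candle 0.2.3:

```rust
use candle::{Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn sample_next_token(logits: &Tensor, seed: u64) -> Result<u32> {
    // temperature = 0.8, top_p = 0.9; pass None to disable either one.
    let mut processor = LogitsProcessor::new(seed, Some(0.8), Some(0.9));
    processor.sample(logits)
}
```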
@ -27,6 +27,10 @@ struct InferenceCmd {
#[arg(long)]
temperature: Option<f64>,

/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,

#[arg(long, default_value = "")]
prompt: String,

@ -133,6 +137,7 @@ fn main() -> anyhow::Result<()> {
None => {
let cmd = InferenceCmd {
temperature: None,
top_p: None,
prompt: "".to_string(),
config: None,
model_id: "karpathy/tinyllamas".to_string(),
@ -256,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let model = Llama::load(vb, &cache, config)?;

println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
let mut index_pos = 0;

print!("{}", args.prompt);
@ -89,6 +89,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,

/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,

/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@ -222,7 +226,7 @@ fn main() -> Result<()> {
.to_vec();

println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let mut new_tokens = vec![];
let start_gen = std::time::Instant::now();
let mut index_pos = 0;

|
||||
mod encodec_model;
|
||||
mod musicgen_model;
|
||||
mod nn;
|
||||
mod t5_model;
|
||||
|
||||
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
|
||||
|
||||
@ -78,7 +77,7 @@ fn main() -> Result<()> {
|
||||
let model = model.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
|
||||
let config = GenConfig::small();
|
||||
let model = MusicgenForConditionalGeneration::load(vb, config)?;
|
||||
let mut model = MusicgenForConditionalGeneration::load(vb, config)?;
|
||||
|
||||
let tokens = tokenizer
|
||||
.encode(args.prompt.as_str(), true)
|
||||
|
@ -1,9 +1,10 @@
|
||||
use crate::{encodec_model, t5_model};
|
||||
use crate::encodec_model;
|
||||
use candle::{DType, Device, Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
|
||||
VarBuilder,
|
||||
};
|
||||
use candle_transformers::models::t5;
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
@ -370,7 +371,7 @@ impl MusicgenForCausalLM {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MusicgenForConditionalGeneration {
|
||||
pub text_encoder: crate::t5_model::T5EncoderModel,
|
||||
pub text_encoder: t5::T5EncoderModel,
|
||||
pub audio_encoder: crate::encodec_model::EncodecModel,
|
||||
pub decoder: MusicgenForCausalLM,
|
||||
cfg: GenConfig,
|
||||
@ -379,7 +380,7 @@ pub struct MusicgenForConditionalGeneration {
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct GenConfig {
|
||||
musicgen: Config,
|
||||
t5: crate::t5_model::Config,
|
||||
t5: t5::Config,
|
||||
encodec: crate::encodec_model::Config,
|
||||
}
|
||||
|
||||
@ -387,7 +388,7 @@ impl GenConfig {
|
||||
pub fn small() -> Self {
|
||||
Self {
|
||||
musicgen: Config::musicgen_small(),
|
||||
t5: t5_model::Config::musicgen_small(),
|
||||
t5: t5::Config::musicgen_small(),
|
||||
encodec: encodec_model::Config::musicgen_small(),
|
||||
}
|
||||
}
|
||||
@ -399,7 +400,7 @@ impl MusicgenForConditionalGeneration {
|
||||
}
|
||||
|
||||
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
|
||||
let text_encoder = t5_model::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
|
||||
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
|
||||
let audio_encoder =
|
||||
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
|
||||
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
|
||||
|
@ -1,397 +0,0 @@
|
||||
// T5 Text Encoder
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
||||
|
||||
use candle::{DType, Result, Tensor, D};
|
||||
use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Config {
|
||||
vocab_size: usize,
|
||||
d_model: usize,
|
||||
d_kv: usize,
|
||||
d_ff: usize,
|
||||
num_layers: usize,
|
||||
num_decoder_layers: Option<usize>,
|
||||
num_heads: usize,
|
||||
relative_attention_num_buckets: usize,
|
||||
relative_attention_max_distance: usize,
|
||||
dropout_rate: f64,
|
||||
layer_norm_epsilon: f64,
|
||||
initializer_factor: f64,
|
||||
feed_forward_proj: Activation,
|
||||
is_decoder: bool,
|
||||
is_encoder_decoder: bool,
|
||||
use_cache: bool,
|
||||
pad_token_id: usize,
|
||||
eos_token_id: usize,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
vocab_size: 32128,
|
||||
d_model: 512,
|
||||
d_kv: 64,
|
||||
d_ff: 2048,
|
||||
num_layers: 6,
|
||||
num_decoder_layers: None,
|
||||
num_heads: 8,
|
||||
relative_attention_num_buckets: 32,
|
||||
relative_attention_max_distance: 128,
|
||||
dropout_rate: 0.1,
|
||||
layer_norm_epsilon: 1e-6,
|
||||
initializer_factor: 1.0,
|
||||
feed_forward_proj: Activation::Relu,
|
||||
is_decoder: false,
|
||||
is_encoder_decoder: true,
|
||||
use_cache: true,
|
||||
pad_token_id: 0,
|
||||
eos_token_id: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L184
|
||||
pub fn musicgen_small() -> Self {
|
||||
Self {
|
||||
d_ff: 3072,
|
||||
d_kv: 64,
|
||||
d_model: 768,
|
||||
dropout_rate: 0.1,
|
||||
eos_token_id: 1,
|
||||
feed_forward_proj: Activation::Relu,
|
||||
initializer_factor: 1.0,
|
||||
is_decoder: false,
|
||||
is_encoder_decoder: true,
|
||||
layer_norm_epsilon: 1e-6,
|
||||
num_decoder_layers: Some(12),
|
||||
num_heads: 12,
|
||||
num_layers: 12,
|
||||
pad_token_id: 0,
|
||||
relative_attention_max_distance: 128,
|
||||
relative_attention_num_buckets: 32,
|
||||
use_cache: true,
|
||||
vocab_size: 32128,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct T5LayerNorm {
|
||||
weight: Tensor,
|
||||
variance_epsilon: f64,
|
||||
}
|
||||
|
||||
impl T5LayerNorm {
|
||||
fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
||||
let weight = vb.get(h, "weight")?;
|
||||
Ok(Self {
|
||||
weight,
|
||||
variance_epsilon: eps,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let dtype = xs.dtype();
|
||||
let xs_f32 = xs.to_dtype(DType::F32)?;
|
||||
// variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
|
||||
let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
|
||||
let xs = xs.to_dtype(dtype)?;
|
||||
let xs = xs.broadcast_mul(&self.weight)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct T5DenseActDense {
|
||||
wi: Linear,
|
||||
wo: Linear,
|
||||
act: Activation,
|
||||
}
|
||||
|
||||
impl T5DenseActDense {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
|
||||
let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
|
||||
Ok(Self {
|
||||
wi,
|
||||
wo,
|
||||
act: Activation::Relu,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.wi.forward(xs)?;
|
||||
let xs = self.act.forward(&xs)?;
|
||||
let xs = self.wo.forward(&xs)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct T5LayerFF {
|
||||
dense_relu_dense: T5DenseActDense,
|
||||
layer_norm: T5LayerNorm,
|
||||
}
|
||||
|
||||
impl T5LayerFF {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
// is_gated_act is not supported.
|
||||
let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
|
||||
let layer_norm =
|
||||
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
||||
Ok(Self {
|
||||
dense_relu_dense,
|
||||
layer_norm,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let ys = self.layer_norm.forward(xs)?;
|
||||
let ys = self.dense_relu_dense.forward(&ys)?;
|
||||
let xs = (xs + ys)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct T5Attention {
|
||||
q: Linear,
|
||||
k: Linear,
|
||||
v: Linear,
|
||||
o: Linear,
|
||||
n_heads: usize,
|
||||
d_kv: usize,
|
||||
relative_attention_bias: Option<Embedding>,
|
||||
relative_attention_num_buckets: usize,
|
||||
relative_attention_max_distance: usize,
|
||||
inner_dim: usize,
|
||||
}
|
||||
|
||||
impl T5Attention {
|
||||
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let inner_dim = cfg.num_heads * cfg.d_kv;
|
||||
let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
|
||||
let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
|
||||
let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
|
||||
let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
|
||||
let relative_attention_bias = if h {
|
||||
let emb = embedding(
|
||||
cfg.relative_attention_num_buckets,
|
||||
cfg.num_heads,
|
||||
vb.pp("relative_attention_bias"),
|
||||
)?;
|
||||
Some(emb)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
n_heads: cfg.num_heads,
|
||||
d_kv: cfg.d_kv,
|
||||
relative_attention_bias,
|
||||
relative_attention_num_buckets: cfg.relative_attention_num_buckets,
|
||||
relative_attention_max_distance: cfg.relative_attention_max_distance,
|
||||
inner_dim,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: Apply the mask(s)?
|
||||
// TODO: kv caching.
|
||||
let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
|
||||
let q = self.q.forward(xs)?;
|
||||
let k = self.k.forward(xs)?;
|
||||
let v = self.v.forward(xs)?;
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let v = v
|
||||
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let scores = q.matmul(&k.t()?)?;
|
||||
|
||||
let scores = match &self.relative_attention_bias {
|
||||
None => scores,
|
||||
Some(relative_attention_bias) => {
|
||||
let query_length = seq_len;
|
||||
let key_length = seq_len;
|
||||
// This only handles the bidirectional case.
|
||||
let num_buckets = self.relative_attention_num_buckets / 2;
|
||||
let relative_position = (0..query_length as u32)
|
||||
.map(|i| {
|
||||
(0..key_length as u32)
|
||||
.map(|j| {
|
||||
if i < j {
|
||||
j - i + num_buckets as u32
|
||||
} else {
|
||||
i - j
|
||||
}
|
||||
})
|
||||
.collect::<Vec<u32>>()
|
||||
})
|
||||
.collect::<Vec<Vec<_>>>();
|
||||
let relative_buckets = Tensor::new(relative_position, q.device())?;
|
||||
let position_bias = relative_attention_bias
|
||||
.forward(&relative_buckets)?
|
||||
.permute((2, 0, 1))?
|
||||
.unsqueeze(0)?;
|
||||
(scores + position_bias)?
|
||||
// TODO: position_bias_masked?
|
||||
}
|
||||
};
|
||||
|
||||
let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
|
||||
let attn_output = attn_weights.matmul(&v)?;
|
||||
let attn_output = attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, seq_len, self.inner_dim))?;
|
||||
let attn_output = self.o.forward(&attn_output)?;
|
||||
Ok(attn_output)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct T5LayerSelfAttention {
|
||||
self_attention: T5Attention,
|
||||
layer_norm: T5LayerNorm,
|
||||
}
|
||||
|
||||
impl T5LayerSelfAttention {
|
||||
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?;
|
||||
let layer_norm =
|
||||
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
||||
Ok(Self {
|
||||
self_attention,
|
||||
layer_norm,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let normed_xs = self.layer_norm.forward(xs)?;
|
||||
let ys = self.self_attention.forward(&normed_xs)?;
|
||||
let ys = (xs + ys)?;
|
||||
Ok(ys)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct T5LayerCrossAttention {}
|
||||
|
||||
impl T5LayerCrossAttention {
|
||||
fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct T5Block {
|
||||
self_attn: T5LayerSelfAttention,
|
||||
cross_attn: Option<T5LayerCrossAttention>,
|
||||
ff: T5LayerFF,
|
||||
}
|
||||
|
||||
impl T5Block {
|
||||
fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let vb = vb.pp("layer");
|
||||
let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
|
||||
let cross_attn = if cfg.is_decoder {
|
||||
Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let ff_i = if cross_attn.is_some() { 2 } else { 1 };
|
||||
let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
cross_attn,
|
||||
ff,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = self.self_attn.forward(xs)?;
|
||||
// TODO: clamp for f16?
|
||||
if let Some(cross_attn) = &self.cross_attn {
|
||||
xs = cross_attn.forward(&xs)?;
|
||||
// TODO: clamp for f16?
|
||||
}
|
||||
let xs = self.ff.forward(&xs)?;
|
||||
// TODO: clamp for f16?
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct T5Stack {
|
||||
block: Vec<T5Block>,
|
||||
shared: Arc<Embedding>,
|
||||
final_layer_norm: T5LayerNorm,
|
||||
}
|
||||
|
||||
impl T5Stack {
|
||||
fn load(vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
|
||||
let block = (0..cfg.num_layers)
|
||||
.map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let final_layer_norm = T5LayerNorm::load(
|
||||
cfg.d_model,
|
||||
cfg.layer_norm_epsilon,
|
||||
vb.pp("final_layer_norm"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
block,
|
||||
shared: shared.clone(),
|
||||
final_layer_norm,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let input_embeds = self.shared.as_ref().forward(input_ids)?;
|
||||
let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
|
||||
|
||||
let mut hidden_states = input_embeds;
|
||||
for block in self.block.iter() {
|
||||
hidden_states = block.forward(&hidden_states)?
|
||||
}
|
||||
let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
|
||||
Ok(hidden_states)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct T5EncoderModel {
|
||||
shared: Arc<Embedding>,
|
||||
encoder: T5Stack,
|
||||
}
|
||||
|
||||
impl T5EncoderModel {
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
||||
let shared = Arc::new(shared);
|
||||
let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?;
|
||||
Ok(Self { shared, encoder })
|
||||
}
|
||||
|
||||
pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let encoder_outputs = self.encoder.forward(input_ids)?;
|
||||
Ok(encoder_outputs)
|
||||
}
|
||||
}
|
17
candle-examples/examples/quantized-t5/README.md
Normal file
@ -0,0 +1,17 @@
# candle-quantized-t5

This example uses a quantized version of the T5 model.

```bash
$ cargo run --example quantized-t5 --release -- --prompt "translate to German: A beautiful candle."
...
Eine schöne Kerze.
```

The weight file is automatically retrieved from the hub. It is also possible to
generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:

```bash
cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```
214
candle-examples/examples/quantized-t5/main.rs
Normal file
@ -0,0 +1,214 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use std::io::Write;
use std::path::PathBuf;

use candle_transformers::models::quantized_t5 as t5;

use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    T5Small,
    FlanT5Small,
    FlanT5Base,
    FlanT5Large,
    FlanT5Xl,
    FlanT5Xxl,
}

#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The model repository to use on the HuggingFace hub.
    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    weight_file: Option<String>,

    /// Disable the key-value cache used during decoding.
    #[arg(long, default_value = "false")]
    disable_cache: bool,

    /// The prompt to be used for generation.
    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model size to use.
    #[arg(long, default_value = "t5-small")]
    which: Which,
}

struct T5ModelBuilder {
    device: Device,
    config: t5::Config,
    weights_filename: PathBuf,
}

impl T5ModelBuilder {
    pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
        let device = Device::Cpu;
        let default_model = "lmz/candle-quantized-t5".to_string();
        let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
            (Some(model_id), Some(revision)) => (model_id, revision),
            (Some(model_id), None) => (model_id, "main".to_string()),
            (None, Some(revision)) => (default_model, revision),
            (None, None) => (default_model, "main".to_string()),
        };

        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
        let api = Api::new()?;
        let api = api.repo(repo);
        let config_filename = match args.which {
            Which::T5Small => api.get("config.json")?,
            Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
            Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
            Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
            Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
            Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
        };
        let tokenizer_filename = api.get("tokenizer.json")?;
        let weights_filename = match &args.weight_file {
            Some(filename) => std::path::PathBuf::from(filename),
            None => match args.which {
                Which::T5Small => api.get("model.gguf")?,
                Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
                Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?,
                Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?,
                Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?,
                Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?,
            },
        };
        let config = std::fs::read_to_string(config_filename)?;
        let mut config: t5::Config = serde_json::from_str(&config)?;
        config.use_cache = !args.disable_cache;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        Ok((
            Self {
                device,
                config,
                weights_filename,
            },
            tokenizer,
        ))
    }

    pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
        let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
        Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
    }
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        println!("tracing...");
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
    let device = &builder.device;
    let tokenizer = tokenizer
        .with_padding(None)
        .with_truncation(None)
        .map_err(E::msg)?;
    let tokens = tokenizer
        .encode(args.prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
    let mut model = builder.build_model()?;
    // The decoder is seeded with the padding token id.
    let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
    let temperature = if args.temperature <= 0. {
        None
    } else {
        Some(args.temperature)
    };
    let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
    let encoder_output = model.encode(&input_token_ids)?;
    let start = std::time::Instant::now();

    for index in 0.. {
        if output_token_ids.len() > 512 {
            break;
        }
        // When the kv cache is enabled, only the last generated token has to be
        // passed to the decoder after the first step.
        let decoder_token_ids = if index == 0 || !builder.config.use_cache {
            Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
        } else {
            let last_token = *output_token_ids.last().unwrap();
            Tensor::new(&[last_token], device)?.unsqueeze(0)?
        };
        let logits = model
            .decode(&decoder_token_ids, &encoder_output)?
            .squeeze(0)?;
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &output_token_ids[start_at..],
            )?
        };

        let next_token_id = logits_processor.sample(&logits)?;
        if next_token_id as usize == builder.config.eos_token_id {
            break;
        }
        output_token_ids.push(next_token_id);
        if let Some(text) = tokenizer.id_to_token(next_token_id) {
            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
            print!("{text}");
            std::io::stdout().flush()?;
        }
    }
    let dt = start.elapsed();
    println!(
        "\n{} tokens generated ({:.2} token/s)\n",
        output_token_ids.len(),
        output_token_ids.len() as f64 / dt.as_secs_f64(),
    );
    Ok(())
}
37
candle-examples/examples/quantized/README.md
Normal file
@ -0,0 +1,37 @@
# candle-quantized-llama: Fast Inference of quantized LLaMA models

This example provides a quantized LLaMA model similar to
[llama.cpp](https://github.com/ggerganov/llama.cpp). It is based on candle's
built-in quantization methods. Supported features include:

- 2-bit, 3-bit, 4-bit, 5-bit, 6-bit and 8-bit integer quantization support.
- SIMD optimizations on Apple Silicon and x86.
- Support for using the `gguf` and `ggml` file formats.

The weights are automatically downloaded for you from the [HuggingFace
Hub](https://huggingface.co/) on the first run. There are various command line
flags to use local files instead; run with `--help` to learn about them.

## Running an example

```bash
cargo run --example quantized --release -- --prompt "The best thing about coding in rust is "

> avx: true, neon: false, simd128: false, f16c: true
> temp: 0.80 repeat-penalty: 1.10 repeat-last-n: 64
> loaded 291 tensors (3.79GB) in 2.17s
> params: HParams { n_vocab: 32000, n_embd: 4096, n_mult: 256, n_head: 32, n_layer: 32, n_rot: 128, ftype: 2 }
> The best thing about coding in rust is 1.) that I don’t need to worry about memory leaks, 2.) speed and 3.) my program will compile even on old machines.
```

## Command-line flags

Run with `--help` to see all options.

- `--which`: specify the model to use, e.g. `7b`, `13-chat`, `7b-code`.
- `--prompt interactive`: interactive mode where multiple prompts can be
  entered (see the sample invocation below).
- `--model mymodelfile.gguf`: use a local model file rather than getting one
  from the hub.
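For instance, the interactive mode can be combined with a specific model choice
(a sketch based on the flags above; adjust the `--which` value to the model you
want):

```bash
cargo run --example quantized --release -- --prompt interactive --which 7b-code
```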
BIN
candle-examples/examples/quantized/assets/aoc.gif
Normal file
After Width: | Height: | Size: 119 KiB |
@ -12,7 +12,7 @@ use candle::quantized::{ggml_file, gguf_file};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;

mod model;
use candle_transformers::models::quantized_llama as model;
use model::ModelWeights;

const DEFAULT_PROMPT: &str = "My favorite theorem is ";
@ -71,6 +71,10 @@ struct Args {
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,
@ -310,7 +314,7 @@ fn main() -> anyhow::Result<()> {
        prompt_tokens
    };
    let mut all_tokens = vec![];
    let mut logits_processor = LogitsProcessor::new(args.seed, temperature);
    let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);

    let start_prompt_processing = std::time::Instant::now();
    let mut next_token = {
40
candle-examples/examples/segment-anything/README.md
Normal file
@ -0,0 +1,40 @@
# candle-segment-anything: Segment-Anything Model

This example is based on Meta AI's [Segment-Anything
Model](https://github.com/facebookresearch/segment-anything). This model
provides a robust and fast image segmentation pipeline that can be tweaked via
some prompting (requesting some points to be in the target mask, requesting some
points to be part of the background so _not_ in the target mask, or specifying a
bounding box).

The default backbone can be replaced by the smaller and faster TinyViT model
based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).

## Running an example

```bash
cargo run --example segment-anything --release -- \
  --image candle-examples/examples/yolo-v8/assets/bike.jpg \
  --use-tiny \
  --point-x 0.4 \
  --point-y 0.3
```

Running this command generates a `sam_merged.jpg` file containing the original
image with a blue overlay of the selected mask. The red dot represents the prompt
specified by `--point-x 0.4 --point-y 0.3`; this prompt is assumed to be part
of the target mask.

The values used for `--point-x` and `--point-y` should be between 0 and 1 and
are proportional to the image dimensions, i.e. use 0.5 for the image center.

### Command-line flags

- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default
  one.
- `--point-x`, `--point-y`: specify the location of the target point.
- `--threshold`: sets the threshold value to be part of the mask; a negative
  value results in a larger mask and can be specified via `--threshold=-1.2`
  (see the example below).
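For instance, a larger mask for the same image can be requested by lowering the
threshold (a sketch combining the flags above):

```bash
cargo run --example segment-anything --release -- \
  --image candle-examples/examples/yolo-v8/assets/bike.jpg \
  --point-x 0.4 \
  --point-y 0.3 \
  --threshold=-1.2
```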
BIN
candle-examples/examples/segment-anything/assets/sam_merged.jpg
Normal file
After Width: | Height: | Size: 157 KiB |
164
candle-examples/examples/segment-anything/main.rs
Normal file
@ -0,0 +1,164 @@
//! SAM: Segment Anything Model
//! https://github.com/facebookresearch/segment-anything

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::DType;
use candle_nn::VarBuilder;
use candle_transformers::models::segment_anything::sam;
use clap::Parser;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    #[arg(long)]
    generate_masks: bool,

    /// The target point x coordinate, between 0 and 1 (0.5 is at the middle of the image).
    #[arg(long, default_value_t = 0.5)]
    point_x: f64,

    /// The target point y coordinate, between 0 and 1 (0.5 is at the middle of the image).
    #[arg(long, default_value_t = 0.5)]
    point_y: f64,

    /// The detection threshold for the mask, 0 is the default value, negative values mean a larger
    /// mask, positive makes the mask more selective.
    #[arg(long, default_value_t = 0.)]
    threshold: f32,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Use the TinyViT based models from MobileSAM
    #[arg(long)]
    use_tiny: bool,
}

pub fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let device = candle_examples::device(args.cpu)?;

    let (image, initial_h, initial_w) =
        candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
    let image = image.to_device(&device)?;
    println!("loaded image {image:?}");

    let model = match args.model {
        Some(model) => std::path::PathBuf::from(model),
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("lmz/candle-sam".to_string());
            let filename = if args.use_tiny {
                "mobile_sam-tiny-vitt.safetensors"
            } else {
                "sam_vit_b_01ec64.safetensors"
            };
            api.get(filename)?
        }
    };
    let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
    let weights = weights.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
    let sam = if args.use_tiny {
        sam::Sam::new_tiny(vb)? // tiny vit_t
    } else {
        sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
    };

    if args.generate_masks {
        // Default options similar to the Python version.
        let bboxes = sam.generate_masks(
            &image,
            /* points_per_side */ 32,
            /* crop_n_layer */ 0,
            /* crop_overlap_ratio */ 512. / 1500.,
            /* crop_n_points_downscale_factor */ 1,
        )?;
        for (idx, bbox) in bboxes.iter().enumerate() {
            println!("{idx} {bbox:?}");
            let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
            let (h, w) = mask.dims2()?;
            let mask = mask.broadcast_as((3, h, w))?;
            candle_examples::save_image_resize(
                &mask,
                format!("sam_mask{idx}.png"),
                initial_h,
                initial_w,
            )?;
        }
    } else {
        let point = Some((args.point_x, args.point_y));
        let start_time = std::time::Instant::now();
        let (mask, iou_predictions) = sam.forward(&image, point, false)?;
        println!(
            "mask generated in {:.2}s",
            start_time.elapsed().as_secs_f32()
        );
        println!("mask:\n{mask}");
        println!("iou_predictions: {iou_predictions:?}");

        let mask = (mask.ge(args.threshold)? * 255.)?;
        let (_one, h, w) = mask.dims3()?;
        let mask = mask.expand((3, h, w))?;

        let mut img = image::io::Reader::open(&args.image)?
            .decode()
            .map_err(candle::Error::wrap)?;
        let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
        let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
            match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
                Some(image) => image,
                None => anyhow::bail!("error saving merged image"),
            };
        let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
            img.width(),
            img.height(),
            image::imageops::FilterType::CatmullRom,
        );
        // Blend the mask into the original image: boost the blue channel and
        // dim red/green wherever the mask is set.
        for x in 0..img.width() {
            for y in 0..img.height() {
                let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
                if mask_p.0[0] > 100 {
                    let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
                    img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
                    img_p.0[1] /= 2;
                    img_p.0[0] /= 2;
                    imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
                }
            }
        }
        // Draw a red dot at the prompt location.
        let (x, y) = (
            (args.point_x * img.width() as f64) as i32,
            (args.point_y * img.height() as f64) as i32,
        );
        imageproc::drawing::draw_filled_circle(&img, (x, y), 3, image::Rgba([255, 0, 0, 200]))
            .save("sam_merged.jpg")?
    }
    Ok(())
}
63
candle-examples/examples/stable-diffusion/README.md
Normal file
@ -0,0 +1,63 @@
# candle-stable-diffusion: A Diffusers API in Rust/Candle

_A rusty robot holding a fire torch in its hand_, generated by Stable Diffusion
XL using Rust and [candle](https://github.com/huggingface/candle).

The `stable-diffusion` example is a conversion of
[diffusers-rs](https://github.com/LaurentMazare/diffusers-rs) using candle
rather than libtorch. This implementation supports Stable Diffusion v1.5, v2.1,
as well as Stable Diffusion XL 1.0.

## Getting the weights

The weights are automatically downloaded for you from the [HuggingFace
Hub](https://huggingface.co/) on the first run. There are various command line
flags to use local files instead; run with `--help` to learn about them.

## Running an example

```bash
cargo run --example stable-diffusion --release --features=cuda,cudnn \
  -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)"
```

The final image is named `sd_final.png` by default.
The default scheduler is the Denoising Diffusion Implicit Models (DDIM) scheduler. The
original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim).

### Command-line flags

- `--prompt`: the prompt to be used to generate the image.
- `--uncond-prompt`: the optional unconditional prompt.
- `--sd-version`: the Stable Diffusion version to use, can be `v1-5`, `v2-1`, or
  `xl`.
- `--cpu`: use the cpu rather than the gpu (much slower).
- `--height`, `--width`: set the height and width for the generated image.
- `--n-steps`: the number of steps to be used in the diffusion process.
- `--num-samples`: the number of samples to generate.
- `--final-image`: the filename for the generated image(s).

### Using flash-attention

Using flash attention makes image generation a lot faster and uses less memory.
The downside is some long compilation time. You can set the
`CANDLE_FLASH_ATTN_BUILD_DIR` environment variable to something like
`/home/user/.candle` to ensure that the compilation artifacts are properly
cached.

Enabling flash-attention requires both a feature flag, `--features flash-attn`,
and using the command line flag `--use-flash-attn`.
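Putting the two together, an invocation might look like this (a sketch assuming
the same CUDA features as in the example above):

```bash
cargo run --example stable-diffusion --release --features=cuda,cudnn,flash-attn \
  -- --use-flash-attn --prompt "a cosmonaut on a horse (hd, realistic, high-def)"
```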
## Image to Image Pipeline

...
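Based on the `img2img` and `img2img_strength` flags defined in the example's
`main.rs` (see the diff below), an image-to-image run might look like this;
`input.png` is a placeholder for the initial image:

```bash
cargo run --example stable-diffusion --release -- \
  --prompt "a cosmonaut on a horse (hd, realistic, high-def)" \
  --img2img input.png \
  --img2img-strength 0.8
```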
## FAQ

### Memory Issues

This requires a GPU with more than 8GB of memory; as a fallback, the CPU version
can be used with the `--cpu` flag, but it is much slower.
Alternatively, reducing the height and width with the `--height` and `--width`
flags is likely to reduce memory usage significantly.
After Width: | Height: | Size: 36 KiB |
@ -4,20 +4,10 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

mod attention;
mod clip;
mod ddim;
mod embeddings;
mod resnet;
mod schedulers;
mod stable_diffusion;
mod unet_2d;
mod unet_2d_blocks;
mod utils;
mod vae;
use candle_transformers::models::stable_diffusion;

use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor, D};
use candle::{DType, Device, IndexOp, Module, Tensor, D};
use clap::Parser;
use tokenizers::Tokenizer;

@ -96,6 +86,15 @@ struct Args {

    #[arg(long)]
    use_f16: bool,

    #[arg(long, value_name = "FILE")]
    img2img: Option<String>,

    /// The strength, indicates how much to transform the initial image. The
    /// value must be between 0 and 1, a value of 1 discards the initial image
    /// information.
    #[arg(long, default_value_t = 0.8)]
    img2img_strength: f64,
}

#[derive(Debug, Clone, Copy, clap::ValueEnum)]
@ -306,6 +305,26 @@ fn text_embeddings(
    Ok(text_embeddings)
}

fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> anyhow::Result<Tensor> {
    let img = image::io::Reader::open(path)?.decode()?;
    let (height, width) = (img.height() as usize, img.width() as usize);
    // Round the dimensions down to a multiple of 32.
    let height = height - height % 32;
    let width = width - width % 32;
    let img = img.resize_to_fill(
        width as u32,
        height as u32,
        image::imageops::FilterType::CatmullRom,
    );
    let img = img.to_rgb8();
    let img = img.into_raw();
    // Map pixel values from [0, 255] to [-1, 1] and lay the data out as (1, 3, h, w).
    let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
        .permute((2, 0, 1))?
        .to_dtype(DType::F32)?
        .affine(2. / 255., -1.)?
        .unsqueeze(0)?;
    Ok(img)
}

fn run(args: Args) -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;
@ -328,9 +347,15 @@ fn run(args: Args) -> Result<()> {
        tracing,
        use_f16,
        use_flash_attn,
        img2img,
        img2img_strength,
        ..
    } = args;

    if !(0. ..=1.).contains(&img2img_strength) {
        anyhow::bail!("img2img-strength should be between 0 and 1, got {img2img_strength}")
    }

    let _guard = if tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
@ -382,25 +407,53 @@ fn run(args: Args) -> Result<()> {
    println!("Building the autoencoder.");
    let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
    let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
    let init_latent_dist = match &img2img {
        None => None,
        Some(image) => {
            let image = image_preprocess(image)?.to_device(&device)?;
            Some(vae.encode(&image)?)
        }
    };
    println!("Building the unet.");
    let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
    let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;

    // In img2img mode the first denoising steps are skipped: the lower the
    // strength, the more of the initial image is preserved.
    let t_start = if img2img.is_some() {
        n_steps - (n_steps as f64 * img2img_strength) as usize
    } else {
        0
    };
    let bsize = 1;
    for idx in 0..num_samples {
        let mut latents = Tensor::randn(
        let timesteps = scheduler.timesteps();
        let latents = match &init_latent_dist {
            Some(init_latent_dist) => {
                let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
                if t_start < timesteps.len() {
                    let noise = latents.randn_like(0f64, 1f64)?;
                    scheduler.add_noise(&latents, noise, timesteps[t_start])?
                } else {
                    latents
                }
            }
            None => {
                let latents = Tensor::randn(
                    0f32,
                    1f32,
                    (bsize, 4, sd_config.height / 8, sd_config.width / 8),
                    &device,
                )?
                .to_dtype(dtype)?;

                )?;
                // scale the initial noise by the standard deviation required by the scheduler
                latents = (latents * scheduler.init_noise_sigma())?;
                (latents * scheduler.init_noise_sigma())?
            }
        };
        let mut latents = latents.to_dtype(dtype)?;

        println!("starting sampling");
        for (timestep_index, &timestep) in scheduler.timesteps().iter().enumerate() {
        for (timestep_index, &timestep) in timesteps.iter().enumerate() {
            if timestep_index < t_start {
                continue;
            }
            let start_time = std::time::Instant::now();
            let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
25
candle-examples/examples/t5/README.md
Normal file
@ -0,0 +1,25 @@
# candle-t5

## Encoder-decoder example:

```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
...
Running on CPU, to run on GPU, build this example with `--features cuda`
Eine schöne Kerze.
9 tokens generated (2.42 token/s)
```

## Sentence embedding example:

```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."
...
[[[ 0.0515, -0.0541, -0.0761, ..., -0.0392,  0.1511, -0.0265],
  [-0.0974,  0.0998, -0.1659, ..., -0.2450,  0.1738, -0.0164],
  [ 0.0624, -0.1024,  0.0430, ..., -0.1388,  0.0564, -0.2962],
  [-0.0389, -0.1173,  0.0026, ...,  0.1064, -0.1065,  0.0990],
  [ 0.1300,  0.0027, -0.0326, ...,  0.0026, -0.0317,  0.0851]]]
Tensor[[1, 5, 512], f32]
Took 303.766583ms
```
314
candle-examples/examples/t5/main.rs
Normal file
@ -0,0 +1,314 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use std::io::Write;
use std::path::PathBuf;

use candle_transformers::models::t5;

use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

const DTYPE: DType = DType::F32;

#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The model repository to use on the HuggingFace hub.
    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    /// Enable decoding.
    #[arg(long)]
    decode: bool,

    /// Disable the key-value cache used during decoding.
    #[arg(long, default_value = "false")]
    disable_cache: bool,

    /// Use this prompt, otherwise compute sentence similarities.
    #[arg(long)]
    prompt: Option<String>,

    /// If set along with --decode, will use this prompt to initialize the decoder.
    #[arg(long)]
    decoder_prompt: Option<String>,

    /// L2 normalization for embeddings.
    #[arg(long, default_value = "true")]
    normalize_embeddings: bool,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

struct T5ModelBuilder {
    device: Device,
    config: t5::Config,
    weights_filename: Vec<PathBuf>,
}

impl T5ModelBuilder {
    pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
        let device = candle_examples::device(args.cpu)?;
        let default_model = "t5-small".to_string();
        let default_revision = "refs/pr/15".to_string();
        let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
            (Some(model_id), Some(revision)) => (model_id, revision),
            (Some(model_id), None) => (model_id, "main".to_string()),
            (None, Some(revision)) => (default_model, revision),
            (None, None) => (default_model, default_revision),
        };

        let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
        let api = Api::new()?;
        let api = api.repo(repo);
        let config_filename = api.get("config.json")?;
        let tokenizer_filename = api.get("tokenizer.json")?;
        // The xxl model is sharded over multiple safetensors files.
        let weights_filename = if model_id == "google/flan-t5-xxl" {
            vec![
                api.get("model-00001-of-00005.safetensors")?,
                api.get("model-00002-of-00005.safetensors")?,
                api.get("model-00003-of-00005.safetensors")?,
                api.get("model-00004-of-00005.safetensors")?,
                api.get("model-00005-of-00005.safetensors")?,
            ]
        } else {
            vec![api.get("model.safetensors")?]
        };
        let config = std::fs::read_to_string(config_filename)?;
        let mut config: t5::Config = serde_json::from_str(&config)?;
        config.use_cache = !args.disable_cache;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        Ok((
            Self {
                device,
                config,
                weights_filename,
            },
            tokenizer,
        ))
    }

    pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
        let weights = self
            .weights_filename
            .iter()
            .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
            .collect::<candle::Result<Vec<_>>>()?;
        let weights = weights
            .iter()
            .map(|w| w.deserialize())
            .collect::<candle::Result<Vec<_>>>()?;
        let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
        Ok(t5::T5EncoderModel::load(vb, &self.config)?)
    }

    pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
        let weights = self
            .weights_filename
            .iter()
            .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
            .collect::<candle::Result<Vec<_>>>()?;
        let weights = weights
            .iter()
            .map(|w| w.deserialize())
            .collect::<candle::Result<Vec<_>>>()?;
        let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
        Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
    }
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        println!("tracing...");
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
    let device = &builder.device;
    let tokenizer = tokenizer
        .with_padding(None)
        .with_truncation(None)
        .map_err(E::msg)?;
    match args.prompt {
        Some(prompt) => {
            let tokens = tokenizer
                .encode(prompt, true)
                .map_err(E::msg)?
                .get_ids()
                .to_vec();
            let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
            if !args.decode {
                let mut model = builder.build_encoder()?;
                let start = std::time::Instant::now();
                let ys = model.forward(&input_token_ids)?;
                println!("{ys}");
                println!("Took {:?}", start.elapsed());
            } else {
                let mut model = builder.build_conditional_generation()?;
                // The decoder is seeded with the padding token id, optionally
                // followed by the tokens of the decoder prompt.
                let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
                if let Some(decoder_prompt) = &args.decoder_prompt {
                    print!("{decoder_prompt}");
                    output_token_ids.extend(
                        tokenizer
                            .encode(decoder_prompt.to_string(), false)
                            .map_err(E::msg)?
                            .get_ids()
                            .to_vec(),
                    );
                }
                let temperature = if args.temperature <= 0. {
                    None
                } else {
                    Some(args.temperature)
                };
                let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
                let encoder_output = model.encode(&input_token_ids)?;
                let start = std::time::Instant::now();

                for index in 0.. {
                    if output_token_ids.len() > 512 {
                        break;
                    }
                    // With the kv cache enabled, only the last generated token
                    // needs to be passed to the decoder after the first step.
                    let decoder_token_ids = if index == 0 || !builder.config.use_cache {
                        Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
                    } else {
                        let last_token = *output_token_ids.last().unwrap();
                        Tensor::new(&[last_token], device)?.unsqueeze(0)?
                    };
                    let logits = model
                        .decode(&decoder_token_ids, &encoder_output)?
                        .squeeze(0)?;
                    let logits = if args.repeat_penalty == 1. {
                        logits
                    } else {
                        let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
                        candle_transformers::utils::apply_repeat_penalty(
                            &logits,
                            args.repeat_penalty,
                            &output_token_ids[start_at..],
                        )?
                    };

                    let next_token_id = logits_processor.sample(&logits)?;
                    if next_token_id as usize == builder.config.eos_token_id {
                        break;
                    }
                    output_token_ids.push(next_token_id);
                    if let Some(text) = tokenizer.id_to_token(next_token_id) {
                        let text = text.replace('▁', " ").replace("<0x0A>", "\n");
                        print!("{text}");
                        std::io::stdout().flush()?;
                    }
                }
                let dt = start.elapsed();
                println!(
                    "\n{} tokens generated ({:.2} token/s)\n",
                    output_token_ids.len(),
                    output_token_ids.len() as f64 / dt.as_secs_f64(),
                );
            }
        }
        None => {
            let mut model = builder.build_encoder()?;
            let sentences = [
                "The cat sits outside",
                "A man is playing guitar",
                "I love pasta",
                "The new movie is awesome",
                "The cat plays in the garden",
                "A woman watches TV",
                "The new movie is so great",
                "Do you like pizza?",
            ];
            let n_sentences = sentences.len();
            let mut all_embeddings = Vec::with_capacity(n_sentences);
            for sentence in sentences {
                let tokens = tokenizer
                    .encode(sentence, true)
                    .map_err(E::msg)?
                    .get_ids()
                    .to_vec();
                let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
                let embeddings = model.forward(&token_ids)?;
                println!("generated embeddings {:?}", embeddings.shape());
                // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
                let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
                let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
                let embeddings = if args.normalize_embeddings {
                    normalize_l2(&embeddings)?
                } else {
                    embeddings
                };
                println!("pooled embeddings {:?}", embeddings.shape());
                all_embeddings.push(embeddings)
            }

            // Compute the pairwise cosine similarities and report the top five pairs.
            let mut similarities = vec![];
            for (i, e_i) in all_embeddings.iter().enumerate() {
                for (j, e_j) in all_embeddings
                    .iter()
                    .enumerate()
                    .take(n_sentences)
                    .skip(i + 1)
                {
                    let sum_ij = (e_i * e_j)?.sum_all()?.to_scalar::<f32>()?;
                    let sum_i2 = (e_i * e_i)?.sum_all()?.to_scalar::<f32>()?;
                    let sum_j2 = (e_j * e_j)?.sum_all()?.to_scalar::<f32>()?;
                    let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
                    similarities.push((cosine_similarity, i, j))
                }
            }
            similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
            for &(score, i, j) in similarities[..5].iter() {
                println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
            }
        }
    }
    Ok(())
}

/// Normalizes each row of `v` by its L2 norm.
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}
39
candle-examples/examples/whisper/README.md
Normal file
@ -0,0 +1,39 @@
# candle-whisper: speech recognition

An implementation of [OpenAI Whisper](https://github.com/openai/whisper) using
candle. Whisper is a general purpose speech recognition model: it can be used to
convert audio files (in the `.wav` format) to text. Supported features include
language detection as well as multilingual speech recognition.

## Running an example

If no audio file is passed as input, a [sample
file](https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav) is automatically downloaded
from the hub.

```bash
cargo run --example whisper --release

> No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
> loaded wav data: Header { audio_format: 1, channel_count: 1, sampling_rate: 16000, bytes_per_second: 32000, bytes_per_sample: 2, bits_per_sample: 16 }
> pcm data loaded 176000
> loaded mel: [1, 80, 3000]
> 0.0s -- 30.0s: And so my fellow Americans ask not what your country can do for you ask what you can do for your country
```

In order to use the multilingual mode, specify a multilingual model via the
`--model` flag; see the details and the sample invocation below.

## Command line flags

- `--input`: the audio file to be converted to text, in wav format.
- `--language`: force the language to some specific value rather than having it
  detected, e.g. `en`.
- `--task`: the task to be performed, can be `transcribe` (return the text data
  in the original language) or `translate` (translate the text to English).
- `--timestamps`: enable the timestamp mode where some timestamps are reported
  for each recognized audio extract.
- `--model`: the model to be used. Models that do not end with `-en` are
  multilingual models, the other ones are English only models. The supported models
  are `tiny`, `tiny.en`, `base`, `base.en`, `small`, `small.en`, `medium`,
  `medium.en`, `large`, and `large-v2`.
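For instance, transcribing a French recording with the multilingual `medium`
model might look like this (a sketch; `audio.wav` is a placeholder input file):

```bash
cargo run --example whisper --release -- --model medium --language fr --input audio.wav
```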
@ -10,41 +10,16 @@ extern crate accelerate_src;
extern crate intel_mkl_src;

use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor};
use candle::{Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;

mod audio;
mod model;
use model::{Config, Whisper};
mod multilingual;

const DTYPE: DType = DType::F32;

// Audio parameters.
const SAMPLE_RATE: usize = 16000;
const N_FFT: usize = 400;
const N_MELS: usize = 80;
const HOP_LENGTH: usize = 160;
const CHUNK_LENGTH: usize = 30;
const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input

const NO_SPEECH_THRESHOLD: f64 = 0.6;
const LOGPROB_THRESHOLD: f64 = -1.0;
const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;

// Tokenizer dependent bits.
const SOT_TOKEN: &str = "<|startoftranscript|>";
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
const TRANSLATE_TOKEN: &str = "<|translate|>";
const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
const EOT_TOKEN: &str = "<|endoftext|>";
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
use candle_transformers::models::whisper::{self as m, audio, model};
use model::{Config, Whisper};

#[allow(dead_code)]
#[derive(Debug, Clone)]
@ -94,7 +69,7 @@ impl Decoder {
        timestamps: bool,
        verbose: bool,
    ) -> Result<Self> {
        let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
        let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
        // Suppress the notimestamps token when in timestamps mode.
        // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
        let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
@ -109,11 +84,11 @@ impl Decoder {
            })
            .collect();
        let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
        let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
        let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
        let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
        let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
        let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
        let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
        let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
        let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
        let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
        let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
        Ok(Self {
            model,
            rng: rand::rngs::StdRng::seed_from_u64(seed),
@ -220,17 +195,17 @@ impl Decoder {
    }

    fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
        for (i, &t) in TEMPERATURES.iter().enumerate() {
        for (i, &t) in m::TEMPERATURES.iter().enumerate() {
            let dr: Result<DecodingResult> = self.decode(segment, t);
            if i == TEMPERATURES.len() - 1 {
            if i == m::TEMPERATURES.len() - 1 {
                return dr;
            }
            // On errors, we try again with a different temperature.
            match dr {
                Ok(dr) => {
                    let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
                        || dr.avg_logprob < LOGPROB_THRESHOLD;
                    if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
                    let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
                        || dr.avg_logprob < m::LOGPROB_THRESHOLD;
                    if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
                        return Ok(dr);
                    }
                }
@ -248,13 +223,13 @@ impl Decoder {
        let mut segments = vec![];
        while seek < content_frames {
            let start = std::time::Instant::now();
            let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
            let segment_size = usize::min(content_frames - seek, N_FRAMES);
            let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
            let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
            let mel_segment = mel.narrow(2, seek, segment_size)?;
            let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
            let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
            let dr = self.decode_with_fallback(&mel_segment)?;
            seek += segment_size;
            if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
            if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
                println!("no speech detected, skipping {seek} {dr:?}");
                continue;
            }
@ -431,7 +406,6 @@ fn main() -> Result<()> {

    let args = Args::parse();
    let _guard = if args.tracing {
        println!("tracing...");
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
@ -493,8 +467,8 @@ fn main() -> Result<()> {
    let mut input = std::fs::File::open(input)?;
    let (header, data) = wav::read(&mut input)?;
    println!("loaded wav data: {header:?}");
    if header.sampling_rate != SAMPLE_RATE as u32 {
        anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE)
    if header.sampling_rate != m::SAMPLE_RATE as u32 {
        anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
    }
    let data = data.as_sixteen().expect("expected 16 bit wav file");
    let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
@ -502,14 +476,14 @@ fn main() -> Result<()> {
        .map(|v| *v as f32 / 32768.)
        .collect();
    println!("pcm data loaded {}", pcm_data.len());
    let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?;
    let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
    let mel_len = mel.len();
    let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
    let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
    println!("loaded mel: {:?}", mel.dims());

    let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
    let weights = weights.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
    let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device);
    let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
    let mut model = Whisper::load(&vb, config)?;

@ -113,7 +113,7 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor)
        .iter()
        .map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
        .collect::<Result<Vec<_>>>()?;
    let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
    let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
    let audio_features = model.encoder.forward(&mel, true)?;
    let tokens = Tensor::new(&[[sot_token]], device)?;
    let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
27
candle-examples/examples/wuerstchen/README.md
Normal file
@ -0,0 +1,27 @@
# candle-wuerstchen: Efficient Pretraining of Text-to-Image Models

The `wuerstchen` example is a port of the [diffusers
implementation](https://github.com/huggingface/diffusers/tree/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen) for Würstchen v2.
The candle implementation reproduces the same structure/files for models and
pipelines. Useful resources:

- [Official implementation](https://github.com/dome272/Wuerstchen).
- [Arxiv paper](https://arxiv.org/abs/2306.00637).
- Blog post: [Introducing Würstchen: Fast Diffusion for Image Generation](https://huggingface.co/blog/wuerstchen).

## Getting the weights

The weights are automatically downloaded for you from the [HuggingFace
Hub](https://huggingface.co/) on the first run. There are various command line
flags to use local files instead; run with `--help` to learn about them.

## Running an example

```bash
cargo run --example wuerstchen --release --features cuda,cudnn -- \
  --prompt "Anthropomorphic cat dressed as a fire fighter"
```

The final image is named `sd_final.png` by default.
BIN
candle-examples/examples/wuerstchen/assets/cat.jpg
Normal file
After Width: | Height: | Size: 38 KiB |
396
candle-examples/examples/wuerstchen/main.rs
Normal file
@ -0,0 +1,396 @@
|
||||
#![allow(unused)]
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use candle_transformers::models::stable_diffusion;
|
||||
use candle_transformers::models::wuerstchen;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Device, IndexOp, Module, Tensor, D};
|
||||
use clap::Parser;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const PRIOR_GUIDANCE_SCALE: f64 = 4.0;
|
||||
const RESOLUTION_MULTIPLE: f64 = 42.67;
|
||||
const LATENT_DIM_SCALE: f64 = 10.67;
|
||||
const PRIOR_CIN: usize = 16;
|
||||
const DECODER_CIN: usize = 4;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// The prompt to be used for image generation.
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "A very realistic photo of a rusty robot walking on a sandy beach"
|
||||
)]
|
||||
prompt: String,
|
||||
|
||||
#[arg(long, default_value = "")]
|
||||
uncond_prompt: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
/// The height in pixels of the generated image.
|
||||
#[arg(long)]
|
||||
height: Option<usize>,
|
||||
|
||||
/// The width in pixels of the generated image.
|
||||
#[arg(long)]
|
||||
width: Option<usize>,
|
||||
|
||||
/// The decoder weight file, in .safetensors format.
|
||||
#[arg(long, value_name = "FILE")]
|
||||
decoder_weights: Option<String>,
|
||||
|
||||
/// The CLIP weight file, in .safetensors format.
|
||||
#[arg(long, value_name = "FILE")]
|
||||
clip_weights: Option<String>,
|
||||
|
||||
/// The CLIP weight file used by the prior model, in .safetensors format.
|
||||
#[arg(long, value_name = "FILE")]
|
||||
prior_clip_weights: Option<String>,
|
||||
|
||||
/// The prior weight file, in .safetensors format.
|
||||
#[arg(long, value_name = "FILE")]
|
||||
prior_weights: Option<String>,
|
||||
|
||||
/// The VQGAN weight file, in .safetensors format.
|
||||
#[arg(long, value_name = "FILE")]
|
||||
vqgan_weights: Option<String>,
|
||||
|
||||
#[arg(long, value_name = "FILE")]
|
||||
/// The file specifying the tokenizer to used for tokenization.
|
||||
tokenizer: Option<String>,
|
||||
|
||||
#[arg(long, value_name = "FILE")]
|
||||
/// The file specifying the tokenizer to used for prior tokenization.
|
||||
prior_tokenizer: Option<String>,
|
||||
|
||||
/// The size of the sliced attention or 0 for automatic slicing (disabled by default)
|
||||
#[arg(long)]
|
||||
sliced_attention_size: Option<usize>,
|
||||
|
||||
/// The number of steps to run the diffusion for.
|
||||
#[arg(long, default_value_t = 30)]
|
||||
n_steps: usize,
|
||||
|
||||
/// The number of samples to generate.
|
||||
#[arg(long, default_value_t = 1)]
|
||||
num_samples: i64,
|
||||
|
||||
/// The name of the final image to generate.
|
||||
#[arg(long, value_name = "FILE", default_value = "sd_final.png")]
|
||||
final_image: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ModelFile {
|
||||
Tokenizer,
|
||||
PriorTokenizer,
|
||||
Clip,
|
||||
PriorClip,
|
||||
Decoder,
|
||||
VqGan,
|
||||
Prior,
|
||||
}
|
||||
|
||||
impl ModelFile {
|
||||
fn get(&self, filename: Option<String>) -> Result<std::path::PathBuf> {
|
||||
use hf_hub::api::sync::Api;
|
||||
match filename {
|
||||
Some(filename) => Ok(std::path::PathBuf::from(filename)),
|
||||
None => {
|
||||
let repo_main = "warp-ai/wuerstchen";
|
||||
let repo_prior = "warp-ai/wuerstchen-prior";
|
||||
let (repo, path) = match self {
|
||||
Self::Tokenizer => (repo_main, "tokenizer/tokenizer.json"),
|
||||
Self::PriorTokenizer => (repo_prior, "tokenizer/tokenizer.json"),
|
||||
Self::Clip => (repo_main, "text_encoder/model.safetensors"),
|
||||
Self::PriorClip => (repo_prior, "text_encoder/model.safetensors"),
|
||||
Self::Decoder => (repo_main, "decoder/diffusion_pytorch_model.safetensors"),
|
||||
Self::VqGan => (repo_main, "vqgan/diffusion_pytorch_model.safetensors"),
|
||||
Self::Prior => (repo_prior, "prior/diffusion_pytorch_model.safetensors"),
|
||||
};
|
||||
let filename = Api::new()?.model(repo.to_string()).get(path)?;
|
||||
Ok(filename)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn output_filename(
|
||||
basename: &str,
|
||||
sample_idx: i64,
|
||||
num_samples: i64,
|
||||
timestep_idx: Option<usize>,
|
||||
) -> String {
|
||||
let filename = if num_samples > 1 {
|
||||
match basename.rsplit_once('.') {
|
||||
None => format!("{basename}.{sample_idx}.png"),
|
||||
Some((filename_no_extension, extension)) => {
|
||||
format!("{filename_no_extension}.{sample_idx}.{extension}")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
basename.to_string()
|
||||
};
|
||||
match timestep_idx {
|
||||
None => filename,
|
||||
Some(timestep_idx) => match filename.rsplit_once('.') {
|
||||
None => format!("{filename}-{timestep_idx}.png"),
|
||||
Some((filename_no_extension, extension)) => {
|
||||
format!("{filename_no_extension}-{timestep_idx}.{extension}")
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_prompt(
|
||||
prompt: &str,
|
||||
uncond_prompt: Option<&str>,
|
||||
tokenizer: std::path::PathBuf,
|
||||
clip_weights: std::path::PathBuf,
|
||||
clip_config: stable_diffusion::clip::Config,
|
||||
device: &Device,
|
||||
) -> Result<Tensor> {
|
||||
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
let pad_id = match &clip_config.pad_with {
|
||||
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
|
||||
None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
|
||||
};
|
||||
println!("Running with prompt \"{prompt}\".");
|
||||
let mut tokens = tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let tokens_len = tokens.len();
|
||||
while tokens.len() < clip_config.max_position_embeddings {
|
||||
tokens.push(pad_id)
|
||||
}
|
||||
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||
|
||||
println!("Building the clip transformer.");
|
||||
let text_model =
|
||||
stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
|
||||
    let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?;
    match uncond_prompt {
        None => Ok(text_embeddings),
        Some(uncond_prompt) => {
            let mut uncond_tokens = tokenizer
                .encode(uncond_prompt, true)
                .map_err(E::msg)?
                .get_ids()
                .to_vec();
            let uncond_tokens_len = uncond_tokens.len();
            while uncond_tokens.len() < clip_config.max_position_embeddings {
                uncond_tokens.push(pad_id)
            }
            let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;

            let uncond_embeddings =
                text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
            let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
            Ok(text_embeddings)
        }
    }
}

fn run(args: Args) -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let Args {
        prompt,
        uncond_prompt,
        cpu,
        height,
        width,
        n_steps,
        tokenizer,
        final_image,
        sliced_attention_size,
        num_samples,
        clip_weights,
        prior_weights,
        vqgan_weights,
        decoder_weights,
        tracing,
        ..
    } = args;

    let _guard = if tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let device = candle_examples::device(cpu)?;
    let height = height.unwrap_or(1024);
    let width = width.unwrap_or(1024);

    let prior_text_embeddings = {
        let tokenizer = ModelFile::PriorTokenizer.get(args.prior_tokenizer)?;
        let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?;
        encode_prompt(
            &prompt,
            Some(&uncond_prompt),
            tokenizer.clone(),
            weights,
            stable_diffusion::clip::Config::wuerstchen_prior(),
            &device,
        )?
    };
    println!("generated prior text embeddings {prior_text_embeddings:?}");

    let text_embeddings = {
        let tokenizer = ModelFile::Tokenizer.get(tokenizer)?;
        let weights = ModelFile::Clip.get(clip_weights)?;
        encode_prompt(
            &prompt,
            None,
            tokenizer.clone(),
            weights,
            stable_diffusion::clip::Config::wuerstchen(),
            &device,
        )?
    };
    println!("generated text embeddings {text_embeddings:?}");

    println!("Building the prior.");
    let b_size = 1;
    let image_embeddings = {
        // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
        let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
        let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
        let mut latents = Tensor::randn(
            0f32,
            1f32,
            (b_size, PRIOR_CIN, latent_height, latent_width),
            &device,
        )?;

        let prior = {
            let prior_weights = ModelFile::Prior.get(prior_weights)?;
            let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
            let weights = weights.deserialize()?;
            let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
            wuerstchen::prior::WPrior::new(
                /* c_in */ PRIOR_CIN,
                /* c */ 1536,
                /* c_cond */ 1280,
                /* c_r */ 64,
                /* depth */ 32,
                /* nhead */ 24,
                args.use_flash_attn,
                vb,
            )?
        };
        let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
        let timesteps = prior_scheduler.timesteps();
        let timesteps = &timesteps[..timesteps.len() - 1];
        println!("prior denoising");
        for (index, &t) in timesteps.iter().enumerate() {
            let start_time = std::time::Instant::now();
            let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
            let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
            let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
            let noise_pred = noise_pred.chunk(2, 0)?;
            let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
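            // Classifier-free guidance: eps = eps_uncond + scale * (eps_text - eps_uncond).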
            let noise_pred = (noise_pred_uncond
                + ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
            latents = prior_scheduler.step(&noise_pred, t, &latents)?;
            let dt = start_time.elapsed().as_secs_f32();
            println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
        }
        ((latents * 42.)? - 1.)?
    };

    println!("Building the vqgan.");
    let vqgan = {
        let vqgan_weights = ModelFile::VqGan.get(vqgan_weights)?;
        let weights = unsafe { candle::safetensors::MmapedFile::new(vqgan_weights)? };
        let weights = weights.deserialize()?;
        let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
        wuerstchen::paella_vq::PaellaVQ::new(vb)?
    };

    println!("Building the decoder.");

    // https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json
    let decoder = {
        let decoder_weights = ModelFile::Decoder.get(decoder_weights)?;
        let weights = unsafe { candle::safetensors::MmapedFile::new(decoder_weights)? };
        let weights = weights.deserialize()?;
        let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
        wuerstchen::diffnext::WDiffNeXt::new(
            /* c_in */ DECODER_CIN,
            /* c_out */ DECODER_CIN,
            /* c_r */ 64,
            /* c_cond */ 1024,
            /* clip_embd */ 1024,
            /* patch_size */ 2,
            args.use_flash_attn,
            vb,
        )?
    };

    for idx in 0..num_samples {
        // https://huggingface.co/warp-ai/wuerstchen/blob/main/model_index.json
        let latent_height = (image_embeddings.dim(2)? as f64 * LATENT_DIM_SCALE) as usize;
        let latent_width = (image_embeddings.dim(3)? as f64 * LATENT_DIM_SCALE) as usize;

        let mut latents = Tensor::randn(
            0f32,
            1f32,
            (b_size, DECODER_CIN, latent_height, latent_width),
            &device,
        )?;

        println!("diffusion process with prior {image_embeddings:?}");
        let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(12, Default::default())?;
        let timesteps = scheduler.timesteps();
        let timesteps = &timesteps[..timesteps.len() - 1];
        for (index, &t) in timesteps.iter().enumerate() {
            let start_time = std::time::Instant::now();
            let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
            let noise_pred =
                decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;
            latents = scheduler.step(&noise_pred, t, &latents)?;
            let dt = start_time.elapsed().as_secs_f32();
            println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
        }
        println!(
            "Generating the final image for sample {}/{}.",
            idx + 1,
            num_samples
        );
        let image = vqgan.decode(&(&latents * 0.3764)?)?;
        // TODO: Add the clamping between 0 and 1.
        let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
        let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
        candle_examples::save_image(&image, image_filename)?
    }
    Ok(())
}

fn main() -> Result<()> {
    let args = Args::parse();
    run(args)
}

@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle_examples::object_detection::{non_maximum_suppression, Bbox};
use candle_transformers::object_detection::{non_maximum_suppression, Bbox};
mod darknet;

use anyhow::Result;

@ -46,7 +46,7 @@ pub fn report(
    let (npreds, pred_size) = pred.dims2()?;
    let nclasses = pred_size - 5;
    // The bounding boxes grouped by (maximum) class index.
    let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
    let mut bboxes: Vec<Vec<Bbox<()>>> = (0..nclasses).map(|_| vec![]).collect();
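    // `Bbox` is now generic over a per-box payload (`data`): plain detection
    // carries `()` while pose estimation carries its keypoints.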
    // Extract the bounding boxes for which confidence is above the threshold.
    for index in 0..npreds {
        let pred = Vec::<f32>::try_from(pred.get(index)?)?;

@ -65,7 +65,7 @@ pub fn report(
            xmax: pred[0] + pred[2] / 2.,
            ymax: pred[1] + pred[3] / 2.,
            confidence,
            keypoints: vec![],
            data: (),
        };
        bboxes[class_index].push(bbox)
    }

47
candle-examples/examples/yolo-v8/README.md
Normal file
@ -0,0 +1,47 @@
# candle-yolo-v8: Object Detection and Pose Estimation

This is a port of [Ultralytics
YOLOv8](https://github.com/ultralytics/ultralytics). The implementation is based
on the [tinygrad
version](https://github.com/tinygrad/tinygrad/blob/master/examples/yolov8.py)
and on the model architecture described in this
[issue](https://github.com/ultralytics/ultralytics/issues/189). The supported
tasks are object detection and pose estimation.

You can try this model online on the [Candle YOLOv8
Space](https://huggingface.co/spaces/lmz/candle-yolo). The model runs fully
in your browser using WebAssembly - if you use a custom image it will never
leave your phone/computer!

## Running some examples

### Object Detection
```bash
cargo run --example yolo-v8 --release -- candle-examples/examples/yolo-v8/assets/bike.jpg
```

This prints details about the detected objects and generates a `bike.pp.jpg` file.

![Leading group, Giro d'Italia 2021](./assets/bike.jpg)

Image source:
[wikimedia](https://commons.wikimedia.org/wiki/File:Leading_group,_Giro_d%27Italia_2021,_Stage_15.jpg).

![Leading group, Giro d'Italia 2021 (detected objects)](./assets/bike.od.jpg)

### Pose Estimation
```bash
cargo run --example yolo-v8 --release -- \
  candle-examples/examples/yolo-v8/assets/peoples.jpeg --task pose
```

![Leading group, Giro d'Italia 2021 (pose estimation)](./assets/bike.pose.jpg)

### Command-line flags

- `--which`: select the model variant to be used: `n`, `s`, `m`, `l`, or `x`, in
  increasing size and quality.
- `--task`: `detect` for object detection and `pose` for pose estimation.
- `--legend-size`: the size of the characters to print.
- `--model`: use a local model file rather than downloading it from the hub.
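
The flags can be combined, e.g. (the values below are purely illustrative, not
defaults):

```bash
cargo run --example yolo-v8 --release -- \
  candle-examples/examples/yolo-v8/assets/bike.jpg --which x --task detect --legend-size 16
```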

BIN candle-examples/examples/yolo-v8/assets/bike.jpg (new binary file, 179 KiB)
BIN candle-examples/examples/yolo-v8/assets/bike.od.jpg (new binary file, 175 KiB)
BIN candle-examples/examples/yolo-v8/assets/bike.pose.jpg (new binary file, 189 KiB)

@ -8,8 +8,8 @@ mod model;
use model::{Multiples, YoloV8, YoloV8Pose};

use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_examples::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use candle_nn::{Module, VarBuilder};
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use clap::{Parser, ValueEnum};
use image::DynamicImage;

@ -64,7 +64,7 @@ pub fn report_detect(
    let (pred_size, npreds) = pred.dims2()?;
    let nclasses = pred_size - 4;
    // The bounding boxes grouped by (maximum) class index.
    let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
    let mut bboxes: Vec<Vec<Bbox<Vec<KeyPoint>>>> = (0..nclasses).map(|_| vec![]).collect();
    // Extract the bounding boxes for which confidence is above the threshold.
    for index in 0..npreds {
        let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;

@ -83,7 +83,7 @@ pub fn report_detect(
            xmax: pred[0] + pred[2] / 2.,
            ymax: pred[1] + pred[3] / 2.,
            confidence,
            keypoints: vec![],
            data: vec![],
        };
        bboxes[class_index].push(bbox)
    }

@ -176,7 +176,7 @@ pub fn report_pose(
            xmax: pred[0] + pred[2] / 2.,
            ymax: pred[1] + pred[3] / 2.,
            confidence,
            keypoints,
            data: keypoints,
        };
        bboxes.push(bbox)
    }

@ -204,7 +204,7 @@ pub fn report_pose(
                image::Rgb([255, 0, 0]),
            );
        }
        for kp in b.keypoints.iter() {
        for kp in b.data.iter() {
            if kp.mask < 0.6 {
                continue;
            }

@ -219,8 +219,8 @@ pub fn report_pose(
        }

        for &(idx1, idx2) in KP_CONNECTIONS.iter() {
            let kp1 = &b.keypoints[idx1];
            let kp2 = &b.keypoints[idx2];
            let kp1 = &b.data[idx1];
            let kp2 = &b.data[idx2];
            if kp1.mask < 0.6 || kp2.mask < 0.6 {
                continue;
            }

@ -1,6 +1,5 @@
pub mod coco_classes;
pub mod imagenet;
pub mod object_detection;

use candle::{Device, Result, Tensor};

@ -16,6 +15,36 @@ pub fn device(cpu: bool) -> Result<Device> {
    }
}

pub fn load_image<P: AsRef<std::path::Path>>(
    p: P,
    resize_longest: Option<usize>,
) -> Result<(Tensor, usize, usize)> {
    let img = image::io::Reader::open(p)?
        .decode()
        .map_err(candle::Error::wrap)?;
    let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
    let img = match resize_longest {
        None => img,
        Some(resize_longest) => {
            let (height, width) = (img.height(), img.width());
            let resize_longest = resize_longest as u32;
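            // Scale so that the longest side becomes `resize_longest` while
            // preserving the aspect ratio (integer division rounds down).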
            let (height, width) = if height < width {
                let h = (resize_longest * height) / width;
                (h, resize_longest)
            } else {
                let w = (resize_longest * width) / height;
                (resize_longest, w)
            };
            img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
        }
    };
    let (height, width) = (img.height() as usize, img.width() as usize);
    let img = img.to_rgb8();
    let data = img.into_raw();
    let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
    Ok((data, initial_h, initial_w))
}

pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
    p: P,
    width: usize,

@ -35,14 +64,14 @@ pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
}

/// Saves an image to disk using the image crate, this expects an input with shape
/// (c, width, height).
/// (c, height, width).
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
    let p = p.as_ref();
    let (channel, width, height) = img.dims3()?;
    let (channel, height, width) = img.dims3()?;
    if channel != 3 {
        candle::bail!("save_image expects an input of shape (3, width, height)")
        candle::bail!("save_image expects an input of shape (3, height, width)")
    }
    let img = img.transpose(0, 1)?.t()?.flatten_all()?;
    let img = img.permute((1, 2, 0))?.flatten_all()?;
    let pixels = img.to_vec1::<u8>()?;
    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {

@ -52,3 +81,27 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
    image.save(p).map_err(candle::Error::wrap)?;
    Ok(())
}

pub fn save_image_resize<P: AsRef<std::path::Path>>(
    img: &Tensor,
    p: P,
    h: usize,
    w: usize,
) -> Result<()> {
    let p = p.as_ref();
    let (channel, height, width) = img.dims3()?;
    if channel != 3 {
        candle::bail!("save_image expects an input of shape (3, height, width)")
    }
    let img = img.permute((1, 2, 0))?.flatten_all()?;
    let pixels = img.to_vec1::<u8>()?;
    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
            Some(image) => image,
            None => candle::bail!("error saving image {p:?}"),
        };
    let image = image::DynamicImage::from(image);
    let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
    image.save(p).map_err(candle::Error::wrap)?;
    Ok(())
}
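
The fixed `save_image` convention takes a channel-first `(3, height, width)` `u8`
tensor. A minimal usage sketch (the file name and sizes are made up for
illustration):

```rust
use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    // A solid red 64x64 image in (channel, height, width) layout.
    let mut data = vec![0u8; 3 * 64 * 64];
    data[..64 * 64].fill(255); // the first channel plane is the red one
    let img = Tensor::from_vec(data, (3, 64, 64), &Device::Cpu)?;
    candle_examples::save_image(&img, "red.png")
}
```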

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.2.1"
version = "0.2.3"
edition = "2021"

description = "Flash attention layer for the candle ML framework."

@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"

[dependencies]
candle = { path = "../candle-core", features = ["cuda"], version = "0.2.1", package = "candle-core" }
candle = { path = "../candle-core", features = ["cuda"], version = "0.2.3", package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]

@ -21,4 +21,4 @@ rayon = "1.7.0"

[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", version = "0.2.1", features = ["cuda"] }
candle-nn = { path = "../candle-nn", version = "0.2.3", features = ["cuda"] }

@ -6,7 +6,7 @@ use rayon::prelude::*;
use std::path::PathBuf;
use std::str::FromStr;

const KERNEL_FILES: [&str; 9] = [
const KERNEL_FILES: [&str; 17] = [
    "flash_api.cu",
    "flash_fwd_hdim128_fp16_sm80.cu",
    "flash_fwd_hdim160_fp16_sm80.cu",

@ -16,14 +16,14 @@ const KERNEL_FILES: [&str; 9] = [
    "flash_fwd_hdim32_fp16_sm80.cu",
    "flash_fwd_hdim64_fp16_sm80.cu",
    "flash_fwd_hdim96_fp16_sm80.cu",
    // "flash_fwd_hdim128_bf16_sm80.cu",
    // "flash_fwd_hdim160_bf16_sm80.cu",
    // "flash_fwd_hdim192_bf16_sm80.cu",
    // "flash_fwd_hdim224_bf16_sm80.cu",
    // "flash_fwd_hdim256_bf16_sm80.cu",
    // "flash_fwd_hdim32_bf16_sm80.cu",
    // "flash_fwd_hdim64_bf16_sm80.cu",
    // "flash_fwd_hdim96_bf16_sm80.cu",
    "flash_fwd_hdim128_bf16_sm80.cu",
    "flash_fwd_hdim160_bf16_sm80.cu",
    "flash_fwd_hdim192_bf16_sm80.cu",
    "flash_fwd_hdim224_bf16_sm80.cu",
    "flash_fwd_hdim256_bf16_sm80.cu",
    "flash_fwd_hdim32_bf16_sm80.cu",
    "flash_fwd_hdim64_bf16_sm80.cu",
    "flash_fwd_hdim96_bf16_sm80.cu",
];

fn main() -> Result<()> {

@ -57,9 +57,20 @@ fn main() -> Result<()> {
            #[allow(clippy::redundant_clone)]
            out_dir.clone()
        }
        Ok(build_dir) => PathBuf::from(build_dir),
        Ok(build_dir) => {
            let path = PathBuf::from(build_dir);
            path.canonicalize().expect(&format!(
                "Directory doesn't exists: {} (the current directory is {})",
                &path.display(),
                std::env::current_dir()?.display()
            ))
        }
    };
    set_cuda_include_dir()?;

    let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
    println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");

    let compute_cap = compute_cap()?;

    let out_file = build_dir.join("libflashattention.a");

@ -95,14 +106,21 @@ fn main() -> Result<()> {
            .args(["--default-stream", "per-thread"])
            .arg("-Icutlass/include")
            .arg("--expt-relaxed-constexpr")
            .arg(cu_file);
            .arg("--verbose");
        if let Ok(ccbin_path) = &ccbin_env {
            command
                .arg("-allow-unsupported-compiler")
                .args(["-ccbin", ccbin_path]);
        }
        command.arg(cu_file);
        let output = command
            .spawn()
            .context("failed spawning nvcc")?
            .wait_with_output()?;
        if !output.status.success() {
            anyhow::bail!(
                "nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
                "nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
                &command,
                String::from_utf8_lossy(&output.stdout),
                String::from_utf8_lossy(&output.stderr)
            )

@ -122,7 +140,8 @@ fn main() -> Result<()> {
            .wait_with_output()?;
        if !output.status.success() {
            anyhow::bail!(
                "nvcc error while linking:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
                "nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
                &command,
                String::from_utf8_lossy(&output.stdout),
                String::from_utf8_lossy(&output.stderr)
            )

@ -1,20 +1,19 @@
#include "flash_fwd_launch_template.h"

// TODO: Switch back to handling bf16.
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    FWD_HEADDIM_SWITCH(params.d, [&] {
        run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
    });
}

// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
//     FP16_SWITCH(!params.is_bf16, [&] {
//         FWD_HEADDIM_SWITCH(params.d, [&] {
//             run_mha_fwd_<elem_type, kHeadDim>(params, stream);
//         });
//     run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
//     });
// }

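// Note (not part of the diff): FP16_SWITCH dispatches on `!params.is_bf16`,
// binding `elem_type` to the matching cutlass fp16/bf16 type before the
// head-dim dispatch below.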
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    FP16_SWITCH(!params.is_bf16, [&] {
        FWD_HEADDIM_SWITCH(params.d, [&] {
            run_mha_fwd_<elem_type, kHeadDim>(params, stream);
        });
    });
}

extern "C" void run_mha(
    void *q_ptr,
    void *k_ptr,

@ -52,7 +51,8 @@ extern "C" void run_mha(
    uint32_t seqlen_q_rounded,
    uint32_t seqlen_k_rounded,

    int is_causal
    int is_causal,
    int is_bf16
) {
    Flash_fwd_params params;
    // Reset the parameters

@ -102,7 +102,7 @@ extern "C" void run_mha(
    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
    params.rp_dropout = 1.f / params.p_dropout;
    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
    params.is_bf16 = 0;
    params.is_bf16 = is_bf16;
    params.cu_seqlens_q = cu_seqlens_q_ptr;
    params.cu_seqlens_k = cu_seqlens_k_ptr;
    params.p_ptr = nullptr; // used for `return_softmax`.

@ -38,6 +38,7 @@ extern "C" {
        seqlen_k_rounded: u32,

        is_causal: c_int,
        is_bf16: c_int,
    );

}

@ -4,7 +4,7 @@ use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, Layout, Result, Shape, Tensor};
use half::f16;
use half::{bf16, f16};

pub struct FlashAttn {
    pub softmax_scale: f32,

@ -15,24 +15,10 @@ fn round_multiple(x: usize, m: usize) -> usize {
    (x + m - 1) / m * m
}

impl candle::CustomOp3 for FlashAttn {
    fn name(&self) -> &'static str {
        "flash-attn"
    }

    fn cpu_fwd(
        &self,
        _: &CpuStorage,
        _: &Layout,
        _: &CpuStorage,
        _: &Layout,
        _: &CpuStorage,
        _: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        candle::bail!("no cpu support for flash-attn")
    }

    fn cuda_fwd(
impl FlashAttn {
    fn cuda_fwd_t<
        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
    >(
        &self,
        q: &candle::CudaStorage,
        q_l: &Layout,

@ -40,15 +26,16 @@ impl candle::CustomOp3 for FlashAttn {
        k_l: &Layout,
        v: &candle::CudaStorage,
        v_l: &Layout,
        is_bf16: bool,
    ) -> Result<(candle::CudaStorage, Shape)> {
        // https://github.com/Dao-AILab/flash-attention/blob/b252072409e69c25f2b9d473cc534e49b24decd2/csrc/flash_attn/flash_api.cpp#L187
        let dev = q.device();
        let out_shape = q_l.shape().clone();
        let out_l = Layout::contiguous(&out_shape);

        let q = q.as_cuda_slice::<f16>()?;
        let k = k.as_cuda_slice::<f16>()?;
        let v = v.as_cuda_slice::<f16>()?;
        let q = q.as_cuda_slice::<T>()?;
        let k = k.as_cuda_slice::<T>()?;
        let v = v.as_cuda_slice::<T>()?;
        let q = q.slice(q_l.start_offset()..);
        let k = k.slice(k_l.start_offset()..);
        let v = v.slice(v_l.start_offset()..);

@ -104,10 +91,11 @@ impl candle::CustomOp3 for FlashAttn {
        let seqlen_k_rounded = round_multiple(seqlen_k, 128);

        let elem_count = out_shape.elem_count();
        let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
        let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
        let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;

        let causal = if self.causal { 1 } else { 0 };
        let is_bf16 = if is_bf16 { 1 } else { 0 };

        unsafe {
            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;

@ -146,6 +134,7 @@ impl candle::CustomOp3 for FlashAttn {
                /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                /* is_causal */ causal,
                /* is_bf16 */ is_bf16,
            )
        }

@ -154,6 +143,40 @@ impl candle::CustomOp3 for FlashAttn {
    }
}

impl candle::CustomOp3 for FlashAttn {
    fn name(&self) -> &'static str {
        "flash-attn"
    }

    fn cpu_fwd(
        &self,
        _: &CpuStorage,
        _: &Layout,
        _: &CpuStorage,
        _: &Layout,
        _: &CpuStorage,
        _: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        candle::bail!("no cpu support for flash-attn")
    }

    fn cuda_fwd(
        &self,
        q: &candle::CudaStorage,
        q_l: &Layout,
        k: &candle::CudaStorage,
        k_l: &Layout,
        v: &candle::CudaStorage,
        v_l: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        match q.dtype() {
            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false),
            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true),
            dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
        }
    }
}

/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.

@ -190,24 +213,10 @@ struct FlashAttnVarLen {
    seqlens_k: Tensor,
}

impl candle::CustomOp3 for FlashAttnVarLen {
    fn name(&self) -> &'static str {
        "flash-attn-varlen"
    }

    fn cpu_fwd(
        &self,
        _: &CpuStorage,
        _: &Layout,
        _: &CpuStorage,
        _: &Layout,
        _: &CpuStorage,
        _: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        candle::bail!("no cpu support for flash-attn")
    }

    fn cuda_fwd(
impl FlashAttnVarLen {
    fn cuda_fwd_t<
        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
    >(
        &self,
        q: &candle::CudaStorage,
        q_l: &Layout,

@ -215,6 +224,7 @@ impl candle::CustomOp3 for FlashAttnVarLen {
        k_l: &Layout,
        v: &candle::CudaStorage,
        v_l: &Layout,
        is_bf16: bool,
    ) -> Result<(candle::CudaStorage, Shape)> {
        // https://github.com/Dao-AILab/flash-attention/blob/184b992dcb2a0890adaa19eb9b541c3e4f9d2a08/csrc/flash_attn/flash_api.cpp#L327
        let dev = q.device();

@ -314,6 +324,7 @@ impl candle::CustomOp3 for FlashAttnVarLen {
            .w()?;

        let causal = if self.causal { 1 } else { 0 };
        let is_bf16 = if is_bf16 { 1 } else { 0 };

        unsafe {
            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;

@ -354,6 +365,7 @@ impl candle::CustomOp3 for FlashAttnVarLen {
                /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                /* is_causal */ causal,
                /* is_bf16 */ is_bf16,
            )
        }

@ -362,6 +374,40 @@ impl candle::CustomOp3 for FlashAttnVarLen {
    }
}

impl candle::CustomOp3 for FlashAttnVarLen {
    fn name(&self) -> &'static str {
        "flash-attn-varlen"
    }

    fn cpu_fwd(
        &self,
        _: &CpuStorage,
        _: &Layout,
        _: &CpuStorage,
        _: &Layout,
        _: &CpuStorage,
        _: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        candle::bail!("no cpu support for flash-attn")
    }

    fn cuda_fwd(
        &self,
        q: &candle::CudaStorage,
        q_l: &Layout,
        k: &candle::CudaStorage,
        k_l: &Layout,
        v: &candle::CudaStorage,
        v_l: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        match q.dtype() {
            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false),
            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true),
            dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
        }
    }
}

#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.2.1"
version = "0.2.3"
edition = "2021"

description = "CUDA kernels for Candle"

@ -164,6 +164,8 @@ mod cuda {

    println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");

    let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
    println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
    let children = kernel_paths
        .par_iter()
        .flat_map(|p| {

@ -188,8 +190,13 @@ mod cuda {
            .args(["--output-directory", &out_dir])
            // Flash attention only
            // .arg("--expt-relaxed-constexpr")
            .args(&include_options)
            .arg(p);
            .args(&include_options);
        if let Ok(ccbin_path) = &ccbin_env {
            command
                .arg("-allow-unsupported-compiler")
                .args(["-ccbin", ccbin_path]);
        }
        command.arg(p);
        Some((p, command.spawn()
            .expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
    }})
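
Both build scripts now honor `CANDLE_NVCC_CCBIN` to point nvcc at a specific
host compiler. An illustrative invocation (the compiler path and the `cuda`
feature name are assumptions, not part of the diff):

```bash
CANDLE_NVCC_CCBIN=/usr/bin/g++-11 cargo build --features cuda
```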

@ -77,20 +77,30 @@ CAST_OP(double, __half, cast_f64_f16)

CAST_OP(uint32_t, uint32_t, cast_u32_u32)
CAST_OP(uint32_t, uint8_t, cast_u32_u8 )
CAST_OP(uint32_t, int64_t, cast_u32_i64 )
CAST_OP(uint32_t, float, cast_u32_f32)
CAST_OP(uint32_t, double, cast_u32_f64)

CAST_OP(uint8_t, uint32_t, cast_u8_u32)
CAST_OP(uint8_t, uint8_t, cast_u8_u8 )
CAST_OP(uint8_t, int64_t, cast_u8_i64 )
CAST_OP(uint8_t, float, cast_u8_f32)
CAST_OP(uint8_t, double, cast_u8_f64)

CAST_OP(int64_t, uint32_t, cast_i64_u32)
CAST_OP(int64_t, uint8_t, cast_i64_u8 )
CAST_OP(int64_t, int64_t, cast_i64_i64 )
CAST_OP(int64_t, float, cast_i64_f32)
CAST_OP(int64_t, double, cast_i64_f64)

CAST_OP(float, uint8_t, cast_f32_u8 )
CAST_OP(float, uint32_t, cast_f32_u32)
CAST_OP(float, int64_t, cast_f32_i64 )
CAST_OP(float, float, cast_f32_f32)
CAST_OP(float, double, cast_f32_f64)

CAST_OP(double, uint8_t, cast_f64_u8 )
CAST_OP(double, uint32_t, cast_f64_u32)
CAST_OP(double, int64_t, cast_f64_i64 )
CAST_OP(double, float, cast_f64_f32)
CAST_OP(double, double, cast_f64_f64)

@ -51,6 +51,118 @@ __device__ void conv1d(
    dst[dst_i] = static_cast<T>(d);
}

template <typename T>
__device__ void im2col1d(
    const size_t dst_numel,
    const size_t l_out,
    const size_t l_k,
    const size_t stride,
    const size_t padding,
    const size_t dilation,
    const size_t *info,
    const T *src,
    T *dst
) {
    const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
    // dst: (b_size, l_out, c_in, l_k)
    // src: (b_size, c_in, l_in)
    if (dst_i >= dst_numel) {
        return;
    }
    const size_t *src_dims = info;
    const size_t *src_s = info + 3;
    const size_t b_in = src_dims[0];
    const size_t c_in = src_dims[1];
    const size_t l_in = src_dims[2];

    const size_t dst_s2 = l_k;
    const size_t dst_s1 = c_in * dst_s2;
    const size_t dst_s0 = l_out * dst_s1;

    size_t tmp_dst_i = dst_i;
    const size_t b_idx = tmp_dst_i / dst_s0;
    tmp_dst_i -= b_idx * dst_s0;
    const size_t l_idx = tmp_dst_i / dst_s1;
    tmp_dst_i -= l_idx * dst_s1;
    const size_t c_idx = tmp_dst_i / dst_s2;
    tmp_dst_i -= c_idx * dst_s2;
    const size_t l_k_idx = tmp_dst_i;
    size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
    if (src_l_idx < padding || src_l_idx >= l_in + padding) {
        dst[dst_i] = static_cast<T>(0);
    }
    else {
        src_l_idx -= padding;
        const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
        dst[dst_i] = src[src_i];
    }
}

template <typename T>
__device__ void im2col(
    const size_t dst_numel,
    const size_t h_out,
    const size_t w_out,
    const size_t h_k,
    const size_t w_k,
    const size_t stride,
    const size_t padding,
    const size_t dilation,
    const size_t *info,
    const T *src,
    T *dst
) {
    const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
    // dst: (b_size, h_out, w_out, c_in, h_k, w_k)
    // src: (b_size, c_in, h_in, w_in)
    if (dst_i >= dst_numel) {
        return;
    }
    const size_t *src_dims = info;
    const size_t *src_s = info + 4;
    const size_t b_in = src_dims[0];
    const size_t c_in = src_dims[1];
    const size_t h_in = src_dims[2];
    const size_t w_in = src_dims[3];

    const size_t dst_s4 = w_k;
    const size_t dst_s3 = h_k * dst_s4;
    const size_t dst_s2 = c_in * dst_s3;
    const size_t dst_s1 = w_out * dst_s2;
    const size_t dst_s0 = h_out * dst_s1;

    size_t tmp_dst_i = dst_i;
    const size_t b_idx = tmp_dst_i / dst_s0;
    tmp_dst_i -= b_idx * dst_s0;
    const size_t h_idx = tmp_dst_i / dst_s1;
    tmp_dst_i -= h_idx * dst_s1;
    const size_t w_idx = tmp_dst_i / dst_s2;
    tmp_dst_i -= w_idx * dst_s2;
    const size_t c_idx = tmp_dst_i / dst_s3;
    tmp_dst_i -= c_idx * dst_s3;
    const size_t h_k_idx = tmp_dst_i / dst_s4;
    tmp_dst_i -= h_k_idx * dst_s4;
    const size_t w_k_idx = tmp_dst_i;
    size_t src_h_idx = h_idx * stride + h_k_idx * dilation;
    size_t src_w_idx = w_idx * stride + w_k_idx * dilation;
    if (src_h_idx < padding || src_h_idx >= h_in + padding) {
        dst[dst_i] = static_cast<T>(0);
    }
    else if (src_w_idx < padding || src_w_idx >= w_in + padding) {
        dst[dst_i] = static_cast<T>(0);
    }
    else {
        src_h_idx -= padding;
        src_w_idx -= padding;
        const size_t src_i =
            b_idx * src_s[0]
            + c_idx * src_s[1]
            + src_h_idx * src_s[2]
            + src_w_idx * src_s[3];
        dst[dst_i] = src[src_i];
    }
}

// Naive implementation of conv2d.
template <typename T, typename A>
__device__ void conv2d(

@ -363,6 +475,38 @@ extern "C" __global__ void FN_NAME( \
    conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
} \

#define IM2COL1D_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t dst_numel, \
    const size_t l_out, \
    const size_t l_k, \
    const size_t stride, \
    const size_t padding, \
    const size_t dilation, \
    const size_t *info, \
    const TYPENAME *src, \
    TYPENAME *dst \
) { \
    im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
} \

#define IM2COL_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t dst_numel, \
    const size_t h_out, \
    const size_t w_out, \
    const size_t h_k, \
    const size_t w_k, \
    const size_t stride, \
    const size_t padding, \
    const size_t dilation, \
    const size_t *info, \
    const TYPENAME *src, \
    TYPENAME *dst \
) { \
    im2col<TYPENAME>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, info, src, dst); \
} \

#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t src_numel, \

@ -428,6 +572,8 @@ CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16)
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
IM2COL_OP(__nv_bfloat16, im2col_bf16)
IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
#endif

#if __CUDA_ARCH__ >= 530

@ -437,6 +583,8 @@ CONVT2D_OP(__half, float, conv_transpose2d_f16)
AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
MAX_POOL2D_OP(__half, max_pool2d_f16)
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
IM2COL_OP(__half, im2col_f16)
IM2COL1D_OP(__half, im2col1d_f16)
#endif

CONV1D_OP(float, float, conv1d_f32)

@ -468,3 +616,13 @@ UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
UPSAMPLE_NEAREST2D_OP(double, upsample_nearest2d_f64)
UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)

IM2COL_OP(float, im2col_f32)
IM2COL_OP(double, im2col_f64)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)

IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(double, im2col1d_f64)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)

@ -129,6 +129,10 @@ __device__ __forceinline__ float powg(float a, float b) { return powf(a, b); }
__device__ __forceinline__ double powg(double a, double b) { return pow(a, b); }
__device__ __forceinline__ float tanhg(float a) { return tanhf(a); }
__device__ __forceinline__ double tanhg(double a) { return tanh(a); }
__device__ __forceinline__ float erfg(float a) { return erff(a); }
__device__ __forceinline__ double erfg(double a) { return erf(a); }
__device__ __forceinline__ float normcdfg(float a) { return normcdff(a); }
__device__ __forceinline__ double normcdfg(double a) { return normcdf(a); }
__device__ __forceinline__ float maxg(float a, float b) { return fmaxf(a, b); }
__device__ __forceinline__ double maxg(double a, double b) { return fmax(a, b); }
__device__ __forceinline__ float ming(float a, float b) { return fminf(a, b); }

@ -157,6 +161,8 @@ __device__ __forceinline__ __half sing(__half a) { return hsin(a); }
__device__ __forceinline__ __half recipg(__half a) { __half one = 1.0; return one / a; }
__device__ __forceinline__ __half maxg(__half a, __half b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __half tanhg(__half a) { return __float2half(tanhf(__half2float(a))); }
__device__ __forceinline__ __half erfg(__half a) { return __float2half(erff(__half2float(a))); }
__device__ __forceinline__ __half normcdfg(__half a) { return __float2half(normcdff(__half2float(a))); }
__device__ __forceinline__ __half ming(__half a, __half b) { return __hmin_nan(a, b); }
__device__ __forceinline__ __half logg(__half a) { return hlog(a); }
__device__ __forceinline__ __half expg(__half a) { return hexp(a); }

@ -173,6 +179,8 @@ __device__ __forceinline__ __nv_bfloat16 sing(__nv_bfloat16 a) { return hsin(a);
__device__ __forceinline__ __nv_bfloat16 recipg(__nv_bfloat16 a) { __nv_bfloat16 one = 1.0; return one / a; }
__device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __nv_bfloat16 tanhg(__nv_bfloat16 a) { return __float2bfloat16(tanhf(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 erfg(__nv_bfloat16 a) { return __float2bfloat16(erff(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 normcdfg(__nv_bfloat16 a) { return __float2bfloat16(normcdff(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 ming(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmin_nan(a, b); }
__device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); }
__device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a); }

@ -49,6 +49,50 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
    dst[dst_id] = shr[0];
}

// Softmax implementation adapted from ggml.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L4159
template <typename T, typename ACC>
__device__ void softmax(const T * x, T * dst, const int ncols) {
    const int row = blockDim.x*blockIdx.x + threadIdx.x;
    const int block_size = blockDim.y;
    const int tid = threadIdx.y;

    T max_val = -INFINITY;

    for (int col = tid; col < ncols; col += block_size) {
        const int i = row*ncols + col;
        max_val = maxg(max_val, x[i]);
    }

    // find the max value in the block
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        max_val = maxg(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
    }

    ACC tmp = 0.;

    for (int col = tid; col < ncols; col += block_size) {
        const int i = row*ncols + col;
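        // Subtracting the row max before expg is the standard numerically
        // stable softmax; the partial sums are normalized below.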
        const T val = expg(x[i] - max_val);
        tmp += static_cast<ACC>(val);
        dst[i] = val;
    }

    // sum up partial sums
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }

    const ACC inv_tmp = 1. / tmp;

    for (int col = tid; col < ncols; col += block_size) {
        const int i = row*ncols + col;
        dst[i] *= inv_tmp;
    }
}

template <typename T>
__device__ void
fast_max(const size_t src_numel, const size_t el_to_sum_per_block,

@ -290,12 +334,21 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
    } \
}

#define SOFTMAX_OP(TYPENAME, ACC_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const TYPENAME *src, TYPENAME *dst, \
    const int n_cols) { \
    softmax<TYPENAME, ACC_TYPENAME>(src, dst, n_cols); \
} \

#if __CUDA_ARCH__ >= 800
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
#endif

#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
#endif

@ -303,6 +356,8 @@ FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fa
SUM_OP(float, sum_f32)
SUM_OP(double, sum_f64)
SUM_OP(uint32_t, sum_u32)
SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)

FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)

@ -28,6 +28,11 @@ extern "C" __global__ void FN_NAME( \
    } \
} \

template<typename T>
__device__ __forceinline__ T gelu_erf_fwd(T x) {
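    // Exact GELU: gelu_erf(x) = x * Phi(x), with Phi the standard normal CDF
    // (normcdfg above).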
    return x * normcdfg(x);
}

template<typename T>
__device__ __forceinline__ T gelu_fwd(T x) {
    T x_sq = x * x;

@ -86,10 +91,13 @@ UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x))
UNARY_OP(__nv_bfloat16, usin_bf16, sing(x))
UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x))
UNARY_OP(__nv_bfloat16, utanh_bf16, tanhg(x))
UNARY_OP(__nv_bfloat16, uerf_bf16, erfg(x))
UNARY_OP(__nv_bfloat16, unormcdf_bf16, normcdfg(x))
UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x))
UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)
UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))

@ -104,10 +112,13 @@ UNARY_OP(__half, ulog_f16, logg(x))
UNARY_OP(__half, usin_f16, sing(x))
UNARY_OP(__half, ucos_f16, cosg(x))
UNARY_OP(__half, utanh_f16, tanhg(x))
UNARY_OP(__half, uerf_f16, erfg(x))
UNARY_OP(__half, unormcdf_f16, normcdfg(x))
UNARY_OP(__half, uabs_f16, absg(x))
UNARY_OP(__half, usqr_f16, x*x)
UNARY_OP(__half, usqrt_f16, sqrtg(x))
UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x))
UNARY_OP(__half, urelu_f16, relu_fwd(x))
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
UNARY_OP1(__half, upowf_f16, powg(x, param))

@ -131,6 +142,10 @@ UNARY_OP(float, ucos_f32, cosg(x))
UNARY_OP(double, ucos_f64, cosg(x))
UNARY_OP(float, utanh_f32, tanhg(x))
UNARY_OP(double, utanh_f64, tanhg(x))
UNARY_OP(float, uerf_f32, erfg(x))
UNARY_OP(double, uerf_f64, erfg(x))
UNARY_OP(float, unormcdf_f32, normcdfg(x))
UNARY_OP(double, unormcdf_f64, normcdfg(x))
UNARY_OP(float, uabs_f32, absg(x))
UNARY_OP(double, uabs_f64, absg(x))
UNARY_OP(float, usqr_f32, x*x)

@ -139,6 +154,8 @@ UNARY_OP(float, usqrt_f32, sqrtg(x))
UNARY_OP(double, usqrt_f64, sqrtg(x))
UNARY_OP(float, ugelu_f32, gelu_fwd(x))
UNARY_OP(double, ugelu_f64, gelu_fwd(x))
UNARY_OP(float, ugelu_erf_f32, gelu_erf_fwd(x))
UNARY_OP(double, ugelu_erf_f64, gelu_erf_fwd(x))
UNARY_OP(float, urelu_f32, relu_fwd(x))
UNARY_OP(double, urelu_f64, relu_fwd(x))
UNARY_OP1(float, uelu_f32, elu_fwd(x, param))

@ -11,13 +11,18 @@ readme = "README.md"

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
half = { workspace = true }
thiserror = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }

[dev-dependencies]
anyhow = { workspace = true }
clap = { workspace = true }

[features]
default = []

302
candle-nn/examples/cpu_benchmarks.rs
Normal file
@ -0,0 +1,302 @@
/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::quantized::GgmlType;
use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D};
use clap::{Parser, Subcommand};

const CHECK_CONV2D: bool = false;

trait Benchmark {
    type PreProcessData;
    type RunResult;

    fn preprocess() -> Result<Self::PreProcessData>;
    fn run_one(_: &Self::PreProcessData) -> Result<Self::RunResult>;

    const ITERS: usize;
}

struct Im2Col {
    h_k: usize,
    w_k: usize,
    stride: usize,
    dilation: usize,
    padding: usize,
}

impl Im2Col {
    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
        let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
        let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
        (h_out, w_out)
    }
}

impl candle::CustomOp1 for Im2Col {
    fn name(&self) -> &'static str {
        "im2col"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        let &Self {
            h_k,
            w_k,
            stride,
            dilation,
            padding,
        } = self;
        let (b, c, h, w) = layout.shape().dims4()?;
        let (h_out, w_out) = self.hw_out(h, w);
        let slice = storage.as_slice::<f32>()?;
        let src = &slice[layout.start_offset()..];
        let mut dst = vec![0f32; b * h_out * w_out * c * h_k * w_k];
        let (src_s0, src_s1, src_s2, src_s3) = {
            let s = layout.stride();
            (s[0], s[1], s[2], s[3])
        };
        // TODO: provide specialized kernels for the common use cases.
        // - h_k = w_k = 1
        // - padding = 0
        // - stride = 1
        // - dilation = 1
        for b_idx in 0..b {
            let src_idx = b_idx * src_s0;
            let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
            for h_idx in 0..h_out {
                let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
                for w_idx in 0..w_out {
                    let dst_idx = dst_idx + w_idx * c * h_k * w_k;
                    for c_idx in 0..c {
                        let dst_idx = dst_idx + c_idx * h_k * w_k;
                        let src_idx = c_idx * src_s1 + src_idx;
                        for h_k_idx in 0..h_k {
                            let src_h = h_idx * stride + h_k_idx * dilation;
                            if padding != 0 && (src_h < padding || src_h >= h + padding) {
                                continue;
                            }
                            let src_h = src_h - padding;
                            let src_idx = src_idx + src_h * src_s2;
                            let dst_idx = dst_idx + h_k_idx * w_k;
                            for w_k_idx in 0..w_k {
                                let src_w = w_idx * stride + w_k_idx * dilation;
                                if padding != 0 && (src_w < padding || src_w >= w + padding) {
                                    continue;
                                }
                                let src_w = src_w - padding;
                                let src_idx = src_idx + src_w * src_s3;
                                let dst_idx = dst_idx + w_k_idx;
                                dst[dst_idx] = src[src_idx]
                            }
                        }
                    }
                }
            }
        }
        let storage = candle::WithDType::to_cpu_storage_owned(dst);
        Ok((storage, (b * h_out * w_out, c * h_k * w_k).into()))
    }
}

// Conv1d example as used in whisper.
struct Conv1d;
impl Benchmark for Conv1d {
    type PreProcessData = (Tensor, Tensor);
    type RunResult = Tensor;
    fn preprocess() -> Result<Self::PreProcessData> {
        let inp = Tensor::randn(0f32, 1., (1, 384, 3000), &Device::Cpu)?;
        let w = Tensor::randn(0f32, 1., (384, 384, 3), &Device::Cpu)?;
        Ok((inp, w))
    }

    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
        d.0.conv1d(&d.1, 0, 1, 1, 1)
    }

    const ITERS: usize = 5;
}

// Conv2d example as used in stable-diffusion.
struct Conv2d;
impl Benchmark for Conv2d {
    type PreProcessData = (Tensor, Tensor);
    type RunResult = Tensor;

    fn preprocess() -> Result<Self::PreProcessData> {
        let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
        let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
        Ok((inp, w))
    }

    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
        d.0.conv2d(&d.1, 0, 1, 1, 1)
    }

    const ITERS: usize = 5;
}

// Conv2d example as used in stable-diffusion, im2col implementation.
struct Conv2dIm2Col;
impl Benchmark for Conv2dIm2Col {
    type PreProcessData = (Tensor, Tensor);
    type RunResult = Tensor;

    fn preprocess() -> Result<Self::PreProcessData> {
        let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
        let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
        Ok((inp, w))
    }

    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
        // d.0.conv2d(&d.1, 0, 1, 1, 1)
        let (b, _, h, w) = d.0.dims4()?;
        let (_, _, h_k, w_k) = d.1.dims4()?;
        let op = Im2Col {
            h_k,
            w_k,
            stride: 1,
            dilation: 1,
            padding: 0,
        };
        let (h_out, w_out) = op.hw_out(h, w);
        let col = d.0.apply_op1_no_bwd(&op)?;
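        // col has shape (b * h_out * w_out, c * h_k * w_k); flattening the
        // kernel to (c_out, c * h_k * w_k) and transposing lets the matmul
        // below produce (b * h_out * w_out, c_out).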
let res = col.matmul(&d.1.flatten_from(1)?.t()?)?;
|
||||
let res = res
|
||||
.reshape((b, h_out, w_out, ()))?
|
||||
.permute((0, 3, 1, 2))?
|
||||
.contiguous()?;
|
||||
if CHECK_CONV2D {
|
||||
let res2 = d.0.conv2d(&d.1, op.padding, op.stride, op.dilation, 1);
|
||||
let diff = (&res - res2)?.sqr()?.mean_all()?;
|
||||
println!("{diff}");
|
||||
}
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
const ITERS: usize = 5;
|
||||
}
|
||||
|
||||
struct Matmul;
|
||||
impl Benchmark for Matmul {
|
||||
type PreProcessData = (Tensor, Tensor);
|
||||
type RunResult = Tensor;
|
||||
fn preprocess() -> Result<Self::PreProcessData> {
|
||||
let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
|
||||
let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
|
||||
Ok((lhs, rhs))
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
d.0.matmul(&d.1)
|
||||
}
|
||||
|
||||
const ITERS: usize = 100;
|
||||
}
|
||||
|
||||
// This benchmark is similar to:
|
||||
// https://github.com/ggerganov/llama.cpp/blob/master/examples/benchmark/benchmark-matmult.cpp
|
||||
struct QMatMul;
|
||||
impl Benchmark for QMatMul {
|
||||
type PreProcessData = (candle::quantized::QMatMul, Tensor);
|
||||
type RunResult = Tensor;
|
||||
fn preprocess() -> Result<Self::PreProcessData> {
|
||||
let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
|
||||
let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
|
||||
let mm = candle::quantized::QMatMul::from_qtensor(mm);
|
||||
let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
|
||||
Ok((mm, arg))
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
d.0.forward(&d.1)
|
||||
}
|
||||
|
||||
const ITERS: usize = 100;
|
||||
}
|
||||
|
||||
struct Softmax;
|
||||
impl Benchmark for Softmax {
|
||||
type PreProcessData = Tensor;
|
||||
type RunResult = Tensor;
|
||||
fn preprocess() -> Result<Self::PreProcessData> {
|
||||
// Typical whisper tiny size.
|
||||
let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
|
||||
Ok(x)
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
candle_nn::ops::softmax(d, D::Minus1)
|
||||
}
|
||||
|
||||
const ITERS: usize = 100;
|
||||
}
|
||||
|
||||
struct SoftmaxLastDim;
|
||||
impl Benchmark for SoftmaxLastDim {
|
||||
type PreProcessData = Tensor;
|
||||
type RunResult = Tensor;
|
||||
fn preprocess() -> Result<Self::PreProcessData> {
|
||||
// Typical whisper tiny size.
|
||||
let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
|
||||
Ok(x)
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
candle_nn::ops::softmax_last_dim(d)
|
||||
}
|
||||
|
||||
const ITERS: usize = 100;
|
||||
}
|
||||
|
||||
fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
|
||||
use std::hint::black_box;
|
||||
|
||||
let iters = iters.unwrap_or(B::ITERS);
|
||||
let d = B::preprocess()?;
|
||||
let start = std::time::Instant::now();
|
||||
for _iter in 0..iters {
|
||||
let _res = black_box(B::run_one(black_box(&d))?);
|
||||
}
|
||||
println!("{:?}", start.elapsed() / iters as u32);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug, Clone)]
|
||||
enum Task {
|
||||
Conv1d,
|
||||
Conv2d,
|
||||
Conv2dIm2Col,
|
||||
Matmul,
|
||||
Qmatmul,
|
||||
Softmax,
|
||||
SoftmaxLastDim,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub struct Args {
|
||||
/// The benchmark to be run.
|
||||
#[command(subcommand)]
|
||||
task: Task,
|
||||
|
||||
#[arg(long)]
|
||||
iters: Option<usize>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
    let args = Args::parse();
    match args.task {
        Task::Conv1d => run::<Conv1d>(args.iters)?,
        Task::Conv2d => run::<Conv2d>(args.iters)?,
        Task::Conv2dIm2Col => run::<Conv2dIm2Col>(args.iters)?,
        Task::Matmul => run::<Matmul>(args.iters)?,
        Task::Softmax => run::<Softmax>(args.iters)?,
        Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?,
        Task::Qmatmul => run::<QMatMul>(args.iters)?,
    }
    Ok(())
}

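The harness can also be driven programmatically, bypassing the clap parser; a minimal sketch using the items defined above:

fn quick_check() -> Result<()> {
    run::<Matmul>(Some(10))?; // override the default iteration count
    run::<SoftmaxLastDim>(None) // fall back to B::ITERS (100)
}
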
@ -1,18 +1,29 @@
use candle::Tensor;
use serde::Deserialize;

#[derive(Debug, Clone, Copy, PartialEq)]
#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Activation {
    #[default]
    Gelu,
    #[serde(rename = "gated-gelu")]
    NewGelu,
    Relu,
    Elu(f64),
    LeakyRelu(f64),
}

impl super::Module for Activation {
    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
        match self {
            Self::Gelu => xs.gelu(),
            // TODO: This is "gelu_new", not the original "gelu".
            // There's some small numerical difference:
            // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
            Self::NewGelu => xs.gelu(),
            Self::Relu => xs.relu(),
            &Self::Elu(alpha) => xs.elu(alpha),
            &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
        }
    }
}

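With `Deserialize` now derived, an `Activation` can be parsed straight out of a model configuration. A minimal sketch, assuming `serde_json` and a hypothetical `ModelConfig` struct for illustration:

use serde::Deserialize;

// Hypothetical config struct; model configs typically carry a field like this.
#[derive(Debug, Deserialize)]
struct ModelConfig {
    hidden_act: Activation,
}

fn parse_config() -> serde_json::Result<()> {
    // "gated-gelu" maps to NewGelu via the serde(rename) attribute.
    let cfg: ModelConfig = serde_json::from_str(r#"{ "hidden_act": "gated-gelu" }"#)?;
    assert!(matches!(cfg.hidden_act, Activation::NewGelu));
    // Unit variants use their lowercase names, e.g. "relu".
    let act: Activation = serde_json::from_str(r#""relu""#)?;
    assert!(matches!(act, Activation::Relu));
    Ok(())
}
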
@ -38,7 +38,7 @@ impl From<f64> for BatchNormConfig {
    }
}

#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct BatchNorm {
    running_mean: Tensor,
    running_var: Tensor,

@ -20,7 +20,7 @@ impl Default for Conv1dConfig {
    }
}

#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct Conv1d {
    weight: Tensor,
    bias: Option<Tensor>,

@ -39,6 +39,14 @@ impl Conv1d {
    pub fn config(&self) -> &Conv1dConfig {
        &self.config
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }
}

impl crate::Module for Conv1d {

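The new `weight`/`bias` accessors make it possible to inspect a layer after construction or weight loading; a small sketch with illustrative shapes:

use candle::{DType, Device, Result, Tensor};
use candle_nn::{Conv1d, Conv1dConfig};

fn inspect_layer() -> Result<()> {
    // Conv1d weights are laid out as (out_channels, in_channels, kernel_size).
    let w = Tensor::zeros((8, 4, 3), DType::F32, &Device::Cpu)?;
    let conv = Conv1d::new(w, None, Conv1dConfig::default());
    assert_eq!(conv.weight().dims(), &[8, 4, 3]);
    assert!(conv.bias().is_none());
    Ok(())
}
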
@ -80,8 +88,7 @@ impl Default for Conv2dConfig {
    }
}

#[allow(dead_code)]
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct Conv2d {
    weight: Tensor,
    bias: Option<Tensor>,

@ -100,6 +107,14 @@ impl Conv2d {
    pub fn config(&self) -> &Conv2dConfig {
        &self.config
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }
}

impl crate::Module for Conv2d {

@ -122,15 +137,76 @@ impl crate::Module for Conv2d {
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConvTranspose2dConfig {
    pub padding: usize,
    pub output_padding: usize,
    pub stride: usize,
    pub dilation: usize,
    // TODO: support groups.
}

impl Default for ConvTranspose2dConfig {
    fn default() -> Self {
        Self {
            padding: 0,
            output_padding: 0,
            stride: 1,
            dilation: 1,
        }
    }
}

#[derive(Clone, Debug)]
pub struct ConvTranspose2d {
    weight: Tensor,
    bias: Option<Tensor>,
    config: ConvTranspose2dConfig,
}

impl ConvTranspose2d {
    pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose2dConfig) -> Self {
        Self {
            weight,
            bias,
            config,
        }
    }

    pub fn config(&self) -> &ConvTranspose2dConfig {
        &self.config
    }
}

impl crate::Module for ConvTranspose2d {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = x.conv_transpose2d(
            &self.weight,
            self.config.padding,
            self.config.output_padding,
            self.config.stride,
            self.config.dilation,
        )?;
        match &self.bias {
            None => Ok(x),
            Some(bias) => {
                let b = bias.dims1()?;
                let bias = bias.reshape((1, b, 1, 1))?;
                Ok(x.broadcast_add(&bias)?)
            }
        }
    }
}

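In the `forward` above, the 1-d bias is reshaped to (1, c, 1, 1) so it broadcasts over the batch and spatial dimensions. With the default config the transposed convolution grows the spatial dimensions; a sketch of the shape relation (zero weights for illustration, names assumed re-exported from `candle-nn`):

use candle::{DType, Device, Result, Tensor};
use candle_nn::{ConvTranspose2d, ConvTranspose2dConfig, Module};

fn upsample_shape() -> Result<()> {
    // Weights are laid out as (in_channels, out_channels, k, k), per the builders below.
    let w = Tensor::zeros((4, 2, 3, 3), DType::F32, &Device::Cpu)?;
    let layer = ConvTranspose2d::new(w, None, ConvTranspose2dConfig::default());
    let x = Tensor::zeros((1, 4, 8, 8), DType::F32, &Device::Cpu)?;
    let y = layer.forward(&x)?;
    // With stride 1 and no padding: h_out = h_in + kernel_size - 1.
    assert_eq!(y.dims(), &[1, 2, 10, 10]);
    Ok(())
}
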
pub fn conv1d(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv1dConfig,
    vs: crate::VarBuilder,
    vb: crate::VarBuilder,
) -> Result<Conv1d> {
    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
    let ws = vs.get_with_hints(
    let ws = vb.get_with_hints(
        (out_channels, in_channels / cfg.groups, kernel_size),
        "weight",
        init_ws,
@ -140,7 +216,7 @@ pub fn conv1d(
        lo: -bound,
        up: bound,
    };
    let bs = vs.get_with_hints(out_channels, "bias", init_bs)?;
    let bs = vb.get_with_hints(out_channels, "bias", init_bs)?;
    Ok(Conv1d::new(ws, Some(bs), cfg))
}

@ -149,10 +225,10 @@ pub fn conv2d(
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv2dConfig,
    vs: crate::VarBuilder,
    vb: crate::VarBuilder,
) -> Result<Conv2d> {
    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
    let ws = vs.get_with_hints(
    let ws = vb.get_with_hints(
        (
            out_channels,
            in_channels / cfg.groups,
@ -167,7 +243,7 @@ pub fn conv2d(
        lo: -bound,
        up: bound,
    };
    let bs = vs.get_with_hints(out_channels, "bias", init_bs)?;
    let bs = vb.get_with_hints(out_channels, "bias", init_bs)?;
    Ok(Conv2d::new(ws, Some(bs), cfg))
}

@ -176,10 +252,10 @@ pub fn conv2d_no_bias(
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv2dConfig,
    vs: crate::VarBuilder,
    vb: crate::VarBuilder,
) -> Result<Conv2d> {
    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
    let ws = vs.get_with_hints(
    let ws = vb.get_with_hints(
        (
            out_channels,
            in_channels / cfg.groups,
@ -191,3 +267,44 @@ pub fn conv2d_no_bias(
    )?;
    Ok(Conv2d::new(ws, None, cfg))
}

pub fn conv_transpose2d(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: ConvTranspose2dConfig,
    vb: crate::VarBuilder,
) -> Result<ConvTranspose2d> {
    let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64;
    let init = crate::Init::Uniform {
        lo: -bound,
        up: bound,
    };
    let ws = vb.get_with_hints(
        (in_channels, out_channels, kernel_size, kernel_size),
        "weight",
        init,
    )?;
    let bs = vb.get_with_hints(out_channels, "bias", init)?;
    Ok(ConvTranspose2d::new(ws, Some(bs), cfg))
}

pub fn conv_transpose2d_no_bias(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: ConvTranspose2dConfig,
    vb: crate::VarBuilder,
) -> Result<ConvTranspose2d> {
    let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64;
    let init = crate::Init::Uniform {
        lo: -bound,
        up: bound,
    };
    let ws = vb.get_with_hints(
        (in_channels, out_channels, kernel_size, kernel_size),
        "weight",
        init,
    )?;
    Ok(ConvTranspose2d::new(ws, None, cfg))
}

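These builders pull their parameters from a `VarBuilder` (hence the `vs` to `vb` renames above, for consistency with the rest of the crate). A usage sketch, assuming the `VarMap`/`VarBuilder` pair from `candle-nn` as used in the training examples:

use candle::{DType, Device, Result};
use candle_nn::{VarBuilder, VarMap};

fn build_layer() -> Result<()> {
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    // in_channels = 4, out_channels = 2, kernel_size = 3, default config;
    // "up1" is an illustrative variable-name prefix.
    let layer = candle_nn::conv_transpose2d(4, 2, 3, Default::default(), vb.pp("up1"))?;
    println!("{:?}", layer.config());
    Ok(())
}
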
@ -1,7 +1,7 @@
//! Embedding Layer.
use candle::{Result, Tensor};

#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct Embedding {
    embeddings: Tensor,
    hidden_size: usize,
@ -18,6 +18,11 @@ impl Embedding {
    pub fn embeddings(&self) -> &Tensor {
        &self.embeddings
    }

    /// Get the hidden size of the embedding matrix
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }
}

impl crate::Module for Embedding {

@ -4,7 +4,7 @@
use candle::{DType, Result, Tensor};

// This group norm version handles both weight and bias so removes the mean.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct GroupNorm {
    weight: Tensor,
    bias: Tensor,

@ -28,7 +28,7 @@
//! ```
//!
//! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
use candle::{DType, Result, Tensor};
use candle::{DType, Result, Tensor, D};

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LayerNormConfig {

@ -60,7 +60,7 @@ impl From<f64> for LayerNormConfig {
}

// This layer norm version handles both weight and bias so removes the mean.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct LayerNorm {
    weight: Tensor,
    bias: Option<Tensor>,

@ -104,15 +104,15 @@ impl crate::Module for LayerNorm {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
        let (_bsize, _seq_len, hidden_size) = x.dims3()?;
        let hidden_size = x.dim(D::Minus1)?;
        let x = x.to_dtype(internal_dtype)?;
        let x = if self.remove_mean {
            let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
            let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
            x.broadcast_sub(&mean_x)?
        } else {
            x
        };
        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
        let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?;
        match &self.bias {

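Switching from the `dims3` destructuring to `D::Minus1` means the layer no longer requires rank-3 (batch, seq, hidden) input: normalization now runs over the last dimension of a tensor of any rank. A sketch of the newly supported case (assuming `LayerNorm::new(weight, bias, eps)` as exposed by `candle-nn`):

use candle::{DType, Device, Result, Tensor};
use candle_nn::{LayerNorm, Module};

fn norm_any_rank() -> Result<()> {
    let hidden = 4;
    let weight = Tensor::ones(hidden, DType::F32, &Device::Cpu)?;
    let bias = Tensor::zeros(hidden, DType::F32, &Device::Cpu)?;
    let ln = LayerNorm::new(weight, bias, 1e-5);
    // A rank-2 input would previously have failed the dims3 destructuring.
    let x = Tensor::randn(0f32, 1., (2, hidden), &Device::Cpu)?;
    let y = ln.forward(&x)?;
    assert_eq!(y.dims(), &[2, hidden]);
    Ok(())
}
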
@ -143,7 +143,7 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
}

/// RmsNorm is a specialized version of the LayerNorm module.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct RmsNorm(LayerNorm);

impl RmsNorm {