diff --git a/.gitignore b/.gitignore
index 2748d37e..d0a8c320 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,6 +23,7 @@ flamegraph.svg
*.dylib
*.so
*.swp
+*.swo
trace-*.json
candle-wasm-examples/*/build
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a52429cf..df9574d5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,13 +1,58 @@
# Changelog
This documents the main changes to the `candle` crate.
-## v0.2.1 - Unreleased
+## v0.2.3 - Unreleased
### Added
+### Modified
+
+## v0.2.2 - 2023-09-18
+
+### Added
+- Support for `top_p` sampling
+ [819](https://github.com/huggingface/candle/pull/819).
+- T5 model including decoding
+ [864](https://github.com/huggingface/candle/pull/864).
+- 1-d upsampling
+ [839](https://github.com/huggingface/candle/pull/839).
+
+### Modified
+- Bugfix for conv2d
+ [820](https://github.com/huggingface/candle/pull/820).
+- Support tensor based indexing using `.i`
+ [842](https://github.com/huggingface/candle/pull/842).
+
+## v0.2.1 - 2023-09-11
+
+### Added
+- Add some RNNs (GRU and LSTM) in `candle-nn`
+ [674](https://github.com/huggingface/candle/pull/674),
+ [688](https://github.com/huggingface/candle/pull/688).
+- gguf v2 support
+ [725](https://github.com/huggingface/candle/pull/725).
+- Quantized llama example in Python using the pyo3 api
+ [716](https://github.com/huggingface/candle/pull/716).
+- `candle-nn` layer for conv2d-transposed
+ [760](https://github.com/huggingface/candle/pull/760).
+- Add the Segment-Anything Model (SAM) as an example
+ [773](https://github.com/huggingface/candle/pull/773).
+- TinyViT backbone for the segment anything example
+ [787](https://github.com/huggingface/candle/pull/787).
+- Shape with holes support
+ [770](https://github.com/huggingface/candle/pull/770).
+
### Modified
- Dilations are now supported in conv-transpose2d.
[671](https://github.com/huggingface/candle/pull/671).
+- Interactive mode for the quantized model
+ [690](https://github.com/huggingface/candle/pull/690).
+- Faster softmax operation
+ [747](https://github.com/huggingface/candle/pull/747).
+- Faster convolution operations on CPU and CUDA via im2col
+ [802](https://github.com/huggingface/candle/pull/802).
+- Moving some models to a more central location
+ [796](https://github.com/huggingface/candle/pull/796).
## v0.2.0 - 2023-08-30
diff --git a/Cargo.toml b/Cargo.toml
index ce41876a..6cbbf00f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,17 +8,16 @@ members = [
"candle-pyo3",
"candle-transformers",
"candle-wasm-examples/llama2-c",
+ "candle-wasm-examples/segment-anything",
"candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
+ "candle-wasm-examples/bert",
]
-exclude = [
- "candle-flash-attn",
- "candle-kernels",
-]
+exclude = ["candle-flash-attn", "candle-kernels"]
resolver = "2"
[workspace.package]
-version = "0.2.1"
+version = "0.2.3"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@@ -33,7 +32,7 @@ byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] }
cudarc = { version = "0.9.14", features = ["f16"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
-gemm = { version = "0.15.6", package = "candle-gemm" }
+gemm = { version = "0.16.0", package = "candle-gemm" }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
diff --git a/README.md b/README.md
index 140382c7..93a47082 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,9 @@ Candle is a minimalist ML framework for Rust with a focus on performance (includ
and ease of use. Try our online demos:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
-[yolo](https://huggingface.co/spaces/lmz/candle-yolo).
+[yolo](https://huggingface.co/spaces/lmz/candle-yolo),
+[Segment
+Anything](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
## Get started
@@ -45,37 +47,54 @@ For more advanced examples, please have a look at the following section.
## Check out our examples
-Check out our [examples](./candle-examples/examples/):
+These online demos run entirely in your browser:
+- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
+ object recognition.
+- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
+- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
+- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
+
+We also provide some command line based examples using state of the art models:
-- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
generation.
-- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
- image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
-- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
- using self-supervision (can be used for imagenet classification, depth
- evaluation, segmentation).
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
the LLaMA model using the same quantization techniques as
[llama.cpp](https://github.com/ggerganov/llama.cpp).
+
+
+
+- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
+ image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
+
+
+
+- [Wuerstchen](./candle-examples/examples/wuerstchen/): another text to
+ image generative model.
+
+
+
- [yolo-v3](./candle-examples/examples/yolo-v3/) and
[yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
estimation models.
-Run them using the following commands:
+
+
+- [segment-anything](./candle-examples/examples/segment-anything/): image
+  segmentation model with prompts.
+
+
+
+- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
+- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
+- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
+ using self-supervision (can be used for imagenet classification, depth
+ evaluation, segmentation).
+
+Run them using commands like:
```
-cargo run --example whisper --release
-cargo run --example llama --release
-cargo run --example falcon --release
-cargo run --example bert --release
-cargo run --example bigcode --release
-cargo run --example stable-diffusion --release -- --prompt "a rusty robot holding a fire torch"
-cargo run --example dinov2 --release -- --image path/to/myinput.jpg
cargo run --example quantized --release
-cargo run --example yolo-v3 --release -- myimage.jpg
-cargo run --example yolo-v8 --release -- myimage.jpg # for pose estimation, add --task pose
```
In order to use **CUDA** add `--features cuda` to the example command line. If
@@ -85,7 +104,8 @@ There are also some wasm examples for whisper and
[llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
`trunk` or try them online:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
-[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
+[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
+[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
For LLaMA2, run the following command to retrieve the weight files and start a
test server:
@@ -98,6 +118,15 @@ trunk serve --release --port 8081
And then head over to
[http://localhost:8081/](http://localhost:8081/).
+
+
+## Useful Libraries
+- [`candle-lora`](https://github.com/EricLBuehler/candle-lora) provides a LoRA implementation that conforms to the official `peft` implementation.
+
+If you have an addition to this list, please submit a pull request.
+
+
+
## Features
@@ -110,10 +139,21 @@ And then head over to
- CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
- WASM support, run your models in a browser.
- Included models.
- - LLMs: LLaMA v1 and v2, Falcon, StarCoder.
+ - Language Models.
+ - LLaMA v1 and v2.
+ - Falcon.
+ - StarCoder.
+ - T5.
+ - Bert.
- Whisper (multi-lingual support).
- - Stable Diffusion.
- - Computer Vision: DINOv2, EfficientNet, yolo-v3, yolo-v8.
+ - Stable Diffusion v1.5, v2.1, XL v1.0.
+ - Wuerstchen v2.
+ - Computer Vision Models.
+ - DINOv2.
+ - EfficientNet.
+ - yolo-v3.
+ - yolo-v8.
+ - Segment-Anything Model (SAM).
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments.
- Quantization support using the llama.cpp quantized types.
@@ -243,6 +283,35 @@ authentication token. See issue
git submodule update --init
```
+#### Compiling with flash-attention fails
+
+```
+/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
+```
+
+This is a bug in gcc-11 triggered by the CUDA compiler. To fix it, install a different, supported gcc version, for example gcc-10, and specify the path to that compiler in the `CANDLE_NVCC_CCBIN` environment variable.
+```
+env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
+```
+
+#### Linking error on Windows when running rustdoc or mdbook tests
+
+```
+Couldn't compile the test.
+---- .\candle-book\src\inference\hub.md - Using_the_hub::Using_in_a_real_model_ (line 50) stdout ----
+error: linking with `link.exe` failed: exit code: 1181
+//very long chain of linking
+ = note: LINK : fatal error LNK1181: cannot open input file 'windows.0.48.5.lib'
+```
+
+Make sure you link all native libraries that might be located outside the project target. For example, to run the mdbook tests, you should run:
+
+```
+mdbook test candle-book -L .\target\debug\deps\ `
+-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.42.2\lib `
+-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib
+```
+
#### Tracking down errors
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
diff --git a/candle-book/Cargo.toml b/candle-book/Cargo.toml
index 320fb887..8ec92e87 100644
--- a/candle-book/Cargo.toml
+++ b/candle-book/Cargo.toml
@@ -11,11 +11,11 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
-candle-nn = { path = "../candle-nn", version = "0.2.1" }
-candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
+candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
+candle-nn = { path = "../candle-nn", version = "0.2.3" }
+candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
diff --git a/candle-book/src/SUMMARY.md b/candle-book/src/SUMMARY.md
index 1d05568a..59831af2 100644
--- a/candle-book/src/SUMMARY.md
+++ b/candle-book/src/SUMMARY.md
@@ -10,10 +10,10 @@
# Reference Guide
-- [Running a model](inference/README.md)
+- [Running a model](inference/inference.md)
- [Using the hub](inference/hub.md)
- [Error management](error_manage.md)
-- [Training](training/README.md)
+- [Training](training/training.md)
- [Simplified](training/simplified.md)
- [MNIST](training/mnist.md)
- [Fine-tuning]()
diff --git a/candle-book/src/error_manage.md b/candle-book/src/error_manage.md
index c1a16bd9..0623e0e3 100644
--- a/candle-book/src/error_manage.md
+++ b/candle-book/src/error_manage.md
@@ -29,7 +29,7 @@ After adding `RUST_BACKTRACE=1`:
Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls:: for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
```
-Not super pretty at the moment, but we can see error occured on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
+Not super pretty at the moment, but we can see that the error occurred at `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
Another thing to note, is that since Rust is compiled it is not necessarily as easy to recover proper stacktraces
diff --git a/candle-book/src/guide/hello_world.md b/candle-book/src/guide/hello_world.md
index fc4af0e1..b5b8d7b4 100644
--- a/candle-book/src/guide/hello_world.md
+++ b/candle-book/src/guide/hello_world.md
@@ -6,7 +6,7 @@ Open `src/main.rs` and fill in this content:
```rust
# extern crate candle_core;
-use candle_core::{DType, Device, Result, Tensor};
+use candle_core::{Device, Result, Tensor};
struct Model {
first: Tensor,
@@ -25,11 +25,11 @@ fn main() -> Result<()> {
// Use Device::new_cuda(0)?; to use the GPU.
let device = Device::Cpu;
- let first = Tensor::zeros((784, 100), DType::F32, &device)?;
- let second = Tensor::zeros((100, 10), DType::F32, &device)?;
+ let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
+ let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
let model = Model { first, second };
- let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+ let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
@@ -50,7 +50,7 @@ the classical `Linear` layer. We can do as such
```rust
# extern crate candle_core;
-# use candle_core::{DType, Device, Result, Tensor};
+# use candle_core::{Device, Result, Tensor};
struct Linear{
weight: Tensor,
bias: Tensor,
@@ -80,7 +80,7 @@ This will change the model running code into a new function
```rust
# extern crate candle_core;
-# use candle_core::{DType, Device, Result, Tensor};
+# use candle_core::{Device, Result, Tensor};
# struct Linear{
# weight: Tensor,
# bias: Tensor,
@@ -110,15 +110,15 @@ fn main() -> Result<()> {
let device = Device::cuda_if_available(0)?;
// Creating a dummy model
- let weight = Tensor::zeros((784, 100), DType::F32, &device)?;
- let bias = Tensor::zeros((100, ), DType::F32, &device)?;
+ let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
+ let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
let first = Linear{weight, bias};
- let weight = Tensor::zeros((100, 10), DType::F32, &device)?;
- let bias = Tensor::zeros((10, ), DType::F32, &device)?;
+ let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
+ let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
let second = Linear{weight, bias};
let model = Model { first, second };
- let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+ let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
// Inference on the model
let digit = model.forward(&dummy_image)?;
@@ -146,7 +146,7 @@ And rewrite our examples using it
```rust
# extern crate candle_core;
# extern crate candle_nn;
-use candle_core::{DType, Device, Result, Tensor};
+use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};
struct Model {
@@ -167,15 +167,15 @@ fn main() -> Result<()> {
let device = Device::Cpu;
// This has changed (784, 100) -> (100, 784) !
- let weight = Tensor::zeros((100, 784), DType::F32, &device)?;
- let bias = Tensor::zeros((100, ), DType::F32, &device)?;
+ let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
+ let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
let first = Linear::new(weight, Some(bias));
- let weight = Tensor::zeros((10, 100), DType::F32, &device)?;
- let bias = Tensor::zeros((10, ), DType::F32, &device)?;
+ let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?;
+ let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
let second = Linear::new(weight, Some(bias));
let model = Model { first, second };
- let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+ let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
@@ -188,8 +188,8 @@ Feel free to modify this example to use `Conv2d` to create a classical convnet i
Now that we have the running dummy code we can get to more advanced topics:
-- [For PyTorch users](./guide/cheatsheet.md)
-- [Running existing models](./inference/README.md)
-- [Training models](./training/README.md)
+- [For PyTorch users](../guide/cheatsheet.md)
+- [Running existing models](../inference/inference.md)
+- [Training models](../training/training.md)
diff --git a/candle-book/src/inference/README.md b/candle-book/src/inference/inference.md
similarity index 100%
rename from candle-book/src/inference/README.md
rename to candle-book/src/inference/inference.md
diff --git a/candle-book/src/training/README.md b/candle-book/src/training/training.md
similarity index 100%
rename from candle-book/src/training/README.md
rename to candle-book/src/training/training.md
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index e7213919..7af9b6fa 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -12,7 +12,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
-candle-kernels = { path = "../candle-kernels", version = "0.2.1", optional = true }
+candle-kernels = { path = "../candle-kernels", version = "0.2.3", optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
diff --git a/candle-core/examples/cpu_benchmarks.rs b/candle-core/examples/cpu_benchmarks.rs
deleted file mode 100644
index 13175ac1..00000000
--- a/candle-core/examples/cpu_benchmarks.rs
+++ /dev/null
@@ -1,166 +0,0 @@
-/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
-#[cfg(feature = "mkl")]
-extern crate intel_mkl_src;
-
-#[cfg(feature = "accelerate")]
-extern crate accelerate_src;
-
-use candle_core::quantized::GgmlType;
-use candle_core::{Device, Result, Tensor, D};
-use clap::{Parser, Subcommand};
-
-fn softmax<D: candle_core::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
- let dim = dim.to_index(xs.shape(), "softmax")?;
- let max = xs.max_keepdim(dim)?;
- let diff = xs.broadcast_sub(&max)?;
- let num = diff.exp()?;
- let den = num.sum_keepdim(dim)?;
- num.broadcast_div(&den)
-}
-
-trait Benchmark {
- type PreProcessData;
- type RunResult;
-
- fn preprocess() -> Result<Self::PreProcessData>;
- fn run_one(_: &Self::PreProcessData) -> Result<Self::RunResult>;
-
- const ITERS: usize;
-}
-
-// Conv1d example as used in whisper.
-struct Conv1d;
-impl Benchmark for Conv1d {
- type PreProcessData = (Tensor, Tensor);
- type RunResult = Tensor;
- fn preprocess() -> Result<Self::PreProcessData> {
- let inp = Tensor::randn(0f32, 1., (1, 384, 3000), &Device::Cpu)?;
- let w = Tensor::randn(0f32, 1., (384, 384, 3), &Device::Cpu)?;
- Ok((inp, w))
- }
-
- fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- d.0.conv1d(&d.1, 0, 1, 1, 1)
- }
-
- const ITERS: usize = 5;
-}
-
-// Conv2d example as used in stable-diffusion.
-struct Conv2d;
-impl Benchmark for Conv2d {
- type PreProcessData = (Tensor, Tensor);
- type RunResult = Tensor;
-
- fn preprocess() -> Result<Self::PreProcessData> {
- let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
- let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
- Ok((inp, w))
- }
-
- fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- d.0.conv2d(&d.1, 0, 1, 1, 1)
- }
-
- const ITERS: usize = 1;
-}
-
-struct Matmul;
-impl Benchmark for Matmul {
- type PreProcessData = (Tensor, Tensor);
- type RunResult = Tensor;
- fn preprocess() -> Result<Self::PreProcessData> {
- let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
- let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
- Ok((lhs, rhs))
- }
-
- fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- d.0.matmul(&d.1)
- }
-
- const ITERS: usize = 100;
-}
-
-// This benchmark is similar to:
-// https://github.com/ggerganov/llama.cpp/blob/master/examples/benchmark/benchmark-matmult.cpp
-struct QMatMul;
-impl Benchmark for QMatMul {
- type PreProcessData = (candle_core::quantized::QMatMul, Tensor);
- type RunResult = Tensor;
- fn preprocess() -> Result<Self::PreProcessData> {
- let zeros = vec![candle_core::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
- let mm = candle_core::quantized::QTensor::new(zeros, (4096, 11008))?;
- let mm = candle_core::quantized::QMatMul::from_qtensor(mm);
- let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
- Ok((mm, arg))
- }
-
- fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- d.0.forward(&d.1)
- }
-
- const ITERS: usize = 100;
-}
-
-struct Softmax;
-impl Benchmark for Softmax {
- type PreProcessData = Tensor;
- type RunResult = Tensor;
- fn preprocess() -> Result<Self::PreProcessData> {
- // Typical whisper tiny size.
- let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
- Ok(x)
- }
-
- fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- softmax(d, D::Minus1)
- }
-
- const ITERS: usize = 100;
-}
-
-fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
- use std::hint::black_box;
-
- let iters = iters.unwrap_or(B::ITERS);
- let d = B::preprocess()?;
- let start = std::time::Instant::now();
- for _iter in 0..iters {
- let _res = black_box(B::run_one(black_box(&d))?);
- }
- println!("{:?}", start.elapsed() / iters as u32);
- Ok(())
-}
-
-#[derive(Subcommand, Debug, Clone)]
-enum Task {
- Conv1d,
- Conv2d,
- Matmul,
- Qmatmul,
- Softmax,
-}
-
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-pub struct Args {
- /// The benchmark to be run.
- #[command(subcommand)]
- task: Task,
-
- #[arg(long)]
- iters: Option<usize>,
-}
-
-fn main() -> Result<()> {
- let args = Args::parse();
- match args.task {
- Task::Conv1d => run::<Conv1d>(args.iters)?,
- Task::Conv2d => run::<Conv2d>(args.iters)?,
- Task::Matmul => run::<Matmul>(args.iters)?,
- Task::Softmax => run::<Softmax>(args.iters)?,
- Task::Qmatmul => run::<QMatMul>(args.iters)?,
- }
- Ok(())
-}
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index 2bc1fa2e..c3459004 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -218,12 +218,65 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
Ok(())
}
+fn run_quantize_safetensors(
+ in_file: std::path::PathBuf,
+ out_file: std::path::PathBuf,
+ q: Quantization,
+) -> Result<()> {
+ let mut out_file = std::fs::File::create(out_file)?;
+ let tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
+ println!("tensors: {}", tensors.len());
+
+ let quantize_fn = match q {
+ Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
+ Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
+ Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
+ Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
+ Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
+ Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
+ Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
+ Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
+ Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
+ Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
+ Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
+ Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
+ Quantization::F16 => QTensor::quantize::<half::f16>,
+ Quantization::F32 => QTensor::quantize::<f32>,
+ };
+
+ let qtensors = tensors
+ .into_par_iter()
+ .map(|(name, tensor)| {
+ println!(" quantizing {name} {tensor:?}");
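+ // Only quantize 2D weight matrices whose first dimension is a multiple of 256
+ // (the k-quant super-block size); all other tensors are kept as f32.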
+ let should_quantize = tensor.rank() == 2 && tensor.dim(0)? % 256 == 0;
+ let tensor = if should_quantize {
+ quantize_fn(&tensor)?
+ } else {
+ QTensor::quantize::<f32>(&tensor)?
+ };
+ Ok((name, tensor))
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let qtensors = qtensors
+ .iter()
+ .map(|(k, v)| (k.as_str(), v))
+ .collect::<Vec<_>>();
+ gguf_file::write(&mut out_file, &[], &qtensors)?;
+ Ok(())
+}
+
fn run_quantize(
in_file: std::path::PathBuf,
out_file: std::path::PathBuf,
q: Quantization,
qmode: QuantizationMode,
) -> Result<()> {
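+ // Safetensors inputs take a dedicated path: each tensor is quantized individually
+ // and the result is written out as a gguf file.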
+ if let Some(extension) = in_file.extension() {
+ if extension == "safetensors" {
+ return run_quantize_safetensors(in_file, out_file, q);
+ }
+ }
+
// Open the out file early so as to fail directly on missing directories etc.
let mut out_file = std::fs::File::create(out_file)?;
let mut in_ = std::fs::File::open(&in_file)?;
diff --git a/candle-core/src/accelerate.rs b/candle-core/src/accelerate.rs
index 87e0ee8d..1cb34e19 100644
--- a/candle-core/src/accelerate.rs
+++ b/candle-core/src/accelerate.rs
@@ -370,6 +370,38 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
}
+#[inline]
+pub fn vs_tanh_inplace(y: &mut [f32]) {
+ unsafe { ffi::vvtanhf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
+#[inline]
+pub fn vd_tanh_inplace(y: &mut [f64]) {
+ unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
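+/// Tanh-based GELU approximation, using the vectorized tanh above:
+/// gelu(v) = 0.5 * v * (1 + tanh(sqrt(2/pi) * (v + 0.044715 * v^3))).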
+#[inline]
+pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+ }
+ vs_tanh_inplace(ys);
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = 0.5 * v * (1.0 + *y)
+ }
+}
+
+#[inline]
+pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+ }
+ vd_tanh_inplace(ys);
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = 0.5 * v * (1.0 + *y)
+ }
+}
+
macro_rules! binary_op {
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
#[inline]
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 67a08714..03a07434 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -57,6 +57,7 @@ pub trait BackendStorage: Sized {
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
+ fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index d2099df7..a2548198 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -91,13 +91,14 @@ impl Tensor {
}
}
Op::Reshape(node)
+ | Op::UpsampleNearest1D(node)
| Op::UpsampleNearest2D(node)
| Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. }
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
- | Op::Reduce(node, _, _)
+ | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
@@ -111,6 +112,7 @@ impl Tensor {
track_grad |= tg;
nodes
}
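+ // ArgMin/ArgMax return integer indices and are not differentiable, so no gradient
+ // is tracked through them.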
+ Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
}
} else {
nodes
@@ -262,6 +264,9 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
}
+ Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
+ op: "upsample-nearest1d",
+ })?,
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
@@ -437,6 +442,10 @@ impl Tensor {
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
+ Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
+ Op::Unary(_, UnaryOp::GeluErf) => {
+ Err(Error::BackwardNotSupported { op: "gelu-erf" })?
+ }
Op::Unary(arg, UnaryOp::Relu) => {
let sum_grad = grads.or_insert(arg)?;
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
@@ -517,6 +526,7 @@ impl Tensor {
}
}
+#[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>);
impl GradStore {
diff --git a/candle-core/src/cpu/erf.rs b/candle-core/src/cpu/erf.rs
new file mode 100644
index 00000000..ca6be53f
--- /dev/null
+++ b/candle-core/src/cpu/erf.rs
@@ -0,0 +1,763 @@
+#![allow(clippy::excessive_precision)]
+// Code taken from https://github.com/statrs-dev/statrs
+//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
+//! related functions
+
+mod evaluate {
+ //! Provides functions that don't have a numerical solution and must
+ //! be solved computationally (e.g. evaluation of a polynomial)
+
+ /// Evaluates a polynomial at `z`, where `coeff` holds the coefficients of a
+ /// polynomial of order `k` (`k` being the length of `coeff`) and the coefficient
+ /// of the `k`th power is the `k`th element of `coeff`. E.g. [3,-1,2] equates to
+ /// `2z^2 - z + 3`.
+ ///
+ /// # Remarks
+ ///
+ /// Returns 0 for a 0 length coefficient slice
+ pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
+ let n = coeff.len();
+ if n == 0 {
+ return 0.0;
+ }
+
+ let mut sum = *coeff.last().unwrap();
+ for c in coeff[0..n - 1].iter().rev() {
+ sum = *c + z * sum;
+ }
+ sum
+ }
+}
+use std::f64;
+
+/// `erf` calculates the error function at `x`.
+pub fn erf(x: f64) -> f64 {
+ if x.is_nan() {
+ f64::NAN
+ } else if x >= 0.0 && x.is_infinite() {
+ 1.0
+ } else if x <= 0.0 && x.is_infinite() {
+ -1.0
+ } else if x == 0. {
+ 0.0
+ } else {
+ erf_impl(x, false)
+ }
+}
+
+/// `erf_inv` calculates the inverse error function
+/// at `x`.
+pub fn erf_inv(x: f64) -> f64 {
+ if x == 0.0 {
+ 0.0
+ } else if x >= 1.0 {
+ f64::INFINITY
+ } else if x <= -1.0 {
+ f64::NEG_INFINITY
+ } else if x < 0.0 {
+ erf_inv_impl(-x, 1.0 + x, -1.0)
+ } else {
+ erf_inv_impl(x, 1.0 - x, 1.0)
+ }
+}
+
+/// `erfc` calculates the complementary error function
+/// at `x`.
+pub fn erfc(x: f64) -> f64 {
+ if x.is_nan() {
+ f64::NAN
+ } else if x == f64::INFINITY {
+ 0.0
+ } else if x == f64::NEG_INFINITY {
+ 2.0
+ } else {
+ erf_impl(x, true)
+ }
+}
+
+/// `erfc_inv` calculates the complementary inverse
+/// error function at `x`.
+pub fn erfc_inv(x: f64) -> f64 {
+ if x <= 0.0 {
+ f64::INFINITY
+ } else if x >= 2.0 {
+ f64::NEG_INFINITY
+ } else if x > 1.0 {
+ erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
+ } else {
+ erf_inv_impl(1.0 - x, x, 1.0)
+ }
+}
+
+// **********************************************************
+// ********** Coefficients for erf_impl polynomial **********
+// **********************************************************
+
+/// Polynomial coefficients for a numerator of `erf_impl`
+/// in the interval [1e-10, 0.5].
+const ERF_IMPL_AN: &[f64] = &[
+ 0.00337916709551257388990745,
+ -0.00073695653048167948530905,
+ -0.374732337392919607868241,
+ 0.0817442448733587196071743,
+ -0.0421089319936548595203468,
+ 0.0070165709512095756344528,
+ -0.00495091255982435110337458,
+ 0.000871646599037922480317225,
+];
+
+/// Polynomial coefficients for a denominator of `erf_impl`
+/// in the interval [1e-10, 0.5]
+const ERF_IMPL_AD: &[f64] = &[
+ 1.0,
+ -0.218088218087924645390535,
+ 0.412542972725442099083918,
+ -0.0841891147873106755410271,
+ 0.0655338856400241519690695,
+ -0.0120019604454941768171266,
+ 0.00408165558926174048329689,
+ -0.000615900721557769691924509,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [0.5, 0.75].
+const ERF_IMPL_BN: &[f64] = &[
+ -0.0361790390718262471360258,
+ 0.292251883444882683221149,
+ 0.281447041797604512774415,
+ 0.125610208862766947294894,
+ 0.0274135028268930549240776,
+ 0.00250839672168065762786937,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [0.5, 0.75].
+const ERF_IMPL_BD: &[f64] = &[
+ 1.0,
+ 1.8545005897903486499845,
+ 1.43575803037831418074962,
+ 0.582827658753036572454135,
+ 0.124810476932949746447682,
+ 0.0113724176546353285778481,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [0.75, 1.25].
+const ERF_IMPL_CN: &[f64] = &[
+ -0.0397876892611136856954425,
+ 0.153165212467878293257683,
+ 0.191260295600936245503129,
+ 0.10276327061989304213645,
+ 0.029637090615738836726027,
+ 0.0046093486780275489468812,
+ 0.000307607820348680180548455,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [0.75, 1.25].
+const ERF_IMPL_CD: &[f64] = &[
+ 1.0,
+ 1.95520072987627704987886,
+ 1.64762317199384860109595,
+ 0.768238607022126250082483,
+ 0.209793185936509782784315,
+ 0.0319569316899913392596356,
+ 0.00213363160895785378615014,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [1.25, 2.25].
+const ERF_IMPL_DN: &[f64] = &[
+ -0.0300838560557949717328341,
+ 0.0538578829844454508530552,
+ 0.0726211541651914182692959,
+ 0.0367628469888049348429018,
+ 0.00964629015572527529605267,
+ 0.00133453480075291076745275,
+ 0.778087599782504251917881e-4,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [1.25, 2.25].
+const ERF_IMPL_DD: &[f64] = &[
+ 1.0,
+ 1.75967098147167528287343,
+ 1.32883571437961120556307,
+ 0.552528596508757581287907,
+ 0.133793056941332861912279,
+ 0.0179509645176280768640766,
+ 0.00104712440019937356634038,
+ -0.106640381820357337177643e-7,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [2.25, 3.5].
+const ERF_IMPL_EN: &[f64] = &[
+ -0.0117907570137227847827732,
+ 0.014262132090538809896674,
+ 0.0202234435902960820020765,
+ 0.00930668299990432009042239,
+ 0.00213357802422065994322516,
+ 0.00025022987386460102395382,
+ 0.120534912219588189822126e-4,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [2.25, 3.5].
+const ERF_IMPL_ED: &[f64] = &[
+ 1.0,
+ 1.50376225203620482047419,
+ 0.965397786204462896346934,
+ 0.339265230476796681555511,
+ 0.0689740649541569716897427,
+ 0.00771060262491768307365526,
+ 0.000371421101531069302990367,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [3.5, 5.25].
+const ERF_IMPL_FN: &[f64] = &[
+ -0.00546954795538729307482955,
+ 0.00404190278731707110245394,
+ 0.0054963369553161170521356,
+ 0.00212616472603945399437862,
+ 0.000394984014495083900689956,
+ 0.365565477064442377259271e-4,
+ 0.135485897109932323253786e-5,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [3.5, 5.25].
+const ERF_IMPL_FD: &[f64] = &[
+ 1.0,
+ 1.21019697773630784832251,
+ 0.620914668221143886601045,
+ 0.173038430661142762569515,
+ 0.0276550813773432047594539,
+ 0.00240625974424309709745382,
+ 0.891811817251336577241006e-4,
+ -0.465528836283382684461025e-11,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [5.25, 8].
+const ERF_IMPL_GN: &[f64] = &[
+ -0.00270722535905778347999196,
+ 0.0013187563425029400461378,
+ 0.00119925933261002333923989,
+ 0.00027849619811344664248235,
+ 0.267822988218331849989363e-4,
+ 0.923043672315028197865066e-6,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [5.25, 8].
+const ERF_IMPL_GD: &[f64] = &[
+ 1.0,
+ 0.814632808543141591118279,
+ 0.268901665856299542168425,
+ 0.0449877216103041118694989,
+ 0.00381759663320248459168994,
+ 0.000131571897888596914350697,
+ 0.404815359675764138445257e-11,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [8, 11.5].
+const ERF_IMPL_HN: &[f64] = &[
+ -0.00109946720691742196814323,
+ 0.000406425442750422675169153,
+ 0.000274499489416900707787024,
+ 0.465293770646659383436343e-4,
+ 0.320955425395767463401993e-5,
+ 0.778286018145020892261936e-7,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [8, 11.5].
+const ERF_IMPL_HD: &[f64] = &[
+ 1.0,
+ 0.588173710611846046373373,
+ 0.139363331289409746077541,
+ 0.0166329340417083678763028,
+ 0.00100023921310234908642639,
+ 0.24254837521587225125068e-4,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [11.5, 17].
+const ERF_IMPL_IN: &[f64] = &[
+ -0.00056907993601094962855594,
+ 0.000169498540373762264416984,
+ 0.518472354581100890120501e-4,
+ 0.382819312231928859704678e-5,
+ 0.824989931281894431781794e-7,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [11.5, 17].
+const ERF_IMPL_ID: &[f64] = &[
+ 1.0,
+ 0.339637250051139347430323,
+ 0.043472647870310663055044,
+ 0.00248549335224637114641629,
+ 0.535633305337152900549536e-4,
+ -0.117490944405459578783846e-12,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [17, 24].
+const ERF_IMPL_JN: &[f64] = &[
+ -0.000241313599483991337479091,
+ 0.574224975202501512365975e-4,
+ 0.115998962927383778460557e-4,
+ 0.581762134402593739370875e-6,
+ 0.853971555085673614607418e-8,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [17, 24].
+const ERF_IMPL_JD: &[f64] = &[
+ 1.0,
+ 0.233044138299687841018015,
+ 0.0204186940546440312625597,
+ 0.000797185647564398289151125,
+ 0.117019281670172327758019e-4,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [24, 38].
+const ERF_IMPL_KN: &[f64] = &[
+ -0.000146674699277760365803642,
+ 0.162666552112280519955647e-4,
+ 0.269116248509165239294897e-5,
+ 0.979584479468091935086972e-7,
+ 0.101994647625723465722285e-8,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [24, 38].
+const ERF_IMPL_KD: &[f64] = &[
+ 1.0,
+ 0.165907812944847226546036,
+ 0.0103361716191505884359634,
+ 0.000286593026373868366935721,
+ 0.298401570840900340874568e-5,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [38, 60].
+const ERF_IMPL_LN: &[f64] = &[
+ -0.583905797629771786720406e-4,
+ 0.412510325105496173512992e-5,
+ 0.431790922420250949096906e-6,
+ 0.993365155590013193345569e-8,
+ 0.653480510020104699270084e-10,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [38, 60].
+const ERF_IMPL_LD: &[f64] = &[
+ 1.0,
+ 0.105077086072039915406159,
+ 0.00414278428675475620830226,
+ 0.726338754644523769144108e-4,
+ 0.477818471047398785369849e-6,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [60, 85].
+const ERF_IMPL_MN: &[f64] = &[
+ -0.196457797609229579459841e-4,
+ 0.157243887666800692441195e-5,
+ 0.543902511192700878690335e-7,
+ 0.317472492369117710852685e-9,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [60, 85].
+const ERF_IMPL_MD: &[f64] = &[
+ 1.0,
+ 0.052803989240957632204885,
+ 0.000926876069151753290378112,
+ 0.541011723226630257077328e-5,
+ 0.535093845803642394908747e-15,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [85, 110].
+const ERF_IMPL_NN: &[f64] = &[
+ -0.789224703978722689089794e-5,
+ 0.622088451660986955124162e-6,
+ 0.145728445676882396797184e-7,
+ 0.603715505542715364529243e-10,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [85, 110].
+const ERF_IMPL_ND: &[f64] = &[
+ 1.0,
+ 0.0375328846356293715248719,
+ 0.000467919535974625308126054,
+ 0.193847039275845656900547e-5,
+];
+
+// **********************************************************
+// ********** Coefficients for erf_inv_impl polynomial ******
+// **********************************************************
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0, 0.5].
+const ERF_INV_IMPL_AN: &[f64] = &[
+ -0.000508781949658280665617,
+ -0.00836874819741736770379,
+ 0.0334806625409744615033,
+ -0.0126926147662974029034,
+ -0.0365637971411762664006,
+ 0.0219878681111168899165,
+ 0.00822687874676915743155,
+ -0.00538772965071242932965,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0, 0.5].
+const ERF_INV_IMPL_AD: &[f64] = &[
+ 1.0,
+ -0.970005043303290640362,
+ -1.56574558234175846809,
+ 1.56221558398423026363,
+ 0.662328840472002992063,
+ -0.71228902341542847553,
+ -0.0527396382340099713954,
+ 0.0795283687341571680018,
+ -0.00233393759374190016776,
+ 0.000886216390456424707504,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.5, 0.75].
+const ERF_INV_IMPL_BN: &[f64] = &[
+ -0.202433508355938759655,
+ 0.105264680699391713268,
+ 8.37050328343119927838,
+ 17.6447298408374015486,
+ -18.8510648058714251895,
+ -44.6382324441786960818,
+ 17.445385985570866523,
+ 21.1294655448340526258,
+ -3.67192254707729348546,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.5, 0.75].
+const ERF_INV_IMPL_BD: &[f64] = &[
+ 1.0,
+ 6.24264124854247537712,
+ 3.9713437953343869095,
+ -28.6608180499800029974,
+ -20.1432634680485188801,
+ 48.5609213108739935468,
+ 10.8268667355460159008,
+ -22.6436933413139721736,
+ 1.72114765761200282724,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x less than 3.
+const ERF_INV_IMPL_CN: &[f64] = &[
+ -0.131102781679951906451,
+ -0.163794047193317060787,
+ 0.117030156341995252019,
+ 0.387079738972604337464,
+ 0.337785538912035898924,
+ 0.142869534408157156766,
+ 0.0290157910005329060432,
+ 0.00214558995388805277169,
+ -0.679465575181126350155e-6,
+ 0.285225331782217055858e-7,
+ -0.681149956853776992068e-9,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x less than 3.
+const ERF_INV_IMPL_CD: &[f64] = &[
+ 1.0,
+ 3.46625407242567245975,
+ 5.38168345707006855425,
+ 4.77846592945843778382,
+ 2.59301921623620271374,
+ 0.848854343457902036425,
+ 0.152264338295331783612,
+ 0.01105924229346489121,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 3 and 6.
+const ERF_INV_IMPL_DN: &[f64] = &[
+ -0.0350353787183177984712,
+ -0.00222426529213447927281,
+ 0.0185573306514231072324,
+ 0.00950804701325919603619,
+ 0.00187123492819559223345,
+ 0.000157544617424960554631,
+ 0.460469890584317994083e-5,
+ -0.230404776911882601748e-9,
+ 0.266339227425782031962e-11,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 3 and 6.
+const ERF_INV_IMPL_DD: &[f64] = &[
+ 1.0,
+ 1.3653349817554063097,
+ 0.762059164553623404043,
+ 0.220091105764131249824,
+ 0.0341589143670947727934,
+ 0.00263861676657015992959,
+ 0.764675292302794483503e-4,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 6 and 18.
+const ERF_INV_IMPL_EN: &[f64] = &[
+ -0.0167431005076633737133,
+ -0.00112951438745580278863,
+ 0.00105628862152492910091,
+ 0.000209386317487588078668,
+ 0.149624783758342370182e-4,
+ 0.449696789927706453732e-6,
+ 0.462596163522878599135e-8,
+ -0.281128735628831791805e-13,
+ 0.99055709973310326855e-16,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 6 and 18.
+const ERF_INV_IMPL_ED: &[f64] = &[
+ 1.0,
+ 0.591429344886417493481,
+ 0.138151865749083321638,
+ 0.0160746087093676504695,
+ 0.000964011807005165528527,
+ 0.275335474764726041141e-4,
+ 0.282243172016108031869e-6,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 18 and 44.
+const ERF_INV_IMPL_FN: &[f64] = &[
+ -0.0024978212791898131227,
+ -0.779190719229053954292e-5,
+ 0.254723037413027451751e-4,
+ 0.162397777342510920873e-5,
+ 0.396341011304801168516e-7,
+ 0.411632831190944208473e-9,
+ 0.145596286718675035587e-11,
+ -0.116765012397184275695e-17,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 18 and 44.
+const ERF_INV_IMPL_FD: &[f64] = &[
+ 1.0,
+ 0.207123112214422517181,
+ 0.0169410838120975906478,
+ 0.000690538265622684595676,
+ 0.145007359818232637924e-4,
+ 0.144437756628144157666e-6,
+ 0.509761276599778486139e-9,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x greater than 44.
+const ERF_INV_IMPL_GN: &[f64] = &[
+ -0.000539042911019078575891,
+ -0.28398759004727721098e-6,
+ 0.899465114892291446442e-6,
+ 0.229345859265920864296e-7,
+ 0.225561444863500149219e-9,
+ 0.947846627503022684216e-12,
+ 0.135880130108924861008e-14,
+ -0.348890393399948882918e-21,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x greater than 44.
+const ERF_INV_IMPL_GD: &[f64] = &[
+ 1.0,
+ 0.0845746234001899436914,
+ 0.00282092984726264681981,
+ 0.468292921940894236786e-4,
+ 0.399968812193862100054e-6,
+ 0.161809290887904476097e-8,
+ 0.231558608310259605225e-11,
+];
+
+/// `erf_impl` computes the error function at `z`.
+/// If `inv` is true, `1 - erf` is calculated as opposed to `erf`
+fn erf_impl(z: f64, inv: bool) -> f64 {
+ if z < 0.0 {
+ if !inv {
+ return -erf_impl(-z, false);
+ }
+ if z < -0.5 {
+ return 2.0 - erf_impl(-z, true);
+ }
+ return 1.0 + erf_impl(-z, false);
+ }
+
+ let result = if z < 0.5 {
+ if z < 1e-10 {
+ z * 1.125 + z * 0.003379167095512573896158903121545171688
+ } else {
+ z * 1.125
+ + z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
+ }
+ } else if z < 110.0 {
+ let (r, b) = if z < 0.75 {
+ (
+ evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
+ / evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
+ 0.3440242112,
+ )
+ } else if z < 1.25 {
+ (
+ evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
+ / evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
+ 0.419990927,
+ )
+ } else if z < 2.25 {
+ (
+ evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
+ / evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
+ 0.4898625016,
+ )
+ } else if z < 3.5 {
+ (
+ evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
+ / evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
+ 0.5317370892,
+ )
+ } else if z < 5.25 {
+ (
+ evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
+ / evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
+ 0.5489973426,
+ )
+ } else if z < 8.0 {
+ (
+ evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
+ / evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
+ 0.5571740866,
+ )
+ } else if z < 11.5 {
+ (
+ evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
+ / evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
+ 0.5609807968,
+ )
+ } else if z < 17.0 {
+ (
+ evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
+ / evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
+ 0.5626493692,
+ )
+ } else if z < 24.0 {
+ (
+ evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
+ / evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
+ 0.5634598136,
+ )
+ } else if z < 38.0 {
+ (
+ evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
+ / evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
+ 0.5638477802,
+ )
+ } else if z < 60.0 {
+ (
+ evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
+ / evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
+ 0.5640528202,
+ )
+ } else if z < 85.0 {
+ (
+ evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
+ / evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
+ 0.5641309023,
+ )
+ } else {
+ (
+ evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
+ / evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
+ 0.5641584396,
+ )
+ };
+ let g = (-z * z).exp() / z;
+ g * b + g * r
+ } else {
+ 0.0
+ };
+
+ if inv && z >= 0.5 {
+ result
+ } else if z >= 0.5 || inv {
+ 1.0 - result
+ } else {
+ result
+ }
+}
+
+// `erf_inv_impl` computes the inverse error function where
+// `p`,`q`, and `s` are the first, second, and third intermediate
+// parameters respectively
+fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 {
+ let result = if p <= 0.5 {
+ let y = 0.0891314744949340820313;
+ let g = p * (p + 10.0);
+ let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD);
+ g * y + g * r
+ } else if q >= 0.25 {
+ let y = 2.249481201171875;
+ let g = (-2.0 * q.ln()).sqrt();
+ let xs = q - 0.25;
+ let r =
+ evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD);
+ g / (y + r)
+ } else {
+ let x = (-q.ln()).sqrt();
+ if x < 3.0 {
+ let y = 0.807220458984375;
+ let xs = x - 1.125;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_CD);
+ y * x + r * x
+ } else if x < 6.0 {
+ let y = 0.93995571136474609375;
+ let xs = x - 3.0;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_DD);
+ y * x + r * x
+ } else if x < 18.0 {
+ let y = 0.98362827301025390625;
+ let xs = x - 6.0;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_ED);
+ y * x + r * x
+ } else if x < 44.0 {
+ let y = 0.99714565277099609375;
+ let xs = x - 18.0;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_FD);
+ y * x + r * x
+ } else {
+ let y = 0.99941349029541015625;
+ let xs = x - 44.0;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_GD);
+ y * x + r * x
+ }
+ };
+ s * result
+}
diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs
index 97e195ef..527646d6 100644
--- a/candle-core/src/cpu/kernels.rs
+++ b/candle-core/src/cpu/kernels.rs
@@ -1,4 +1,7 @@
-pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
+pub trait VecOps: num_traits::NumAssign + Copy {
+ fn min(self, rhs: Self) -> Self;
+ fn max(self, rhs: Self) -> Self;
+
/// Dot-product of two vectors.
///
/// # Safety
@@ -37,10 +40,7 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
- let x = *xs.add(i);
- if x > *res {
- *res = x
- }
+ *res = (*res).max(*xs.add(i))
}
}
@@ -54,15 +54,22 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
- let x = *xs.add(i);
- if x < *res {
- *res = x
- }
+ *res = (*res).min(*xs.add(i))
}
}
}
impl VecOps for f32 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+
#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
super::vec_dot_f32(lhs, rhs, res, len)
@@ -75,6 +82,16 @@ impl VecOps for f32 {
}
impl VecOps for half::f16 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+
#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
let mut res_f32 = 0f32;
@@ -83,11 +100,61 @@ impl VecOps for half::f16 {
}
}
-impl VecOps for f64 {}
-impl VecOps for half::bf16 {}
-impl VecOps for u8 {}
-impl VecOps for u32 {}
-impl VecOps for i64 {}
+impl VecOps for f64 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+}
+impl VecOps for half::bf16 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+}
+impl VecOps for u8 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ <Self as Ord>::max(self, other)
+ }
+}
+impl VecOps for u32 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ <Self as Ord>::max(self, other)
+ }
+}
+impl VecOps for i64 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ <Self as Ord>::max(self, other)
+ }
+}
#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
diff --git a/candle-core/src/cpu/mod.rs b/candle-core/src/cpu/mod.rs
index 9a8e6317..50afb30f 100644
--- a/candle-core/src/cpu/mod.rs
+++ b/candle-core/src/cpu/mod.rs
@@ -1,3 +1,4 @@
+pub mod erf;
pub mod kernels;
trait Cpu {
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index ed3dd3fc..4e808b34 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -2,6 +2,10 @@ use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use half::{bf16, f16};
+use rayon::prelude::*;
+
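+// When enabled, conv1d/conv2d are computed with an im2col transform followed by a
+// matmul instead of the direct convolution kernels.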
+const USE_IM2COL_CONV1D: bool = true;
+const USE_IM2COL_CONV2D: bool = true;
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
// intercept the oom errors to avoid panicking and provide a proper error.
@@ -445,7 +449,7 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
}
// This function maps over two strided index sequences.
-fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
+pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
@@ -525,7 +529,7 @@ fn binary_map U>(
}
// Similar to binary_map but with vectorized variants.
-fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
+pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
@@ -723,6 +727,36 @@ impl Map1 for MaxPool2D {
}
}
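+/// Nearest-neighbor 1d upsampling to a target length: output index `i` reads the source
+/// element at `min(src_len - 1, floor(i * src_len / dst_len))`.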
+struct UpsampleNearest1D(usize);
+
+impl Map1 for UpsampleNearest1D {
+ fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+ // TODO: Specialized implementation for the case 2*sz?
+ let dst_sz = self.0;
+ let (b_sz, c, src_sz) = layout.shape().dims3()?;
+ let stride = layout.stride();
+ let stride_sz = stride[2];
+ let src_index = layout.start_offset();
+ let scale_sz = src_sz as f64 / dst_sz as f64;
+ let mut dst = vec![T::zero(); b_sz * c * dst_sz];
+ let src_idxs = (0..dst_sz)
+ .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
+ .collect::<Vec<_>>();
+ for b_idx in 0..b_sz {
+ let dst = &mut dst[b_idx * c * dst_sz..];
+ let src_index = src_index + b_idx * stride[0];
+ for c_idx in 0..c {
+ let dst = &mut dst[c_idx * dst_sz..];
+ let src_index = src_index + c_idx * stride[1];
+ for (idx, src_idx) in src_idxs.iter().enumerate() {
+ dst[idx] = src[src_index + src_idx * stride_sz]
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
struct UpsampleNearest2D(usize, usize);
impl Map1 for UpsampleNearest2D {
@@ -1052,10 +1086,8 @@ impl<'a> Map2 for Conv1D<'a> {
}
}
- let num_threads = crate::utils::get_num_threads();
-
for offset in 0..p.k_size {
- crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
+ (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let dst_idx = dst_c_idx * l_out;
let k_cont = (0..p.c_in)
.map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
@@ -1090,6 +1122,140 @@ impl<'a> Map2 for Conv1D<'a> {
}
}
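+/// Im2col for 1d convolutions: unfolds the input into a (b, l_out, c * l_k) buffer so
+/// that conv1d can be evaluated as a matmul against the flattened kernel.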
+struct Im2Col1D {
+ l_k: usize,
+ stride: usize,
+ dilation: usize,
+ padding: usize,
+}
+
+impl Im2Col1D {
+ fn l_out(&self, l: usize) -> usize {
+ (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
+ }
+}
+
+impl Map1 for Im2Col1D {
+ fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+ let &Self {
+ l_k,
+ stride,
+ dilation,
+ padding,
+ } = self;
+ let (b, c, l) = layout.shape().dims3()?;
+ let l_out = self.l_out(l);
+ let src = &vs[layout.start_offset()..];
+ let mut dst = vec![T::zero(); b * l_out * c * l_k];
+ let (src_s0, src_s1, src_s2) = {
+ let s = layout.stride();
+ (s[0], s[1], s[2])
+ };
+ // TODO: provide specialized kernels for the common use cases.
+ // - l_k = 1
+ // - padding = 0
+ // - stride = 1
+ // - dilation = 1
+ for b_idx in 0..b {
+ let src_idx = b_idx * src_s0;
+ let dst_idx = b_idx * l_out * c * l_k;
+ for l_idx in 0..l_out {
+ let dst_idx = dst_idx + l_idx * c * l_k;
+ for c_idx in 0..c {
+ let dst_idx = dst_idx + c_idx * l_k;
+ let src_idx = c_idx * src_s1 + src_idx;
+ for l_k_idx in 0..l_k {
+ let src_l = l_idx * stride + l_k_idx * dilation;
+ if padding != 0 && (src_l < padding || src_l >= l + padding) {
+ continue;
+ }
+ let src_l = src_l - padding;
+ let src_idx = src_idx + src_l * src_s2;
+ let dst_idx = dst_idx + l_k_idx;
+ dst[dst_idx] = src[src_idx]
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
+struct Im2Col {
+ h_k: usize,
+ w_k: usize,
+ stride: usize,
+ dilation: usize,
+ padding: usize,
+}
+
+impl Im2Col {
+ fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
+ let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
+ let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
+ (h_out, w_out)
+ }
+}
+
+impl Map1 for Im2Col {
+ fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+ let &Self {
+ h_k,
+ w_k,
+ stride,
+ dilation,
+ padding,
+ } = self;
+ let (b, c, h, w) = layout.shape().dims4()?;
+ let (h_out, w_out) = self.hw_out(h, w);
+ let src = &vs[layout.start_offset()..];
+ let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
+ let (src_s0, src_s1, src_s2, src_s3) = {
+ let s = layout.stride();
+ (s[0], s[1], s[2], s[3])
+ };
+ // TODO: provide specialized kernels for the common use cases.
+ // - h_k = w_k = 1
+ // - padding = 0
+ // - stride = 1
+ // - dilation = 1
+ for b_idx in 0..b {
+ let src_idx = b_idx * src_s0;
+ let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
+ for h_idx in 0..h_out {
+ let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
+ for w_idx in 0..w_out {
+ let dst_idx = dst_idx + w_idx * c * h_k * w_k;
+ for c_idx in 0..c {
+ let dst_idx = dst_idx + c_idx * h_k * w_k;
+ let src_idx = c_idx * src_s1 + src_idx;
+ for h_k_idx in 0..h_k {
+ let src_h = h_idx * stride + h_k_idx * dilation;
+ if padding != 0 && (src_h < padding || src_h >= h + padding) {
+ continue;
+ }
+ let src_h = src_h - padding;
+ let src_idx = src_idx + src_h * src_s2;
+ let dst_idx = dst_idx + h_k_idx * w_k;
+ for w_k_idx in 0..w_k {
+ let src_w = w_idx * stride + w_k_idx * dilation;
+ if padding != 0 && (src_w < padding || src_w >= w + padding) {
+ continue;
+ }
+ let src_w = src_w - padding;
+ let src_idx = src_idx + src_w * src_s3;
+ let dst_idx = dst_idx + w_k_idx;
+ dst[dst_idx] = src[src_idx]
+ }
+ }
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> {
@@ -1123,11 +1289,9 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
- let num_threads = crate::utils::get_num_threads();
-
for offset_h in 0..p.k_h {
for offset_w in 0..p.k_w {
- crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
+ (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let dst_idx = dst_c_idx * out_w * out_h;
let k_cont = (0..p.c_in)
.map(|c_in_idx| {
@@ -1216,11 +1380,10 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
}
}
}
- let num_threads = crate::utils::get_num_threads();
for k_y in 0..p.k_h {
for k_x in 0..p.k_w {
- crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
+ (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let k_cont = (0..p.c_in)
.map(|c_in_idx| {
k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
@@ -1298,8 +1461,9 @@ impl Map2 for MatMul {
) -> Result<Vec<T>> {
use gemm::{gemm, Parallelism};
- if T::DTYPE == DType::BF16 {
- return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?;
+ match T::DTYPE {
+ DType::F16 | DType::F32 | DType::F64 => {}
+ _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
}
let (b, m, n, k) = self.0;
@@ -2003,6 +2167,10 @@ impl BackendStorage for CpuStorage {
MaxPool2D(kernel_size, stride).map(self, layout)
}
+ fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
+ UpsampleNearest1D(sz).map(self, layout)
+ }
+
fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
UpsampleNearest2D(h, w).map(self, layout)
}
@@ -2231,7 +2399,40 @@ impl BackendStorage for CpuStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv1D,
) -> Result<Self> {
- Conv1D(params).map(self, l, kernel, kernel_l)
+ if !USE_IM2COL_CONV1D {
+ return Conv1D(params).map(self, l, kernel, kernel_l);
+ }
+ let op = Im2Col1D {
+ l_k: params.k_size,
+ padding: params.padding,
+ stride: params.stride,
+ dilation: params.dilation,
+ };
+ let col = op.map(self, l)?;
+ let b = params.b_size;
+ let n = params.c_out;
+ let l_out = params.l_out();
+ let k = op.l_k * params.c_in;
+ let m = l_out;
+ let col_l = Layout::contiguous((b, m, k));
+ let res = if kernel_l.is_contiguous() {
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ } else {
+ // Make the kernel contiguous if not already the case.
+ let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ };
+ let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
+ let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ res.copy_strided_src(&mut res_t, 0, &res_l)?;
+ Ok(res_t)
}
fn conv2d(
@@ -2241,7 +2442,43 @@ impl BackendStorage for CpuStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
- Conv2D(params).map(self, l, kernel, kernel_l)
+ if !USE_IM2COL_CONV2D {
+ return Conv2D(params).map(self, l, kernel, kernel_l);
+ }
+ let op = Im2Col {
+ h_k: params.k_h,
+ w_k: params.k_w,
+ padding: params.padding,
+ stride: params.stride,
+ dilation: params.dilation,
+ };
+ let col = op.map(self, l)?;
+ let b = params.b_size;
+ let n = params.c_out;
+ let (h_out, w_out) = (params.out_h(), params.out_w());
+ let k = op.h_k * op.w_k * params.c_in;
+ let m = h_out * w_out;
+ let col_l = Layout::contiguous((b, m, k));
+ let res = if kernel_l.is_contiguous() {
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ } else {
+ // Make the kernel contiguous if not already the case.
+ let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ };
+ let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
+ .transpose(1, 2)?
+ .transpose(1, 3)?;
+ let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ res.copy_strided_src(&mut res_t, 0, &res_l)?;
+ Ok(res_t)
}
fn conv_transpose2d(
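The conv1d/conv2d paths above (and their CUDA counterparts below) share one strategy: build an im2col buffer, run a single matmul against the flattened kernel, then copy the result back into channel-first order. The following is a rough, self-contained sketch of that idea in plain Rust, with no batching, padding, or dilation and with illustrative helper names; it is not candle's internal `Map1`/`Layout` machinery.

```rust
// Minimal 1d im2col sketch on contiguous data (single batch, no padding or
// dilation), illustrating the rewrite used by the new conv paths.
fn im2col_1d(src: &[f32], c_in: usize, l: usize, k: usize, stride: usize) -> (Vec<f32>, usize) {
    let l_out = (l - k) / stride + 1; // assumes l >= k
    // One row per output position; each row gathers the c_in * k inputs it reads.
    let mut col = vec![0f32; l_out * c_in * k];
    for l_idx in 0..l_out {
        for c_idx in 0..c_in {
            for k_idx in 0..k {
                col[l_idx * c_in * k + c_idx * k + k_idx] =
                    src[c_idx * l + l_idx * stride + k_idx];
            }
        }
    }
    (col, l_out)
}

// The convolution then becomes an (l_out, c_in*k) x (c_in*k, c_out) matmul.
fn conv1d_via_im2col(
    src: &[f32],
    kernel: &[f32], // laid out as (c_out, c_in, k)
    c_in: usize,
    c_out: usize,
    l: usize,
    k: usize,
    stride: usize,
) -> Vec<f32> {
    let (col, l_out) = im2col_1d(src, c_in, l, k, stride);
    let mut out = vec![0f32; c_out * l_out];
    for o in 0..c_out {
        for l_idx in 0..l_out {
            let mut acc = 0f32;
            for i in 0..c_in * k {
                acc += col[l_idx * c_in * k + i] * kernel[o * c_in * k + i];
            }
            out[o * l_out + l_idx] = acc;
        }
    }
    out
}

fn main() {
    // 1 input channel of length 5, kernel size 3, stride 1 -> l_out = 3.
    let src = [1f32, 2., 3., 4., 5.];
    let kernel = [1f32, 0., -1.]; // c_out = 1, c_in = 1, k = 3
    let out = conv1d_via_im2col(&src, &kernel, 1, 1, 5, 3, 1);
    assert_eq!(out, vec![-2.0, -2.0, -2.0]);
    println!("{out:?}");
}
```

The `USE_IM2COL_CONV1D` / `USE_IM2COL_CONV2D` constants simply gate whether this rewrite or the direct convolution kernels are used.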
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 663f2319..00fd1d04 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1,7 +1,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
-use candle_kernels as kernels;
+pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{
@@ -312,6 +312,13 @@ impl BackendDevice for CudaDevice {
// cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
+ // curand can only generate an even number of values.
+ // https://github.com/huggingface/candle/issues/734
+ let elem_count_round = if elem_count % 2 == 1 {
+ elem_count + 1
+ } else {
+ elem_count
+ };
let slice = match dtype {
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
@@ -321,7 +328,7 @@ impl BackendDevice for CudaDevice {
.w()?
}
DType::F32 => {
- let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
+ let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
@@ -329,7 +336,7 @@ impl BackendDevice for CudaDevice {
CudaStorageSlice::F32(data)
}
DType::F64 => {
- let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
+ let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
@@ -383,7 +390,7 @@ impl BackendDevice for CudaDevice {
}
#[derive(Debug)]
-enum CudaStorageSlice {
+pub enum CudaStorageSlice {
U8(CudaSlice<u8>),
U32(CudaSlice<u32>),
I64(CudaSlice<i64>),
@@ -394,7 +401,7 @@ enum CudaStorageSlice {
}
type S = CudaStorageSlice;
-trait Map1 {
+pub trait Map1 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
@@ -416,7 +423,7 @@ trait Map1 {
}
}
-trait Map2 {
+pub trait Map2 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
@@ -441,7 +448,7 @@ trait Map2 {
}
}
-trait Map2InPlace {
+pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
@@ -472,7 +479,7 @@ trait Map2InPlace {
}
}
-trait Map1Any {
+pub trait Map1Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
@@ -495,7 +502,7 @@ trait Map1Any {
}
}
-trait Map2Any {
+pub trait Map2Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
@@ -532,7 +539,7 @@ impl Map1 for Clone {
}
}
-fn kernel_name<T: WithDType>(root: &str) -> String {
+pub fn kernel_name<T: WithDType>(root: &str) -> String {
let dtype = T::DTYPE.as_str();
format!("{root}_{dtype}")
}
@@ -593,6 +600,105 @@ impl Map1 for Elu {
}
}
+struct Im2Col1D {
+ l_k: usize,
+ stride: usize,
+ dilation: usize,
+ padding: usize,
+}
+
+impl Im2Col1D {
+ fn l_out(&self, l: usize) -> usize {
+ (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
+ }
+}
+
+impl Map1 for Im2Col1D {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ src: &CudaSlice<T>,
+ dev: &CudaDevice,
+ layout: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ let shape = layout.shape();
+ let dims = shape.dims();
+ let l_out = self.l_out(dims[2]);
+ let dst_el = dims[0] * l_out * dims[1] * self.l_k;
+ let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+ let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+ let src = &src.slice(layout.start_offset()..);
+ let func = dev.get_or_load_func(&kernel_name::("im2col1d"), kernels::CONV)?;
+ // SAFETY: Set later by running the kernel.
+ let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+ let params = (
+ dst_el,
+ l_out,
+ self.l_k,
+ self.stride,
+ self.padding,
+ self.dilation,
+ &ds,
+ src,
+ &dst,
+ );
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(dst)
+ }
+}
+
+struct Im2Col {
+ h_k: usize,
+ w_k: usize,
+ stride: usize,
+ dilation: usize,
+ padding: usize,
+}
+
+impl Im2Col {
+ fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
+ let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
+ let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
+ (h_out, w_out)
+ }
+}
+
+impl Map1 for Im2Col {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ src: &CudaSlice<T>,
+ dev: &CudaDevice,
+ layout: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ let shape = layout.shape();
+ let dims = shape.dims();
+ let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
+ let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
+ let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+ let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+ let src = &src.slice(layout.start_offset()..);
+ let func = dev.get_or_load_func(&kernel_name::("im2col"), kernels::CONV)?;
+ // SAFETY: Set later by running the kernel.
+ let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+ let params = (
+ dst_el,
+ h_out,
+ w_out,
+ self.h_k,
+ self.w_k,
+ self.stride,
+ self.padding,
+ self.dilation,
+ &ds,
+ src,
+ &dst,
+ );
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(dst)
+ }
+}
+
struct Powf(f64);
impl Map1 for Powf {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1310,8 +1416,8 @@ fn slice_src_and_dst<'a, T>(
#[derive(Debug)]
pub struct CudaStorage {
- slice: CudaStorageSlice,
- device: CudaDevice,
+ pub slice: CudaStorageSlice,
+ pub device: CudaDevice,
}
pub trait CudaDType: Sized {
@@ -1650,9 +1756,46 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv1D,
) -> Result {
+ const USE_IM2COL_CONV1D: bool = true;
+
let device = self.device().clone();
- let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
- Ok(Self { slice, device })
+ if !USE_IM2COL_CONV1D {
+ let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+ return Ok(Self { slice, device });
+ }
+
+ let col = Im2Col1D {
+ l_k: params.k_size,
+ stride: params.stride,
+ dilation: params.dilation,
+ padding: params.padding,
+ }
+ .map(&self.slice, &device, l)?;
+ let col = Self { slice: col, device };
+ let l_out = params.l_out();
+ let b = params.b_size;
+ let n = params.c_out;
+ let k = params.k_size * params.c_in;
+ let m = l_out;
+ let col_l = Layout::contiguous((b, m, k));
+ let res = if kernel_l.is_contiguous() {
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ } else {
+ // Make the kernel contiguous if not already the case.
+ let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ };
+ let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
+ let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ res.copy_strided_src(&mut res_t, 0, &res_l)?;
+ Ok(res_t)
}
#[cfg(not(feature = "cudnn"))]
@@ -1663,9 +1806,50 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result {
+ const USE_IM2COL_CONV2D: bool = true;
+
let device = self.device().clone();
- let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
- Ok(Self { slice, device })
+ if !USE_IM2COL_CONV2D {
+ let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+ return Ok(Self { slice, device });
+ }
+
+ let col = Im2Col {
+ h_k: params.k_h,
+ w_k: params.k_w,
+ stride: params.stride,
+ dilation: params.dilation,
+ padding: params.padding,
+ }
+ .map(&self.slice, &device, l)?;
+ let col = Self { slice: col, device };
+ let h_out = params.out_h();
+ let w_out = params.out_w();
+ let b = params.b_size;
+ let n = params.c_out;
+ let k = params.k_h * params.k_w * params.c_in;
+ let m = h_out * w_out;
+ let col_l = Layout::contiguous((b, m, k));
+ let res = if kernel_l.is_contiguous() {
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ } else {
+ // Make the kernel contiguous if not already the case.
+ let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ };
+ let res_l = Layout::contiguous((b, h_out, w_out, n))
+ .transpose(1, 2)?
+ .transpose(1, 3)?;
+ let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ res.copy_strided_src(&mut res_t, 0, &res_l)?;
+ Ok(res_t)
}
#[cfg(feature = "cudnn")]
@@ -1770,6 +1954,10 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
+ fn upsample_nearest1d(&self, _: &Layout, _out_sz: usize) -> Result<Self> {
+ crate::bail!("upsample-nearest1d is not supported on cuda")
+ }
+
fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
let device = self.device().clone();
let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?;
@@ -1889,6 +2077,9 @@ impl BackendStorage for CudaStorage {
let src_shape = src_l.shape();
let dims = src_shape.dims();
let el_count = src_shape.elem_count();
+ if el_count == 0 {
+ return Ok(());
+ }
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device;
let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
diff --git a/candle-core/src/cudnn.rs b/candle-core/src/cudnn.rs
index 235ad6e3..dd466ba2 100644
--- a/candle-core/src/cudnn.rs
+++ b/candle-core/src/cudnn.rs
@@ -54,8 +54,8 @@ pub(crate) fn launch_conv2d<
let x_shape = [
params.b_size as i32,
params.c_in as i32,
- params.i_w as i32,
params.i_h as i32,
+ params.i_w as i32,
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
@@ -75,14 +75,14 @@ pub(crate) fn launch_conv2d<
[
params.c_out as i32,
params.c_in as i32,
- params.k_w as i32,
params.k_h as i32,
+ params.k_w as i32,
],
)?;
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
let y = cudnn.create_4d_tensor(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
- [params.b_size as i32, params.c_out as i32, w_out, h_out],
+ [params.b_size as i32, params.c_out as i32, h_out, w_out],
)?;
let conv2d = Conv2dForward {
conv: &conv,
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index adfc4a3c..c7a1567f 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -1,15 +1,24 @@
+//! Types for elements that can be stored and manipulated using tensors.
#![allow(clippy::redundant_closure_call)]
use crate::backend::BackendStorage;
use crate::{CpuStorage, Error, Result};
+/// The different types of elements allowed in tensors.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
+ // Unsigned 8-bit integer.
U8,
+ // Unsigned 32-bit integer.
U32,
+ // Signed 64-bit integer.
I64,
+ // Brain floating-point using half precision (16 bits).
BF16,
+ // Floating-point using half precision (16 bits).
F16,
+ // Floating-point using single precision (32 bits).
F32,
+ // Floating-point using double precision (64 bits).
F64,
}
@@ -33,6 +42,7 @@ impl std::str::FromStr for DType {
}
impl DType {
+ /// String representation for dtypes.
pub fn as_str(&self) -> &'static str {
match self {
Self::U8 => "u8",
@@ -45,6 +55,7 @@ impl DType {
}
}
+ /// The size used by each element in bytes, i.e. 1 for `U8`, 4 for `F32`.
pub fn size_in_bytes(&self) -> usize {
match self {
Self::U8 => 1,
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 6c896653..5cc9c6d8 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -152,6 +152,10 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
+ fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 1cf20a84..be8f7b07 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -30,7 +30,7 @@ pub enum Error {
UnsupportedDTypeForOp(DType, &'static str),
// === Dimension Index Errors ===
- #[error("{op}: dimension index {dim} out of range for {shape:?}")]
+ #[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
DimOutOfRange {
shape: Shape,
dim: i32,
@@ -207,11 +207,11 @@ pub type Result<T> = std::result::Result<T, Error>;
impl Error {
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
- Self::Wrapped(Box::new(err))
+ Self::Wrapped(Box::new(err)).bt()
}
pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
- Self::Msg(err.to_string())
+ Self::Msg(err.to_string()).bt()
}
pub fn bt(self) -> Self {
diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs
index 2b6d694b..7b84d316 100644
--- a/candle-core/src/indexer.rs
+++ b/candle-core/src/indexer.rs
@@ -46,19 +46,31 @@ impl Tensor {
current_dim += 1;
out
}
+ TensorIndexer::IndexSelect(indexes) => {
+ if indexes.rank() != 1 {
+ crate::bail!("multi-dimensional tensor indexing is not supported")
+ }
+ let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
+ current_dim += 1;
+ out
+ }
+ TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
};
}
Ok(x)
}
}
-#[derive(Debug, Clone)]
+#[derive(Debug)]
/// Generic structure used to index a slice of the tensor
pub enum TensorIndexer {
/// This selects the elements for which an index has some specific value.
Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor
Narrow(Bound<usize>, Bound<usize>),
+ /// Indexing via a 1d tensor
+ IndexSelect(Tensor),
+ Err(Error),
}
impl From<usize> for TensorIndexer {
@@ -67,6 +79,31 @@ impl From<usize> for TensorIndexer {
}
}
+impl From<&[u32]> for TensorIndexer {
+ fn from(index: &[u32]) -> Self {
+ match Tensor::new(index, &crate::Device::Cpu) {
+ Ok(tensor) => TensorIndexer::IndexSelect(tensor),
+ Err(e) => TensorIndexer::Err(e),
+ }
+ }
+}
+
+impl From<Vec<u32>> for TensorIndexer {
+ fn from(index: Vec<u32>) -> Self {
+ let len = index.len();
+ match Tensor::from_vec(index, len, &crate::Device::Cpu) {
+ Ok(tensor) => TensorIndexer::IndexSelect(tensor),
+ Err(e) => TensorIndexer::Err(e),
+ }
+ }
+}
+
+impl From<&Tensor> for TensorIndexer {
+ fn from(tensor: &Tensor) -> Self {
+ TensorIndexer::IndexSelect(tensor.clone())
+ }
+}
+
macro_rules! impl_from_range {
($range_type:ty) => {
impl From<$range_type> for TensorIndexer {
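The `IndexSelect` variant and the `From` impls above let `.i` accept index slices, `Vec<u32>` values, and 1d tensors in addition to ranges, dispatching to `index_select` on the corresponding dimension. A usage sketch, assuming the existing blanket `IndexOp` impl over `Into<TensorIndexer>` arguments (the concrete call forms are illustrative):

```rust
use candle_core::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::new(&[[0u32, 1, 2], [3, 4, 5], [6, 7, 8]], &dev)?;
    // A slice of indices selects along the current dimension via index_select.
    let rows = t.i(&[2u32, 0][..])?;
    assert_eq!(rows.to_vec2::<u32>()?, [[6, 7, 8], [0, 1, 2]]);
    // The same thing with an explicit 1d index tensor.
    let idx = Tensor::new(&[1u32, 1], &dev)?;
    let rows = t.i(&idx)?;
    assert_eq!(rows.to_vec2::<u32>()?, [[3, 4, 5], [3, 4, 5]]);
    Ok(())
}
```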
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index a0347416..52effdcf 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -59,6 +59,7 @@ mod op;
pub mod pickle;
pub mod quantized;
pub mod safetensors;
+pub mod scalar;
pub mod shape;
mod storage;
mod strided_index;
@@ -109,14 +110,8 @@ impl ToUsize2 for (usize, usize) {
}
// A simple trait defining a module with forward method using a single argument.
-pub trait Module: std::fmt::Debug {
+pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
-
- /// Change the module to use training mode vs eval mode.
- ///
- /// The default implementation does nothing as this is only used for a couple modules such as
- /// dropout or batch-normalization.
- fn set_training(&mut self, _training: bool) {}
}
impl Module for quantized::QMatMul {
@@ -124,3 +119,9 @@ impl Module for quantized::QMatMul {
self.forward(xs)
}
}
+
+impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ self(xs)
+ }
+}
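Dropping the `Debug` bound and `set_training` from `Module`, together with the blanket impl above, means any `Fn(&Tensor) -> Result<Tensor>` (a free function or a closure) can now stand in for a `Module`. A minimal sketch:

```rust
use candle_core::{Device, Module, Result, Tensor};

// A plain function (or closure) with this signature now satisfies Module.
fn relu_like(xs: &Tensor) -> Result<Tensor> {
    xs.relu()
}

fn main() -> Result<()> {
    let xs = Tensor::new(&[-1f32, 0., 2.], &Device::Cpu)?;
    // Call through the trait, just as a container of modules would.
    let ys = Module::forward(&relu_like, &xs)?;
    assert_eq!(ys.to_vec1::<f32>()?, [0.0, 0.0, 2.0]);
    Ok(())
}
```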
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index fbfc9c1a..4882a205 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -58,6 +58,8 @@ pub enum UnaryOp {
Sqr,
Sqrt,
Gelu,
+ GeluErf,
+ Erf,
Relu,
Tanh,
}
@@ -116,6 +118,7 @@ pub enum Op {
stride: (usize, usize),
},
+ UpsampleNearest1D(Tensor),
UpsampleNearest2D(Tensor),
Cat(Vec<Tensor>, usize),
@@ -324,6 +327,8 @@ pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
+pub(crate) struct GeluErf;
+pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Tanh;
@@ -600,6 +605,92 @@ impl UnaryOpT for Gelu {
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_gelu(xs, ys)
}
+
+ #[cfg(feature = "accelerate")]
+ const F32_VEC: bool = true;
+
+ #[cfg(feature = "accelerate")]
+ #[inline(always)]
+ fn f32_vec(xs: &[f32], ys: &mut [f32]) {
+ crate::accelerate::vs_gelu(xs, ys)
+ }
+
+ #[cfg(feature = "accelerate")]
+ const F64_VEC: bool = true;
+
+ #[cfg(feature = "accelerate")]
+ #[inline(always)]
+ fn f64_vec(xs: &[f64], ys: &mut [f64]) {
+ crate::accelerate::vd_gelu(xs, ys)
+ }
+}
+
+impl UnaryOpT for Erf {
+ const NAME: &'static str = "erf";
+ const KERNEL: &'static str = "uerf";
+ const V: Self = Erf;
+ #[inline(always)]
+ fn bf16(v: bf16) -> bf16 {
+ bf16::from_f64(Self::f64(v.to_f64()))
+ }
+ #[inline(always)]
+ fn f16(v: f16) -> f16 {
+ f16::from_f64(Self::f64(v.to_f64()))
+ }
+ #[inline(always)]
+ fn f32(v: f32) -> f32 {
+ Self::f64(v as f64) as f32
+ }
+ #[inline(always)]
+ fn f64(v: f64) -> f64 {
+ crate::cpu::erf::erf(v)
+ }
+ #[inline(always)]
+ fn u8(_: u8) -> u8 {
+ 0
+ }
+ #[inline(always)]
+ fn u32(_: u32) -> u32 {
+ 0
+ }
+ #[inline(always)]
+ fn i64(_: i64) -> i64 {
+ 0
+ }
+}
+
+impl UnaryOpT for GeluErf {
+ const NAME: &'static str = "gelu_erf";
+ const KERNEL: &'static str = "ugelu_erf";
+ const V: Self = GeluErf;
+ #[inline(always)]
+ fn bf16(v: bf16) -> bf16 {
+ bf16::from_f64(Self::f64(v.to_f64()))
+ }
+ #[inline(always)]
+ fn f16(v: f16) -> f16 {
+ f16::from_f64(Self::f64(v.to_f64()))
+ }
+ #[inline(always)]
+ fn f32(v: f32) -> f32 {
+ Self::f64(v as f64) as f32
+ }
+ #[inline(always)]
+ fn f64(v: f64) -> f64 {
+ (crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
+ }
+ #[inline(always)]
+ fn u8(_: u8) -> u8 {
+ 0
+ }
+ #[inline(always)]
+ fn u32(_: u32) -> u32 {
+ 0
+ }
+ #[inline(always)]
+ fn i64(_: i64) -> i64 {
+ 0
+ }
}
impl UnaryOpT for Relu {
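Note: `GeluErf` is the exact GELU, 0.5 * v * (1 + erf(v / sqrt(2))) = v * Phi(v), as the `f64` body above shows, while the pre-existing `Gelu` op keeps its faster tanh-based approximation. This is why the new `unary_op` test further down expects slightly different values for `gelu` and `gelu_erf` (for example -0.0036 vs -0.004 at v = -3).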
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 65fd6a6e..a0fe455c 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -85,7 +85,7 @@ const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
pub struct BlockQ8_1 {
pub(crate) d: f16,
pub(crate) s: f16,
- pub(crate) qs: [u8; QK8_1],
+ pub(crate) qs: [i8; QK8_1],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
@@ -278,6 +278,7 @@ impl GgmlType for BlockQ4_1 {
}
sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
+ + f16::to_f32(xs.m) * f16::to_f32(ys.s)
}
Ok(sumf)
}
@@ -471,6 +472,7 @@ impl GgmlType for BlockQ5_1 {
}
sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
+ + f16::to_f32(xs.m) * f16::to_f32(ys.s)
}
Ok(sumf)
}
@@ -652,8 +654,8 @@ impl GgmlType for BlockQ8_1 {
for j in 0..Self::BLCK_SIZE / 2 {
let v0 = xs[j] * id;
let v1 = xs[j + Self::BLCK_SIZE / 2] * id;
- ys.qs[j] = f32::round(v0) as u8;
- ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as u8;
+ ys.qs[j] = f32::round(v0) as i8;
+ ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as i8;
sum += ys.qs[j] as i32 + ys.qs[j + Self::BLCK_SIZE / 2] as i32;
}
ys.s = f16::from_f32(sum as f32) * ys.d;
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 5c2bb2b2..f627f0f6 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -229,7 +229,7 @@ impl QTensor {
}
}
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct QMatMul(std::sync::Arc<QTensor>);
impl QMatMul {
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index f37bb8ef..d588ea67 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -78,11 +78,7 @@ impl st::View for &Tensor {
}
impl Tensor {
- pub fn save_safetensors<P: AsRef<std::path::Path>>(
- &self,
- name: &str,
- filename: P,
- ) -> Result<()> {
+ pub fn save_safetensors<P: AsRef<std::path::Path>>(&self, name: &str, filename: P) -> Result<()> {
let data = [(name, self.clone())];
Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
}
@@ -267,7 +263,7 @@ impl MmapedFile {
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
- pub unsafe fn new>(p: P) -> Result {
+ pub unsafe fn new>(p: P) -> Result {
let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let inner = memmap2::MmapOptions::new()
diff --git a/candle-core/src/scalar.rs b/candle-core/src/scalar.rs
new file mode 100644
index 00000000..43e1f4c8
--- /dev/null
+++ b/candle-core/src/scalar.rs
@@ -0,0 +1,23 @@
+use crate::{Result, Tensor, WithDType};
+
+pub enum TensorScalar {
+ Tensor(Tensor),
+ Scalar(Tensor),
+}
+
+pub trait TensorOrScalar {
+ fn to_tensor_scalar(self) -> Result;
+}
+
+impl TensorOrScalar for &Tensor {
+ fn to_tensor_scalar(self) -> Result {
+ Ok(TensorScalar::Tensor(self.clone()))
+ }
+}
+
+impl<T: WithDType> TensorOrScalar for T {
+ fn to_tensor_scalar(self) -> Result {
+ let scalar = Tensor::new(self, &crate::Device::Cpu)?;
+ Ok(TensorScalar::Scalar(scalar))
+ }
+}
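`TensorOrScalar` is what lets the comparison and `minimum`/`maximum` entry points in `tensor.rs` (further down in this diff) take plain scalars: the scalar is turned into a tensor and broadcast to the left-hand side's shape, dtype, and device. A small usage sketch:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[[3f32, 1., 4.], [1., 5., 9.]], &Device::Cpu)?;
    // Scalar right-hand sides now work for the cmp/minimum/maximum family...
    let mask = t.gt(2f32)?; // u8 mask with the same shape as t
    assert_eq!(mask.to_vec2::<u8>()?, [[1, 0, 1], [0, 1, 1]]);
    // ...which is what the new clamp helper builds on.
    let clamped = t.clamp(2f32, 5f32)?;
    assert_eq!(clamped.to_vec2::<f32>()?, [[3.0, 2.0, 4.0], [2.0, 5.0, 5.0]]);
    Ok(())
}
```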
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index aea8b887..4d500e7f 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -1,3 +1,4 @@
+//! The shape of a tensor is a tuple with the size of each of its dimensions.
#![allow(clippy::redundant_closure_call)]
use crate::{Error, Result};
@@ -72,6 +73,14 @@ impl From<(usize, usize, usize, usize, usize)> for Shape {
}
}
+impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
+ fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
+ Self(vec![
+ d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5,
+ ])
+ }
+}
+
impl From> for Shape {
fn from(dims: Vec) -> Self {
Self(dims)
@@ -119,6 +128,7 @@ impl Shape {
Self(dims.to_vec())
}
+ /// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc.
pub fn rank(&self) -> usize {
self.0.len()
}
@@ -127,10 +137,12 @@ impl Shape {
self.0
}
+ /// The dimensions as a slice of `usize`.
pub fn dims(&self) -> &[usize] {
&self.0
}
+ /// The total number of elements, this is the product of all dimension sizes.
pub fn elem_count(&self) -> usize {
self.0.iter().product()
}
@@ -182,6 +194,8 @@ impl Shape {
true
}
+ /// Modifies the shape by adding a list of additional dimensions at the end of the existing
+ /// dimensions.
pub fn extend(mut self, additional_dims: &[usize]) -> Self {
self.0.extend(additional_dims);
self
@@ -419,6 +433,29 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
}
}
+impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
+ fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+ let d0 = self.0.to_index(shape, op)?;
+ let d1 = self.1.to_index(shape, op)?;
+ let d2 = self.2.to_index(shape, op)?;
+ let d3 = self.3.to_index(shape, op)?;
+ let d4 = self.4.to_index(shape, op)?;
+ Ok(vec![d0, d1, d2, d3, d4])
+ }
+}
+
+impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
+ fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+ let d0 = self.0.to_index(shape, op)?;
+ let d1 = self.1.to_index(shape, op)?;
+ let d2 = self.2.to_index(shape, op)?;
+ let d3 = self.3.to_index(shape, op)?;
+ let d4 = self.4.to_index(shape, op)?;
+ let d5 = self.5.to_index(shape, op)?;
+ Ok(vec![d0, d1, d2, d3, d4, d5])
+ }
+}
+
extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
@@ -457,3 +494,171 @@ mod tests {
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
}
}
+
+pub trait ShapeWithOneHole {
+ fn into_shape(self, el_count: usize) -> Result;
+}
+
+impl<S: Into<Shape>> ShapeWithOneHole for S {
+ fn into_shape(self, _el_count: usize) -> Result<Shape> {
+ Ok(self.into())
+ }
+}
+
+impl ShapeWithOneHole for ((),) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ Ok(el_count.into())
+ }
+}
+
+impl ShapeWithOneHole for ((), usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let ((), d1) = self;
+ if el_count % d1 != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
+ }
+ Ok((el_count / d1, d1).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, ()) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, ()) = self;
+ if el_count % d1 != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
+ }
+ Ok((d1, el_count / d1).into())
+ }
+}
+
+impl ShapeWithOneHole for ((), usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let ((), d1, d2) = self;
+ let d = d1 * d2;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((el_count / d, d1, d2).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, (), usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, (), d2) = self;
+ let d = d1 * d2;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, el_count / d, d2).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, ()) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, ()) = self;
+ let d = d1 * d2;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, el_count / d).into())
+ }
+}
+
+impl ShapeWithOneHole for ((), usize, usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let ((), d1, d2, d3) = self;
+ let d = d1 * d2 * d3;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((el_count / d, d1, d2, d3).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, (), usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, (), d2, d3) = self;
+ let d = d1 * d2 * d3;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, el_count / d, d2, d3).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, (), usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, (), d3) = self;
+ let d = d1 * d2 * d3;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, el_count / d, d3).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, ()) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, d3, ()) = self;
+ let d = d1 * d2 * d3;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, d3, el_count / d).into())
+ }
+}
+
+impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let ((), d1, d2, d3, d4) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((el_count / d, d1, d2, d3, d4).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, (), d2, d3, d4) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, el_count / d, d2, d3, d4).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, (), d3, d4) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, el_count / d, d3, d4).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, d3, (), d4) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, d3, el_count / d, d4).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, d3, d4, ()) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, d3, d4, el_count / d).into())
+ }
+}
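These impls back the `()` hole accepted by `reshape` later in this diff: the missing dimension is inferred from the element count, much like `-1` in PyTorch, and a non-divisible element count is rejected. A short sketch:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::arange(0f32, 24f32, &Device::Cpu)?;
    // One dimension may be left as a hole; it is inferred from the element count.
    let t = t.reshape((2, (), 4))?; // 24 elements / (2 * 4) = 3
    assert_eq!(t.dims(), [2, 3, 4]);
    // A non-divisible element count is rejected at runtime.
    assert!(t.reshape((5, ())).is_err());
    Ok(())
}
```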
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 8bd14ea9..9bd1fed6 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -369,6 +369,19 @@ impl Storage {
}
}
+ pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
+ match self {
+ Storage::Cpu(storage) => {
+ let storage = storage.upsample_nearest1d(layout, sz)?;
+ Ok(Self::Cpu(storage))
+ }
+ Self::Cuda(storage) => {
+ let storage = storage.upsample_nearest1d(layout, sz)?;
+ Ok(Self::Cuda(storage))
+ }
+ }
+ }
+
pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index e181f240..9dccf2b5 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,8 +1,10 @@
+//! Tensors are N-dimensional matrices of elements using a single data type.
#![allow(clippy::redundant_closure_call)]
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
};
+use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@@ -103,6 +105,28 @@ macro_rules! binary_op {
};
}
+macro_rules! binary_op_scalar {
+ ($fn_name:ident, $op_name:ident) => {
+ pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
+ let rhs = match rhs.to_tensor_scalar()? {
+ crate::scalar::TensorScalar::Tensor(rhs) => rhs,
+ crate::scalar::TensorScalar::Scalar(rhs) => rhs
+ .to_dtype(self.dtype())?
+ .to_device(self.device())?
+ .broadcast_as(self.shape())?,
+ };
+ let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
+ let storage = self.storage().binary_impl::<crate::op::$op_name>(
+ &*rhs.storage(),
+ self.layout(),
+ rhs.layout(),
+ )?;
+ let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
+ Ok(from_storage(storage, shape.clone(), op, false))
+ }
+ };
+}
+
macro_rules! broadcast_binary_op {
($fn_name:ident, $inner_fn_name:ident) => {
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
@@ -445,8 +469,8 @@ impl Tensor {
binary_op!(mul, Mul);
binary_op!(sub, Sub);
binary_op!(div, Div);
- binary_op!(maximum, Maximum);
- binary_op!(minimum, Minimum);
+ binary_op_scalar!(maximum, Maximum);
+ binary_op_scalar!(minimum, Minimum);
broadcast_binary_op!(broadcast_add, add);
broadcast_binary_op!(broadcast_mul, mul);
broadcast_binary_op!(broadcast_sub, sub);
@@ -465,6 +489,8 @@ impl Tensor {
unary_op!(sqr, Sqr);
unary_op!(sqrt, Sqrt);
unary_op!(gelu, Gelu);
+ unary_op!(gelu_erf, GeluErf);
+ unary_op!(erf, Erf);
unary_op!(relu, Relu);
/// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple
@@ -642,7 +668,12 @@ impl Tensor {
let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
let mut dims = self.dims().to_vec();
dims[dim] = 1;
- let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
+ let op = match op {
+ ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
+ BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
+ }
+ ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
+ };
let res = from_storage(storage, dims, op, false);
if keepdim {
Ok(res)
@@ -775,8 +806,15 @@ impl Tensor {
/// comparison operation is specified by the `op` argument.
///
/// The returned tensor has the same shape as the original tensors and uses `u8` elements.
- pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
- let shape = self.same_shape_binary_op(rhs, "cmp")?;
+ pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
+ let rhs = match rhs.to_tensor_scalar()? {
+ crate::scalar::TensorScalar::Tensor(rhs) => rhs,
+ crate::scalar::TensorScalar::Scalar(rhs) => rhs
+ .to_dtype(self.dtype())?
+ .to_device(self.device())?
+ .broadcast_as(self.shape())?,
+ };
+ let shape = self.same_shape_binary_op(&rhs, "cmp")?;
let storage = self
.storage()
.cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
@@ -785,45 +823,68 @@ impl Tensor {
}
/// Element-wise equality.
- pub fn eq(&self, rhs: &Self) -> Result<Self> {
+ pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Eq)
}
/// Element-wise non-equality.
- pub fn ne(&self, rhs: &Self) -> Result<Self> {
+ pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Ne)
}
/// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
/// rhs` and 0 otherwise.
- pub fn lt(&self, rhs: &Self) -> Result<Self> {
+ pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Lt)
}
/// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
/// rhs` and 0 otherwise.
- pub fn gt(&self, rhs: &Self) -> Result<Self> {
+ pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Gt)
}
/// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
/// rhs` and 0 otherwise.
- pub fn ge(&self, rhs: &Self) -> Result<Self> {
+ pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Ge)
}
/// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
/// rhs` and 0 otherwise.
- pub fn le(&self, rhs: &Self) -> Result<Self> {
+ pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Le)
}
- /// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the
+ /// Clamp the tensor values to be between `min` and `max`.
+ pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
+ self.maximum(min)?.minimum(max)
+ }
+
+ /// Interpolate the input tensor to the `target_size` size, taking the value of the nearest element.
+ ///
+ /// The input tensor should have three dimensions, `(batch, channels, l)`, the returned
+ /// tensor also has three dimensions, `(batch, channels, target_size)`.
+ pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
+ let (n, c, _l) = self.dims3()?;
+ let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
+ let storage = self
+ .storage()
+ .upsample_nearest1d(self.layout(), target_size)?;
+ Ok(from_storage(storage, (n, c, target_size), op, false))
+ }
+
+ /// Alias for `interpolate1d`.
+ pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
+ self.interpolate1d(target_size)
+ }
+
+ /// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the
/// nearest element.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
- pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
+ pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
let storage = self
@@ -832,6 +893,11 @@ impl Tensor {
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
}
+ /// Alias for `interpolate2d`.
+ pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
+ self.interpolate2d(target_h, target_w)
+ }
+
/// 2D average pooling over an input tensor with multiple channels.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
@@ -1684,12 +1750,15 @@ impl Tensor {
Ok(from_storage(storage, shape, BackpropOp::none(), true))
}
- // TODO: Do we want to allow target shape using -1 on some dimensions?
/// Reshape returns a tensor with the target shape provided that the number of elements of the
/// original tensor is the same.
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
/// a new storage and copies the data over, the returned tensor is always contiguous.
///
+ /// The shape can be specified using a tuple of `usize` and at most one `()` in which case
+ /// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so
+ /// as to match the number of elements in the tensor.
+ ///
/// ```rust
/// # use candle_core::{Tensor, DType, Device, D};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
@@ -1699,10 +1768,14 @@ impl Tensor {
///
/// let c = a.reshape((3, 2))?;
/// assert_eq!(c.shape().dims(), &[3, 2]);
+ ///
+ /// let c = a.reshape((2, (), 1))?;
+ /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
+ ///
/// # Ok::<(), candle_core::Error>(())
/// ```
- pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
- let shape = shape.into();
+ pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
+ let shape = s.into_shape(self.elem_count())?;
if shape.elem_count() != self.elem_count() {
return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
@@ -1836,6 +1909,34 @@ impl Tensor {
for arg in args {
arg.as_ref().check_dim(dim, "cat")?;
}
+ for (arg_idx, arg) in args.iter().enumerate() {
+ let arg = arg.as_ref();
+ if arg0.rank() != arg.rank() {
+ Err(Error::UnexpectedNumberOfDims {
+ expected: arg0.rank(),
+ got: arg.rank(),
+ shape: arg.shape().clone(),
+ }
+ .bt())?
+ }
+ for (dim_idx, (v1, v2)) in arg0
+ .shape()
+ .dims()
+ .iter()
+ .zip(arg.shape().dims().iter())
+ .enumerate()
+ {
+ if dim_idx != dim && v1 != v2 {
+ Err(Error::ShapeMismatchCat {
+ dim: dim_idx,
+ first_shape: arg0.shape().clone(),
+ n: arg_idx + 1,
+ nth_shape: arg.shape().clone(),
+ }
+ .bt())?
+ }
+ }
+ }
if dim == 0 {
Self::cat0(args)
} else {
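The added loop makes `Tensor::cat` validate ranks and all non-concatenated dimensions up front rather than failing (or mis-copying) deeper in the backends. A sketch of the resulting behavior:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let a = Tensor::zeros((2, 3), DType::F32, &dev)?;
    let b = Tensor::zeros((4, 3), DType::F32, &dev)?;
    // Concatenating along dim 0 only requires the remaining dims to match.
    assert_eq!(Tensor::cat(&[&a, &b], 0)?.dims(), [6, 3]);
    // A mismatch on a non-concatenated dimension is reported as a shape error.
    let c = Tensor::zeros((2, 5), DType::F32, &dev)?;
    assert!(Tensor::cat(&[&a, &c], 0).is_err());
    Ok(())
}
```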
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 6af43196..edd0bd79 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1,4 +1,4 @@
-use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
+use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@@ -33,6 +33,44 @@ fn tensor_2d(device: &Device) -> Result<()> {
Ok(())
}
+fn clamp(device: &Device) -> Result<()> {
+ let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
+ let tensor = Tensor::new(data, device)?;
+ let tensor = tensor.clamp(1.5, 6.2)?;
+ assert_eq!(
+ tensor.to_vec2::<f32>()?,
+ [[3.0, 1.5, 4.0, 1.5, 5.0], [2.0, 1.5, 6.2, 6.2, 2.0]],
+ );
+ Ok(())
+}
+
+fn unary_op(device: &Device) -> Result<()> {
+ let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
+ let tensor = Tensor::new(data, device)?;
+ assert_eq!(
+ test_utils::to_vec2_round(&tensor.gelu()?, 4)?,
+ [
+ [-0.0036, 0.8412, 3.9999, -0.046, 0.3457],
+ [2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
+ ]
+ );
+ assert_eq!(
+ test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
+ [
+ [-0.004, 0.8413, 3.9999, -0.046, 0.3457],
+ [2.6906, -0.0647, -0.1091, 1.7353, 2.7928]
+ ]
+ );
+ assert_eq!(
+ test_utils::to_vec2_round(&tensor.erf()?, 4)?,
+ [
+ [-1.0, 0.8427, 1.0, -0.1125, 0.5205],
+ [0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
+ ]
+ );
+ Ok(())
+}
+
fn binary_op(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor1 = Tensor::new(data, device)?;
@@ -877,6 +915,14 @@ fn broadcasting(device: &Device) -> Result<()> {
Ok(())
}
+fn randn(device: &Device) -> Result<()> {
+ let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
+ assert_eq!(tensor.dims(), [5, 3]);
+ let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
+ assert_eq!(tensor.dims(), [5, 3]);
+ Ok(())
+}
+
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
@@ -889,6 +935,7 @@ test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
+test_device!(unary_op, unary_op_cpu, unary_op_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
@@ -899,6 +946,8 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
+test_device!(randn, randn_cpu, randn_gpu);
+test_device!(clamp, clamp_cpu, clamp_gpu);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
diff --git a/candle-datasets/Cargo.toml b/candle-datasets/Cargo.toml
index d69318e1..316f31c5 100644
--- a/candle-datasets/Cargo.toml
+++ b/candle-datasets/Cargo.toml
@@ -11,8 +11,8 @@ readme = "README.md"
[dependencies]
byteorder = { workspace = true }
-candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.2.1" }
+candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
+candle-nn = { path = "../candle-nn", version = "0.2.3" }
hf-hub = { workspace = true}
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }
diff --git a/candle-datasets/src/vision/mnist.rs b/candle-datasets/src/vision/mnist.rs
index 30b0d01f..2dac883c 100644
--- a/candle-datasets/src/vision/mnist.rs
+++ b/candle-datasets/src/vision/mnist.rs
@@ -8,13 +8,9 @@ use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File;
use std::io::{self, BufReader, Read};
-fn read_u32<T: Read>(reader: &mut T) -> Result<u32> {
- let mut b = vec![0u8; 4];
- reader.read_exact(&mut b)?;
- let (result, _) = b.iter().rev().fold((0u64, 1u64), |(s, basis), &x| {
- (s + basis * u64::from(x), basis * 256)
- });
- Ok(result as u32)
+fn read_u32<T: Read>(reader: &mut T) -> std::io::Result<u32> {
+ use byteorder::ReadBytesExt;
+ reader.read_u32::<byteorder::BigEndian>()
}
fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 9035eae0..0e2e8093 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -11,19 +11,19 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
-candle-nn = { path = "../candle-nn", version = "0.2.1" }
-candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
-safetensors = { workspace = true }
-serde = { workspace = true }
-serde_json = { workspace = true }
-num-traits = { workspace = true }
-intel-mkl-src = { workspace = true, optional = true }
+candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
+candle-nn = { path = "../candle-nn", version = "0.2.3" }
+candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
+intel-mkl-src = { workspace = true, optional = true }
+num-traits = { workspace = true }
+rayon = { workspace = true }
+safetensors = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
@@ -50,7 +50,7 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
cudnn = ["candle/cudnn"]
-flash-attn = ["cuda", "dep:candle-flash-attn"]
+flash-attn = ["cuda", "candle-transformers/flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
diff --git a/candle-examples/examples/bert/README.md b/candle-examples/examples/bert/README.md
new file mode 100644
index 00000000..82ca5f40
--- /dev/null
+++ b/candle-examples/examples/bert/README.md
@@ -0,0 +1,44 @@
+# candle-bert
+
+Bert is a general-purpose language model. In this example it can be used for two
+different tasks:
+- Compute sentence embeddings for a prompt.
+- Compute similarities between a set of sentences.
+
+
+## Sentence embeddings
+
+Bert is used to compute the sentence embeddings for a prompt. The model weights
+are downloaded from the hub on the first run.
+
+```bash
+cargo run --example bert --release -- --prompt "Here is a test sentence"
+
+> [[[ 0.0798, -0.0665, -0.0247, ..., -0.1082, -0.1000, -0.2751],
+> [ 0.4218, 0.2690, 0.2740, ..., 0.3889, 1.3503, 0.9908],
+> [ 0.0466, 0.3041, -0.1143, ..., 0.4427, 0.6926, -0.1515],
+> ...
+> [ 0.3396, 0.4320, -0.4408, ..., 0.9212, 0.2331, -0.6777],
+> [ 0.2789, 0.7539, 0.4306, ..., -0.0095, 0.3375, -1.7529],
+> [ 0.6737, 0.7882, 0.0548, ..., 0.1836, 0.7299, -0.6617]]]
+> Tensor[[1, 7, 384], f32]
+```
+
+## Similarities
+
+In this example, Bert is used to compute the sentence embeddings for a set of
+sentences (hardcoded in the example). Cosine similarities are then computed for
+each sentence pair and reported in decreasing order, so the first reported pair
+contains the two sentences with the highest similarity score.
+The sentence embeddings are computed using average pooling over all the
+sentence tokens, including any padding.
+
+```bash
+cargo run --example bert --release
+
+> score: 0.85 'The new movie is awesome' 'The new movie is so great'
+> score: 0.61 'The cat sits outside' 'The cat plays in the garden'
+> score: 0.52 'I love pasta' 'Do you like pizza?'
+> score: 0.23 'The new movie is awesome' 'Do you like pizza?'
+> score: 0.22 'I love pasta' 'The new movie is awesome'
+```
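The average pooling mentioned above is just a mean over the token axis of the encoder output, followed by a cosine similarity between the pooled vectors. A hedged sketch of those two steps (helper names are illustrative, and the embeddings are assumed to be f32 with shape `(n_sentences, n_tokens, hidden)` as in the printed tensor above):

```rust
use candle_core::{Device, Result, Tensor};

/// Mean-pool token embeddings into one vector per sentence.
fn mean_pool(embeddings: &Tensor) -> Result<Tensor> {
    let (_n_sentences, n_tokens, _hidden) = embeddings.dims3()?;
    // Sum over the token dimension, then divide by the token count.
    embeddings.sum(1)? / (n_tokens as f64)
}

/// Cosine similarity between two same-shaped f32 tensors.
fn cosine_similarity(a: &Tensor, b: &Tensor) -> Result<f64> {
    let dot = (a * b)?.sum_all()?.to_scalar::<f32>()? as f64;
    let na = (a * a)?.sum_all()?.to_scalar::<f32>()? as f64;
    let nb = (b * b)?.sum_all()?.to_scalar::<f32>()? as f64;
    Ok(dot / (na.sqrt() * nb.sqrt()))
}

fn main() -> Result<()> {
    // Stand-in for the encoder output: 2 sentences, 3 tokens, hidden size 4.
    let embeddings = Tensor::arange(0f32, 24f32, &Device::Cpu)?.reshape((2, 3, 4))?;
    let pooled = mean_pool(&embeddings)?; // shape (2, 4)
    let sim = cosine_similarity(&pooled.get(0)?, &pooled.get(1)?)?;
    println!("cosine similarity: {sim:.3}");
    Ok(())
}
```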
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 6cee66ee..9d0eccdf 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -3,14 +3,13 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-mod model;
+use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use anyhow::{anyhow, Error as E, Result};
use candle::Tensor;
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
-use model::{BertModel, Config, DTYPE};
use tokenizers::{PaddingParams, Tokenizer};
#[derive(Parser, Debug)]
diff --git a/candle-examples/examples/bigcode/README.md b/candle-examples/examples/bigcode/README.md
new file mode 100644
index 00000000..cb4e79b1
--- /dev/null
+++ b/candle-examples/examples/bigcode/README.md
@@ -0,0 +1,19 @@
+# candle-starcoder: code generation model
+
+[StarCoder/BigCode](https://huggingface.co/bigcode/starcoderbase-1b) is an LLM
+specialized in code generation. The initial model was trained on 80
+programming languages.
+
+## Running an example
+
+```bash
+cargo run --example bigcode --release -- --prompt "fn fact(n: u64) -> u64 "
+
+> fn fact(n: u64) -> u64 {
+> if n == 0 {
+> 1
+> } else {
+> n * fact(n - 1)
+> }
+> }
+```
diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs
index 652cd47f..5f17109e 100644
--- a/candle-examples/examples/bigcode/main.rs
+++ b/candle-examples/examples/bigcode/main.rs
@@ -7,8 +7,7 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
-mod model;
-use model::{Config, GPTBigCode};
+use candle_transformers::models::bigcode::{Config, GPTBigCode};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
@@ -29,9 +28,10 @@ impl TextGeneration {
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
+ top_p: Option<f64>,
device: &Device,
) -> Self {
- let logits_processor = LogitsProcessor::new(seed, temp);
+ let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
@@ -95,6 +95,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -150,7 +154,14 @@ fn main() -> Result<()> {
let model = GPTBigCode::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());
- let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
+ let mut pipeline = TextGeneration::new(
+ model,
+ tokenizer,
+ args.seed,
+ args.temperature,
+ args.top_p,
+ &device,
+ );
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}
diff --git a/candle-examples/examples/dinov2/README.md b/candle-examples/examples/dinov2/README.md
new file mode 100644
index 00000000..10d4ac1f
--- /dev/null
+++ b/candle-examples/examples/dinov2/README.md
@@ -0,0 +1,19 @@
+# candle-dinov2
+
+[DINOv2](https://github.com/facebookresearch/dinov2) is a computer vision model
+trained with self-supervision. In this example, it is used as an ImageNet
+classifier: the model returns the probability that the image belongs to each of
+the 1000 ImageNet categories.
+
+## Running an example
+
+```bash
+cargo run --example dinov2 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
+
+> mountain bike, all-terrain bike, off-roader: 43.67%
+> bicycle-built-for-two, tandem bicycle, tandem: 33.20%
+> crash helmet : 13.23%
+> unicycle, monocycle : 2.44%
+> maillot : 2.42%
+```
+
+
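+A minimal sketch of how the logits produced by the classification head are
+turned into the probabilities shown above (the logits tensor here is made up;
+the real example gets it from `dinov2::vit_small`):
+
+```rust
+use candle::{D, Device, Result, Tensor};
+
+fn main() -> Result<()> {
+    // Pretend logits over a tiny 4-class problem instead of the 1000
+    // ImageNet categories.
+    let logits = Tensor::new(&[[2.0f32, 0.5, -1.0, 3.0]], &Device::Cpu)?;
+    // Softmax over the last dimension gives per-class probabilities.
+    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?;
+    println!("{prs}");
+    Ok(())
+}
+```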
diff --git a/candle-examples/examples/dinov2/main.rs b/candle-examples/examples/dinov2/main.rs
index e80c81e2..d3adb37c 100644
--- a/candle-examples/examples/dinov2/main.rs
+++ b/candle-examples/examples/dinov2/main.rs
@@ -9,285 +9,10 @@ extern crate accelerate_src;
use clap::Parser;
-use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle::{DType, IndexOp, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::dinov2;
-const IMG_SIZE: usize = 518;
-const PATCH_SIZE: usize = 14;
-const NUM_CLASSES: usize = 1000;
-
-fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
- if bias {
- candle_nn::linear(in_dim, out_dim, vb)
- } else {
- candle_nn::linear_no_bias(in_dim, out_dim, vb)
- }
-}
-
-#[derive(Debug)]
-struct Attention {
- qkv: Linear,
- proj: Linear,
- num_heads: usize,
- scale: f64,
-}
-
-impl Attention {
- fn new(
- vb: VarBuilder,
- dim: usize,
- num_heads: usize,
- qkv_bias: bool,
- proj_bias: bool,
- ) -> Result<Self> {
- let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
- let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
- let scale = 1. / ((dim / num_heads) as f64).sqrt();
- Ok(Self {
- qkv,
- proj,
- num_heads,
- scale,
- })
- }
-}
-
-impl Module for Attention {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let (b, n, c) = xs.dims3()?;
- let qkv = self
- .qkv
- .forward(xs)?
- .reshape((b, n, 3, self.num_heads, c / self.num_heads))?
- .transpose(1, 2)? // 02134
- .transpose(0, 1)? // 20134
- .transpose(2, 3)?; // 20314
- let q = (qkv.i(0)? * self.scale)?;
- let k = qkv.i(1)?;
- let v = qkv.i(2)?;
- let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
- let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
- self.proj.forward(&attn)
- }
-}
-
-#[derive(Debug)]
-struct LayerScale {
- gamma: Tensor,
-}
-
-impl LayerScale {
- fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
- let gamma = vb.get(dim, "gamma")?;
- Ok(Self { gamma })
- }
-}
-
-impl Module for LayerScale {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- xs.broadcast_mul(&self.gamma)
- }
-}
-
-#[derive(Debug)]
-struct Mlp {
- fc1: Linear,
- fc2: Linear,
-}
-
-impl Mlp {
- fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
- let out_features = in_features;
- let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
- let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
- Ok(Self { fc1, fc2 })
- }
-}
-
-impl Module for Mlp {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let xs = self.fc1.forward(xs)?.gelu()?;
- self.fc2.forward(&xs)
- }
-}
-
-#[derive(Debug)]
-struct Block {
- norm1: LayerNorm,
- attn: Attention,
- ls1: LayerScale,
- norm2: LayerNorm,
- mlp: Mlp,
- ls2: LayerScale,
-}
-
-impl Block {
- fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
- let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
- let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
- let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
- let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
- let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
- let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
- Ok(Self {
- norm1,
- attn,
- ls1,
- norm2,
- mlp,
- ls2,
- })
- }
-}
-
-impl Module for Block {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let residual = xs;
- let xs = self
- .ls1
- .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
- let xs = (xs + residual)?;
- let residual = &xs;
- let xs = self
- .ls2
- .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
- xs + residual
- }
-}
-
-#[derive(Debug)]
-struct PatchEmbed {
- proj: candle_nn::Conv2d,
- patch_size: (usize, usize),
- num_patches: usize,
-}
-
-impl PatchEmbed {
- fn new(
- vb: VarBuilder,
- img_size: usize,
- patch_size: usize,
- in_chans: usize,
- embed_dim: usize,
- ) -> Result<Self> {
- let config = candle_nn::Conv2dConfig {
- stride: patch_size,
- ..Default::default()
- };
- let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
- let num_patches = (img_size / patch_size) * (img_size / patch_size);
- Ok(Self {
- proj,
- patch_size: (patch_size, patch_size),
- num_patches,
- })
- }
-}
-
-impl Module for PatchEmbed {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let (_b, _c, h, w) = xs.dims4()?;
- let (patch_h, patch_w) = self.patch_size;
- if (h % patch_h) != 0 {
- candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
- }
- if (w % patch_w) != 0 {
- candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
- }
- let xs = self.proj.forward(xs)?;
- let (b, c, h, w) = xs.dims4()?;
- // flatten embeddings.
- xs.reshape((b, c, h * w))?.transpose(1, 2)
- }
-}
-
-#[derive(Debug)]
-pub struct DinoVisionTransformer {
- patch_embed: PatchEmbed,
- cls_token: Tensor,
- pos_embed: Tensor,
- blocks: Vec<Block>,
- norm: LayerNorm,
- head: Linear,
-}
-
-impl DinoVisionTransformer {
- pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
- let patch_embed =
- PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
- let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
- let num_tokens = 1;
- let pos_embed = vb.get(
- (1, patch_embed.num_patches + num_tokens, embed_dim),
- "pos_embed",
- )?;
- let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
- let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
- let vb_b = vb.pp("blocks");
- let blocks = (0..depth)
- .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
- .collect::<Result<Vec<_>>>()?;
- Ok(Self {
- patch_embed,
- cls_token,
- pos_embed,
- blocks,
- norm,
- head,
- })
- }
-
- fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
- let npatch = xs.dim(1)? - 1;
- let n = self.pos_embed.dim(1)? - 1;
- let sqrt_n = (n as f64).sqrt();
- if npatch == n && w == h {
- return Ok(xs.clone());
- }
- let class_pos_embed = self.pos_embed.i((.., ..1))?;
- let patch_pos_embed = self.pos_embed.i((.., 1..))?;
- let dim = xs.dim(D::Minus1)?;
- let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
- let patch_pos_embed = patch_pos_embed
- .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
- .transpose(2, 3)?
- .transpose(1, 2)?;
- // This uses bicubic interpolation in the original implementation.
- let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
- let el_count = patch_pos_embed.shape().elem_count();
- let patch_pos_embed =
- patch_pos_embed
- .transpose(1, 2)?
- .transpose(2, 3)?
- .reshape((1, el_count / dim, dim))?;
- Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
- }
-
- fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
- let (_b, _nc, w, h) = xs.dims4()?;
- let xs = self.patch_embed.forward(xs)?;
- let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
- &xs + &self.interpolate_pos_encoding(&xs, w, h)?
- }
-}
-
-impl Module for DinoVisionTransformer {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let mut xs = self.prepare_tokens_with_mask(xs)?;
- for blk in self.blocks.iter() {
- xs = blk.forward(&xs)?
- }
- let xs = self.norm.forward(&xs)?;
- let xs_norm_clstoken = xs.i((.., 0))?;
- let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
- let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
- self.head.forward(&xs)
- }
-}
-
-pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
- DinoVisionTransformer::new(vb, 12, 384, 6)
-}
#[derive(Parser)]
struct Args {
#[arg(long)]
@@ -320,7 +45,7 @@ pub fn main() -> anyhow::Result<()> {
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
- let model = vit_small(vb)?;
+ let model = dinov2::vit_small(vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
diff --git a/candle-examples/examples/efficientnet/main.rs b/candle-examples/examples/efficientnet/main.rs
index cbe2c90a..1e45e301 100644
--- a/candle-examples/examples/efficientnet/main.rs
+++ b/candle-examples/examples/efficientnet/main.rs
@@ -8,340 +8,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
+use candle::{DType, IndexOp, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig};
use clap::{Parser, ValueEnum};
-use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn as nn;
-use nn::{Module, VarBuilder};
-
-// Based on the Python version from torchvision.
-// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
-#[derive(Debug, Clone, Copy)]
-pub struct MBConvConfig {
- expand_ratio: f64,
- kernel: usize,
- stride: usize,
- input_channels: usize,
- out_channels: usize,
- num_layers: usize,
-}
-
-fn make_divisible(v: f64, divisor: usize) -> usize {
- let min_value = divisor;
- let new_v = usize::max(
- min_value,
- (v + divisor as f64 * 0.5) as usize / divisor * divisor,
- );
- if (new_v as f64) < 0.9 * v {
- new_v + divisor
- } else {
- new_v
- }
-}
-
-fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
- let bneck_conf = |e, k, s, i, o, n| {
- let input_channels = make_divisible(i as f64 * width_mult, 8);
- let out_channels = make_divisible(o as f64 * width_mult, 8);
- let num_layers = (n as f64 * depth_mult).ceil() as usize;
- MBConvConfig {
- expand_ratio: e,
- kernel: k,
- stride: s,
- input_channels,
- out_channels,
- num_layers,
- }
- };
- vec![
- bneck_conf(1., 3, 1, 32, 16, 1),
- bneck_conf(6., 3, 2, 16, 24, 2),
- bneck_conf(6., 5, 2, 24, 40, 2),
- bneck_conf(6., 3, 2, 40, 80, 3),
- bneck_conf(6., 5, 1, 80, 112, 3),
- bneck_conf(6., 5, 2, 112, 192, 4),
- bneck_conf(6., 3, 1, 192, 320, 1),
- ]
-}
-
-impl MBConvConfig {
- fn b0() -> Vec<MBConvConfig> {
- bneck_confs(1.0, 1.0)
- }
- fn b1() -> Vec<MBConvConfig> {
- bneck_confs(1.0, 1.1)
- }
- fn b2() -> Vec<MBConvConfig> {
- bneck_confs(1.1, 1.2)
- }
- fn b3() -> Vec<MBConvConfig> {
- bneck_confs(1.2, 1.4)
- }
- fn b4() -> Vec<MBConvConfig> {
- bneck_confs(1.4, 1.8)
- }
- fn b5() -> Vec<MBConvConfig> {
- bneck_confs(1.6, 2.2)
- }
- fn b6() -> Vec<MBConvConfig> {
- bneck_confs(1.8, 2.6)
- }
- fn b7() -> Vec<MBConvConfig> {
- bneck_confs(2.0, 3.1)
- }
-}
-
-/// Conv2D with same padding.
-#[derive(Debug)]
-struct Conv2DSame {
- conv2d: nn::Conv2d,
- s: usize,
- k: usize,
-}
-
-impl Conv2DSame {
- fn new(
- vb: VarBuilder,
- i: usize,
- o: usize,
- k: usize,
- stride: usize,
- groups: usize,
- bias: bool,
- ) -> Result<Self> {
- let conv_config = nn::Conv2dConfig {
- stride,
- groups,
- ..Default::default()
- };
- let conv2d = if bias {
- nn::conv2d(i, o, k, conv_config, vb)?
- } else {
- nn::conv2d_no_bias(i, o, k, conv_config, vb)?
- };
- Ok(Self {
- conv2d,
- s: stride,
- k,
- })
- }
-}
-
-impl Module for Conv2DSame {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let s = self.s;
- let k = self.k;
- let (_, _, ih, iw) = xs.dims4()?;
- let oh = (ih + s - 1) / s;
- let ow = (iw + s - 1) / s;
- let pad_h = usize::max((oh - 1) * s + k - ih, 0);
- let pad_w = usize::max((ow - 1) * s + k - iw, 0);
- if pad_h > 0 || pad_w > 0 {
- let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
- let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
- self.conv2d.forward(&xs)
- } else {
- self.conv2d.forward(xs)
- }
- }
-}
-
-#[derive(Debug)]
-struct ConvNormActivation {
- conv2d: Conv2DSame,
- bn2d: nn::BatchNorm,
- activation: bool,
-}
-
-impl ConvNormActivation {
- fn new(
- vb: VarBuilder,
- i: usize,
- o: usize,
- k: usize,
- stride: usize,
- groups: usize,
- ) -> Result<Self> {
- let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
- let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
- Ok(Self {
- conv2d,
- bn2d,
- activation: true,
- })
- }
-
- fn no_activation(self) -> Self {
- Self {
- activation: false,
- ..self
- }
- }
-}
-
-impl Module for ConvNormActivation {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let xs = self.conv2d.forward(xs)?;
- let xs = self.bn2d.forward(&xs)?;
- if self.activation {
- swish(&xs)
- } else {
- Ok(xs)
- }
- }
-}
-
-#[derive(Debug)]
-struct SqueezeExcitation {
- fc1: Conv2DSame,
- fc2: Conv2DSame,
-}
-
-impl SqueezeExcitation {
- fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
- let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
- let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
- Ok(Self { fc1, fc2 })
- }
-}
-
-impl Module for SqueezeExcitation {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let residual = xs;
- // equivalent to adaptive_avg_pool2d([1, 1])
- let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
- let xs = self.fc1.forward(&xs)?;
- let xs = swish(&xs)?;
- let xs = self.fc2.forward(&xs)?;
- let xs = nn::ops::sigmoid(&xs)?;
- residual.broadcast_mul(&xs)
- }
-}
-
-#[derive(Debug)]
-struct MBConv {
- expand_cna: Option<ConvNormActivation>,
- depthwise_cna: ConvNormActivation,
- squeeze_excitation: SqueezeExcitation,
- project_cna: ConvNormActivation,
- config: MBConvConfig,
-}
-
-impl MBConv {
- fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
- let vb = vb.pp("block");
- let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
- let expand_cna = if exp != c.input_channels {
- Some(ConvNormActivation::new(
- vb.pp("0"),
- c.input_channels,
- exp,
- 1,
- 1,
- 1,
- )?)
- } else {
- None
- };
- let start_index = if expand_cna.is_some() { 1 } else { 0 };
- let depthwise_cna =
- ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
- let squeeze_channels = usize::max(1, c.input_channels / 4);
- let squeeze_excitation =
- SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
- let project_cna =
- ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
- .no_activation();
- Ok(Self {
- expand_cna,
- depthwise_cna,
- squeeze_excitation,
- project_cna,
- config: c,
- })
- }
-}
-
-impl Module for MBConv {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let use_res_connect =
- self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
- let ys = match &self.expand_cna {
- Some(expand_cna) => expand_cna.forward(xs)?,
- None => xs.clone(),
- };
- let ys = self.depthwise_cna.forward(&ys)?;
- let ys = self.squeeze_excitation.forward(&ys)?;
- let ys = self.project_cna.forward(&ys)?;
- if use_res_connect {
- ys + xs
- } else {
- Ok(ys)
- }
- }
-}
-
-fn swish(s: &Tensor) -> Result<Tensor> {
- s * nn::ops::sigmoid(s)?
-}
-
-#[derive(Debug)]
-struct EfficientNet {
- init_cna: ConvNormActivation,
- blocks: Vec<MBConv>,
- final_cna: ConvNormActivation,
- classifier: nn::Linear,
-}
-
-impl EfficientNet {
- fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
- let f_p = p.pp("features");
- let first_in_c = configs[0].input_channels;
- let last_out_c = configs.last().unwrap().out_channels;
- let final_out_c = 4 * last_out_c;
- let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
- let nconfigs = configs.len();
- let mut blocks = vec![];
- for (index, cnf) in configs.into_iter().enumerate() {
- let f_p = f_p.pp(index + 1);
- for r_index in 0..cnf.num_layers {
- let cnf = if r_index == 0 {
- cnf
- } else {
- MBConvConfig {
- input_channels: cnf.out_channels,
- stride: 1,
- ..cnf
- }
- };
- blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
- }
- }
- let final_cna =
- ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
- let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
- Ok(Self {
- init_cna,
- blocks,
- final_cna,
- classifier,
- })
- }
-}
-
-impl Module for EfficientNet {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let mut xs = self.init_cna.forward(xs)?;
- for block in self.blocks.iter() {
- xs = block.forward(&xs)?
- }
- let xs = self.final_cna.forward(&xs)?;
- // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
- let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
- self.classifier.forward(&xs)
- }
-}
-
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
B0,
diff --git a/candle-examples/examples/falcon/README.md b/candle-examples/examples/falcon/README.md
new file mode 100644
index 00000000..267c78c2
--- /dev/null
+++ b/candle-examples/examples/falcon/README.md
@@ -0,0 +1,3 @@
+# candle-falcon
+
+Falcon is a general-purpose large language model.
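+
+The example's `main.rs` groups its sampling parameters (temperature,
+nucleus-sampling `top_p`, and the repeat penalty) into a small
+`GenerationOptions` struct. A minimal standalone sketch of that wiring, with
+purely illustrative values rather than the example's own defaults:
+
+```rust
+use candle_transformers::generation::LogitsProcessor;
+
+// Mirrors the options struct used by the example.
+struct GenerationOptions {
+    temp: Option<f64>,
+    top_p: Option<f64>,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+}
+
+fn main() {
+    let opts = GenerationOptions {
+        temp: Some(0.8),
+        top_p: Some(0.95),
+        repeat_penalty: 1.1,
+        repeat_last_n: 64,
+    };
+    // The sampler only needs the temperature and top-p parts of the options.
+    let _logits_processor = LogitsProcessor::new(299792458, opts.temp, opts.top_p);
+    println!(
+        "repeat penalty {} applied over the last {} tokens",
+        opts.repeat_penalty, opts.repeat_last_n
+    );
+}
+```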
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index 05507f08..b0973d64 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -14,8 +14,7 @@ use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
-mod model;
-use model::{Config, Falcon};
+use candle_transformers::models::falcon::{Config, Falcon};
struct TextGeneration {
model: Falcon,
@@ -26,17 +25,25 @@ struct TextGeneration {
repeat_last_n: usize,
}
+struct GenerationOptions {
+ temp: Option<f64>,
+ top_p: Option<f64>,
+ repeat_penalty: f32,
+ repeat_last_n: usize,
+}
+
impl TextGeneration {
fn new(
model: Falcon,
tokenizer: Tokenizer,
+ generation_options: GenerationOptions,
seed: u64,
- temp: Option<f64>,
device: &Device,
- repeat_penalty: f32,
- repeat_last_n: usize,
) -> Self {
- let logits_processor = LogitsProcessor::new(seed, temp);
+ let logits_processor =
+ LogitsProcessor::new(seed, generation_options.temp, generation_options.top_p);
+ let repeat_penalty = generation_options.repeat_penalty;
+ let repeat_last_n = generation_options.repeat_last_n;
Self {
model,
tokenizer,
@@ -119,6 +126,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -186,15 +197,14 @@ fn main() -> Result<()> {
let model = Falcon::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());
- let mut pipeline = TextGeneration::new(
- model,
- tokenizer,
- args.seed,
- args.temperature,
- &device,
- args.repeat_penalty,
- args.repeat_last_n,
- );
+ let generation_options = GenerationOptions {
+ temp: args.temperature,
+ top_p: args.top_p,
+ repeat_penalty: args.repeat_penalty,
+ repeat_last_n: args.repeat_last_n,
+ };
+ let mut pipeline =
+ TextGeneration::new(model, tokenizer, generation_options, args.seed, &device);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 6f8766d4..b2d7d938 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -21,11 +21,10 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
-mod model;
+use candle_transformers::models::llama as model;
use model::{Config, Llama, LlamaConfig};
const EOS_TOKEN: &str = "</s>";
-const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
#[derive(Parser, Debug)]
@@ -43,6 +42,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,