Compare commits

...

42 Commits
0.8.0 ... 0.8.2

Author SHA1 Message Date
236c35e578 Bump the caret version to 0.8.2. (#2703) 2025-01-07 15:50:16 +01:00
6f8351dfda add link to README (#2701) 2025-01-04 23:07:30 +01:00
57f41da13b Fix mistral attention on Metal (#2699)
Co-authored-by: Luka Zakrajsek <luka.zakrajsek@soniox.com>
2025-01-04 16:11:20 +01:00
cbaa0ad46f UniPC for diffusion sampling (#2684)
* feat: Add unipc multistep scheduler

* chore: Clippy and formatting

* chore: Update comments

* chore: Avoid unsafety in float ordering

* refactor: Update Scheduler::step mutability requirements

* fix: Corrector img2img

* chore: Update unipc ref link to latest diffusers release

* chore: Deduplicate float ordering

* fix: Panic when running with dev profile
2025-01-01 21:34:17 +01:00
b12c7c2888 Update the hf-hub dependency to 0.4.0. (#2691)
* Update the hf-hub dependency to 0.4.0.

* Fix the book.

* Use 0.4.1.
2024-12-31 19:07:47 +01:00
94ffc2ec6f Actually remove the default hf-hub cache path for glm. (#2696) 2024-12-31 11:00:44 +01:00
7354afc673 Use the default hf-hub cache for glm. (#2695) 2024-12-31 10:55:45 +01:00
2a705e6f37 Flash-Attn upgrade / SoftCap Candle-FlashAttn [3/n] (#2690)
* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace

* softcap is working, including test and api.

* make softcap test case better

* unpadded lse added
2024-12-31 10:04:47 +01:00
a594ef669c Flash-Attn upgrade / SoftCap Candle-FlashAttn [2/n] (#2689)
* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace

* softcap is working, including test and api.

* make softcap test case better

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-12-31 09:41:23 +01:00
71cd6d5533 Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)
* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace
2024-12-31 09:32:22 +01:00
d60eba1408 Streamline the glm4 example. (#2694) 2024-12-31 09:21:41 +01:00
e38e2a85dd Fix a cuda warning. (#2693) 2024-12-31 09:06:10 +01:00
460616fc84 Update README.org (#2670)
Fix the command line error in the CPU section of the documentation.
2024-12-30 11:32:02 +01:00
91f1f019b1 Added XLMRobertaModel for Reranking (#2686)
* add xlm-roberta-base

* Add task enum for fill-mask and reranker in xlm-roberta example; update README and fix attention mask dimensions

- Introduced a new `Task` enum to replace string task identifiers in the xlm-roberta example.
- Updated the logic in `main.rs` to handle tasks using the new enum.
- Enhanced README with example output for fill-mask task.
- Fixed dimension retrieval in `prepare_4d_attention_mask` function for better clarity and safety.

* Clippy fix.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-12-30 11:16:57 +01:00
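
A minimal sketch of the idea described in the commit above: an enum replacing free-form string task identifiers. The names below are illustrative, not the example's exact definitions.

```rust
// Illustrative only: the actual enum lives in the xlm-roberta example.
#[derive(Clone, Copy, Debug)]
enum Task {
    FillMask,
    Reranker,
}

fn task_name(task: Task) -> &'static str {
    // Matching on the enum replaces comparisons against string identifiers.
    match task {
        Task::FillMask => "fill-mask",
        Task::Reranker => "reranker",
    }
}

fn main() {
    println!("running the {} task", task_name(Task::FillMask));
}
```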
cd639131f0 Fix bug in whisper transformer (#2681)
* Fix bug in whisper transformer
- due to num_threads going to zero in the single-threaded case

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-12-24 13:58:21 +01:00
11aa30be10 Fix Batcher iterator break when return_last_incomplete_batch and items.is_empty (#2654) (#2655) 2024-12-24 08:41:26 +01:00
1be6b090c7 Fix position encodings for Pixtral (#2678)
* init commit: add position id in meshgrid

* pass in subsampled positions

* clippy fix

* clippy fix
2024-12-23 13:22:35 +01:00
62ced44ea9 Add a Context trait similar to anyhow::Context. (#2676)
* Add a Context trait similar to anyhow::Context.

* Switch two unwrap to context.
2024-12-22 09:18:13 +01:00
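
A minimal usage sketch of the `Context` trait introduced here (the full trait definition appears in the candle-core error diff further down); the helper function is hypothetical.

```rust
use candle_core::{Context, Result};

// Hypothetical helper: `.context(..)` converts the `None` case into an error
// carrying a message, instead of panicking the way `.unwrap()` would.
fn first_dim(dims: &[usize]) -> Result<usize> {
    let d = dims.first().context("empty shape")?;
    Ok(*d)
}

fn main() -> Result<()> {
    println!("{}", first_dim(&[2, 3])?);
    Ok(())
}
```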
5c2f893e5a make DepthAnythingV2 more reusable (#2675)
* make DepthAnythingV2 more reusable

* Fix clippy lints.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-12-21 12:06:03 +01:00
67cab7d6b8 Bump the crate version to 0.8.1. (#2662) 2024-12-07 17:03:53 +01:00
1807be84f4 Change/bert encoder public (#2658)
* change: BertEncoder struct to public

* change: make certain fields in Config struct public

* change: all fields in bert config struct to be public

* change: add clone to bert encoder and others

* Clippy fix.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-12-04 21:22:30 +01:00
145aa7193c Add Nvembed v2 model (#2649)
* Update mod.rs

* Create mod.rs

* Create decoder.rs

* Create model.rs

* Create main.rs

* Create README.md

* Update README.md

* Update main.rs

* Update and rename decoder.rs to embedding.rs

* Update mod.rs

* Update model.rs
2024-12-03 10:56:01 +01:00
6f715f9256 add scatter add (#2656) 2024-12-01 18:39:38 +01:00
dba7a9c93e add u32 - U32 gather (#2653) 2024-11-30 23:18:07 +01:00
b52c2c6050 Clippy fixes for the cuda feature. (#2650) 2024-11-29 09:01:34 +01:00
4f59ed38b0 Adds support for stella_en_v5 embedding model -400M variant (#2608)
* Adds support for stella_en_v5 embedding model -400M variant

* Unified stella

* WIP: Unified Stella

* Combined stella for both 1.5B and 400M variants

* Cargo fmt for the CI

* removed redundant stella-400m model and example after merge into stella-en-v5

* cargo fmt --all

---------

Co-authored-by: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-11-29 09:01:08 +01:00
54e7fc3c97 Lint fixes introduced with Rust 1.83 (#2646)
* Fixes for lint errors introduced with Rust 1.83

* rustfmt

* Fix more lints.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-28 23:00:21 +01:00
23ed8a9ded Fix for whisper-microphone example failure if audio isn't chunk aligned (#2645)
At least on my macOS Sequoia system (MBP 14" 2021, M1 Pro), when I run
the `whisper-microphone` example after it has gathered 10 seconds of
audio, it fails before the transcription:

```
Error: Insufficient buffer size 384 for input channel 0, expected 1024
```

At least for the audio device I'm using (Airpods Pro Max), there is no
guarantee that each audio buffer is a multiple of 1024 samples.  Thus at
the end of the 10 seconds, `buffered_pcm` can have some samples at the
end that do not form a complete 1024 sample chunk.

This fixes that by tracking when there is a partial chunk at the end of
the buffer, and leaving it in `buffered_pcm` to be processed on the next
loop iteration.

Note that, in the interest of keeping this PR as small as possible, I
didn't make any other changes to this example.
2024-11-27 22:35:11 +01:00
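
A minimal sketch, assuming a 1024-sample chunk size, of the approach described above: drain only whole chunks and leave the partial tail in `buffered_pcm` for the next iteration (illustrative, not the example's exact code).

```rust
const CHUNK_SIZE: usize = 1024;

// Returns the complete chunks and keeps any trailing partial chunk buffered.
fn drain_full_chunks(buffered_pcm: &mut Vec<f32>) -> Vec<f32> {
    let full_len = (buffered_pcm.len() / CHUNK_SIZE) * CHUNK_SIZE;
    buffered_pcm.drain(..full_len).collect()
}

fn main() {
    let mut buffered_pcm: Vec<f32> = vec![0.0; 2500];
    let ready = drain_full_chunks(&mut buffered_pcm);
    // 2048 samples get processed now; 452 stay buffered for the next loop iteration.
    assert_eq!((ready.len(), buffered_pcm.len()), (2048, 452));
}
```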
21c686387c Onnx Support for Sign operation #2641 (#2642)
* Support for Sign operation #2641

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-26 23:10:09 +01:00
b4deb5c5a9 Provide a method to allow PTH files with state maps to be loaded. (#2639)
* Provide a method to allow PTH files with state maps to be loaded.

* add a line to the doc

* String -> &str
2024-11-26 22:52:53 +01:00
c12db594e3 fix typo (#2606) 2024-11-23 08:40:00 +01:00
f86f4d6224 Tweak the CI to avoid running out of disk space. (#2630)
* Tweak the CI to avoid running out of disk space.

* Linux only.
2024-11-19 04:32:36 +01:00
3159f91b90 20241118 docs (#2629)
* module docs

* varbuilder gguf docs

* add a link to gguf files

* small additional mod doc titles

* safetensor docs

* more core docs

* more module docs in candle_core

* 2 more link fixes
2024-11-19 04:07:07 +01:00
1a0f9ccf16 Import the ggml_cuda_dp4a function. (#2628) 2024-11-19 03:41:34 +01:00
e86565624b Fix for clippy. (#2626) 2024-11-18 14:32:38 +01:00
386fd8abb4 Module Docs (#2624)
* update whisper

* update llama2c

* update t5

* update phi and t5

* add a blip model

* qlamma doc

* add two new docs

* add docs and emoji

* additional models

* openclip

* pixtral

* edits on the  model docs

* update yu

* update a few more models

* add persimmon

* add model-level doc

* names

* update module doc

* links in hiera

* remove empty URL

* update more hyperlinks

* updated hyperlinks

* more links

* Update mod.rs

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2024-11-18 14:19:23 +01:00
12d7e7b145 More Model Module Docs (#2623)
* dinov2

* add another example

* add dinov2reg4

* eva2

* efficientvit

* moondream

* update t5

* update t5

* rwkv

* stable diffusion docs

* add wasm link

* add segment_anything

* adjust for clippy

* ignore bertdoc

* dinov2 ignore

* update block to be text

* remove the rust blocks for the moment

* bump python to 3.11

* add a setup-python step

* add py311 to test as well
2024-11-17 20:27:24 +01:00
a3f200e369 Module Docs (#2620)
* update bert docs

* update based

* update bigcode

* add pixtral

* add flux as well
2024-11-16 09:09:17 +01:00
00d8a0c178 Remove some unused macros. (#2618)
* Remove some unused macros.

* More unused fixes.
2024-11-15 16:46:55 +01:00
f689ce5d39 Documentation Pass for Models (#2617)
* links in chinese_clip

* links for clip model

* add mod docs for flux and llava

* module doc for MMDIT and MIMI

* add docs for a few more modesl

* mod docs for bert naser and beit

* add module docs for convmixer colpali codegeex and chatglm

* add another series of moddocs

* add  fastvit-llama2_c

* module docs mamba -> mobileone

* module docs from moondream-phi3

* mod docs for quantized and qwen

* update to yi

* fix long names

* Update llama2_c.rs

* Update llama2_c_weights.rs

* Fix the link for mimi + tweaks

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2024-11-15 08:30:15 +01:00
0ed24b9852 Add max-all/min-all. (#2616) 2024-11-14 21:08:04 +01:00
06350c31c7 Add some missing index-select metal kernels. (#2613)
* Add some missing index-select metal kernels.

* Make some matrix contiguous pre-matmul.
2024-11-12 17:10:12 +01:00
232 changed files with 5762 additions and 804 deletions

View File

@ -16,6 +16,9 @@ jobs:
rust: [stable]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@ -34,7 +37,13 @@ jobs:
os: [ubuntu-latest, windows-latest, macOS-latest]
rust: [stable]
steps:
- name: Delete huge unnecessary tools folder
if: runner.os == 'Linux'
run: rm -rf /opt/hostedtoolcache
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
profile: minimal

View File

@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.8.0"
version = "0.8.2"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,20 +33,20 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.8.0" }
candle-datasets = { path = "./candle-datasets", version = "0.8.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.0" }
candle-kernels = { path = "./candle-kernels", version = "0.8.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.0" }
candle-nn = { path = "./candle-nn", version = "0.8.0" }
candle-onnx = { path = "./candle-onnx", version = "0.8.0" }
candle-transformers = { path = "./candle-transformers", version = "0.8.0" }
candle = { path = "./candle-core", package = "candle-core", version = "0.8.2" }
candle-datasets = { path = "./candle-datasets", version = "0.8.2" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.2" }
candle-kernels = { path = "./candle-kernels", version = "0.8.2" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.2" }
candle-nn = { path = "./candle-nn", version = "0.8.2" }
candle-onnx = { path = "./candle-onnx", version = "0.8.2" }
candle-transformers = { path = "./candle-transformers", version = "0.8.2" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = { version = "0.3.3", package = "candle-hf-hub" }
hf-hub = "0.4.1"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }

View File

@ -189,6 +189,7 @@ And then head over to
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.
- [`atoma-infer`](https://github.com/atoma-network/atoma-infer): A Rust library for fast inference at scale, leveraging FlashAttention2 for efficient attention computation, PagedAttention for efficient KV-cache memory management, and multi-GPU support. It is OpenAI api compatible.
- [`llms-from-scratch-rs`](https://github.com/nerdai/llms-from-scratch-rs): A comprehensive Rust translation of the code from Sebastian Raschka's Build an LLM from Scratch book.
If you have an addition to this list, please submit a pull request.

View File

@ -11,8 +11,8 @@ Then let's start by downloading the [model file](https://huggingface.co/bert-bas
```rust
# extern crate candle_core;
# extern crate candle_hf_hub;
use candle_hf_hub::api::sync::Api;
# extern crate hf_hub;
use hf_hub::api::sync::Api;
use candle_core::Device;
let api = Api::new().unwrap();
@ -50,8 +50,8 @@ Now that we have our weights, we can use them in our bert architecture:
```rust
# extern crate candle_core;
# extern crate candle_nn;
# extern crate candle_hf_hub;
# use candle_hf_hub::api::sync::Api;
# extern crate hf_hub;
# use hf_hub::api::sync::Api;
#
# let api = Api::new().unwrap();
# let repo = api.model("bert-base-uncased".to_string());

View File

@ -1,3 +1,5 @@
//! Traits to Define Backend Behavior
//!
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};

View File

@ -1,4 +1,4 @@
/// Methods for backpropagation of gradients.
//! Methods for backpropagation of gradients.
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;

View File

@ -1,3 +1,5 @@
//! 1D and 2D Convolutions
//!
use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
#[derive(Debug, Clone, PartialEq, Eq)]

View File

@ -1,3 +1,5 @@
//! Traits and methods for CPU-backed Tensors
pub mod erf;
pub mod kernels;

View File

@ -1,3 +1,4 @@
//! Implementation of Backend Fns for CPU
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
@ -65,7 +66,7 @@ impl Map2U8 for Cmp {
struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);
impl<'a, I: IntDType> Map2 for WCond<'a, I> {
impl<I: IntDType> Map2 for WCond<'_, I> {
const OP: &'static str = "where";
#[inline(always)]
fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
@ -215,7 +216,7 @@ struct ReduceSum<'a> {
reduce_dims_and_stride: Vec<(usize, usize)>,
}
impl<'a> ReduceSum<'a> {
impl ReduceSum<'_> {
#[inline(always)]
fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
where
@ -280,7 +281,7 @@ impl<'a> ReduceSum<'a> {
}
}
impl<'a> Map1 for ReduceSum<'a> {
impl Map1 for ReduceSum<'_> {
#[inline(always)]
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
self.fold_impl(src, src_l, T::zero())
@ -453,7 +454,7 @@ struct Gather<'a, I: IntDType> {
dim: usize,
}
impl<'a, I: IntDType> Map1 for Gather<'a, I> {
impl<I: IntDType> Map1 for Gather<'_, I> {
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b],
@ -506,7 +507,7 @@ struct IndexSelect<'a, T: IntDType> {
dim: usize,
}
impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
impl<I: IntDType> Map1 for IndexSelect<'_, I> {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
let src = match layout.contiguous_offsets() {
Some((a, b)) => &src[a..b],
@ -559,7 +560,7 @@ struct ScatterAdd<'a, I: IntDType> {
dim: usize,
}
impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
const OP: &'static str = "scatter-add";
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let dst_len = l1.shape().elem_count();
@ -615,7 +616,7 @@ struct IndexAdd<'a, I: IntDType> {
dim: usize,
}
impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
impl<I: IntDType> Map2 for IndexAdd<'_, I> {
const OP: &'static str = "index-add";
// https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
// v1, l1 -> self
@ -735,7 +736,7 @@ fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
impl<'a> Map2 for Conv1D<'a> {
impl Map2 for Conv1D<'_> {
const OP: &'static str = "conv1d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
@ -959,7 +960,7 @@ impl Map1 for Col2Im1D {
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
impl Map2 for ConvTranspose1D<'_> {
const OP: &'static str = "conv_transpose1d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
@ -1028,7 +1029,7 @@ impl<'a> Map2 for ConvTranspose1D<'a> {
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> {
impl Map2 for Conv2D<'_> {
const OP: &'static str = "conv2d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
@ -1116,7 +1117,7 @@ impl<'a> Map2 for Conv2D<'a> {
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
impl<'a> Map2 for ConvTranspose2D<'a> {
impl Map2 for ConvTranspose2D<'_> {
const OP: &'static str = "conv_transpose2d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;

View File

@ -1,3 +1,5 @@
//! Implementation of Backend traits for CUDA device
//!
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
@ -253,7 +255,7 @@ impl Map1 for Powf {
}
struct FastReduce<'a>(&'a [usize], ReduceOp);
impl<'a> Map1Any for FastReduce<'a> {
impl Map1Any for FastReduce<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
@ -348,7 +350,7 @@ impl<U: UnaryOpT> Map1 for U {
}
struct IndexSelect<'a>(&'a CudaStorage, &'a Layout, usize);
impl<'a> Map1 for IndexSelect<'a> {
impl Map1 for IndexSelect<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
@ -408,7 +410,7 @@ impl<'a> Map1 for IndexSelect<'a> {
}
struct Gather<'a>(&'a CudaStorage, &'a Layout, usize);
impl<'a> Map1 for Gather<'a> {
impl Map1 for Gather<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
@ -459,7 +461,7 @@ impl<'a> Map1 for Gather<'a> {
}
struct IndexAdd<'a>(&'a CudaStorage, &'a Layout, usize);
impl<'a> Map2InPlace for IndexAdd<'a> {
impl Map2InPlace for IndexAdd<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
@ -507,7 +509,7 @@ impl<'a> Map2InPlace for IndexAdd<'a> {
}
struct ScatterAdd<'a>(&'a CudaStorage, &'a Layout, usize);
impl<'a> Map2InPlace for ScatterAdd<'a> {
impl Map2InPlace for ScatterAdd<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
@ -552,7 +554,7 @@ impl<'a> Map2InPlace for ScatterAdd<'a> {
}
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
impl<'a> Map2 for Conv1D<'a> {
impl Map2 for Conv1D<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
@ -593,7 +595,7 @@ impl<'a> Map2 for Conv1D<'a> {
}
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> {
impl Map2 for Conv2D<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
@ -658,7 +660,7 @@ impl Map1 for Col2Im1D {
}
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
impl Map2 for ConvTranspose1D<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
@ -707,7 +709,7 @@ impl<'a> Map2 for ConvTranspose1D<'a> {
}
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
impl<'a> Map2 for ConvTranspose2D<'a> {
impl Map2 for ConvTranspose2D<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
@ -848,7 +850,7 @@ impl Map1 for UpsampleNearest2D {
}
struct WhereCond<'a>(&'a CudaStorage, &'a Layout);
impl<'a> Map2 for WhereCond<'a> {
impl Map2 for WhereCond<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
t: &CudaSlice<T>,

View File

@ -11,6 +11,7 @@ pub enum DeviceLocation {
Metal { gpu_id: usize },
}
/// Cpu, Cuda, or Metal
#[derive(Debug, Clone)]
pub enum Device {
Cpu,

View File

@ -1,6 +1,7 @@
/// Pretty printing of tensors
/// This implementation should be in line with the PyTorch version.
/// https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py
//! Pretty printing of tensors
//!
//! This implementation should be in line with the [PyTorch version](https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py).
//!
use crate::{DType, Result, Tensor, WithDType};
use half::{bf16, f16};

View File

@ -1,3 +1,5 @@
//! Implementation of the Cuda backend when Cuda support has not been compiled in.
//!
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};

View File

@ -1,3 +1,4 @@
//! Candle-specific Error and Result
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
#[derive(Debug, Clone)]
@ -8,8 +9,14 @@ pub struct MatMulUnexpectedStriding {
pub msg: &'static str,
}
impl std::fmt::Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self}")
}
}
/// Main library error type.
#[derive(thiserror::Error, Debug)]
#[derive(thiserror::Error)]
pub enum Error {
// === DType Errors ===
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
@ -198,8 +205,14 @@ pub enum Error {
UnsupportedSafeTensorDtype(safetensors::Dtype),
/// Arbitrary errors wrapping.
#[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>),
#[error("{0}")]
Wrapped(Box<dyn std::fmt::Display + Send + Sync>),
#[error("{context}\n{inner}")]
Context {
inner: Box<Self>,
context: Box<dyn std::fmt::Display + Send + Sync>,
},
/// Adding path information to an error.
#[error("path: {path:?} {inner}")]
@ -217,16 +230,19 @@ pub enum Error {
/// User generated error message, typically created via `bail!`.
#[error("{0}")]
Msg(String),
#[error("unwrap none")]
UnwrapNone,
}
pub type Result<T> = std::result::Result<T, Error>;
impl Error {
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
pub fn wrap(err: impl std::fmt::Display + Send + Sync + 'static) -> Self {
Self::Wrapped(Box::new(err)).bt()
}
pub fn msg(err: impl std::error::Error) -> Self {
pub fn msg(err: impl std::fmt::Display) -> Self {
Self::Msg(err.to_string()).bt()
}
@ -252,6 +268,13 @@ impl Error {
path: p.as_ref().to_path_buf(),
}
}
pub fn context(self, c: impl std::fmt::Display + Send + Sync + 'static) -> Self {
Self::Context {
inner: Box::new(self),
context: Box::new(c),
}
}
}
#[macro_export]
@ -274,3 +297,41 @@ pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
(_, Err(e)) => Err(e),
}
}
// Taken from anyhow.
pub trait Context<T> {
/// Wrap the error value with additional context.
fn context<C>(self, context: C) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static;
/// Wrap the error value with additional context that is evaluated lazily
/// only once an error does occur.
fn with_context<C, F>(self, f: F) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
F: FnOnce() -> C;
}
impl<T> Context<T> for Option<T> {
fn context<C>(self, context: C) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
{
match self {
Some(v) => Ok(v),
None => Err(Error::UnwrapNone.context(context).bt()),
}
}
fn with_context<C, F>(self, f: F) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
F: FnOnce() -> C,
{
match self {
Some(v) => Ok(v),
None => Err(Error::UnwrapNone.context(f()).bt()),
}
}
}

View File

@ -1,3 +1,4 @@
//! Tensor Layouts including contiguous or sparse strides
use crate::{Error, Result, Shape};
#[derive(Debug, PartialEq, Eq, Clone)]

View File

@ -7,8 +7,8 @@
//!
//! let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
//! let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
//!
//! let c = a.matmul(&b)?;
//!
//! # Ok(())}
//! ```
//!
@ -94,7 +94,7 @@ pub use cpu_backend::{CpuStorage, CpuStorageRef};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use error::{Context, Error, Result};
pub use indexer::{IndexOp, TensorIndexer};
pub use layout::Layout;
pub use shape::{Shape, D};
@ -140,7 +140,7 @@ impl ToUsize2 for (usize, usize) {
}
}
// A simple trait defining a module with forward method using a single argument.
/// Defining a module with a forward method using a single argument.
pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
@ -160,8 +160,8 @@ impl<M: Module> Module for Option<&M> {
}
}
// A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
/// A single forward method using a single tensor argument and a flag to
/// separate the training and evaluation behaviors.
pub trait ModuleT {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}

View File

@ -1,3 +1,5 @@
//! Implementation of Backend traits for Metal
//!
use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
@ -1237,11 +1239,12 @@ impl BackendStorage for MetalStorage {
let dst_el = ids_l.shape().elem_count();
let dtype = self.dtype;
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let buffer = device.new_buffer(dst_el, dtype, "gather")?;
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "gather_u32_f32",
(DType::U32, DType::F16) => "gather_u32_f16",
(DType::U32, DType::BF16) => "gather_u32_bf16",
(DType::U32, DType::U32) => "gather_u32_u32",
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
};
let command_buffer = self.device.command_buffer()?;
@ -1281,6 +1284,7 @@ impl BackendStorage for MetalStorage {
(DType::U8, DType::F32) => "sa_u8_f32",
(DType::U8, DType::F16) => "sa_u8_f16",
(DType::U8, DType::BF16) => "sa_u8_bf16",
(DType::U32, DType::U32) => "sa_u32_u32",
(DType::U32, DType::F32) => "sa_u32_f32",
(DType::U32, DType::F16) => "sa_u32_f16",
(DType::U32, DType::BF16) => "sa_u32_bf16",
@ -1324,14 +1328,23 @@ impl BackendStorage for MetalStorage {
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::U8) => "is_u8_u8",
(DType::U8, DType::U32) => "is_u8_u32",
(DType::U8, DType::I64) => "is_u8_i64",
(DType::U8, DType::BF16) => "is_u8_bf16",
(DType::U8, DType::F32) => "is_u8_f32",
(DType::U8, DType::F16) => "is_u8_f16",
(DType::U32, DType::U8) => "is_u32_u8",
(DType::U32, DType::U32) => "is_u32_u32",
(DType::U32, DType::I64) => "is_u32_i64",
(DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16",
(DType::U32, DType::BF16) => "is_u32_bf16",
(DType::I64, DType::U8) => "is_i64_u8",
(DType::I64, DType::U32) => "is_i64_u32",
(DType::I64, DType::I64) => "is_i64_i64",
(DType::I64, DType::F32) => "is_i64_f32",
(DType::I64, DType::F16) => "is_i64_f16",
(DType::I64, DType::BF16) => "is_i64_bf16",

View File

@ -1,3 +1,5 @@
//! Tensor Operation Enums and Traits
//!
#![allow(clippy::redundant_closure_call)]
use crate::Tensor;
use half::{bf16, f16};

View File

@ -1,7 +1,7 @@
// Just enough pickle support to be able to read PyTorch checkpoints.
//! Just enough pickle support to be able to read PyTorch checkpoints.
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
// composable/tensor agnostic at some point.
use crate::{DType, Error as E, Layout, Result, Tensor};
use crate::{Context, DType, Error as E, Layout, Result, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
use std::io::BufRead;
@ -537,7 +537,7 @@ impl Stack {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().unwrap();
let key = objs.pop().context("empty objs")?;
d.push((key, value))
}
} else {
@ -557,7 +557,7 @@ impl Stack {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().unwrap();
let key = objs.pop().context("empty objs")?;
pydict.push((key, value))
}
self.push(Object::Dict(pydict))
@ -661,7 +661,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
if !file_name.ends_with("data.pkl") {
continue;
}
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").context("no .pkl")?);
let reader = zip.by_name(file_name)?;
let mut reader = std::io::BufReader::new(reader);
let mut stack = Stack::empty();

View File

@ -36,7 +36,7 @@ pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
pub const MATRIX_ROW_PADDING: usize = 512;
fn ceil_div(p: usize, q: usize) -> usize {
(p + q - 1) / q
p.div_ceil(q)
}
fn pad(p: usize, q: usize) -> usize {

View File

@ -134,7 +134,7 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
super::QTensor::new(data, dims)
}
/// Creates a [Tensor] from a raw GGML tensor.
/// Creates a Tensor from a raw GGML tensor.
pub fn qtensor_from_ggml(
ggml_dtype: GgmlDType,
raw_data: &[u8],

View File

@ -1,9 +1,8 @@
//! Support for the GGUF file format.
//! Support for the [GGUF file format](https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md).
//!
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
use super::{GgmlDType, QTensor};
use crate::{Device, Result};
use crate::{Context, Device, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;
@ -339,7 +338,7 @@ impl Value {
if value_type.len() != 1 {
crate::bail!("multiple value-types in the same array {value_type:?}")
}
value_type.into_iter().next().unwrap()
value_type.into_iter().next().context("empty value_type")?
};
w.write_u32::<LittleEndian>(value_type.to_u32())?;
w.write_u64::<LittleEndian>(v.len() as u64)?;
@ -458,7 +457,7 @@ impl Content {
Some(Value::I32(v)) if *v >= 0 => *v as u64,
_ => DEFAULT_ALIGNMENT,
};
let tensor_data_offset = (position + alignment - 1) / alignment * alignment;
let tensor_data_offset = position.div_ceil(alignment) * alignment;
Ok(Self {
magic,
metadata,

View File

@ -1850,8 +1850,8 @@ pub fn matmul<T: GgmlType>(
crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len());
}
let k_in_lhs_blocks = (k + T::BLCK_SIZE - 1) / T::BLCK_SIZE;
let k_in_rhs_blocks = (k + T::VecDotType::BLCK_SIZE - 1) / T::VecDotType::BLCK_SIZE;
let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE);
let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE);
// TODO: Do not make this copy if the DotType is f32.
// TODO: Pre-allocate this.
let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks];

View File

@ -1,4 +1,5 @@
use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
//! Code for GGML and GGUF files
use crate::{Context, CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;
@ -480,7 +481,7 @@ impl crate::CustomOp1 for QTensor {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let mut dst_shape = src_shape.dims().to_vec();
let last_k = dst_shape.pop().unwrap();
let last_k = dst_shape.pop().context("empty dst_shape")?;
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
}

View File

@ -1,3 +1,14 @@
//! Module to load `safetensor` files into CPU/GPU memory.
//!
//! There are multiple ways to load tensors from safetensor files:
//! - `load` function for loading directly into memory and returning a HashMap of tensors
//! - `MmapedSafetensors` for memory mapping files and avoiding full allocation
//! - `SliceSafetensors` for working with in-memory buffers
//! - `BufferedSafetensors` for owning a buffer of data
//!
//! Tensors can also be serialized to safetensor format using the `save` function or
//! `Tensor::save_safetensors` method.
//!
use crate::{DType, Device, Error, Result, Tensor, WithDType};
use safetensors::tensor as st;
use safetensors::tensor::SafeTensors;
@ -171,7 +182,7 @@ pub trait Load {
fn load(&self, device: &Device) -> Result<Tensor>;
}
impl<'a> Load for st::TensorView<'a> {
impl Load for st::TensorView<'_> {
fn load(&self, device: &Device) -> Result<Tensor> {
convert(self, device)
}
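
The module doc added above lists several loading paths; a minimal sketch of the simplest one, assuming a local `model.safetensors` file is available:

```rust
use candle_core::{Device, Result};

fn main() -> Result<()> {
    // `load` reads every tensor in the file into a HashMap<String, Tensor>.
    let tensors = candle_core::safetensors::load("model.safetensors", &Device::Cpu)?;
    for (name, tensor) in tensors.iter() {
        println!("{name}: {:?}", tensor.shape());
    }
    Ok(())
}
```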

View File

@ -1,3 +1,5 @@
//! TensorScalar Enum and Trait
//!
use crate::{Result, Tensor, WithDType};
pub enum TensorScalar {

View File

@ -52,6 +52,49 @@ impl ArgSort {
}
}
#[cfg(feature = "cuda")]
mod cuda {
use super::*;
use crate::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
use crate::{CudaDevice, WithDType};
impl crate::cuda_backend::Map1Any for ArgSort {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &crate::Layout,
_wrap: W,
) -> Result<S> {
let slice = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
let func = if self.asc {
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
} else {
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
};
let ncols = self.last_dim;
let nrows = elem_count / ncols;
let ncols_pad = next_power_of_2(ncols);
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
let cfg = LaunchConfig {
grid_dim: (1, nrows as u32, 1),
block_dim: (ncols_pad as u32, 1, 1),
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
};
unsafe { func.launch(cfg, params) }.w()?;
Ok(S::U32(dst))
}
}
}
impl crate::CustomOp1 for ArgSort {
fn name(&self) -> &'static str {
"argsort"
@ -81,46 +124,8 @@ impl crate::CustomOp1 for ArgSort {
storage: &crate::CudaStorage,
layout: &crate::Layout,
) -> Result<(crate::CudaStorage, crate::Shape)> {
use crate::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr};
use crate::{CudaDevice, WithDType};
impl Map1Any for ArgSort {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &crate::Layout,
_wrap: W,
) -> Result<S> {
let slice = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
let func = if self.asc {
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
} else {
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
};
let ncols = self.last_dim;
let nrows = elem_count / ncols;
let ncols_pad = next_power_of_2(ncols);
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
let cfg = LaunchConfig {
grid_dim: (1, nrows as u32, 1),
block_dim: (ncols_pad as u32, 1, 1),
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
};
unsafe { func.launch(cfg, params) }.w()?;
Ok(S::U32(dst))
}
}
use crate::backend::BackendStorage;
use crate::cuda_backend::Map1Any;
let dev = storage.device();
let slice = self.map(&storage.slice, dev, layout)?;
let dst = crate::cuda_backend::CudaStorage {

View File

@ -1,3 +1,5 @@
//! StreamTensor, useful for streaming ops.
//!
use crate::{Result, Shape, Tensor};
pub trait Dim: crate::shape::Dim + Copy {}

View File

@ -32,7 +32,7 @@ impl<'a> StridedIndex<'a> {
}
}
impl<'a> Iterator for StridedIndex<'a> {
impl Iterator for StridedIndex<'_> {
type Item = usize;
fn next(&mut self) -> Option<Self::Item> {

View File

@ -242,7 +242,7 @@ impl Tensor {
Self::zeros_impl(shape, dtype, device, false)
}
/// Creates a new tensor filled with ones with same shape, dtype, and device as the other
/// Creates a new tensor filled with zeros with same shape, dtype, and device as the other
/// tensor.
///
/// ```rust
@ -1760,6 +1760,42 @@ impl Tensor {
&self.op
}
/// Computes the max of all the elements in this tensor and returns a tensor holding this
/// scalar with zero dimensions.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.max_all()?;
/// assert_eq!(tensor.to_scalar::<f32>()?, 5.);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn max_all(&self) -> Result<Tensor> {
if self.rank() == 0 {
Ok(self.clone())
} else {
self.flatten_all()?.max(0)
}
}
/// Computes the min of all the elements in this tensor and returns a tensor holding this
/// scalar with zero dimensions.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.min_all()?;
/// assert_eq!(tensor.to_scalar::<f32>()?, 0.);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn min_all(&self) -> Result<Tensor> {
if self.rank() == 0 {
Ok(self.clone())
} else {
self.flatten_all()?.min(0)
}
}
/// Computes the sum of all the elements in this tensor and returns a tensor holding this
/// scalar with zero dimensions.
///

View File

@ -1,4 +1,4 @@
use crate::{shape::Dim, Error, Result, Shape, Tensor};
use crate::{shape::Dim, Context, Error, Result, Shape, Tensor};
impl Tensor {
/// Concatenates two or more tensors along a particular dimension.
@ -134,7 +134,7 @@ impl Tensor {
.bt())?
}
}
let next_offset = offsets.last().unwrap() + arg.elem_count();
let next_offset = offsets.last().context("empty offsets")? + arg.elem_count();
offsets.push(next_offset);
}
let shape = Shape::from(cat_dims);

View File

@ -1,3 +1,4 @@
//! Useful functions for checking features.
use std::str::FromStr;
pub fn get_num_threads() -> usize {

View File

@ -78,7 +78,7 @@ impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> {
match self.inner.inner.next() {
Some(item) => items.push(item),
None => {
if self.return_last_incomplete_batch {
if self.return_last_incomplete_batch && !items.is_empty() {
break;
}
return None;
@ -102,7 +102,7 @@ impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> {
ys.push(y)
}
None => {
if self.return_last_incomplete_batch {
if self.return_last_incomplete_batch && !xs.is_empty() && !ys.is_empty() {
break;
}
return None;
@ -127,7 +127,7 @@ impl<I: Iterator<Item = Result<Tensor>>> Iterator for Batcher<IterResult1<I>> {
match self.inner.inner.next() {
Some(item) => items.push(item),
None => {
if self.return_last_incomplete_batch {
if self.return_last_incomplete_batch && !items.is_empty() {
break;
}
return None;
@ -154,7 +154,7 @@ impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Iterator for Batcher<IterResu
}
Some(Err(err)) => errs.push(err),
None => {
if self.return_last_incomplete_batch {
if self.return_last_incomplete_batch && !xs.is_empty() && !ys.is_empty() {
break;
}
return None;

View File

@ -87,7 +87,7 @@ impl<'a> DatasetRandomIter<'a> {
}
}
impl<'a> Iterator for DatasetRandomIter<'a> {
impl Iterator for DatasetRandomIter<'_> {
type Item = Result<(Tensor, Tensor)>;
fn next(&mut self) -> Option<Self::Item> {

View File

@ -27,7 +27,7 @@ intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true}
pyo3 = { version = "0.22.0", features = ["auto-initialize"], optional = true }
pyo3 = { version = "0.22.0", features = ["auto-initialize", "abi3-py311"], optional = true }
rayon = { workspace = true }
rubato = { version = "0.15.0", optional = true }
safetensors = { workspace = true }

View File

@ -13,7 +13,7 @@ THUDM/CodeGeeX4 is a versatile model for all AI software development scenarios,
** Running with ~cpu~
#+begin_src shell
cargo run --example codegeex4-9b --release --cpu -- --prompt "please write a insertion sort in rust" --sample-len 300
cargo run --example codegeex4-9b --release -- --cpu --prompt "please write a insertion sort in rust" --sample-len 300
#+end_src
** Output_Example

View File

@ -6,10 +6,8 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use std::ffi::OsString;
use std::path::PathBuf;
use clap::Parser;
use std::{ffi::OsString, path::PathBuf, sync::Arc};
use candle::DType::{F32, U8};
use candle::{DType, Device, Module, Result, Tensor};
@ -82,7 +80,7 @@ pub fn main() -> anyhow::Result<()> {
};
let config = DepthAnythingV2Config::vit_small();
let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;
let depth_anything = DepthAnythingV2::new(Arc::new(dinov2), config, vb)?;
let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;

View File

@ -250,7 +250,11 @@ fn run(args: Args) -> Result<()> {
};
println!("img\n{img}");
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
candle_examples::save_image(&img.i(0)?, "out.jpg")?;
let filename = match args.seed {
None => "out.jpg".to_string(),
Some(s) => format!("out-{s}.jpg"),
};
candle_examples::save_image(&img.i(0)?, filename)?;
Ok(())
}

View File

@ -7,48 +7,25 @@ GLM-4-9B is the open-source version of the latest generation of pre-trained mode
** Running with ~cuda~
#+begin_src shell
cargo run --example glm4 --release --features cuda
cargo run --example glm4 --release --features cuda -- --prompt "Hello world"
#+end_src
** Running with ~cpu~
#+begin_src shell
cargo run --example glm4 --release -- --cpu
cargo run --example glm4 --release -- --cpu --prompt "Hello world"
#+end_src
** Output Example
#+begin_src shell
cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache .
Finished release [optimized] target(s) in 0.24s
Running `/root/candle/target/release/examples/glm4 --sample-len 500 --cache .`
cargo run --features cuda -r --example glm4 -- --prompt "Hello "
avx: true, neon: false, simd128: false, f16c: true
temp: 0.60 repeat-penalty: 1.20 repeat-last-n: 64
cache path .
retrieved the files in 6.88963ms
loaded the model in 6.113752297s
retrieved the files in 6.454375ms
loaded the model in 3.652383779s
starting the inference loop
[欢迎使用GLM-4,请输入prompt]
请你告诉我什么是FFT
266 tokens generated (34.50 token/s)
Result:
。Fast Fourier Transform (FFT) 是一种快速计算离散傅里叶变换DFT的方法它广泛应用于信号处理、图像处理和数据分析等领域。
具体来说FFT是一种将时域数据转换为频域数据的算法。在数字信号处理中我们通常需要知道信号的频率成分这就需要进行傅立叶变换。传统的傅立叶变换的计算复杂度较高而 FFT 则大大提高了计算效率,使得大规模的 DFT 换成为可能。
以下是使用 Python 中的 numpy 进行 FFT 的简单示例:
```python
import numpy as np
# 创建一个时域信号
t = np.linspace(0, 1, num=100)
f = np.sin(2*np.pi*5*t) + 3*np.cos(2*np.pi*10*t)
# 对该信号做FFT变换并计算其幅值谱
fft_result = np.fft.fftshift(np.abs(np.fft.fft(f)))
```
在这个例子中,我们首先创建了一个时域信号 f。然后我们对这个信号进行了 FFT 换,得到了一个频域结果 fft_result。
Hello 2018, hello new year! Im so excited to be back and sharing with you all my favorite things from the past month. This is a monthly series where I share whats been inspiring me lately in hopes that it will inspire you too!
...
#+end_src
This example will read prompt from stdin

View File

@ -12,136 +12,117 @@ struct TextGeneration {
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
args: Args,
dtype: DType,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
dtype: DType,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
fn new(model: Model, tokenizer: Tokenizer, args: Args, device: &Device, dtype: DType) -> Self {
let logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
args,
device: device.clone(),
dtype,
}
}
fn run(&mut self, sample_len: usize) -> anyhow::Result<()> {
use std::io::BufRead;
use std::io::BufReader;
fn run(&mut self) -> anyhow::Result<()> {
use std::io::Write;
let args = &self.args;
println!("starting the inference loop");
println!("[欢迎使用GLM-4,请输入prompt]");
let stdin = std::io::stdin();
let reader = BufReader::new(stdin);
for line in reader.lines() {
let line = line.expect("Failed to read line");
let tokens = self.tokenizer.encode(line, true).expect("tokens error");
if tokens.is_empty() {
panic!("Empty prompts are not supported in the chatglm model.")
let tokens = self
.tokenizer
.encode(args.prompt.to_string(), true)
.expect("tokens error");
if tokens.is_empty() {
panic!("Empty prompts are not supported in the chatglm model.")
}
if args.verbose {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => panic!("cannot find the endoftext token"),
} else {
print!("{}", &args.prompt);
std::io::stdout().flush()?;
}
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => panic!("cannot find the endoftext token"),
};
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
std::io::stdout().flush().expect("output flush error");
let start_gen = std::time::Instant::now();
for index in 0..args.sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&tokens[start_at..],
)?
};
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
std::io::stdout().flush().expect("output flush error");
let start_gen = std::time::Instant::now();
let mut count = 0;
let mut result = vec![];
for index in 0..sample_len {
count += 1;
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self
.tokenizer
.decode(&[next_token], true)
.expect("Token error");
if self.verbose_prompt {
println!(
"[Count: {}] [Raw Token: {}] [Decode Token: {}]",
count, next_token, token
);
}
result.push(token);
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self
.tokenizer
.decode(&[next_token], true)
.expect("token decode error");
if args.verbose {
println!(
"[Count: {}] [Raw Token: {}] [Decode Token: {}]",
generated_tokens, next_token, token
);
} else {
print!("{token}");
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
println!("Result:");
for tokens in result {
print!("{tokens}");
}
self.model.reset_kv_cache(); // clean the cache
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(name = "cache", short, long, default_value = ".")]
cache_path: String,
#[arg(name = "cache", short)]
cache_path: Option<String>,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
prompt: String,
/// Display the tokens for the specified prompt and outputs.
#[arg(long)]
verbose: bool,
/// The temperature used to generate samples.
#[arg(long)]
@ -197,28 +178,32 @@ fn main() -> anyhow::Result<()> {
);
let start = std::time::Instant::now();
println!("cache path {}", args.cache_path);
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
.build()
.map_err(anyhow::Error::msg)?;
let api = match args.cache_path.as_ref() {
None => hf_hub::api::sync::Api::new()?,
Some(path) => {
hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(path.to_string().into()))
.build()
.map_err(anyhow::Error::msg)?
}
};
let model_id = match args.model_id {
let model_id = match args.model_id.as_ref() {
Some(model_id) => model_id.to_string(),
None => "THUDM/glm-4-9b".to_string(),
};
let revision = match args.revision {
let revision = match args.revision.as_ref() {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
let tokenizer_filename = match args.tokenizer.as_ref() {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("THUDM/codegeex4-all-9b".to_string())
.get("tokenizer.json")
.map_err(anyhow::Error::msg)?,
};
let filenames = match args.weight_file {
let filenames = match args.weight_file.as_ref() {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
@ -238,18 +223,7 @@ fn main() -> anyhow::Result<()> {
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
dtype,
);
pipeline.run(args.sample_len)?;
let mut pipeline = TextGeneration::new(model, tokenizer, args, &device, dtype);
pipeline.run()?;
Ok(())
}

View File

@ -17,7 +17,7 @@ pub struct Config {
impl Config {
fn vocab_size(&self) -> usize {
let pad = self.pad_vocab_size_multiple;
(self.vocab_size + pad - 1) / pad * pad
self.vocab_size.div_ceil(pad) * pad
}
fn dt_rank(&self) -> usize {

View File

@ -0,0 +1,43 @@
# NV-Embed-v2
Candle implementation (inference only) of [NV-Embed-v2](https://huggingface.co/nvidia/NV-Embed-v2), a text embedding model that ranks No. 1 (as of Nov 25 2024) on the [MTEB](https://huggingface.co/spaces/mteb/leaderboard) benchmark with a score of 72.31 across 56 text embedding tasks.
## Running an example: Retrieval
```bash
cargo run --example nvembed_v2 --release
> scores: [[87.4269, 0.4629],
> [ 0.9653, 86.0372]]
> Tensor[[2, 2], f32]
```
In this example, we have two queries and two passages (the corresponding answers). The output tensor represents the similarity scores between each query-passage pair. The scores are computed by taking the dot product of the query and passage embeddings and scaling the result by 100.
```rust
let queries = [
"are judo throws allowed in wrestling?",
"how to become a radiology technician in michigan?",
];
let query_instruction =
"Instruct: Given a question, retrieve passages that answer the question\nQuery: "
.to_string();
let passages = [
"Since you're reading this, you are probably someone from a judo background or someone who is just wondering how judo techniques can be applied under wrestling rules. So without further ado, let's get to the question. Are Judo throws allowed in wrestling? Yes, judo throws are allowed in freestyle and folkstyle wrestling. You only need to be careful to follow the slam rules when executing judo throws. In wrestling, a slam is lifting and returning an opponent to the mat with unnecessary force.",
"Below are the basic steps to becoming a radiologic technologist in Michigan:Earn a high school diploma. As with most careers in health care, a high school education is the first step to finding entry-level employment. Taking classes in math and science, such as anatomy, biology, chemistry, physiology, and physics, can help prepare students for their college studies and future careers.Earn an associate degree. Entry-level radiologic positions typically require at least an Associate of Applied Science. Before enrolling in one of these degree programs, students should make sure it has been properly accredited by the Joint Review Committee on Education in Radiologic Technology (JRCERT).Get licensed or certified in the state of Michigan."
];
let passage_instruction = "".to_string();
```
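
A hedged sketch of the scoring step described above, with tiny dummy embeddings standing in for the model's 4096-dimensional output from `encode`:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Stand-ins for the [n_queries, d] and [n_passages, d] embedding matrices.
    let query_embeddings = Tensor::new(&[[1f32, 0.], [0., 1.]], &Device::Cpu)?;
    let passage_embeddings = Tensor::new(&[[0.9f32, 0.1], [0.05, 0.95]], &Device::Cpu)?;
    // Dot product of every query with every passage, scaled by 100.
    let scores = (query_embeddings.matmul(&passage_embeddings.t()?)? * 100.)?;
    println!("scores: {scores}");
    Ok(())
}
```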
If you already have the model and tokenizer files, you can use the `--tokenizer` and `--model-files` options to specify their full paths, instead of downloading them from the hub.
## Running an example: Sentence embedding
```bash
cargo run --example nvembed_v2 --release -- --prompt "Here is a test sentence"
> Embedding: [[ 0.0066, -0.0048, 0.0066, ..., -0.0096, 0.0119, -0.0052]]
> Tensor[[1, 4096], f32]
```
In this example, we pass a prompt to the model and it outputs the vector encoding of the prompt.
## Hardware Requirements
29.25GB at fp32
## License
CC-BY-NC-4.0. This model should not be used for any commercial purpose. Refer to the [license](https://spdx.org/licenses/CC-BY-NC-4.0) for the detailed terms.

View File

@ -0,0 +1,214 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use candle::{DType, IndexOp, Shape, Tensor, D};
use candle_nn::VarBuilder;
use candle_transformers::models::nvembed_v2::model::Model;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{PaddingDirection, PaddingParams, Tokenizer, TruncationParams};
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: Option<String>,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
model: Option<String>,
/// Comma-separated list of model files (e.g., '/path/file1.safetensors,/path/file2.safetensors,/path/file3.safetensors')
#[arg(long)]
model_files: Option<String>,
}
impl Args {
fn build_model_and_tokenizer(&self) -> anyhow::Result<(Model, tokenizers::Tokenizer)> {
let model_name = match self.model.as_ref() {
Some(model) => model.to_string(),
None => "nvidia/NV-Embed-v2".to_string(),
};
let api = Api::new()?;
let repo = api.repo(Repo::new(model_name.to_string(), RepoType::Model));
let model_files = match &self.model_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
let tokenizer_file = match &self.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let device = candle_examples::device(self.cpu)?;
let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
let _ = tokenizer
.with_padding(Some(PaddingParams {
direction: PaddingDirection::Right,
pad_id: 2,
pad_token: "</s>".to_string(),
..Default::default()
}))
.with_truncation(Some(TruncationParams {
max_length: 32768,
..Default::default()
}));
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, DType::F32, &device) }?;
let nvembed_model = Model::new(vb);
Ok((nvembed_model?, tokenizer))
}
}
fn encode(
model: &mut Model,
tokenizer: &Tokenizer,
examples: Vec<String>,
instruction: &str,
) -> Result<Tensor> {
let device = &model.device;
let dtype = model.dtype;
// Format input text
let eos_token = if let Some(padding) = tokenizer.get_padding() {
padding.pad_token.clone()
} else {
"".to_string()
};
let bos = "<s>".to_string();
let input_texts = examples
.iter()
.map(|input_example| format!("{bos}{instruction}{input_example}{eos_token}"))
.collect::<Vec<String>>();
// Tokenize
let encodings = tokenizer.encode_batch(input_texts, false).map_err(E::msg)?;
let input_ids_list = encodings
.iter()
.map(|encoding| {
Tensor::from_slice(
encoding.get_ids(),
Shape::from(encoding.get_ids().len()),
device,
)
})
.collect::<Result<Vec<_>, _>>()?;
let input_ids = Tensor::stack(&input_ids_list, 0)?;
// Mask out padding tokens for both embedding model and latent attention model
let attention_masks: Vec<Tensor> = encodings
.iter()
.map(|encoding| {
Tensor::from_slice(
encoding.get_attention_mask(),
Shape::from(encoding.get_attention_mask().len()),
device,
)?
.to_dtype(dtype)
})
.collect::<Result<Vec<_>, _>>()?;
let attention_mask = Tensor::stack(&attention_masks, 0)?;
// Mask out instruction tokens for latent attention model
let pool_mask = if !instruction.is_empty() {
let encoded_instruction = tokenizer.encode(instruction, false).map_err(E::msg)?;
let instruction_lens = encoded_instruction.get_tokens().len();
let zeros = Tensor::zeros(
attention_mask.i((.., ..instruction_lens))?.shape(),
dtype,
device,
)?;
let b = attention_mask.dims()[0];
attention_mask.slice_assign(&[..b, ..instruction_lens], &zeros)?
} else {
attention_mask.clone()
};
let hiddens = model
.forward(&input_ids, &attention_mask, &pool_mask)?
.squeeze(1)?;
// Normalize embedding
div_l2_norm(&hiddens)
}
fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
Ok(v.broadcast_div(&l2_norm)?)
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let (mut model, tokenizer) = args.build_model_and_tokenizer()?;
if let Some(prompt) = args.prompt {
let emb = encode(&mut model, &tokenizer, vec![prompt], "")?;
println!("Embedding: {emb}");
} else {
let queries = [
"are judo throws allowed in wrestling?",
"how to become a radiology technician in michigan?",
];
let passages = [
"Since you're reading this, you are probably someone from a judo background or someone who is just wondering how judo techniques can be applied under wrestling rules. So without further ado, let's get to the question. Are Judo throws allowed in wrestling? Yes, judo throws are allowed in freestyle and folkstyle wrestling. You only need to be careful to follow the slam rules when executing judo throws. In wrestling, a slam is lifting and returning an opponent to the mat with unnecessary force.",
"Below are the basic steps to becoming a radiologic technologist in Michigan:Earn a high school diploma. As with most careers in health care, a high school education is the first step to finding entry-level employment. Taking classes in math and science, such as anatomy, biology, chemistry, physiology, and physics, can help prepare students for their college studies and future careers.Earn an associate degree. Entry-level radiologic positions typically require at least an Associate of Applied Science. Before enrolling in one of these degree programs, students should make sure it has been properly accredited by the Joint Review Committee on Education in Radiologic Technology (JRCERT).Get licensed or certified in the state of Michigan."
];
let passage_instruction = "".to_string();
let query_instruction =
"Instruct: Given a question, retrieve passages that answer the question\nQuery: "
.to_string();
let passages: Vec<String> = passages.iter().map(|s| s.to_string()).collect();
let queries: Vec<String> = queries.iter().map(|s| s.to_string()).collect();
let emb_query = encode(&mut model, &tokenizer, queries, &query_instruction)?;
let emb_passage = encode(&mut model, &tokenizer, passages, &passage_instruction)?;
let scores = (emb_query.matmul(&emb_passage.t()?)? * 100.0)?;
println!("scores: {scores}");
}
Ok(())
}

View File

@ -1,5 +1,4 @@
use std::collections::VecDeque;
use std::fmt::Display;
use candle::{DType, Device, Error, Module, Result, Tensor, Var};
use candle_nn::{
@ -167,6 +166,7 @@ fn track(
Ok(())
}
#[allow(unused)]
struct Actor<'a> {
varmap: VarMap,
vb: VarBuilder<'a>,
@ -211,7 +211,7 @@ impl Actor<'_> {
let target_network = make_network("target-actor")?;
// this sets the two networks to be equal to each other using tau = 1.0
track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0);
track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0)?;
Ok(Self {
varmap,
@ -244,6 +244,7 @@ impl Actor<'_> {
}
}
#[allow(unused)]
struct Critic<'a> {
varmap: VarMap,
vb: VarBuilder<'a>,
@ -287,7 +288,7 @@ impl Critic<'_> {
let target_network = make_network("target-critic")?;
// this sets the two networks to be equal to each other using tau = 1.0
track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0);
track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0)?;
Ok(Self {
varmap,
@ -322,6 +323,7 @@ impl Critic<'_> {
}
}
#[allow(unused)]
#[allow(clippy::upper_case_acronyms)]
pub struct DDPG<'a> {
actor: Actor<'a>,

View File

@ -1,4 +1,3 @@
#![allow(unused)]
//! Wrappers around the Python API of Gymnasium (the new version of OpenAI gym)
use candle::{Device, Result, Tensor};
use pyo3::prelude::*;

View File

@ -1,5 +1,3 @@
#![allow(unused)]
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

View File

@ -14,7 +14,7 @@ fn new_model(
) -> Result<(impl Module, VarMap)> {
let input_size = input_shape.iter().product();
let mut varmap = VarMap::new();
let varmap = VarMap::new();
let var_builder = VarBuilder::from_varmap(&varmap, dtype, device);
let model = seq()

View File

@ -1,9 +1,8 @@
#![allow(unused)]
//! Vectorized version of the gym environment.
use candle::{DType, Device, Result, Tensor};
use pyo3::prelude::*;
use pyo3::types::PyDict;
#[allow(unused)]
#[derive(Debug)]
pub struct Step {
pub obs: Tensor,
@ -11,6 +10,7 @@ pub struct Step {
pub is_done: Tensor,
}
#[allow(unused)]
pub struct VecGymEnv {
env: PyObject,
action_space: usize,
@ -21,6 +21,7 @@ fn w(res: PyErr) -> candle::Error {
candle::Error::wrap(res)
}
#[allow(unused)]
impl VecGymEnv {
pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
Python::with_gil(|py| {

View File

@ -477,7 +477,7 @@ fn run(args: Args) -> Result<()> {
),
};
let scheduler = sd_config.build_scheduler(n_steps)?;
let mut scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
if let Some(seed) = seed {
device.set_seed(seed)?;
@ -539,7 +539,7 @@ fn run(args: Args) -> Result<()> {
};
for idx in 0..num_samples {
let timesteps = scheduler.timesteps();
let timesteps = scheduler.timesteps().to_vec();
let latents = match &init_latent_dist {
Some(init_latent_dist) => {
let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;

View File

@ -21,7 +21,7 @@ Stella_en_1.5B_v5 is trained by [MRL](https://arxiv.org/abs/2205.13147) enabling
The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) for a retrieval task (s2p). The sample queries and docs are hardcoded in the example.
```bash
$ cargo run --example stella-en-v5 --release --features <metal | cuda>
$ cargo run --example stella-en-v5 --release --features <metal | cuda> -- --which 1.5b
>
> Score: 0.8178786
@ -37,9 +37,29 @@ $ cargo run --example stella-en-v5 --release --features <metal | cuda>
> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types >
> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
>
$ cargo run --example stella-en-v5 --release --features <metal | cuda> -- --which 400m
>
> Score: 0.8397539
> Query: What are some ways to reduce stress?
> Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending
> time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent
> stress from building up.
>
>
>
> Score: 0.809545
> Query: What are the benefits of drinking green tea?
> Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage
> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types
> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
>
```
## Supported options:
- `Stella_en_15B_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`.
- `Stella_en_v5` has two published model variants: a 1.5B variant and a 400M variant. The variant is selected with the `--which` flag, e.g. `--which 400m` or `--which 1.5b`.
- `Stella_en_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for it). In the example run this is selected with the `--embed-dim` option, e.g. `... --embed-dim 4096`. Defaults to `1024`.
- As per the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require slightly different `query` preprocessing (a different prompt template for each). In this example this is enabled through the `--task` option; a combined run using these flags is sketched below.
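Putting the options above together, a combined invocation might look like the following (the feature flag and flag values are illustrative; adjust them for your hardware and use case):
```bash
cargo run --example stella-en-v5 --release --features cuda -- --which 400m --embed-dim 1024
```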

View File

@ -212,6 +212,14 @@ impl EncodeTask {
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "1.5b")]
Large,
#[value(name = "400m")]
Small,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -219,6 +227,9 @@ struct Args {
#[arg(long)]
cpu: bool,
#[arg(long)]
which: Which,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@ -250,24 +261,33 @@ struct Args {
// Tokenizer creation is super critical in our case.
// We are going to be `padding: Left` for each batch
fn create_tokenizer(tokenizer_file: &Path) -> Result<Tokenizer> {
fn create_tokenizer(tokenizer_file: &Path, which: Which) -> Result<Tokenizer> {
let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
pad_id
} else {
return Err(anyhow!(
"Tokenizer doesn't contain expected `<|endoftext|>` token"
));
};
// This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Left,
pad_id,
pad_token: "<|endoftext|>".to_string(),
..Default::default()
}));
if which == Which::Large {
let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
pad_id
} else {
return Err(anyhow!(
"Tokenizer doesn't contain expected `<|endoftext|>` token"
));
};
// This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Left,
pad_id,
pad_token: "<|endoftext|>".to_string(),
..Default::default()
}));
} else {
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Right,
..Default::default()
}));
}
Ok(tokenizer)
}
@ -298,7 +318,19 @@ fn main() -> Result<()> {
Some(d) => d,
None => EmbedDim::Dim1024,
};
let repo = api.repo(Repo::model("dunzhang/stella_en_1.5B_v5".to_string()));
let (repo, cfg) = match args.which {
Which::Large => (
"dunzhang/stella_en_1.5B_v5",
Config::new_1_5_b_v5(embed_dim.embed_dim()),
),
Which::Small => (
"dunzhang/stella_en_400M_v5",
Config::new_400_m_v5(embed_dim.embed_dim()),
),
};
let repo = api.repo(Repo::model(repo.to_string()));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
@ -330,7 +362,7 @@ fn main() -> Result<()> {
println!("retrieved the files in {:?}", start.elapsed());
// Initializing the tokenizer which would require us to add padding to the `left` for batch encoding
let tokenizer = create_tokenizer(tokenizer_filename.as_path())?;
let tokenizer = create_tokenizer(tokenizer_filename.as_path(), args.which)?;
let start = std::time::Instant::now();
@ -343,11 +375,7 @@ fn main() -> Result<()> {
let embed_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? };
let model = EmbeddingModel::new(
&Config::new_1_5_b_v5(embed_dim.embed_dim()),
base_vb,
embed_vb,
)?;
let model = EmbeddingModel::new(&cfg, base_vb, embed_vb)?;
println!("loaded the model in {:?}", start.elapsed());

View File

@ -624,13 +624,27 @@ pub fn main() -> Result<()> {
continue;
}
let mut resampled_pcm = vec![];
for buffered_pcm in buffered_pcm.chunks(1024) {
// resample the audio, one chunk of 1024 samples at a time.
// in case the audio input failed to produce an exact multiple of 1024 samples,
// process the remainder on the next iteration of the loop.
let full_chunks = buffered_pcm.len() / 1024;
let remainder = buffered_pcm.len() % 1024;
for chunk in 0..full_chunks {
let buffered_pcm = &buffered_pcm[chunk * 1024..(chunk + 1) * 1024];
let pcm = resampler.process(&[&buffered_pcm], None)?;
resampled_pcm.extend_from_slice(&pcm[0])
resampled_pcm.extend_from_slice(&pcm[0]);
}
let pcm = resampled_pcm;
println!("{} {}", buffered_pcm.len(), pcm.len());
buffered_pcm.clear();
if remainder == 0 {
buffered_pcm.clear();
} else {
// efficiently copy the remainder to the beginning of the `buffered_pcm` buffer and
// truncate it. That's more efficient than allocating a new vector and copying into it
println!("audio device produced partial chunk with {remainder} samples; processing the remainder on the next iteration of the loop");
buffered_pcm.copy_within(full_chunks * 1024.., 0);
buffered_pcm.truncate(remainder);
}
let mel = audio::pcm_to_mel(&config, &pcm, &mel_filters);
let mel_len = mel.len();
let mel = Tensor::from_vec(

View File

@ -0,0 +1,30 @@
# candle-xlm-roberta
This example demonstrates how to use XLM-RoBERTa models in Candle, which are especially known for their use in reranking. It uses the `fill-mask` task to generate a word for a masked token and the `reranker` task to rerank a list of documents for a given query.
## Usage
Fill Mask:
```bash
cargo run --example xlm-roberta --release -- --task fill-mask --model xlm-roberta-base
```
```markdown
Sentence: 0 : Hello I'm a fashion model.
Sentence: 1 : I'm a little boy.
Sentence: 2 : I'm living in berlin.
```
Reranker:
```bash
cargo run --example xlm-roberta --release -- --task reranker --model bge-reranker-base
```
```markdown
Ranking Results:
--------------------------------------------------------------------------------
> Rank #4 | Score: 0.0001 | South Korea is a country in East Asia.
> Rank #5 | Score: 0.0000 | There are forests in the mountains.
> Rank #2 | Score: 0.7314 | Pandas look like bears.
> Rank #3 | Score: 0.6948 | There are some animals with black and white fur.
> Rank #1 | Score: 0.9990 | The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.
--------------------------------------------------------------------------------
```
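The `--model` flag also accepts the other variants declared in the example's `Model` enum (the value names assume clap's default kebab-case naming, which matches the `bge-reranker-base` default). For instance, to use the larger reranker:
```bash
cargo run --example xlm-roberta --release -- --task reranker --model bge-reranker-large
```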

View File

@ -0,0 +1,277 @@
use std::path::PathBuf;
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::xlm_roberta::{
Config, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};
#[derive(Debug, Clone, ValueEnum)]
enum Model {
BgeRerankerBase,
BgeRerankerLarge,
BgeRerankerBaseV2,
XLMRobertaBase,
XLMRobertaLarge,
}
#[derive(Debug, Clone, ValueEnum)]
enum Task {
FillMask,
Reranker,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long, default_value = "bge-reranker-base")]
model: Model,
#[arg(long, default_value = "reranker")]
task: Task,
// Path to the tokenizer file.
#[arg(long)]
tokenizer_file: Option<String>,
// Path to the weight files.
#[arg(long)]
weight_files: Option<String>,
// Path to the config file.
#[arg(long)]
config_file: Option<String>,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: Option<String>,
}
fn main() -> Result<()> {
let args = Args::parse();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.task {
Task::FillMask => match args.model {
Model::XLMRobertaBase => "FacebookAI/xlm-roberta-base".to_string(),
Model::XLMRobertaLarge => "FacebookAI/xlm-roberta-large".to_string(),
_ => anyhow::bail!("BGE models are not supported for fill-mask task"),
},
Task::Reranker => match args.model {
Model::BgeRerankerBase => "BAAI/bge-reranker-base".to_string(),
Model::BgeRerankerLarge => "BAAI/bge-reranker-large".to_string(),
Model::BgeRerankerBaseV2 => "BAAI/bge-reranker-base-v2-m3".to_string(),
_ => anyhow::bail!("XLM-RoBERTa models are not supported for reranker task"),
},
},
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let weights_filename = match args.weight_files {
Some(files) => PathBuf::from(files),
None => match repo.get("model.safetensors") {
Ok(safetensors) => safetensors,
Err(_) => match repo.get("pytorch_model.bin") {
Ok(pytorch_model) => pytorch_model,
Err(e) => {
return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}", e)));
}
},
},
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let device = candle_examples::device(args.cpu)?;
let vb = if weights_filename.ends_with("model.safetensors") {
unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_filename], candle::DType::F16, &device)
.unwrap()
}
} else {
println!("Loading weights from pytorch_model.bin");
VarBuilder::from_pth(&weights_filename, candle::DType::F16, &device).unwrap()
};
tokenizer
.with_padding(Some(PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
pad_id: config.pad_token_id,
..Default::default()
}))
.with_truncation(None)
.map_err(E::msg)?;
match args.task {
Task::FillMask => {
let prompt = vec![
"Hello I'm a <mask> model.".to_string(),
"I'm a <mask> boy.".to_string(),
"I'm <mask> in berlin.".to_string(),
];
let model = XLMRobertaForMaskedLM::new(&config, vb)?;
let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Single(&prompt), &device)?;
let attention_mask =
get_attention_mask(&tokenizer, TokenizeInput::Single(&prompt), &device)?;
let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
let output = model
.forward(
&input_ids,
&attention_mask,
&token_type_ids,
None,
None,
None,
)?
.to_dtype(candle::DType::F32)?;
let max_outs = output.argmax(2)?;
let max_out = max_outs.to_vec2::<u32>()?;
let max_out_refs: Vec<&[u32]> = max_out.iter().map(|v| v.as_slice()).collect();
let decoded = tokenizer.decode_batch(&max_out_refs, true).unwrap();
for (i, sentence) in decoded.iter().enumerate() {
println!("Sentence: {} : {}", i + 1, sentence);
}
}
Task::Reranker => {
let query = "what is panda?".to_string();
let documents = ["South Korea is a country in East Asia.".to_string(),
"There are forests in the mountains.".to_string(),
"Pandas look like bears.".to_string(),
"There are some animals with black and white fur.".to_string(),
"The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.".to_string()];
// create pairs of query and documents
let pairs = documents
.iter()
.map(|doc| (query.clone(), doc.clone()))
.collect::<Vec<_>>();
let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
let attention_mask =
get_attention_mask(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
let model = XLMRobertaForSequenceClassification::new(1, &config, vb)?;
let output = model.forward(&input_ids, &attention_mask, &token_type_ids)?;
let output = candle_nn::ops::sigmoid(&output)?.t().unwrap();
let ranks = output
.arg_sort_last_dim(false)?
.to_vec2::<u32>()?
.into_iter()
.flatten()
.collect::<Vec<_>>();
println!("\nRanking Results:");
println!("{:-<80}", "");
documents.iter().enumerate().for_each(|(idx, doc)| {
let rank = ranks.iter().position(|&r| r == idx as u32).unwrap();
let score = output
.get_on_dim(1, idx)
.unwrap()
.to_dtype(candle::DType::F32)
.unwrap()
.to_vec1::<f32>()
.unwrap();
println!("Rank #{:<2} | Score: {:.4} | {}", rank + 1, score[0], doc);
});
println!("{:-<80}", "");
}
}
Ok(())
}
#[derive(Debug)]
pub enum TokenizeInput<'a> {
Single(&'a [String]),
Pairs(&'a [(String, String)]),
}
pub fn tokenize_batch(
tokenizer: &Tokenizer,
input: TokenizeInput,
device: &Device,
) -> anyhow::Result<Tensor> {
let tokens = match input {
TokenizeInput::Single(text_batch) => tokenizer
.encode_batch(text_batch.to_vec(), true)
.map_err(E::msg)?,
TokenizeInput::Pairs(pairs) => tokenizer
.encode_batch(pairs.to_vec(), true)
.map_err(E::msg)?,
};
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<candle::Result<Vec<_>>>()?;
Ok(Tensor::stack(&token_ids, 0)?)
}
pub fn get_attention_mask(
tokenizer: &Tokenizer,
input: TokenizeInput,
device: &Device,
) -> anyhow::Result<Tensor> {
let tokens = match input {
TokenizeInput::Single(text_batch) => tokenizer
.encode_batch(text_batch.to_vec(), true)
.map_err(E::msg)?,
TokenizeInput::Pairs(pairs) => tokenizer
.encode_batch(pairs.to_vec(), true)
.map_err(E::msg)?,
};
let attention_mask = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_attention_mask().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<candle::Result<Vec<_>>>()?;
Ok(Tensor::stack(&attention_mask, 0)?)
}

View File

@ -6,7 +6,6 @@ pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];
/// Loads an image from disk using the image crate at the requested resolution,
/// using the given std and mean parameters.
/// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
pub fn load_image_with_std_mean<P: AsRef<std::path::Path>>(
p: P,
res: usize,

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.8.0"
version = "0.8.2"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.0" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.2" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -54,6 +54,7 @@ fn main() -> Result<()> {
println!("cargo:rerun-if-changed=kernels/kernel_traits.h");
println!("cargo:rerun-if-changed=kernels/block_info.h");
println!("cargo:rerun-if-changed=kernels/static_switch.h");
println!("cargo:rerun-if-changed=kernels/hardware_info.h");
let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
Err(_) =>

View File

@ -18,8 +18,9 @@ struct BlockInfo {
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
, leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
, seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k)
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
{
}
@ -30,13 +31,14 @@ struct BlockInfo {
template <typename index_t>
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
}
const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
const int leftpad_k;
const int seqlen_k_cache;
const int actual_seqlen_k;
};

View File

@ -7,13 +7,7 @@
#include <cuda.h>
#include <vector>
// #ifdef OLD_GENERATOR_PATH
// #include <ATen/CUDAGeneratorImpl.h>
// #else
// #include <ATen/cuda/CUDAGeneratorImpl.h>
// #endif
//
// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
// #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
@ -76,6 +70,7 @@ struct Flash_fwd_params : public Qkv_params {
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
int * __restrict__ leftpad_k;
// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;
@ -189,6 +184,6 @@ struct Flash_bwd_params : public Flash_fwd_params {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
// template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
// template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);

View File

@ -53,9 +53,12 @@ extern "C" void run_mha(
int is_bf16,
int is_causal,
int unpadded_lse,
int window_size_left,
int window_size_right
int window_size_right,
float softcap
) {
Flash_fwd_params params;
// Reset the parameters
@ -99,8 +102,16 @@ extern "C" void run_mha(
params.d_rounded = d_rounded;
// Set the different scale values.
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
if (softcap > 0.0) {
params.softcap = softmax_scale / softcap;
params.scale_softmax = softcap;
params.scale_softmax_log2 = softcap * M_LOG2E;
} else{
// Remove potential NaN
params.softcap = 0.0;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
}
params.p_dropout = 1.; // probability to keep
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
@ -118,6 +129,7 @@ extern "C" void run_mha(
params.is_seqlens_k_cumulative = true;
params.num_splits = 1;
params.unpadded_lse = unpadded_lse;
cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd(params, stream);

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -4,6 +4,8 @@
#pragma once
// #include "philox_unpack.cuh" // For at::cuda::philox::unpack
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
@ -22,14 +24,6 @@ namespace flash {
using namespace cute;
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
#pragma unroll
for (int i = 0; i < size(tensor); ++i) {
tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
@ -328,7 +322,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
flash::apply_softcap(acc_s, params.softcap);
}
mask.template apply_mask<Is_causal, Is_even_MN>(
@ -394,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
smem_thr_copy_Q, smem_thr_copy_K
);
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
flash::apply_softcap(acc_s, params.softcap);
}
flash::cp_async_wait<0>();
@ -691,7 +685,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
// gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
// We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2);
Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
make_stride(params.rotary_dim / 2, _1{}));
@ -712,9 +706,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// if (cute::thread(8, 0)) { print_tensor(gCos); }
// if (cute::thread(0, 0)) { print_tensor(tRgCos); }
const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
// const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
const index_t row_offset_knew = bidb * params.knew_batch_stride
+ ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
// const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
const index_t row_offset_vnew = bidb * params.vnew_batch_stride
+ ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
// Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
// e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
@ -792,7 +788,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM);
} else {
const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
// If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
// We do this by setting the row stride of gCos / gSin to 0.
Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
@ -886,7 +882,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
flash::apply_softcap(acc_s, params.softcap);
}
@ -961,7 +957,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
smem_thr_copy_Q, smem_thr_copy_K
);
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
flash::apply_softcap(acc_s, params.softcap);
}
flash::cp_async_wait<0>();
@ -1226,7 +1222,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
constexpr int kBlockN = kNThreads / kBlockM;
using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
using GmemTiledCopyOaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;

View File

@ -3,11 +3,11 @@
******************************************************************************/
#pragma once
// #include <ATen/cuda/CUDAContext.h>
// #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#include "error.h"
#include "static_switch.h"
#include "hardware_info.h"
#include "flash.h"
#include "flash_fwd_kernel.h"
@ -74,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout>;
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
@ -205,7 +205,8 @@ inline bool cuda_is_sm8x() {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
bool is_sm8x = cuda_is_sm8x();
auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
if (is_sm8x) {
@ -228,7 +229,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
bool is_sm8x = cuda_is_sm8x();
auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if constexpr(!Is_dropout) {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@ -262,7 +264,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 160;
bool is_sm8x = cuda_is_sm8x();
auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// For A100, H100, 128 x 32 is the fastest.
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),

View File

@ -0,0 +1,42 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <tuple>
#include <cstdio>
#if !defined(__CUDACC_RTC__)
#include "cuda_runtime.h"
#endif
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, \
cudaGetErrorString(status_)); \
exit(1); \
} \
} while (0)
inline int get_current_device() {
int device;
CHECK_CUDA(cudaGetDevice(&device));
return device;
}
inline std::tuple<int, int> get_compute_capability(int device) {
int capability_major, capability_minor;
CHECK_CUDA(cudaDeviceGetAttribute(&capability_major, cudaDevAttrComputeCapabilityMajor, device));
CHECK_CUDA(cudaDeviceGetAttribute(&capability_minor, cudaDevAttrComputeCapabilityMinor, device));
return {capability_major, capability_minor};
}
inline int get_num_sm(int device) {
int multiprocessor_count;
CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device));
return multiprocessor_count;
}

View File

@ -101,8 +101,8 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
using SmemCopyAtomO = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;
using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
@ -125,14 +125,14 @@ struct Flash_fwd_kernel_traits : public Base {
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
AutoVectorizingCopyWithAssumedAlignment<128>
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
@ -144,7 +144,7 @@ struct Flash_fwd_kernel_traits : public Base {
Stride< _16, _1>>
>;
using GmemTiledCopyOaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using GmemLayoutAtomRotcossin = GmemLayoutAtom;
@ -153,7 +153,7 @@ struct Flash_fwd_kernel_traits : public Base {
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
using GmemTiledCopyRotcossinCont = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
};
@ -250,7 +250,7 @@ struct Flash_bwd_kernel_traits : public Base {
composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
using SmemLayoutQdOtransposed = decltype(
composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
@ -263,7 +263,7 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutdKV = decltype(tile_to_shape(
SmemLayoutAtomdKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomdKV = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
using SmemLayoutAtomdQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@ -272,7 +272,7 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutdQ = decltype(tile_to_shape(
SmemLayoutAtomdQ{},
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
// Double buffer for sQ
static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
@ -303,22 +303,22 @@ struct Flash_bwd_kernel_traits : public Base {
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
AutoVectorizingCopyWithAssumedAlignment<128>
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopydO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydKV = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydQ = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomdQaccum = std::conditional_t<
@ -329,12 +329,12 @@ struct Flash_bwd_kernel_traits : public Base {
Stride< _16, _1>>
>;
using GmemTiledCopydQaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomdQaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using GmemTiledCopydQaccumAtomicAdd = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
Stride<_32, _1>>{},
Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store

View File

@ -390,4 +390,22 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
#pragma unroll
for (int i = 0; i < size(tensor); ++i) {
tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
#pragma unroll
for (int i = 0; i < size(src_tensor); ++i) {
dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash

View File

@ -42,9 +42,12 @@ extern "C" {
is_bf16: c_int,
is_causal: c_int,
unpadded_lse: c_int,
window_size_left: c_int,
window_size_right: c_int,
softcap: f32,
);
}

View File

@ -11,6 +11,7 @@ pub struct FlashAttn {
pub alibi_slopes: Option<Tensor>,
pub window_size_left: Option<usize>,
pub window_size_right: Option<usize>,
pub softcap: Option<f32>,
}
fn round_multiple(x: usize, m: usize) -> usize {
@ -199,8 +200,10 @@ impl FlashAttn {
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_bf16 */ is_bf16,
/* is_causal */ is_causal,
/* upadded_lse */ 0,
/* window_size_left */ window_size_left,
/* window_size_right */ window_size_right,
/* softcap */ self.softcap.unwrap_or(0f32),
)
}
@ -271,6 +274,7 @@ pub fn flash_attn(
alibi_slopes: None,
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -308,6 +312,7 @@ pub fn flash_attn_windowed(
alibi_slopes: None,
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -342,6 +347,7 @@ pub fn flash_attn_alibi(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -381,6 +387,52 @@ pub fn flash_attn_alibi_windowed(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors `k` and `v` with fewer heads
/// than `q`. The number of heads in `k` and `v` must be divisible by the number of heads in `q`.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Optional alibi slopes tensor with shape `(num_heads_q)`.
/// * `softmax_scale` - Scaling factor for the softmax operation.
/// * `window_size_left` - Optional limit on left attention to value tokens.
/// * `window_size_right` - Optional limit on right attention to value tokens.
/// * `softcap` - Gemma-style softcap applied to the attention logits before the softmax.
///
/// # Causal Mask
///
/// Setting `window_size_left=None` and `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`.
///
/// # Returns
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_alibi_windowed_softcap(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: Option<&Tensor>,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
softcap: f32,
) -> Result<Tensor> {
let op = FlashAttn {
softmax_scale,
alibi_slopes: alibi_slopes.cloned(),
window_size_left,
window_size_right,
softcap: Some(softcap),
};
q.apply_op3(k, v, op)
}
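A minimal usage sketch for this new softcap entry point (not taken from the repository; the shapes, the random inputs, and the `50.0` softcap value are placeholders, and `candle-flash-attn` requires the `cuda` feature with f16/bf16 inputs):
```rust
use candle::{DType, Device, Tensor};
use candle_flash_attn::flash_attn_alibi_windowed_softcap;

fn softcap_attention_demo() -> candle::Result<Tensor> {
    let device = Device::new_cuda(0)?;
    // (batch, seq_len, num_heads, head_size) in f16, matching the doc comment above.
    let q = Tensor::randn(0f32, 1.0, (1, 128, 8, 64), &device)?.to_dtype(DType::F16)?;
    let k = q.clone();
    let v = q.clone();
    let softmax_scale = 1.0 / (64f32).sqrt();
    // No alibi slopes, causal mask via `window_size_right = Some(0)`, placeholder softcap of 50.0.
    flash_attn_alibi_windowed_softcap(&q, &k, &v, None, softmax_scale, None, Some(0), 50.0)
}
```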
@ -394,6 +446,7 @@ struct FlashAttnVarLen {
pub alibi_slopes: Option<Tensor>,
pub window_size_left: Option<usize>,
pub window_size_right: Option<usize>,
pub softcap: Option<f32>,
}
impl FlashAttnVarLen {
@ -466,7 +519,7 @@ impl FlashAttnVarLen {
candle::bail!("the last dim of v must be contiguous {v_stride:?}")
}
let (_total_q, num_heads, head_size_og) = q_l.shape().dims3()?;
let (total_q, num_heads, head_size_og) = q_l.shape().dims3()?;
let (total_k, num_heads_k, _head_size_og) = k_l.shape().dims3()?;
let expected_kv = (total_k, num_heads_k, head_size_og);
if expected_kv != k_l.shape().dims3()? {
@ -549,9 +602,7 @@ impl FlashAttnVarLen {
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
let softmax_lse = dev
.alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q).w()?;
let is_bf16 = if is_bf16 { 1 } else { 0 };
@ -611,8 +662,10 @@ impl FlashAttnVarLen {
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_bf16 */ is_bf16,
/* is_causal */ is_causal,
/* unpadded_lse */ 1,
/* window_size_left */ window_size_left,
/* window_size_right */ window_size_right,
/* softcap */ self.softcap.unwrap_or(0.0),
)
}
@ -699,6 +752,7 @@ pub fn flash_attn_varlen(
alibi_slopes: None,
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -752,6 +806,7 @@ pub fn flash_attn_varlen_windowed(
alibi_slopes: None,
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -802,6 +857,7 @@ pub fn flash_attn_varlen_alibi(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -857,6 +913,65 @@ pub fn flash_attn_varlen_alibi_windowed(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors `k` and `v` with fewer heads
/// than `q`; the number of heads in `q` must be divisible by the number of heads in `k` and `v`.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Optional alibi slopes tensor with shape `(num_heads_q)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum key sequence length for k and v in the batch.
/// * `window_size_left` - Optional limit on left attention to value tokens.
/// * `window_size_right` - Optional limit on right attention to value tokens.
/// * `softcap` - Gemma-style soft cap applied to the attention logits before the softmax.
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
///
/// # Causal mask
///
/// Setting `window_size_left=None` and `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`.
pub fn flash_attn_varlen_alibi_windowed_softcap(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: Option<&Tensor>,
seqlens_q: &Tensor,
seqlens_k: &Tensor,
max_seqlen_q: usize,
max_seqlen_k: usize,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
softcap: f32,
) -> Result<Tensor> {
let op = FlashAttnVarLen {
softmax_scale,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: alibi_slopes.cloned(),
window_size_left,
window_size_right,
softcap: Some(softcap),
};
q.apply_op3(k, v, op)
}
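
Similarly, a sketch of the variable-length variant (again not part of the patch; the crate paths, shapes, and soft-cap value are assumptions). Two sequences of lengths 3 and 5 are packed into `total = 8` rows, and the cumulative offsets `[0, 3, 8]` are passed as `seqlens_q`/`seqlens_k`.

// Illustrative sketch only: two packed sequences with a causal mask and a soft cap.
use candle_core::{DType, Device, Result, Tensor};
use candle_flash_attn::flash_attn_varlen_alibi_windowed_softcap;

fn softcap_varlen_demo() -> Result<Tensor> {
    let dev = Device::new_cuda(0)?;
    let (total, heads_q, heads_kv, head_dim) = (8, 8, 2, 64);
    // Cumulative sequence lengths: batch_size + 1 entries (u32), here for lengths 3 and 5.
    let seqlens_q = Tensor::new(&[0u32, 3u32, 8u32], &dev)?;
    let seqlens_k = seqlens_q.clone();
    let q = Tensor::randn(0f32, 1f32, (total, heads_q, head_dim), &dev)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1f32, (total, heads_kv, head_dim), &dev)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1f32, (total, heads_kv, head_dim), &dev)?.to_dtype(DType::F16)?;
    let softmax_scale = 1f32 / (head_dim as f32).sqrt();
    flash_attn_varlen_alibi_windowed_softcap(
        &q, &k, &v,
        /* alibi_slopes */ None,
        &seqlens_q, &seqlens_k,
        /* max_seqlen_q */ 5, /* max_seqlen_k */ 5,
        softmax_scale,
        /* window_size_left */ None, /* window_size_right */ Some(0),
        /* softcap */ 50.0,
    )
}

Because the rows are packed without padding, this pairs with the switch earlier in the diff to an unpadded LSE buffer of size `num_heads * total_q` and the `unpadded_lse = 1` flag passed to the kernel.
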

Some files were not shown because too many files have changed in this diff.