Compare commits

...

19 Commits
0.8.1 ... 0.8.2

Author SHA1 Message Date
236c35e578 Bump the caret version to 0.8.2. (#2703) 2025-01-07 15:50:16 +01:00
6f8351dfda add link to README (#2701) 2025-01-04 23:07:30 +01:00
57f41da13b Fix mistral attention on Metal (#2699)
Co-authored-by: Luka Zakrajsek <luka.zakrajsek@soniox.com>
2025-01-04 16:11:20 +01:00
cbaa0ad46f UniPC for diffusion sampling (#2684)
* feat: Add unipc multistep scheduler

* chore: Clippy and formatting

* chore: Update comments

* chore: Avoid unsafety in float ordering

* refactor: Update Scheduler::step mutability requirements

* fix: Corrector img2img

* chore: Update unipc ref link to latest diffusers release

* chore: Deduplicate float ordering

* fix: Panic when running with dev profile
2025-01-01 21:34:17 +01:00
b12c7c2888 Update the hf-hub dependency to 0.4.0. (#2691)
* Update the hf-hub dependency to 0.4.0.

* Fix the book.

* Use 0.4.1.
2024-12-31 19:07:47 +01:00
94ffc2ec6f Actually remove the default hf-hub cache path for glm. (#2696) 2024-12-31 11:00:44 +01:00
7354afc673 Use the default hf-hub cache for glm. (#2695) 2024-12-31 10:55:45 +01:00
2a705e6f37 Flash-Attn upgrade / SoftCap Candle-FlashAttn [3/n] (#2690)
* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace

* softcap is working, including test and api.

* make softcap test case better

* unpadded lse added
2024-12-31 10:04:47 +01:00
a594ef669c Flash-Attn upgrade / SoftCap Candle-FlashAttn [2/n] (#2689)
* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace

* softcap is working, including test and api.

* make softcap test case better

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-12-31 09:41:23 +01:00
71cd6d5533 Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)
* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace
2024-12-31 09:32:22 +01:00
d60eba1408 Streamline the glm4 example. (#2694) 2024-12-31 09:21:41 +01:00
e38e2a85dd Fix a cuda warning. (#2693) 2024-12-31 09:06:10 +01:00
460616fc84 Update README.org (#2670)
Fix the command line error in the CPU section of the documentation.
2024-12-30 11:32:02 +01:00
91f1f019b1 Added XLMRobertaModel for Reranking (#2686)
* add xlm-roberta-base

* Add task enum for fill-mask and reranker in xlm-roberta example; update README and fix attention mask dimensions

- Introduced a new `Task` enum to replace string task identifiers in the xlm-roberta example.
- Updated the logic in `main.rs` to handle tasks using the new enum.
- Enhanced README with example output for fill-mask task.
- Fixed dimension retrieval in `prepare_4d_attention_mask` function for better clarity and safety.

* Clippy fix.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-12-30 11:16:57 +01:00
cd639131f0 Fix bug in whisper transformer (#2681)
* Fix bug in whisper transformer
- due to num_threads going to zero
in single threaded case

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-12-24 13:58:21 +01:00
11aa30be10 Fix Batcher iterator break when return_last_incomplete_batch and items.is_empty (#2654) (#2655) 2024-12-24 08:41:26 +01:00
1be6b090c7 Fix position encodings for Pixtral (#2678)
* init commit: add position id in meshgrid

* pass in subsampled positions

* clippy fix

* clippy fix
2024-12-23 13:22:35 +01:00
62ced44ea9 Add a Context trait similar to anyhow::Context. (#2676)
* Add a Context trait similar to anyhow::Context.

* Switch two unwrap to context.
2024-12-22 09:18:13 +01:00
5c2f893e5a make DepthAnythingV2 more reusable (#2675)
* make DepthAnythingV2 more reusable

* Fix clippy lints.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-12-21 12:06:03 +01:00
86 changed files with 2549 additions and 386 deletions

View File

@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.8.1"
version = "0.8.2"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,20 +33,20 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.8.1" }
candle-datasets = { path = "./candle-datasets", version = "0.8.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.1" }
candle-kernels = { path = "./candle-kernels", version = "0.8.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.1" }
candle-nn = { path = "./candle-nn", version = "0.8.1" }
candle-onnx = { path = "./candle-onnx", version = "0.8.1" }
candle-transformers = { path = "./candle-transformers", version = "0.8.1" }
candle = { path = "./candle-core", package = "candle-core", version = "0.8.2" }
candle-datasets = { path = "./candle-datasets", version = "0.8.2" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.2" }
candle-kernels = { path = "./candle-kernels", version = "0.8.2" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.2" }
candle-nn = { path = "./candle-nn", version = "0.8.2" }
candle-onnx = { path = "./candle-onnx", version = "0.8.2" }
candle-transformers = { path = "./candle-transformers", version = "0.8.2" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = { version = "0.3.3", package = "candle-hf-hub" }
hf-hub = "0.4.1"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }

View File

@ -189,6 +189,7 @@ And then head over to
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.
- [`atoma-infer`](https://github.com/atoma-network/atoma-infer): A Rust library for fast inference at scale, leveraging FlashAttention2 for efficient attention computation, PagedAttention for efficient KV-cache memory management, and multi-GPU support. It is OpenAI api compatible.
- [`llms-from-scratch-rs`](https://github.com/nerdai/llms-from-scratch-rs): A comprehensive Rust translation of the code from Sebastian Raschka's Build an LLM from Scratch book.
If you have an addition to this list, please submit a pull request.

View File

@ -11,8 +11,8 @@ Then let's start by downloading the [model file](https://huggingface.co/bert-bas
```rust
# extern crate candle_core;
# extern crate candle_hf_hub;
use candle_hf_hub::api::sync::Api;
# extern crate hf_hub;
use hf_hub::api::sync::Api;
use candle_core::Device;
let api = Api::new().unwrap();
@ -50,8 +50,8 @@ Now that we have our weights, we can use them in our bert architecture:
```rust
# extern crate candle_core;
# extern crate candle_nn;
# extern crate candle_hf_hub;
# use candle_hf_hub::api::sync::Api;
# extern crate hf_hub;
# use hf_hub::api::sync::Api;
#
# let api = Api::new().unwrap();
# let repo = api.model("bert-base-uncased".to_string());
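For reference, a minimal standalone sketch of the same download flow with the plain `hf-hub` 0.4 sync API, outside the mdBook doc-test setup. It assumes the crate's default sync backend and either network access or a warm local cache:

```rust
// Minimal sketch: fetch the BERT weights with hf-hub 0.4, as the updated book snippet does.
use candle_core::Device;
use hf_hub::api::sync::Api;

fn main() {
    let api = Api::new().unwrap();
    let repo = api.model("bert-base-uncased".to_string());
    // Downloads into the hf-hub cache (or reuses a cached copy) and returns the local path.
    let weights = repo.get("model.safetensors").unwrap();
    println!("weights at {weights:?}");
    let _device = Device::Cpu; // the book then loads these weights onto a device
}
```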

View File

@ -9,8 +9,14 @@ pub struct MatMulUnexpectedStriding {
pub msg: &'static str,
}
impl std::fmt::Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self}")
}
}
/// Main library error type.
#[derive(thiserror::Error, Debug)]
#[derive(thiserror::Error)]
pub enum Error {
// === DType Errors ===
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
@ -199,8 +205,14 @@ pub enum Error {
UnsupportedSafeTensorDtype(safetensors::Dtype),
/// Arbitrary errors wrapping.
#[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>),
#[error("{0}")]
Wrapped(Box<dyn std::fmt::Display + Send + Sync>),
#[error("{context}\n{inner}")]
Context {
inner: Box<Self>,
context: Box<dyn std::fmt::Display + Send + Sync>,
},
/// Adding path information to an error.
#[error("path: {path:?} {inner}")]
@ -218,16 +230,19 @@ pub enum Error {
/// User generated error message, typically created via `bail!`.
#[error("{0}")]
Msg(String),
#[error("unwrap none")]
UnwrapNone,
}
pub type Result<T> = std::result::Result<T, Error>;
impl Error {
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
pub fn wrap(err: impl std::fmt::Display + Send + Sync + 'static) -> Self {
Self::Wrapped(Box::new(err)).bt()
}
pub fn msg(err: impl std::error::Error) -> Self {
pub fn msg(err: impl std::fmt::Display) -> Self {
Self::Msg(err.to_string()).bt()
}
@ -253,6 +268,13 @@ impl Error {
path: p.as_ref().to_path_buf(),
}
}
pub fn context(self, c: impl std::fmt::Display + Send + Sync + 'static) -> Self {
Self::Context {
inner: Box::new(self),
context: Box::new(c),
}
}
}
#[macro_export]
@ -275,3 +297,41 @@ pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
(_, Err(e)) => Err(e),
}
}
// Taken from anyhow.
pub trait Context<T> {
/// Wrap the error value with additional context.
fn context<C>(self, context: C) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static;
/// Wrap the error value with additional context that is evaluated lazily
/// only once an error does occur.
fn with_context<C, F>(self, f: F) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
F: FnOnce() -> C;
}
impl<T> Context<T> for Option<T> {
fn context<C>(self, context: C) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
{
match self {
Some(v) => Ok(v),
None => Err(Error::UnwrapNone.context(context).bt()),
}
}
fn with_context<C, F>(self, f: F) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
F: FnOnce() -> C,
{
match self {
Some(v) => Ok(v),
None => Err(Error::UnwrapNone.context(f()).bt()),
}
}
}
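Downstream code can attach messages to `Option` values through this new trait; a small usage sketch, assuming the `Context` re-export from `candle_core` shown in the next file:

```rust
// Sketch: turning `None` into a contextful candle error instead of unwrapping.
use candle_core::{Context, DType, Device, Result, Tensor};

fn last_dim(t: &Tensor) -> Result<usize> {
    // `context` converts a `None` into `Error::UnwrapNone` carrying the message.
    t.dims().last().copied().context("tensor has no dimensions")
}

fn main() -> Result<()> {
    let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let last = last_dim(&t)?;
    // `with_context` builds its message lazily, only on the error path.
    let first = t
        .dims()
        .first()
        .copied()
        .with_context(|| format!("no leading dim in shape {:?}", t.shape()))?;
    println!("first dim {first}, last dim {last}");
    Ok(())
}
```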

View File

@ -94,7 +94,7 @@ pub use cpu_backend::{CpuStorage, CpuStorageRef};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use error::{Context, Error, Result};
pub use indexer::{IndexOp, TensorIndexer};
pub use layout::Layout;
pub use shape::{Shape, D};

View File

@ -1,7 +1,7 @@
//! Just enough pickle support to be able to read PyTorch checkpoints.
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
// composable/tensor agnostic at some point.
use crate::{DType, Error as E, Layout, Result, Tensor};
use crate::{Context, DType, Error as E, Layout, Result, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
use std::io::BufRead;
@ -537,7 +537,7 @@ impl Stack {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().unwrap();
let key = objs.pop().context("empty objs")?;
d.push((key, value))
}
} else {
@ -557,7 +557,7 @@ impl Stack {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().unwrap();
let key = objs.pop().context("empty objs")?;
pydict.push((key, value))
}
self.push(Object::Dict(pydict))
@ -661,7 +661,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
if !file_name.ends_with("data.pkl") {
continue;
}
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").context("no .pkl")?);
let reader = zip.by_name(file_name)?;
let mut reader = std::io::BufReader::new(reader);
let mut stack = Stack::empty();

View File

@ -2,7 +2,7 @@
//!
use super::{GgmlDType, QTensor};
use crate::{Device, Result};
use crate::{Context, Device, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;
@ -338,7 +338,7 @@ impl Value {
if value_type.len() != 1 {
crate::bail!("multiple value-types in the same array {value_type:?}")
}
value_type.into_iter().next().unwrap()
value_type.into_iter().next().context("empty value_type")?
};
w.write_u32::<LittleEndian>(value_type.to_u32())?;
w.write_u64::<LittleEndian>(v.len() as u64)?;

View File

@ -1,5 +1,5 @@
//! Code for GGML and GGUF files
use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use crate::{Context, CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;
@ -481,7 +481,7 @@ impl crate::CustomOp1 for QTensor {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let mut dst_shape = src_shape.dims().to_vec();
let last_k = dst_shape.pop().unwrap();
let last_k = dst_shape.pop().context("empty dst_shape")?;
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
}
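For orientation, the custom op above is what the public quantized matmul path lands on. A sketch of that path, under the assumption that `QTensor::quantize` and `QMatMul::from_qtensor` keep their current 0.8.x signatures:

```rust
// Sketch: quantize a weight matrix and run the quantized matmul custom op.
use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let weight = Tensor::randn(0f32, 1f32, (64, 256), &dev)?;
    // Quantize to a GGML block format; the block size must divide the last dimension.
    let qweight = QTensor::quantize(&weight, GgmlDType::Q4_0)?;
    let mm = QMatMul::from_qtensor(qweight)?;
    let xs = Tensor::randn(0f32, 1f32, (1, 8, 256), &dev)?;
    // This dispatches to the CustomOp1 shown above, which checks the trailing dim.
    let ys = mm.forward(&xs)?;
    println!("{:?}", ys.dims()); // [1, 8, 64]
    Ok(())
}
```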

View File

@ -52,6 +52,49 @@ impl ArgSort {
}
}
#[cfg(feature = "cuda")]
mod cuda {
use super::*;
use crate::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
use crate::{CudaDevice, WithDType};
impl crate::cuda_backend::Map1Any for ArgSort {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &crate::Layout,
_wrap: W,
) -> Result<S> {
let slice = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
let func = if self.asc {
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
} else {
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
};
let ncols = self.last_dim;
let nrows = elem_count / ncols;
let ncols_pad = next_power_of_2(ncols);
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
let cfg = LaunchConfig {
grid_dim: (1, nrows as u32, 1),
block_dim: (ncols_pad as u32, 1, 1),
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
};
unsafe { func.launch(cfg, params) }.w()?;
Ok(S::U32(dst))
}
}
}
impl crate::CustomOp1 for ArgSort {
fn name(&self) -> &'static str {
"argsort"
@ -81,46 +124,8 @@ impl crate::CustomOp1 for ArgSort {
storage: &crate::CudaStorage,
layout: &crate::Layout,
) -> Result<(crate::CudaStorage, crate::Shape)> {
use crate::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr};
use crate::{CudaDevice, WithDType};
impl Map1Any for ArgSort {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &crate::Layout,
_wrap: W,
) -> Result<S> {
let slice = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
let func = if self.asc {
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
} else {
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
};
let ncols = self.last_dim;
let nrows = elem_count / ncols;
let ncols_pad = next_power_of_2(ncols);
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
let cfg = LaunchConfig {
grid_dim: (1, nrows as u32, 1),
block_dim: (ncols_pad as u32, 1, 1),
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
};
unsafe { func.launch(cfg, params) }.w()?;
Ok(S::U32(dst))
}
}
use crate::backend::BackendStorage;
use crate::cuda_backend::Map1Any;
let dev = storage.device();
let slice = self.map(&storage.slice, dev, layout)?;
let dst = crate::cuda_backend::CudaStorage {
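The change above only moves the CUDA launch code into a `#[cfg(feature = "cuda")]` module; the public entry point is untouched. A small sketch of that entry point, run on CPU so it does not need the `cuda` feature:

```rust
// Sketch: argsort along the last dimension via the ArgSort custom op.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[[3f32, 1., 2.], [9., 7., 8.]], &Device::Cpu)?;
    // Ascending indices per row; on a CUDA device this launches the asort kernels above.
    let idx = t.arg_sort_last_dim(true)?;
    println!("{idx}"); // each row sorts to [1, 2, 0]
    Ok(())
}
```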

View File

@ -1,4 +1,4 @@
use crate::{shape::Dim, Error, Result, Shape, Tensor};
use crate::{shape::Dim, Context, Error, Result, Shape, Tensor};
impl Tensor {
/// Concatenates two or more tensors along a particular dimension.
@ -134,7 +134,7 @@ impl Tensor {
.bt())?
}
}
let next_offset = offsets.last().unwrap() + arg.elem_count();
let next_offset = offsets.last().context("empty offsets")? + arg.elem_count();
offsets.push(next_offset);
}
let shape = Shape::from(cat_dims);
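The `offsets` bookkeeping above serves the public concatenation call; for reference, a minimal sketch of that call:

```rust
// Sketch: concatenating tensors along a dimension (the code above computes the copy offsets).
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let a = Tensor::zeros((2, 3), DType::F32, &dev)?;
    let b = Tensor::ones((4, 3), DType::F32, &dev)?;
    let c = Tensor::cat(&[&a, &b], 0)?;
    println!("{:?}", c.dims()); // [6, 3]
    Ok(())
}
```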

View File

@ -78,7 +78,7 @@ impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> {
match self.inner.inner.next() {
Some(item) => items.push(item),
None => {
if self.return_last_incomplete_batch {
if self.return_last_incomplete_batch && !items.is_empty() {
break;
}
return None;
@ -102,7 +102,7 @@ impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> {
ys.push(y)
}
None => {
if self.return_last_incomplete_batch {
if self.return_last_incomplete_batch && !xs.is_empty() && !ys.is_empty() {
break;
}
return None;
@ -127,7 +127,7 @@ impl<I: Iterator<Item = Result<Tensor>>> Iterator for Batcher<IterResult1<I>> {
match self.inner.inner.next() {
Some(item) => items.push(item),
None => {
if self.return_last_incomplete_batch {
if self.return_last_incomplete_batch && !items.is_empty() {
break;
}
return None;
@ -154,7 +154,7 @@ impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Iterator for Batcher<IterResu
}
Some(Err(err)) => errs.push(err),
None => {
if self.return_last_incomplete_batch {
if self.return_last_incomplete_batch && !xs.is_empty() && !ys.is_empty() {
break;
}
return None;
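The added `!items.is_empty()` / `!xs.is_empty()` guards mean a trailing partial batch is only yielded when it actually contains items. A usage sketch; the `new1` / `batch_size` / `return_last_incomplete_batch` builder names are assumed from the crate's existing API and do not appear in this diff:

```rust
// Sketch: the fixed behaviour — a trailing partial batch is yielded only when non-empty.
use candle_core::{Device, Result, Tensor};
use candle_datasets::Batcher;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let items = (0..5).map(move |i| Tensor::full(i as f32, 3, &dev).unwrap());
    let batcher = Batcher::new1(items)
        .batch_size(2)
        .return_last_incomplete_batch(true);
    for (i, batch) in batcher.enumerate() {
        // Two full batches, then one partial batch with a single item;
        // an empty trailing batch is no longer produced.
        println!("batch {i}: {:?}", batch?.dims());
    }
    Ok(())
}
```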

View File

@ -13,7 +13,7 @@ THUDM/CodeGeeX4 is a versatile model for all AI software development scenarios,
** Running with ~cpu~
#+begin_src shell
cargo run --example codegeex4-9b --release --cpu -- --prompt "please write a insertion sort in rust" --sample-len 300
cargo run --example codegeex4-9b --release -- --cpu --prompt "please write a insertion sort in rust" --sample-len 300
#+end_src
** Output_Example

View File

@ -6,10 +6,8 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use std::ffi::OsString;
use std::path::PathBuf;
use clap::Parser;
use std::{ffi::OsString, path::PathBuf, sync::Arc};
use candle::DType::{F32, U8};
use candle::{DType, Device, Module, Result, Tensor};
@ -82,7 +80,7 @@ pub fn main() -> anyhow::Result<()> {
};
let config = DepthAnythingV2Config::vit_small();
let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;
let depth_anything = DepthAnythingV2::new(Arc::new(dinov2), config, vb)?;
let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;

View File

@ -250,7 +250,11 @@ fn run(args: Args) -> Result<()> {
};
println!("img\n{img}");
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
candle_examples::save_image(&img.i(0)?, "out.jpg")?;
let filename = match args.seed {
None => "out.jpg".to_string(),
Some(s) => format!("out-{s}.jpg"),
};
candle_examples::save_image(&img.i(0)?, filename)?;
Ok(())
}

View File

@ -7,48 +7,25 @@ GLM-4-9B is the open-source version of the latest generation of pre-trained mode
** Running with ~cuda~
#+begin_src shell
cargo run --example glm4 --release --features cuda
cargo run --example glm4 --release --features cuda -- --prompt "Hello world"
#+end_src
** Running with ~cpu~
#+begin_src shell
cargo run --example glm4 --release -- --cpu
cargo run --example glm4 --release -- --cpu --prompt "Hello world"
#+end_src
** Output Example
#+begin_src shell
cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache .
Finished release [optimized] target(s) in 0.24s
Running `/root/candle/target/release/examples/glm4 --sample-len 500 --cache .`
cargo run --features cuda -r --example glm4 -- --prompt "Hello "
avx: true, neon: false, simd128: false, f16c: true
temp: 0.60 repeat-penalty: 1.20 repeat-last-n: 64
cache path .
retrieved the files in 6.88963ms
loaded the model in 6.113752297s
retrieved the files in 6.454375ms
loaded the model in 3.652383779s
starting the inference loop
[欢迎使用GLM-4,请输入prompt]
请你告诉我什么是FFT
266 tokens generated (34.50 token/s)
Result:
。Fast Fourier Transform (FFT) 是一种快速计算离散傅里叶变换DFT的方法它广泛应用于信号处理、图像处理和数据分析等领域。
具体来说FFT是一种将时域数据转换为频域数据的算法。在数字信号处理中我们通常需要知道信号的频率成分这就需要进行傅立叶变换。传统的傅立叶变换的计算复杂度较高而 FFT 则大大提高了计算效率,使得大规模的 DFT 换成为可能。
以下是使用 Python 中的 numpy 进行 FFT 的简单示例:
```python
import numpy as np
# 创建一个时域信号
t = np.linspace(0, 1, num=100)
f = np.sin(2*np.pi*5*t) + 3*np.cos(2*np.pi*10*t)
# 对该信号做FFT变换并计算其幅值谱
fft_result = np.fft.fftshift(np.abs(np.fft.fft(f)))
```
在这个例子中,我们首先创建了一个时域信号 f。然后我们对这个信号进行了 FFT 换,得到了一个频域结果 fft_result。
Hello 2018, hello new year! Im so excited to be back and sharing with you all my favorite things from the past month. This is a monthly series where I share whats been inspiring me lately in hopes that it will inspire you too!
...
#+end_src
This example will read prompt from stdin

View File

@ -12,136 +12,117 @@ struct TextGeneration {
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
args: Args,
dtype: DType,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
dtype: DType,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
fn new(model: Model, tokenizer: Tokenizer, args: Args, device: &Device, dtype: DType) -> Self {
let logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
args,
device: device.clone(),
dtype,
}
}
fn run(&mut self, sample_len: usize) -> anyhow::Result<()> {
use std::io::BufRead;
use std::io::BufReader;
fn run(&mut self) -> anyhow::Result<()> {
use std::io::Write;
let args = &self.args;
println!("starting the inference loop");
println!("[欢迎使用GLM-4,请输入prompt]");
let stdin = std::io::stdin();
let reader = BufReader::new(stdin);
for line in reader.lines() {
let line = line.expect("Failed to read line");
let tokens = self.tokenizer.encode(line, true).expect("tokens error");
if tokens.is_empty() {
panic!("Empty prompts are not supported in the chatglm model.")
let tokens = self
.tokenizer
.encode(args.prompt.to_string(), true)
.expect("tokens error");
if tokens.is_empty() {
panic!("Empty prompts are not supported in the chatglm model.")
}
if args.verbose {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => panic!("cannot find the endoftext token"),
} else {
print!("{}", &args.prompt);
std::io::stdout().flush()?;
}
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => panic!("cannot find the endoftext token"),
};
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
std::io::stdout().flush().expect("output flush error");
let start_gen = std::time::Instant::now();
for index in 0..args.sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&tokens[start_at..],
)?
};
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
std::io::stdout().flush().expect("output flush error");
let start_gen = std::time::Instant::now();
let mut count = 0;
let mut result = vec![];
for index in 0..sample_len {
count += 1;
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self
.tokenizer
.decode(&[next_token], true)
.expect("Token error");
if self.verbose_prompt {
println!(
"[Count: {}] [Raw Token: {}] [Decode Token: {}]",
count, next_token, token
);
}
result.push(token);
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self
.tokenizer
.decode(&[next_token], true)
.expect("token decode error");
if args.verbose {
println!(
"[Count: {}] [Raw Token: {}] [Decode Token: {}]",
generated_tokens, next_token, token
);
} else {
print!("{token}");
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
println!("Result:");
for tokens in result {
print!("{tokens}");
}
self.model.reset_kv_cache(); // clean the cache
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(name = "cache", short, long, default_value = ".")]
cache_path: String,
#[arg(name = "cache", short)]
cache_path: Option<String>,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
prompt: String,
/// Display the tokens for the specified prompt and outputs.
#[arg(long)]
verbose: bool,
/// The temperature used to generate samples.
#[arg(long)]
@ -197,28 +178,32 @@ fn main() -> anyhow::Result<()> {
);
let start = std::time::Instant::now();
println!("cache path {}", args.cache_path);
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
.build()
.map_err(anyhow::Error::msg)?;
let api = match args.cache_path.as_ref() {
None => hf_hub::api::sync::Api::new()?,
Some(path) => {
hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(path.to_string().into()))
.build()
.map_err(anyhow::Error::msg)?
}
};
let model_id = match args.model_id {
let model_id = match args.model_id.as_ref() {
Some(model_id) => model_id.to_string(),
None => "THUDM/glm-4-9b".to_string(),
};
let revision = match args.revision {
let revision = match args.revision.as_ref() {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
let tokenizer_filename = match args.tokenizer.as_ref() {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("THUDM/codegeex4-all-9b".to_string())
.get("tokenizer.json")
.map_err(anyhow::Error::msg)?,
};
let filenames = match args.weight_file {
let filenames = match args.weight_file.as_ref() {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
@ -238,18 +223,7 @@ fn main() -> anyhow::Result<()> {
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
dtype,
);
pipeline.run(args.sample_len)?;
let mut pipeline = TextGeneration::new(model, tokenizer, args, &device, dtype);
pipeline.run()?;
Ok(())
}

View File

@ -477,7 +477,7 @@ fn run(args: Args) -> Result<()> {
),
};
let scheduler = sd_config.build_scheduler(n_steps)?;
let mut scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
if let Some(seed) = seed {
device.set_seed(seed)?;
@ -539,7 +539,7 @@ fn run(args: Args) -> Result<()> {
};
for idx in 0..num_samples {
let timesteps = scheduler.timesteps();
let timesteps = scheduler.timesteps().to_vec();
let latents = match &init_latent_dist {
Some(init_latent_dist) => {
let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;
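The `mut` and `.to_vec()` changes go together: once `Scheduler::step` takes `&mut self` (which stateful samplers such as UniPC need), the timestep slice borrowed from the scheduler has to be copied out before the denoising loop. A self-contained sketch of the borrow pattern, using a stand-in trait with hypothetical names that only mirrors the shape of the real one:

```rust
use candle_core::{Result, Tensor};

// Stand-in trait (hypothetical, for illustration): mirrors the mutability change above.
trait StepScheduler {
    fn timesteps(&self) -> &[usize];
    fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
}

fn denoise<S: StepScheduler>(
    scheduler: &mut S,
    mut latents: Tensor,
    unet: impl Fn(&Tensor, usize) -> Result<Tensor>,
) -> Result<Tensor> {
    // Copy the timesteps out so the immutable borrow of `scheduler` ends here...
    let timesteps = scheduler.timesteps().to_vec();
    for &t in timesteps.iter() {
        let noise_pred = unet(&latents, t)?;
        // ...otherwise this `&mut` call would conflict with a live `&self` borrow.
        latents = scheduler.step(&noise_pred, t, &latents)?;
    }
    Ok(latents)
}
```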

View File

@ -0,0 +1,30 @@
# candle-xlm-roberta
This example demonstrates how to use XLM-RoBERTa models in Candle, which are especially common as the backbone of rerankers. It uses the `fill-mask` task to generate a word for a masked token and the `reranker` task to rerank a list of documents for a given query.
## Usage
Fill Mask:
```bash
cargo run --example xlm-roberta --release -- --task fill-mask --model xlm-roberta-base
```
```markdown
Sentence: 0 : Hello I'm a fashion model.
Sentence: 1 : I'm a little boy.
Sentence: 2 : I'm living in berlin.
```
Reranker:
```bash
cargo run --example xlm-roberta --release -- --task reranker --model bge-reranker-base
```
```markdown
Ranking Results:
--------------------------------------------------------------------------------
> Rank #4 | Score: 0.0001 | South Korea is a country in East Asia.
> Rank #5 | Score: 0.0000 | There are forests in the mountains.
> Rank #2 | Score: 0.7314 | Pandas look like bears.
> Rank #3 | Score: 0.6948 | There are some animals with black and white fur.
> Rank #1 | Score: 0.9990 | The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.
--------------------------------------------------------------------------------
```

View File

@ -0,0 +1,277 @@
use std::path::PathBuf;
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::xlm_roberta::{
Config, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};
#[derive(Debug, Clone, ValueEnum)]
enum Model {
BgeRerankerBase,
BgeRerankerLarge,
BgeRerankerBaseV2,
XLMRobertaBase,
XLMRobertaLarge,
}
#[derive(Debug, Clone, ValueEnum)]
enum Task {
FillMask,
Reranker,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long, default_value = "bge-reranker-base")]
model: Model,
#[arg(long, default_value = "reranker")]
task: Task,
// Path to the tokenizer file.
#[arg(long)]
tokenizer_file: Option<String>,
// Path to the weight files.
#[arg(long)]
weight_files: Option<String>,
// Path to the config file.
#[arg(long)]
config_file: Option<String>,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: Option<String>,
}
fn main() -> Result<()> {
let args = Args::parse();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.task {
Task::FillMask => match args.model {
Model::XLMRobertaBase => "FacebookAI/xlm-roberta-base".to_string(),
Model::XLMRobertaLarge => "FacebookAI/xlm-roberta-large".to_string(),
_ => anyhow::bail!("BGE models are not supported for fill-mask task"),
},
Task::Reranker => match args.model {
Model::BgeRerankerBase => "BAAI/bge-reranker-base".to_string(),
Model::BgeRerankerLarge => "BAAI/bge-reranker-large".to_string(),
Model::BgeRerankerBaseV2 => "BAAI/bge-reranker-base-v2-m3".to_string(),
_ => anyhow::bail!("XLM-RoBERTa models are not supported for reranker task"),
},
},
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let weights_filename = match args.weight_files {
Some(files) => PathBuf::from(files),
None => match repo.get("model.safetensors") {
Ok(safetensors) => safetensors,
Err(_) => match repo.get("pytorch_model.bin") {
Ok(pytorch_model) => pytorch_model,
Err(e) => {
return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}", e)));
}
},
},
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let device = candle_examples::device(args.cpu)?;
let vb = if weights_filename.ends_with("model.safetensors") {
unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_filename], candle::DType::F16, &device)
.unwrap()
}
} else {
println!("Loading weights from pytorch_model.bin");
VarBuilder::from_pth(&weights_filename, candle::DType::F16, &device).unwrap()
};
tokenizer
.with_padding(Some(PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
pad_id: config.pad_token_id,
..Default::default()
}))
.with_truncation(None)
.map_err(E::msg)?;
match args.task {
Task::FillMask => {
let prompt = vec![
"Hello I'm a <mask> model.".to_string(),
"I'm a <mask> boy.".to_string(),
"I'm <mask> in berlin.".to_string(),
];
let model = XLMRobertaForMaskedLM::new(&config, vb)?;
let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Single(&prompt), &device)?;
let attention_mask =
get_attention_mask(&tokenizer, TokenizeInput::Single(&prompt), &device)?;
let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
let output = model
.forward(
&input_ids,
&attention_mask,
&token_type_ids,
None,
None,
None,
)?
.to_dtype(candle::DType::F32)?;
let max_outs = output.argmax(2)?;
let max_out = max_outs.to_vec2::<u32>()?;
let max_out_refs: Vec<&[u32]> = max_out.iter().map(|v| v.as_slice()).collect();
let decoded = tokenizer.decode_batch(&max_out_refs, true).unwrap();
for (i, sentence) in decoded.iter().enumerate() {
println!("Sentence: {} : {}", i + 1, sentence);
}
}
Task::Reranker => {
let query = "what is panda?".to_string();
let documents = ["South Korea is a country in East Asia.".to_string(),
"There are forests in the mountains.".to_string(),
"Pandas look like bears.".to_string(),
"There are some animals with black and white fur.".to_string(),
"The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.".to_string()];
// create pairs of query and documents
let pairs = documents
.iter()
.map(|doc| (query.clone(), doc.clone()))
.collect::<Vec<_>>();
let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
let attention_mask =
get_attention_mask(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
let model = XLMRobertaForSequenceClassification::new(1, &config, vb)?;
let output = model.forward(&input_ids, &attention_mask, &token_type_ids)?;
let output = candle_nn::ops::sigmoid(&output)?.t().unwrap();
let ranks = output
.arg_sort_last_dim(false)?
.to_vec2::<u32>()?
.into_iter()
.flatten()
.collect::<Vec<_>>();
println!("\nRanking Results:");
println!("{:-<80}", "");
documents.iter().enumerate().for_each(|(idx, doc)| {
let rank = ranks.iter().position(|&r| r == idx as u32).unwrap();
let score = output
.get_on_dim(1, idx)
.unwrap()
.to_dtype(candle::DType::F32)
.unwrap()
.to_vec1::<f32>()
.unwrap();
println!("Rank #{:<2} | Score: {:.4} | {}", rank + 1, score[0], doc);
});
println!("{:-<80}", "");
}
}
Ok(())
}
#[derive(Debug)]
pub enum TokenizeInput<'a> {
Single(&'a [String]),
Pairs(&'a [(String, String)]),
}
pub fn tokenize_batch(
tokenizer: &Tokenizer,
input: TokenizeInput,
device: &Device,
) -> anyhow::Result<Tensor> {
let tokens = match input {
TokenizeInput::Single(text_batch) => tokenizer
.encode_batch(text_batch.to_vec(), true)
.map_err(E::msg)?,
TokenizeInput::Pairs(pairs) => tokenizer
.encode_batch(pairs.to_vec(), true)
.map_err(E::msg)?,
};
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<candle::Result<Vec<_>>>()?;
Ok(Tensor::stack(&token_ids, 0)?)
}
pub fn get_attention_mask(
tokenizer: &Tokenizer,
input: TokenizeInput,
device: &Device,
) -> anyhow::Result<Tensor> {
let tokens = match input {
TokenizeInput::Single(text_batch) => tokenizer
.encode_batch(text_batch.to_vec(), true)
.map_err(E::msg)?,
TokenizeInput::Pairs(pairs) => tokenizer
.encode_batch(pairs.to_vec(), true)
.map_err(E::msg)?,
};
let attention_mask = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_attention_mask().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<candle::Result<Vec<_>>>()?;
Ok(Tensor::stack(&attention_mask, 0)?)
}

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.8.1"
version = "0.8.2"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.1" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.2" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
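For orientation, the crate's main entry point as called from candle model code. This is only a sketch: it assumes a CUDA device and f16 inputs, and the softcap-enabled variants added in this release are not shown since their exact Rust-side names do not appear in this diff:

```rust
// Sketch: standard flash-attention call (requires the `cuda` feature and a CUDA device);
// q/k/v are laid out as (batch, seq_len, num_heads, head_dim) in f16 or bf16.
use candle_core::{DType, Device, Result, Tensor};

fn attend(q: &Tensor, k: &Tensor, v: &Tensor, head_dim: usize) -> Result<Tensor> {
    let softmax_scale = 1f32 / (head_dim as f32).sqrt();
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, /* causal= */ true)
}

fn main() -> Result<()> {
    let dev = Device::new_cuda(0)?;
    let q = Tensor::randn(0f32, 1f32, (1, 128, 8, 64), &dev)?.to_dtype(DType::F16)?;
    let (k, v) = (q.clone(), q.clone());
    let out = attend(&q, &k, &v, 64)?;
    println!("{:?}", out.dims()); // [1, 128, 8, 64]
    Ok(())
}
```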

View File

@ -54,6 +54,7 @@ fn main() -> Result<()> {
println!("cargo:rerun-if-changed=kernels/kernel_traits.h");
println!("cargo:rerun-if-changed=kernels/block_info.h");
println!("cargo:rerun-if-changed=kernels/static_switch.h");
println!("cargo:rerun-if-changed=kernels/hardware_info.h");
let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
Err(_) =>

View File

@ -18,8 +18,9 @@ struct BlockInfo {
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
, leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
, seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k)
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
{
}
@ -30,13 +31,14 @@ struct BlockInfo {
template <typename index_t>
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
}
const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
const int leftpad_k;
const int seqlen_k_cache;
const int actual_seqlen_k;
};

View File

@ -7,13 +7,7 @@
#include <cuda.h>
#include <vector>
// #ifdef OLD_GENERATOR_PATH
// #include <ATen/CUDAGeneratorImpl.h>
// #else
// #include <ATen/cuda/CUDAGeneratorImpl.h>
// #endif
//
// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
// #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
@ -76,6 +70,7 @@ struct Flash_fwd_params : public Qkv_params {
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
int * __restrict__ leftpad_k;
// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;
@ -189,6 +184,6 @@ struct Flash_bwd_params : public Flash_fwd_params {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
// template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
// template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);

View File

@ -53,9 +53,12 @@ extern "C" void run_mha(
int is_bf16,
int is_causal,
int unpadded_lse,
int window_size_left,
int window_size_right
int window_size_right,
float softcap
) {
Flash_fwd_params params;
// Reset the parameters
@ -99,8 +102,16 @@ extern "C" void run_mha(
params.d_rounded = d_rounded;
// Set the different scale values.
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
if (softcap > 0.0) {
params.softcap = softmax_scale / softcap;
params.scale_softmax = softcap;
params.scale_softmax_log2 = softcap * M_LOG2E;
} else{
// Remove potential NaN
params.softcap = 0.0;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
}
params.p_dropout = 1.; // probability to keep
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
@ -118,6 +129,7 @@ extern "C" void run_mha(
params.is_seqlens_k_cumulative = true;
params.num_splits = 1;
params.unpadded_lse = unpadded_lse;
cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd(params, stream);
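The parameter swap above is how attention soft-capping is wired in: when `softcap > 0`, the raw scores are scaled, squashed with tanh (see `apply_softcap` in the kernel diff further down), and rescaled so they saturate at ±softcap. Read off the code, with s the raw q·k score:

```latex
% params.softcap       = softmax_scale / softcap   (multiplier applied before tanh)
% params.scale_softmax = softcap                   (scale applied after tanh)
\text{score}(s) = \text{softcap}\cdot\tanh\!\Big(\tfrac{\text{softmax\_scale}\cdot s}{\text{softcap}}\Big)
\approx \text{softmax\_scale}\cdot s \quad \text{for small } |s|.
```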

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

View File

@ -4,6 +4,8 @@
#pragma once
// #include "philox_unpack.cuh" // For at::cuda::philox::unpack
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
@ -22,14 +24,6 @@ namespace flash {
using namespace cute;
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
#pragma unroll
for (int i = 0; i < size(tensor); ++i) {
tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
@ -328,7 +322,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
flash::apply_softcap(acc_s, params.softcap);
}
mask.template apply_mask<Is_causal, Is_even_MN>(
@ -394,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
smem_thr_copy_Q, smem_thr_copy_K
);
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
flash::apply_softcap(acc_s, params.softcap);
}
flash::cp_async_wait<0>();
@ -691,7 +685,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
// gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
// We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2);
Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
make_stride(params.rotary_dim / 2, _1{}));
@ -712,9 +706,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// if (cute::thread(8, 0)) { print_tensor(gCos); }
// if (cute::thread(0, 0)) { print_tensor(tRgCos); }
const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
// const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
const index_t row_offset_knew = bidb * params.knew_batch_stride
+ ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
// const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
const index_t row_offset_vnew = bidb * params.vnew_batch_stride
+ ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
// Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
// e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
@ -792,7 +788,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM);
} else {
const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
// If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
// We do this by setting the row stride of gCos / gSin to 0.
Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
@ -886,7 +882,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
flash::apply_softcap(acc_s, params.softcap);
}
@ -961,7 +957,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
smem_thr_copy_Q, smem_thr_copy_K
);
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
flash::apply_softcap(acc_s, params.softcap);
}
flash::cp_async_wait<0>();
@ -1226,7 +1222,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
constexpr int kBlockN = kNThreads / kBlockM;
using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
using GmemTiledCopyOaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;

View File

@ -3,11 +3,11 @@
******************************************************************************/
#pragma once
// #include <ATen/cuda/CUDAContext.h>
// #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#include "error.h"
#include "static_switch.h"
#include "hardware_info.h"
#include "flash.h"
#include "flash_fwd_kernel.h"
@ -74,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout>;
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
@ -205,7 +205,8 @@ inline bool cuda_is_sm8x() {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
bool is_sm8x = cuda_is_sm8x();
auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
if (is_sm8x) {
@ -228,7 +229,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
bool is_sm8x = cuda_is_sm8x();
auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if constexpr(!Is_dropout) {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@ -262,7 +264,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 160;
bool is_sm8x = cuda_is_sm8x();
auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// For A100, H100, 128 x 32 is the fastest.
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),

View File

@ -0,0 +1,42 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <tuple>
#include <cstdio>
#if !defined(__CUDACC_RTC__)
#include "cuda_runtime.h"
#endif
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, \
cudaGetErrorString(status_)); \
exit(1); \
} \
} while (0)
inline int get_current_device() {
int device;
CHECK_CUDA(cudaGetDevice(&device));
return device;
}
inline std::tuple<int, int> get_compute_capability(int device) {
int capability_major, capability_minor;
CHECK_CUDA(cudaDeviceGetAttribute(&capability_major, cudaDevAttrComputeCapabilityMajor, device));
CHECK_CUDA(cudaDeviceGetAttribute(&capability_minor, cudaDevAttrComputeCapabilityMinor, device));
return {capability_major, capability_minor};
}
inline int get_num_sm(int device) {
int multiprocessor_count;
CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device));
return multiprocessor_count;
}

View File

@ -101,8 +101,8 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
using SmemCopyAtomO = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;
using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
@ -125,14 +125,14 @@ struct Flash_fwd_kernel_traits : public Base {
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
AutoVectorizingCopyWithAssumedAlignment<128>
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
@ -144,7 +144,7 @@ struct Flash_fwd_kernel_traits : public Base {
Stride< _16, _1>>
>;
using GmemTiledCopyOaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using GmemLayoutAtomRotcossin = GmemLayoutAtom;
@ -153,7 +153,7 @@ struct Flash_fwd_kernel_traits : public Base {
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
using GmemTiledCopyRotcossinCont = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
};
@ -250,7 +250,7 @@ struct Flash_bwd_kernel_traits : public Base {
composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
using SmemLayoutQdOtransposed = decltype(
composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
@ -263,7 +263,7 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutdKV = decltype(tile_to_shape(
SmemLayoutAtomdKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomdKV = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
using SmemLayoutAtomdQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@ -272,7 +272,7 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutdQ = decltype(tile_to_shape(
SmemLayoutAtomdQ{},
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
// Double buffer for sQ
static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
@ -303,22 +303,22 @@ struct Flash_bwd_kernel_traits : public Base {
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
AutoVectorizingCopyWithAssumedAlignment<128>
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopydO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydKV = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydQ = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomdQaccum = std::conditional_t<
@ -329,12 +329,12 @@ struct Flash_bwd_kernel_traits : public Base {
Stride< _16, _1>>
>;
using GmemTiledCopydQaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomdQaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using GmemTiledCopydQaccumAtomicAdd = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
Stride<_32, _1>>{},
Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store

View File

@ -390,4 +390,22 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
#pragma unroll
for (int i = 0; i < size(tensor); ++i) {
tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
#pragma unroll
for (int i = 0; i < size(src_tensor); ++i) {
dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash

View File

@ -42,9 +42,12 @@ extern "C" {
is_bf16: c_int,
is_causal: c_int,
unpadded_lse: c_int,
window_size_left: c_int,
window_size_right: c_int,
softcap: f32,
);
}

View File

@ -11,6 +11,7 @@ pub struct FlashAttn {
pub alibi_slopes: Option<Tensor>,
pub window_size_left: Option<usize>,
pub window_size_right: Option<usize>,
pub softcap: Option<f32>,
}
fn round_multiple(x: usize, m: usize) -> usize {
@ -199,8 +200,10 @@ impl FlashAttn {
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_bf16 */ is_bf16,
/* is_causal */ is_causal,
/* unpadded_lse */ 0,
/* window_size_left */ window_size_left,
/* window_size_right */ window_size_right,
/* softcap */ self.softcap.unwrap_or(0f32),
)
}
@ -271,6 +274,7 @@ pub fn flash_attn(
alibi_slopes: None,
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -308,6 +312,7 @@ pub fn flash_attn_windowed(
alibi_slopes: None,
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -342,6 +347,7 @@ pub fn flash_attn_alibi(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -381,6 +387,52 @@ pub fn flash_attn_alibi_windowed(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors `k` and `v` with fewer heads
/// than `q`. The number of heads in `q` must be divisible by the number of heads in `k` and `v`.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Optional alibi slopes tensor with shape `(num_heads_q)`.
/// * `softmax_scale` - Scaling factor for the softmax operation.
/// * `window_size_left` - Optional limit on left attention to value tokens.
/// * `window_size_right` - Optional limit on right attention to value tokens.
/// * `softcap` - Gemma-style soft cap applied to the attention logits before the softmax.
///
/// # Causal Mask
///
/// Setting `window_size_left=None` and `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`.
///
/// # Returns
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_alibi_windowed_softcap(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: Option<&Tensor>,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
softcap: f32,
) -> Result<Tensor> {
let op = FlashAttn {
softmax_scale,
alibi_slopes: alibi_slopes.cloned(),
window_size_left,
window_size_right,
softcap: Some(softcap),
};
q.apply_op3(k, v, op)
}
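For reference, a minimal usage sketch of the new soft-capped entry point, mirroring the test added further down (the shapes, scale, and cap value are illustrative assumptions, not part of the diff):

// Minimal sketch: soft-capped flash attention without alibi or windowing.
// Tensors follow the documented (batch, seq_len, num_heads, head_size) layout.
use candle::{DType, Device, Tensor};

fn softcap_attention_demo() -> candle::Result<Tensor> {
    let device = Device::new_cuda(0)?;
    let q = Tensor::randn(0f32, 1., (1, 5, 3, 8), &device)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1., (1, 5, 3, 8), &device)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1., (1, 5, 3, 8), &device)?.to_dtype(DType::F16)?;
    candle_flash_attn::flash_attn_alibi_windowed_softcap(
        &q, &k, &v,
        None,                 // alibi_slopes
        1.0 / (8f32).sqrt(),  // softmax_scale, e.g. 1/sqrt(head_size)
        None,                 // window_size_left
        Some(0),              // window_size_right => causal mask
        30.0,                 // softcap
    )
}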
@ -394,6 +446,7 @@ struct FlashAttnVarLen {
pub alibi_slopes: Option<Tensor>,
pub window_size_left: Option<usize>,
pub window_size_right: Option<usize>,
pub softcap: Option<f32>,
}
impl FlashAttnVarLen {
@ -466,7 +519,7 @@ impl FlashAttnVarLen {
candle::bail!("the last dim of v must be contiguous {v_stride:?}")
}
let (_total_q, num_heads, head_size_og) = q_l.shape().dims3()?;
let (total_q, num_heads, head_size_og) = q_l.shape().dims3()?;
let (total_k, num_heads_k, _head_size_og) = k_l.shape().dims3()?;
let expected_kv = (total_k, num_heads_k, head_size_og);
if expected_kv != k_l.shape().dims3()? {
@ -549,9 +602,7 @@ impl FlashAttnVarLen {
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
let softmax_lse = dev
.alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q).w()?;
let is_bf16 = if is_bf16 { 1 } else { 0 };
@ -611,8 +662,10 @@ impl FlashAttnVarLen {
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_bf16 */ is_bf16,
/* is_causal */ is_causal,
/* unpadded_lse */ 1,
/* window_size_left */ window_size_left,
/* window_size_right */ window_size_right,
/* softcap */ self.softcap.unwrap_or(0.0),
)
}
@ -699,6 +752,7 @@ pub fn flash_attn_varlen(
alibi_slopes: None,
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -752,6 +806,7 @@ pub fn flash_attn_varlen_windowed(
alibi_slopes: None,
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -802,6 +857,7 @@ pub fn flash_attn_varlen_alibi(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -857,6 +913,65 @@ pub fn flash_attn_varlen_alibi_windowed(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors `k` and `v` with fewer heads
/// than `q`; the number of heads in `q` must be divisible by the number of heads in `k` and `v`.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Optional alibi slopes tensor with shape `(num_heads_q)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum key/value sequence length for k and v in the batch.
/// * `window_size_left` - Optional limit on left attention to value tokens.
/// * `window_size_right` - Optional limit on right attention to value tokens.
/// * `softcap` - Gemma-style soft cap applied to the attention logits before the softmax.
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
pub fn flash_attn_varlen_alibi_windowed_softcap(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: Option<&Tensor>,
seqlens_q: &Tensor,
seqlens_k: &Tensor,
max_seqlen_q: usize,
max_seqlen_k: usize,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
softcap: f32,
) -> Result<Tensor> {
let op = FlashAttnVarLen {
softmax_scale,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: alibi_slopes.cloned(),
window_size_left,
window_size_right,
softcap: Some(softcap),
};
q.apply_op3(k, v, op)
}
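A minimal usage sketch of the variable-length variant, packing two sequences of lengths 3 and 4 into a single `total = 7` dimension (the concrete shapes and the cap value are illustrative assumptions):

// seqlens_q / seqlens_k hold cumulative offsets [0, 3, 7] as u32, per the doc comment above.
use candle::{DType, Device, Tensor};

fn softcap_varlen_demo() -> candle::Result<Tensor> {
    let device = Device::new_cuda(0)?;
    let q = Tensor::randn(0f32, 1., (7, 3, 8), &device)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1., (7, 3, 8), &device)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1., (7, 3, 8), &device)?.to_dtype(DType::F16)?;
    let seqlens = Tensor::new(&[0u32, 3, 7], &device)?;
    candle_flash_attn::flash_attn_varlen_alibi_windowed_softcap(
        &q, &k, &v,
        None,                 // alibi_slopes
        &seqlens,             // seqlens_q
        &seqlens,             // seqlens_k
        4,                    // max_seqlen_q
        4,                    // max_seqlen_k
        1.0 / (8f32).sqrt(),  // softmax_scale
        None,                 // window_size_left
        Some(0),              // window_size_right => causal mask
        30.0,                 // softcap
    )
}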

View File

@ -27,6 +27,20 @@ fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<
Ok(output)
}
fn fa_acausal_softcap(q: &Tensor, k: &Tensor, v: &Tensor, softcap: f32) -> Result<Tensor> {
let in_dtype = q.dtype();
let q = q.to_dtype(DType::F32)?;
let k = k.to_dtype(DType::F32)?;
let v = v.to_dtype(DType::F32)?;
// let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
let att = q.matmul(&k.t()?)?;
let att = (softcap as f64 * ((att / softcap as f64)?.tanh())?)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?;
Ok(output)
}
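In equation form, and with the extra softmax scaling left out as in this reference implementation, the soft cap with value c computes:

\[
\mathrm{out} \;=\; \operatorname{softmax}\!\Big(c \cdot \tanh\!\big(\tfrac{Q K^{\top}}{c}\big)\Big)\, V
\]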
#[test]
fn flash_attn_acausal() -> Result<()> {
let device = Device::new_cuda(0)?;
@ -89,6 +103,44 @@ fn flash_attn_acausal() -> Result<()> {
Ok(())
}
#[test]
fn flash_attn_acausal_softcap() -> Result<()> {
let device = Device::new_cuda(0)?;
let q = Tensor::arange(0u32, 3 * 5 * 8, &device)?
.to_dtype(DType::F16)?
.reshape((1, 3, 5, 8))?;
let k = (&q / 40.)?;
let v = (&q / 50.)?;
let q = (&q / 30.)?;
let softcap = 5.0f32;
let ys1 = fa_acausal_softcap(&q, &k, &v, softcap.clone())?;
let ys1 = ys1.i(0)?.to_dtype(DType::F32)?;
let ys2 = {
let q = q.transpose(1, 2)?;
let k = k.transpose(1, 2)?;
let v = v.transpose(1, 2)?;
candle_flash_attn::flash_attn_alibi_windowed_softcap(
&q,
&k,
&v,
None, // alibi_slopes //
1.0, // softmax //
None, // window_size_left //
None, // window_size_right //
softcap.clone(), // softcap //
)?
.transpose(1, 2)?
};
let ys2 = ys2.i(0)?.to_dtype(DType::F32)?;
let diff = ys1.sub(&ys2)?.abs()?.flatten_all()?.max(0)?;
assert_eq!(ys1.dims(), &[3, 5, 8]);
assert_eq!(ys2.dims(), &[3, 5, 8]);
assert!(diff.to_vec0::<f32>()?.abs() < 1e-3);
Ok(())
}
#[test]
fn flash_attn_varlen() -> Result<()> {
let device = Device::new_cuda(0)?;

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.8.1"
version = "0.8.2"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.8.1"
version = "0.8.2"
edition = "2021"
description = "Metal kernels for Candle"

View File

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.8.1"
version = "0.8.2"
edition = "2021"
description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.8.1" }
candle-nn = { path = "../candle-nn", version = "0.8.1" }
candle = { path = "../candle-core", package = "candle-core", version = "0.8.2" }
candle-nn = { path = "../candle-nn", version = "0.8.2" }
prost = "0.12.1"
[build-dependencies]

View File

@ -3,7 +3,7 @@
//! Functionality for modeling sampling strategies and logits processing in text generation
//! with support for temperature-based sampling, top-k filtering, nucleus sampling (top-p),
//! and combinations thereof.
use candle::{DType, Error, Result, Tensor};
use candle::{Context, DType, Error, Result, Tensor};
use rand::{distributions::Distribution, SeedableRng};
#[derive(Clone, PartialEq, Debug)]
@ -45,7 +45,7 @@ impl LogitsProcessor {
.enumerate()
.max_by(|(_, u), (_, v)| u.total_cmp(v))
.map(|(i, _)| i as u32)
.unwrap();
.context("empty logits")?;
Ok(next_token)
}

View File

@ -6,7 +6,7 @@
//! - 💻 [Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)
//! - 💻 [GH](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py)
use candle::{DType, IndexOp, Module, Result, Shape, Tensor, D};
use candle::{Context, DType, IndexOp, Module, Result, Shape, Tensor, D};
use candle_nn as nn;
use super::{Activation, EncoderConfig};
@ -363,7 +363,7 @@ impl ChineseClipVisionTransformer {
.apply(&self.pre_layer_norm)?;
let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
let encoder_outputs = result.last().unwrap();
let encoder_outputs = result.last().context("no last")?;
let pooled_output = encoder_outputs.i((.., 0, ..))?;
result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
Ok(result)

View File

@ -6,7 +6,7 @@
//! https://github.com/openai/CLIP
//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
use candle::{IndexOp, Result, Shape, Tensor, D};
use candle::{Context, IndexOp, Result, Shape, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
use nn::Conv2dConfig;
@ -149,7 +149,7 @@ impl ClipVisionTransformer {
.apply(&self.embeddings)?
.apply(&self.pre_layer_norm)?;
let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
let encoder_outputs = result.last().unwrap();
let encoder_outputs = result.last().context("no last")?;
let pooled_output = encoder_outputs.i((.., 0, ..))?;
result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
Ok(result)

View File

@ -4,6 +4,8 @@
//! - ["Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data"](https://github.com/LiheYoung/Depth-Anything)
//!
use std::sync::Arc;
use candle::D::Minus1;
use candle::{Module, Result, Tensor};
use candle_nn::ops::Identity;
@ -365,16 +367,18 @@ impl Scratch {
const NUM_CHANNELS: usize = 4;
pub struct DPTHead<'a> {
conf: &'a DepthAnythingV2Config,
pub struct DPTHead {
projections: Vec<Conv2d>,
resize_layers: Vec<Box<dyn Module>>,
readout_projections: Vec<Sequential>,
scratch: Scratch,
use_class_token: bool,
input_image_size: usize,
target_patch_size: usize,
}
impl<'a> DPTHead<'a> {
pub fn new(conf: &'a DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
impl DPTHead {
pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
let mut projections: Vec<Conv2d> = Vec::with_capacity(conf.out_channel_sizes.len());
for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() {
projections.push(conv2d(
@ -445,20 +449,22 @@ impl<'a> DPTHead<'a> {
let scratch = Scratch::new(conf, vb.pp("scratch"))?;
Ok(Self {
conf,
projections,
resize_layers,
readout_projections,
scratch,
use_class_token: conf.use_class_token,
input_image_size: conf.input_image_size,
target_patch_size: conf.target_patch_size,
})
}
}
impl Module for DPTHead<'_> {
impl Module for DPTHead {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut out: Vec<Tensor> = Vec::with_capacity(NUM_CHANNELS);
for i in 0..NUM_CHANNELS {
let x = if self.conf.use_class_token {
let x = if self.use_class_token {
let x = xs.get(i)?.get(0)?;
let class_token = xs.get(i)?.get(1)?;
let readout = class_token.unsqueeze(1)?.expand(x.shape())?;
@ -473,8 +479,8 @@ impl Module for DPTHead<'_> {
let x = x.permute((0, 2, 1))?.reshape((
x_dims[0],
x_dims[x_dims.len() - 1],
self.conf.target_patch_size,
self.conf.target_patch_size,
self.target_patch_size,
self.target_patch_size,
))?;
let x = self.projections[i].forward(&x)?;
@ -515,25 +521,25 @@ impl Module for DPTHead<'_> {
let out = self.scratch.output_conv1.forward(&path1)?;
let out = out.interpolate2d(self.conf.input_image_size, self.conf.input_image_size)?;
let out = out.interpolate2d(self.input_image_size, self.input_image_size)?;
self.scratch.output_conv2.forward(&out)
}
}
pub struct DepthAnythingV2<'a> {
pretrained: &'a DinoVisionTransformer,
depth_head: DPTHead<'a>,
conf: &'a DepthAnythingV2Config,
pub struct DepthAnythingV2 {
pretrained: Arc<DinoVisionTransformer>,
depth_head: DPTHead,
conf: DepthAnythingV2Config,
}
impl<'a> DepthAnythingV2<'a> {
impl DepthAnythingV2 {
pub fn new(
pretrained: &'a DinoVisionTransformer,
conf: &'a DepthAnythingV2Config,
pretrained: Arc<DinoVisionTransformer>,
conf: DepthAnythingV2Config,
vb: VarBuilder,
) -> Result<Self> {
let depth_head = DPTHead::new(conf, vb.pp("depth_head"))?;
let depth_head = DPTHead::new(&conf, vb.pp("depth_head"))?;
Ok(Self {
pretrained,
@ -543,7 +549,7 @@ impl<'a> DepthAnythingV2<'a> {
}
}
impl Module for DepthAnythingV2<'_> {
impl Module for DepthAnythingV2 {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let features = self.pretrained.get_intermediate_layers(
xs,

View File

@ -3,7 +3,7 @@
//! See:
//! - ["EfficientBERT: Progressively Searching Multilayer Perceptron Architectures for BERT"](https://arxiv.org/abs/2201.00462)
//!
use candle::{Result, Tensor, D};
use candle::{Context, Result, Tensor, D};
use candle_nn as nn;
use nn::{Module, VarBuilder};
@ -289,7 +289,7 @@ impl EfficientNet {
pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
let f_p = p.pp("features");
let first_in_c = configs[0].input_channels;
let last_out_c = configs.last().unwrap().out_channels;
let last_out_c = configs.last().context("no last")?.out_channels;
let final_out_c = 4 * last_out_c;
let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
let nconfigs = configs.len();

View File

@ -5,7 +5,7 @@
//!
//! Implementation based on [timm model](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py)
use candle::{DType, Result, Tensor, D};
use candle::{Context, DType, Result, Tensor, D};
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, linear, linear_no_bias, ops::sigmoid, ops::softmax,
BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
@ -178,7 +178,7 @@ fn squeeze_and_excitation(
// based on the _fuse_bn_tensor method in timm
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
let (gamma, beta) = bn.weight_and_bias().unwrap();
let (gamma, beta) = bn.weight_and_bias().context("no weight-bias")?;
let mu = bn.running_mean();
let sigma = (bn.running_var() + bn.eps())?.sqrt();
let gps = (gamma / sigma)?;

View File

@ -14,7 +14,7 @@ use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer}
use crate::models::llama::{Cache, Llama};
use crate::models::with_tracing::linear;
use candle::{bail, Device, IndexOp, Result, Tensor};
use candle::{bail, Context, Device, IndexOp, Result, Tensor};
use candle_nn::{seq, Activation, Module, Sequential, VarBuilder};
use fancy_regex::Regex;
use utils::get_anyres_image_grid_shape;
@ -145,7 +145,7 @@ impl ClipVisionTower {
let config = if config.is_none() {
ClipVisionConfig::clip_vit_large_patch14_336()
} else {
config.clone().unwrap()
config.clone().context("no config")?
};
let select_layer = match select_layer {
-1 | -2 => select_layer,
@ -262,14 +262,14 @@ impl LLaVA {
let image_features = if mm_patch_merge_type == "flat" {
image_features
.iter()
.map(|x| x.flatten(0, 1).unwrap())
.collect::<Vec<Tensor>>()
.map(|x| x.flatten(0, 1))
.collect::<Result<Vec<Tensor>>>()?
} else if mm_patch_merge_type.starts_with("spatial") {
let mut new_image_features = Vec::new();
for (image_idx, image_feature) in image_features.iter().enumerate() {
let new_image_feature = if image_feature.dims()[0] > 1 {
let base_image_feature = image_feature.get(0).unwrap();
let patch_image_feature = image_feature.i(1..).unwrap();
let base_image_feature = image_feature.get(0)?;
let patch_image_feature = image_feature.i(1..)?;
let height = self.clip_vision_tower.num_patches_per_side();
let width = height;
assert_eq!(height * width, base_image_feature.dims()[0]);
@ -313,16 +313,12 @@ impl LLaVA {
};
Tensor::cat(&[base_image_feature, new_image_feature], 0)?
} else {
let new_image_feature = image_feature.get(0).unwrap();
let new_image_feature = image_feature.get(0)?;
if mm_patch_merge_type.contains("unpad") {
Tensor::cat(
&[
new_image_feature,
self.image_newline.clone().unsqueeze(0).unwrap(),
],
&[new_image_feature, self.image_newline.clone().unsqueeze(0)?],
0,
)
.unwrap()
)?
} else {
new_image_feature
}

View File

@ -262,7 +262,8 @@ impl Attention {
.contiguous()?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
.transpose(1, 2)?
.contiguous()?;
let (query_states, key_states) =
self.rotary_emb

View File

@ -109,4 +109,5 @@ pub mod vit;
pub mod whisper;
pub mod with_tracing;
pub mod wuerstchen;
pub mod xlm_roberta;
pub mod yi;

View File

@ -1,8 +1,8 @@
use candle::{DType, Module, Result, Tensor, D};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder};
fn default_act() -> candle_nn::Activation {
candle_nn::Activation::Gelu
candle_nn::Activation::Silu
}
fn default_hidden_size() -> usize {
@ -58,7 +58,7 @@ impl Config {
num_attention_heads: 16,
head_dim: None,
// Default
hidden_act: candle_nn::Activation::Gelu,
hidden_act: candle_nn::Activation::Silu,
}
}
@ -104,6 +104,7 @@ impl Attention {
&self,
xs: &Tensor,
emb: &RotaryEmbedding,
subsampled_positions: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let (b, patches, _) = xs.dims3()?;
@ -116,7 +117,8 @@ impl Attention {
let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
let (query_states, key_states) = emb.apply_rotary_emb_qkv(&query_states, &key_states)?;
let (query_states, key_states) =
emb.apply_rotary_emb_qkv(&query_states, &key_states, subsampled_positions)?;
let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
let attn_weights = match attention_mask {
@ -189,12 +191,16 @@ impl AttentionLayer {
&self,
xs: &Tensor,
emb: &RotaryEmbedding,
subsampled_positions: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let residual = xs;
let xs = self
.attention
.forward(&xs.apply(&self.attention_norm)?, emb, attention_mask)?;
let xs = self.attention.forward(
&xs.apply(&self.attention_norm)?,
emb,
subsampled_positions,
attention_mask,
)?;
let xs = (residual + xs)?;
let residual = &xs;
let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?;
@ -222,11 +228,12 @@ impl Transformer {
&self,
xs: &Tensor,
emb: &RotaryEmbedding,
subsampled_positions: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, emb, attention_mask)?
xs = layer.forward(&xs, emb, subsampled_positions, attention_mask)?
}
Ok(xs)
}
@ -270,10 +277,20 @@ impl RotaryEmbedding {
Ok(Self { cos, sin })
}
fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
subsampled_positions: Option<&Tensor>,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?;
let cos = &self.cos;
let sin = &self.sin;
let (cos, sin) = match subsampled_positions {
None => (&self.cos, &self.sin),
Some(pos) => (
&self.cos.index_select(pos, 0)?,
&self.sin.index_select(pos, 0)?,
),
};
let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?;
let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?;
Ok((q_embed, k_embed))
@ -286,6 +303,7 @@ pub struct Model {
ln_pre: RmsNorm,
transformer: Transformer,
patch_positional_embedding: RotaryEmbedding,
max_image_width: u32,
}
impl Model {
@ -305,20 +323,44 @@ impl Model {
let transformer = Transformer::new(cfg, vb.pp("transformer"))?;
let patch_positional_embedding =
RotaryEmbedding::new(cfg, vb.pp("patch_positional_embedding"))?;
let max_image_width = (cfg.image_size / cfg.patch_size) as u32;
Ok(Self {
patch_conv,
ln_pre,
transformer,
patch_positional_embedding,
max_image_width,
})
}
pub fn position_ids_in_meshgrid(
&self,
num_patches_h: usize,
num_patches_w: usize,
device: &Device,
) -> Result<Tensor> {
let idx = Tensor::arange(0, num_patches_h as u32, device)?;
let idy = Tensor::arange(0, num_patches_w as u32, device)?;
let mesh = Tensor::meshgrid(&[idx, idy], false)?;
let ids = (&mesh[0] * (self.max_image_width as f64) + &mesh[1])?.flatten_all()?;
Ok(ids)
}
}
impl Module for Model {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let patch_embeds = xs.apply(&self.patch_conv)?;
let subsampled_positions = Some(self.position_ids_in_meshgrid(
patch_embeds.dim(2)?,
patch_embeds.dim(3)?,
patch_embeds.device(),
)?);
let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?;
self.transformer
.forward(&patch_embeds, &self.patch_positional_embedding, None)
self.transformer.forward(
&patch_embeds,
&self.patch_positional_embedding,
subsampled_positions.as_ref(),
None,
)
}
}
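The new position_ids_in_meshgrid helper flattens a row-major (num_patches_h, num_patches_w) grid of patch coordinates into ids of the form row * max_image_width + col (assuming the non-xy, ij-style indexing that the `false` meshgrid flag selects). A plain-Rust sketch of the same arithmetic, independent of candle tensors:

// ids[i][j] = i * max_image_width + j, flattened row-major.
fn meshgrid_ids(num_patches_h: usize, num_patches_w: usize, max_image_width: usize) -> Vec<u32> {
    let mut ids = Vec::with_capacity(num_patches_h * num_patches_w);
    for i in 0..num_patches_h {
        for j in 0..num_patches_w {
            ids.push((i * max_image_width + j) as u32);
        }
    }
    ids
}

// meshgrid_ids(2, 3, 4) == vec![0, 1, 2, 4, 5, 6]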

View File

@ -15,7 +15,7 @@
//!
use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
use candle::{Module, ModuleT, Result, Tensor, D};
use candle::{Context, Module, ModuleT, Result, Tensor, D};
use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
@ -633,7 +633,7 @@ impl ImageClassificationModel {
impl Module for ImageClassificationModel {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let all_hidden_states = self.segformer.forward(x)?;
let hidden_states = all_hidden_states.last().unwrap();
let hidden_states = all_hidden_states.last().context("no last")?;
let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
let mean = hidden_states.mean(1)?;
self.classifier.forward(&mean)

View File

@ -127,7 +127,7 @@ impl DDIMScheduler {
impl Scheduler for DDIMScheduler {
/// Performs a backward step during inference.
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {

View File

@ -171,7 +171,7 @@ impl Scheduler for EulerAncestralDiscreteScheduler {
}
/// Performs a backward step during inference.
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let step_index = self
.timesteps
.iter()

View File

@ -47,6 +47,7 @@ pub mod resnet;
pub mod schedulers;
pub mod unet_2d;
pub mod unet_2d_blocks;
pub mod uni_pc;
pub mod utils;
pub mod vae;

View File

@ -19,7 +19,7 @@ pub trait Scheduler {
fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
}
/// This represents how beta ranges from its minimum value to the maximum

File diff suppressed because it is too large

View File

@ -204,6 +204,7 @@ pub fn log_mel_spectrogram_<T: Float>(
// ensure that the number of threads is even and less than 12
let n_threads = std::cmp::min(get_num_threads() - get_num_threads() % 2, 12);
let n_threads = std::cmp::max(n_threads, 2);
let hann = Arc::new(hann);
let samples = Arc::new(samples);
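The added clamp guards the single-threaded case: with one available thread, rounding down to an even count would otherwise yield zero workers. A small sketch of the same computation, with the crate's get_num_threads helper stood in for by a plain parameter:

// Round down to an even thread count, cap at 12, and never go below 2.
fn pick_n_threads(available: usize) -> usize {
    let n = std::cmp::min(available - available % 2, 12);
    std::cmp::max(n, 2)
}

// pick_n_threads(1) == 2, pick_n_threads(7) == 6, pick_n_threads(32) == 12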

View File

@ -0,0 +1,545 @@
use crate::models::with_tracing::{linear, Linear};
use candle::{DType, Module, Result, Tensor};
use candle_nn::{
embedding, layer_norm, ops::softmax_last_dim, Activation, Embedding, LayerNorm, VarBuilder,
};
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
pub hidden_size: usize,
pub layer_norm_eps: f64,
pub attention_probs_dropout_prob: f32,
pub hidden_dropout_prob: f32,
pub num_attention_heads: usize,
pub position_embedding_type: String,
pub intermediate_size: usize,
pub hidden_act: Activation,
pub num_hidden_layers: usize,
pub vocab_size: usize,
pub max_position_embeddings: usize,
pub type_vocab_size: usize,
pub pad_token_id: u32,
}
struct XLMRobertaEmbeddings {
word_embeddings: Embedding,
position_embeddings: Option<Embedding>,
token_type_embeddings: Embedding,
layer_norm: LayerNorm,
padding_idx: u32,
span: tracing::Span,
}
impl XLMRobertaEmbeddings {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let word_embeddings = embedding(
config.vocab_size,
config.hidden_size,
vb.pp("word_embeddings"),
)?;
let position_embeddings = embedding(
config.max_position_embeddings,
config.hidden_size,
vb.pp("position_embeddings"),
)?;
let token_type_embeddings = embedding(
config.type_vocab_size,
config.hidden_size,
vb.pp("token_type_embeddings"),
)?;
let layer_norm = layer_norm(
config.hidden_size,
config.layer_norm_eps,
vb.pp("LayerNorm"),
)?;
Ok(Self {
word_embeddings,
position_embeddings: Some(position_embeddings),
token_type_embeddings,
layer_norm,
padding_idx: config.pad_token_id,
span: tracing::span!(tracing::Level::TRACE, "embeddings"),
})
}
fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (_bsize, _) = input_ids.dims2()?;
let input_embeddings = self.word_embeddings.forward(input_ids)?;
let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
let mut embeddings = (&input_embeddings + token_type_embeddings)?;
if let Some(position_embeddings) = &self.position_embeddings {
let mask = input_ids
.ne(self.padding_idx)?
.to_dtype(input_embeddings.dtype())?;
let cumsum = mask.cumsum(1)?;
let position_ids = (cumsum * mask)?
.broadcast_add(
&Tensor::try_from(self.padding_idx)?
.to_dtype(input_embeddings.dtype())?
.to_device(input_embeddings.device())?,
)?
.to_dtype(candle::DType::U32)?;
embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?;
}
let embeddings = self.layer_norm.forward(&embeddings)?;
Ok(embeddings)
}
}
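The position ids above are derived from the padding mask: only non-pad tokens are counted, and the running count is offset by padding_idx, following the RoBERTa convention. A worked example over plain vectors (a pad id of 1 is assumed here):

// Same computation as XLMRobertaEmbeddings::forward, written without tensors.
fn roberta_position_ids(input_ids: &[u32], padding_idx: u32) -> Vec<u32> {
    let mut cumsum = 0u32;
    input_ids
        .iter()
        .map(|&id| {
            if id != padding_idx {
                cumsum += 1;
                cumsum + padding_idx // non-pad tokens: 1-based count shifted past the pad id
            } else {
                padding_idx // pad tokens keep the pad position id
            }
        })
        .collect()
}

// roberta_position_ids(&[5, 7, 1, 1], 1) == vec![2, 3, 1, 1]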
struct XLMRobertaSelfAttention {
num_attention_heads: usize,
attention_head_size: usize,
all_head_size: usize,
query: Linear,
key: Linear,
value: Linear,
}
impl XLMRobertaSelfAttention {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
let all_head_size = cfg.num_attention_heads * attention_head_size;
Ok(Self {
num_attention_heads: cfg.num_attention_heads,
attention_head_size,
all_head_size,
query: linear(cfg.hidden_size, all_head_size, vb.pp("query"))?,
key: linear(cfg.hidden_size, all_head_size, vb.pp("key"))?,
value: linear(cfg.hidden_size, all_head_size, vb.pp("value"))?,
})
}
fn transpose_for_scores(&self, x: &Tensor) -> Result<Tensor> {
let mut new_x_shape = x.dims().to_vec();
new_x_shape[2] = self.num_attention_heads;
new_x_shape.push(self.attention_head_size);
let x = x.reshape(new_x_shape)?;
x.permute((0, 2, 1, 3))?.contiguous()
}
fn forward(
&self,
hidden_states: &Tensor,
encoder_hidden_states: Option<&Tensor>,
attention_mask: &Tensor,
past_key_value: Option<(&Tensor, &Tensor)>,
encoder_attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let mixed_query_layer = self.query.forward(hidden_states)?;
let is_cross_attention = encoder_hidden_states.is_some();
let (key_layer, value_layer, attention_mask) = if is_cross_attention
&& past_key_value.is_some()
{
let key_layer = past_key_value.unwrap().0.clone();
let value_layer = past_key_value.unwrap().1.clone();
let attention_mask = encoder_attention_mask.unwrap().clone();
(key_layer, value_layer, Some(attention_mask))
} else if is_cross_attention {
let key_layer =
self.transpose_for_scores(&self.key.forward(encoder_hidden_states.unwrap())?)?;
let value_layer =
self.transpose_for_scores(&self.value.forward(encoder_hidden_states.unwrap())?)?;
let attention_mask = encoder_attention_mask.unwrap();
(key_layer, value_layer, Some(attention_mask.clone()))
} else if past_key_value.is_some() {
let mut key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
let mut value_layer = self.transpose_for_scores(&self.value.forward(hidden_states)?)?;
key_layer = Tensor::cat(
&[
past_key_value.clone().as_ref().unwrap().0.clone(),
key_layer,
],
2,
)?;
value_layer = Tensor::cat(
&[past_key_value.as_ref().unwrap().1.clone(), value_layer],
2,
)?;
(key_layer, value_layer, Some(attention_mask.clone()))
} else {
let key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
let value_layer = self.transpose_for_scores(&self.value.forward(hidden_states)?)?;
(key_layer, value_layer, Some(attention_mask.clone()))
};
let query_layer = self.transpose_for_scores(&mixed_query_layer)?;
let mut attention_scores = query_layer.matmul(&key_layer.transpose(2, 3)?)?;
let scale = 1f64 / f64::sqrt(self.attention_head_size as f64);
attention_scores = (attention_scores * scale)?;
attention_scores = match attention_mask {
None => attention_scores,
Some(mask) => {
attention_scores.broadcast_add(&mask.to_dtype(attention_scores.dtype())?)?
}
};
let attention_probs = softmax_last_dim(&attention_scores)?;
let context_layer = attention_probs
.matmul(&value_layer)?
.permute((0, 2, 1, 3))?
.contiguous()?;
let mut new_context_layer_shape =
context_layer.dims()[..context_layer.dims().len() - 2].to_vec();
new_context_layer_shape.push(self.all_head_size);
let context_layer = context_layer.reshape(new_context_layer_shape)?;
Ok(context_layer)
}
}
struct XLMRobertaSelfOutput {
dense: Linear,
layernorm: LayerNorm,
}
impl XLMRobertaSelfOutput {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
let layernorm =
candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
Ok(Self { dense, layernorm })
}
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = self.layernorm.forward(&(hidden_states + input_tensor)?)?;
Ok(hidden_states)
}
}
struct XLMRobertaAttention {
output: XLMRobertaSelfOutput,
self_attention: XLMRobertaSelfAttention,
}
impl XLMRobertaAttention {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let output = XLMRobertaSelfOutput::new(cfg, vb.pp("output"))?;
let self_attention = XLMRobertaSelfAttention::new(cfg, vb.pp("self"))?;
Ok(Self {
output,
self_attention,
})
}
fn forward(
&self,
hidden_states: &Tensor,
attention_mask: &Tensor,
encoder_hidden_states: Option<&Tensor>,
encoder_attention_mask: Option<&Tensor>,
past_key_value: Option<(&Tensor, &Tensor)>,
) -> Result<(Tensor, Tensor)> {
let self_outputs = self.self_attention.forward(
hidden_states,
encoder_hidden_states,
attention_mask,
past_key_value,
encoder_attention_mask,
)?;
let attention_output = self.output.forward(&self_outputs, hidden_states)?;
Ok((attention_output, self_outputs))
}
}
struct XLMRobertaOutput {
dense: Linear,
layernorm: LayerNorm,
}
impl XLMRobertaOutput {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
let layernorm =
candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
Ok(Self { dense, layernorm })
}
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = self.layernorm.forward(&(hidden_states + input_tensor)?)?;
Ok(hidden_states)
}
}
struct XLMRobertaIntermediate {
dense: Linear,
intermediate_act_fn: Activation,
}
impl XLMRobertaIntermediate {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
let intermediate_act_fn = cfg.hidden_act;
Ok(Self {
dense,
intermediate_act_fn,
})
}
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = self.intermediate_act_fn.forward(&hidden_states)?;
Ok(hidden_states)
}
}
struct XLMRobertaLayer {
attention: XLMRobertaAttention,
intermediate: XLMRobertaIntermediate,
output: XLMRobertaOutput,
}
impl XLMRobertaLayer {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let attention = XLMRobertaAttention::new(cfg, vb.pp("attention"))?;
let intermediate = XLMRobertaIntermediate::new(cfg, vb.pp("intermediate"))?;
let output = XLMRobertaOutput::new(cfg, vb.pp("output"))?;
Ok(Self {
attention,
intermediate,
output,
})
}
fn forward(
&self,
hidden_states: &Tensor,
attention_mask: &Tensor,
encoder_hidden_states: Option<&Tensor>,
encoder_attention_mask: Option<&Tensor>,
past_key_value: Option<(&Tensor, &Tensor)>,
) -> Result<(Tensor, Tensor)> {
let self_attention_outputs = self.attention.forward(
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
)?;
let attention_output = self_attention_outputs.0;
let outputs = self_attention_outputs.1;
let intermediate_output = self.intermediate.forward(&attention_output)?;
let layer_output = self
.output
.forward(&intermediate_output, &attention_output)?;
Ok((layer_output, outputs))
}
}
struct XLMRobertaEncoder {
layers: Vec<XLMRobertaLayer>,
}
impl XLMRobertaEncoder {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let layers = (0..cfg.num_hidden_layers)
.map(|i| XLMRobertaLayer::new(cfg, vb.pp(format!("layer.{}", i))))
.collect::<Result<Vec<_>>>()?;
Ok(Self { layers })
}
fn forward(
&self,
hidden_states: &Tensor,
attention_mask: &Tensor,
encoder_hidden_states: Option<&Tensor>,
encoder_attention_mask: Option<&Tensor>,
past_key_value: Option<(&Tensor, &Tensor)>,
) -> Result<Tensor> {
let mut hidden_states = hidden_states.clone();
for layer_module in self.layers.iter() {
let layer_outputs = layer_module.forward(
&hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
)?;
hidden_states = layer_outputs.0;
}
Ok(hidden_states)
}
}
pub struct XLMRobertaModel {
encoder: XLMRobertaEncoder,
embeddings: XLMRobertaEmbeddings,
}
impl XLMRobertaModel {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let encoder = XLMRobertaEncoder::new(cfg, vb.pp("encoder"))?;
let embeddings = XLMRobertaEmbeddings::load(vb.pp("embeddings"), cfg)?;
Ok(Self {
encoder,
embeddings,
})
}
pub fn forward(
&self,
input_ids: &Tensor,
attention_mask: &Tensor,
token_type_ids: &Tensor,
past_key_value: Option<(&Tensor, &Tensor)>,
encoder_hidden_states: Option<&Tensor>,
encoder_attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let hidden_states = self.embeddings.forward(input_ids, token_type_ids)?;
let attention_mask = prepare_4d_attention_mask(attention_mask, DType::F32, None)?
.to_device(hidden_states.device())?;
let hidden_states = self.encoder.forward(
&hidden_states,
&attention_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
)?;
Ok(hidden_states)
}
}
struct XLMRobertaLMHead {
dense: Linear,
layer_norm: LayerNorm,
}
impl XLMRobertaLMHead {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
let layer_norm =
candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layer_norm"))?;
Ok(Self { dense, layer_norm })
}
fn forward(&self, hidden_states: &Tensor, shared_embeddings: &Tensor) -> Result<Tensor> {
let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = candle_nn::Activation::Gelu.forward(&hidden_states)?;
let hidden_states = self.layer_norm.forward(&hidden_states)?;
let hidden_states = hidden_states.broadcast_matmul(shared_embeddings)?;
Ok(hidden_states)
}
}
pub struct XLMRobertaForMaskedLM {
roberta: XLMRobertaModel,
lm_head: XLMRobertaLMHead,
}
impl XLMRobertaForMaskedLM {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let roberta = XLMRobertaModel::new(cfg, vb.pp("roberta"))?;
let lm_head = XLMRobertaLMHead::new(cfg, vb.pp("lm_head"))?;
Ok(Self { roberta, lm_head })
}
pub fn forward(
&self,
input_ids: &Tensor,
attention_mask: &Tensor,
token_type_ids: &Tensor,
past_key_value: Option<(&Tensor, &Tensor)>,
encoder_hidden_states: Option<&Tensor>,
encoder_attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let hidden_states = self.roberta.forward(
input_ids,
attention_mask,
token_type_ids,
past_key_value,
encoder_hidden_states,
encoder_attention_mask,
)?;
let lm_logits = self.lm_head.forward(
&hidden_states,
&self
.roberta
.embeddings
.word_embeddings
.embeddings()
.t()?
.unsqueeze(0)?,
)?;
Ok(lm_logits)
}
}
struct XLMRobertaClassificationHead {
dense: Linear,
out_proj: Linear,
}
impl XLMRobertaClassificationHead {
fn new(num_labels: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
let out_proj = linear(cfg.hidden_size, num_labels, vb.pp("out_proj"))?;
Ok(Self { dense, out_proj })
}
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
let hidden_states = self.dense.forward(&cls_states)?;
let hidden_states = candle_nn::Activation::GeluPytorchTanh.forward(&hidden_states)?;
let hidden_states = self.out_proj.forward(&hidden_states)?;
Ok(hidden_states)
}
}
pub struct XLMRobertaForSequenceClassification {
roberta: XLMRobertaModel,
classifier: XLMRobertaClassificationHead,
}
impl XLMRobertaForSequenceClassification {
pub fn new(num_labels: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let roberta = XLMRobertaModel::new(cfg, vb.pp("roberta"))?;
let classifier = XLMRobertaClassificationHead::new(num_labels, cfg, vb.pp("classifier"))?;
Ok(Self {
roberta,
classifier,
})
}
pub fn forward(
&self,
input_ids: &Tensor,
attention_mask: &Tensor,
token_type_ids: &Tensor,
) -> Result<Tensor> {
let hidden_states =
self.roberta
.forward(input_ids, attention_mask, token_type_ids, None, None, None)?;
self.classifier.forward(&hidden_states)
}
}
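A minimal loading sketch for the reranker head; the weight file name, the single-label head, and the already-tokenized inputs are assumptions for illustration (the xlm-roberta example in the repository shows the full download-and-tokenize pipeline):

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::xlm_roberta::{Config, XLMRobertaForSequenceClassification};

fn rerank_scores(config: &Config, input_ids: &Tensor, attention_mask: &Tensor) -> candle::Result<Tensor> {
    let device = input_ids.device();
    // Hypothetical local weight file; the real example fetches it from the Hugging Face hub.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, device)?
    };
    // One output label => one relevance logit per (query, document) pair.
    let model = XLMRobertaForSequenceClassification::new(1, config, vb)?;
    let token_type_ids = input_ids.zeros_like()?;
    model.forward(input_ids, attention_mask, &token_type_ids)
}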
fn prepare_4d_attention_mask(
mask: &Tensor,
dtype: DType,
tgt_len: Option<usize>,
) -> Result<Tensor> {
let bsz = mask.dim(0)?;
let src_len = mask.dim(1)?;
let tgt_len = tgt_len.unwrap_or(src_len);
let expanded_mask = mask
.unsqueeze(1)?
.unsqueeze(2)?
.expand((bsz, 1, tgt_len, src_len))?
.to_dtype(dtype)?;
let inverted_mask = (1.0 - expanded_mask)?;
(inverted_mask * get_dtype_min_val(dtype))?.to_dtype(dtype)
}
fn get_dtype_min_val(dtype: DType) -> f64 {
match dtype {
DType::F32 => f32::MIN as f64,
DType::F64 => f64::MIN,
_ => panic!("Unsupported data type"),
}
}
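As a concrete illustration of the two helpers above: attended positions (mask value 1) contribute 0 to the attention scores, while padded positions (mask value 0) contribute the dtype minimum, which drives their softmax weight to roughly zero. A plain-vector sketch of one broadcast row:

// 1 -> 0.0 (keep), 0 -> f32::MIN (mask out); the real helper broadcasts this
// row over a (bsz, 1, tgt_len, src_len) shape before adding it to the scores.
fn additive_mask_row(mask_row: &[u32]) -> Vec<f32> {
    mask_row
        .iter()
        .map(|&m| (1.0 - m as f32) * f32::MIN)
        .collect()
}

// additive_mask_row(&[1, 1, 0]) -> [0.0, 0.0, f32::MIN]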