Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 03:54:56 +00:00)
Compare commits
19 commits:
- 236c35e578
- 6f8351dfda
- 57f41da13b
- cbaa0ad46f
- b12c7c2888
- 94ffc2ec6f
- 7354afc673
- 2a705e6f37
- a594ef669c
- 71cd6d5533
- d60eba1408
- e38e2a85dd
- 460616fc84
- 91f1f019b1
- cd639131f0
- 11aa30be10
- 1be6b090c7
- 62ced44ea9
- 5c2f893e5a
Cargo.toml (20 changed lines)
@@ -20,7 +20,7 @@ exclude = [
resolver = "2"

[workspace.package]
version = "0.8.1"
version = "0.8.2"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@@ -33,20 +33,20 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.8.1" }
candle-datasets = { path = "./candle-datasets", version = "0.8.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.1" }
candle-kernels = { path = "./candle-kernels", version = "0.8.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.1" }
candle-nn = { path = "./candle-nn", version = "0.8.1" }
candle-onnx = { path = "./candle-onnx", version = "0.8.1" }
candle-transformers = { path = "./candle-transformers", version = "0.8.1" }
candle = { path = "./candle-core", package = "candle-core", version = "0.8.2" }
candle-datasets = { path = "./candle-datasets", version = "0.8.2" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.2" }
candle-kernels = { path = "./candle-kernels", version = "0.8.2" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.2" }
candle-nn = { path = "./candle-nn", version = "0.8.2" }
candle-onnx = { path = "./candle-onnx", version = "0.8.2" }
candle-transformers = { path = "./candle-transformers", version = "0.8.2" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = { version = "0.3.3", package = "candle-hf-hub" }
hf-hub = "0.4.1"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
@@ -189,6 +189,7 @@ And then head over to

- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure Rust implementation of the Python [einops](https://github.com/arogozhnikov/einops) library.
- [`atoma-infer`](https://github.com/atoma-network/atoma-infer): A Rust library for fast inference at scale, leveraging FlashAttention2 for efficient attention computation, PagedAttention for efficient KV-cache memory management, and multi-GPU support. It is OpenAI API compatible.
- [`llms-from-scratch-rs`](https://github.com/nerdai/llms-from-scratch-rs): A comprehensive Rust translation of the code from Sebastian Raschka's _Build an LLM from Scratch_ book.

If you have an addition to this list, please submit a pull request.
@@ -11,8 +11,8 @@ Then let's start by downloading the [model file](https://huggingface.co/bert-bas

```rust
# extern crate candle_core;
# extern crate candle_hf_hub;
use candle_hf_hub::api::sync::Api;
# extern crate hf_hub;
use hf_hub::api::sync::Api;
use candle_core::Device;

let api = Api::new().unwrap();
@@ -50,8 +50,8 @@ Now that we have our weights, we can use them in our bert architecture:
```rust
# extern crate candle_core;
# extern crate candle_nn;
# extern crate candle_hf_hub;
# use candle_hf_hub::api::sync::Api;
# extern crate hf_hub;
# use hf_hub::api::sync::Api;
#
# let api = Api::new().unwrap();
# let repo = api.model("bert-base-uncased".to_string());
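In these book hunks the `candle_hf_hub` lines are the old text and the `hf_hub` lines the replacement: with the workspace now on upstream `hf-hub` 0.4 instead of the `candle-hf-hub` package, downstream code only swaps the crate name in its imports. A minimal sketch of the resulting download flow, assuming the sync API surface used in the book excerpt (`Api::new`, `model`, `get`); the `anyhow` error wrapper is an assumption and not part of this diff:

```rust
// Minimal sketch of fetching model weights after the switch to upstream `hf-hub`.
// Assumes the sync API shown in the book excerpt above; `anyhow` is only used
// here for error handling and is an assumption, not part of this change.
use candle_core::Device;
use hf_hub::api::sync::Api; // was: use candle_hf_hub::api::sync::Api;

fn main() -> anyhow::Result<()> {
    let api = Api::new()?;
    let repo = api.model("bert-base-uncased".to_string());
    let weights = repo.get("model.safetensors")?; // downloads into the local HF cache
    println!("weights at {weights:?}, running on {:?}", Device::Cpu);
    Ok(())
}
```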
@@ -9,8 +9,14 @@ pub struct MatMulUnexpectedStriding {
pub msg: &'static str,
}

impl std::fmt::Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self}")
}
}

/// Main library error type.
#[derive(thiserror::Error, Debug)]
#[derive(thiserror::Error)]
pub enum Error {
// === DType Errors ===
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
@@ -199,8 +205,14 @@ pub enum Error {
UnsupportedSafeTensorDtype(safetensors::Dtype),

/// Arbitrary errors wrapping.
#[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>),
#[error("{0}")]
Wrapped(Box<dyn std::fmt::Display + Send + Sync>),

#[error("{context}\n{inner}")]
Context {
inner: Box<Self>,
context: Box<dyn std::fmt::Display + Send + Sync>,
},

/// Adding path information to an error.
#[error("path: {path:?} {inner}")]
@@ -218,16 +230,19 @@ pub enum Error {
/// User generated error message, typically created via `bail!`.
#[error("{0}")]
Msg(String),

#[error("unwrap none")]
UnwrapNone,
}

pub type Result<T> = std::result::Result<T, Error>;

impl Error {
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
pub fn wrap(err: impl std::fmt::Display + Send + Sync + 'static) -> Self {
Self::Wrapped(Box::new(err)).bt()
}

pub fn msg(err: impl std::error::Error) -> Self {
pub fn msg(err: impl std::fmt::Display) -> Self {
Self::Msg(err.to_string()).bt()
}

@@ -253,6 +268,13 @@ impl Error {
path: p.as_ref().to_path_buf(),
}
}

pub fn context(self, c: impl std::fmt::Display + Send + Sync + 'static) -> Self {
Self::Context {
inner: Box::new(self),
context: Box::new(c),
}
}
}

#[macro_export]
@@ -275,3 +297,41 @@ pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
(_, Err(e)) => Err(e),
}
}

// Taken from anyhow.
pub trait Context<T> {
/// Wrap the error value with additional context.
fn context<C>(self, context: C) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static;

/// Wrap the error value with additional context that is evaluated lazily
/// only once an error does occur.
fn with_context<C, F>(self, f: F) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
F: FnOnce() -> C;
}

impl<T> Context<T> for Option<T> {
fn context<C>(self, context: C) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
{
match self {
Some(v) => Ok(v),
None => Err(Error::UnwrapNone.context(context).bt()),
}
}

fn with_context<C, F>(self, f: F) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
F: FnOnce() -> C,
{
match self {
Some(v) => Ok(v),
None => Err(Error::UnwrapNone.context(f()).bt()),
}
}
}
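The new `Context` trait above (adapted from anyhow, as the comment notes) is what lets the later hunks in this comparison replace `.unwrap()` calls on `Option` values with `.context("...")?`, turning a would-be panic into an `Error::UnwrapNone` carrying a message. A minimal sketch of the pattern, using only the API introduced in this hunk:

```rust
// Minimal sketch of the Option -> Result conversion enabled by the new Context trait,
// mirroring the `.context("empty objs")?` style used elsewhere in this comparison.
use candle_core::{Context, Result};

fn last_dim(dims: &[usize]) -> Result<usize> {
    // Previously: dims.last().copied().unwrap(), which panics on an empty slice.
    // Now the failure becomes Error::UnwrapNone wrapped with the given context.
    dims.last().copied().context("empty dims")
}

fn main() -> Result<()> {
    assert_eq!(last_dim(&[2, 3, 4])?, 4);
    assert!(last_dim(&[]).is_err());
    Ok(())
}
```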
@@ -94,7 +94,7 @@ pub use cpu_backend::{CpuStorage, CpuStorageRef};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use error::{Context, Error, Result};
pub use indexer::{IndexOp, TensorIndexer};
pub use layout::Layout;
pub use shape::{Shape, D};
@@ -1,7 +1,7 @@
//! Just enough pickle support to be able to read PyTorch checkpoints.
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
// composable/tensor agnostic at some point.
use crate::{DType, Error as E, Layout, Result, Tensor};
use crate::{Context, DType, Error as E, Layout, Result, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
use std::io::BufRead;
@@ -537,7 +537,7 @@ impl Stack {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().unwrap();
let key = objs.pop().context("empty objs")?;
d.push((key, value))
}
} else {
@@ -557,7 +557,7 @@ impl Stack {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().unwrap();
let key = objs.pop().context("empty objs")?;
pydict.push((key, value))
}
self.push(Object::Dict(pydict))
@@ -661,7 +661,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
if !file_name.ends_with("data.pkl") {
continue;
}
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").context("no .pkl")?);
let reader = zip.by_name(file_name)?;
let mut reader = std::io::BufReader::new(reader);
let mut stack = Stack::empty();
@@ -2,7 +2,7 @@
//!

use super::{GgmlDType, QTensor};
use crate::{Device, Result};
use crate::{Context, Device, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;

@@ -338,7 +338,7 @@ impl Value {
if value_type.len() != 1 {
crate::bail!("multiple value-types in the same array {value_type:?}")
}
value_type.into_iter().next().unwrap()
value_type.into_iter().next().context("empty value_type")?
};
w.write_u32::<LittleEndian>(value_type.to_u32())?;
w.write_u64::<LittleEndian>(v.len() as u64)?;
@@ -1,5 +1,5 @@
//! Code for GGML and GGUF files
use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use crate::{Context, CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;

@@ -481,7 +481,7 @@ impl crate::CustomOp1 for QTensor {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let mut dst_shape = src_shape.dims().to_vec();
let last_k = dst_shape.pop().unwrap();
let last_k = dst_shape.pop().context("empty dst_shape")?;
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
}
@@ -52,6 +52,49 @@ impl ArgSort {
}
}

#[cfg(feature = "cuda")]
mod cuda {
use super::*;
use crate::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
use crate::{CudaDevice, WithDType};

impl crate::cuda_backend::Map1Any for ArgSort {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &crate::Layout,
_wrap: W,
) -> Result<S> {
let slice = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
let func = if self.asc {
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
} else {
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
};
let ncols = self.last_dim;
let nrows = elem_count / ncols;
let ncols_pad = next_power_of_2(ncols);
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
let cfg = LaunchConfig {
grid_dim: (1, nrows as u32, 1),
block_dim: (ncols_pad as u32, 1, 1),
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
};
unsafe { func.launch(cfg, params) }.w()?;
Ok(S::U32(dst))
}
}
}

impl crate::CustomOp1 for ArgSort {
fn name(&self) -> &'static str {
"argsort"
@@ -81,46 +124,8 @@ impl crate::CustomOp1 for ArgSort {
storage: &crate::CudaStorage,
layout: &crate::Layout,
) -> Result<(crate::CudaStorage, crate::Shape)> {
use crate::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr};
use crate::{CudaDevice, WithDType};

impl Map1Any for ArgSort {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &crate::Layout,
_wrap: W,
) -> Result<S> {
let slice = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
let func = if self.asc {
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
} else {
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
};
let ncols = self.last_dim;
let nrows = elem_count / ncols;
let ncols_pad = next_power_of_2(ncols);
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
let cfg = LaunchConfig {
grid_dim: (1, nrows as u32, 1),
block_dim: (ncols_pad as u32, 1, 1),
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
};
unsafe { func.launch(cfg, params) }.w()?;
Ok(S::U32(dst))
}
}

use crate::backend::BackendStorage;
use crate::cuda_backend::Map1Any;
let dev = storage.device();
let slice = self.map(&storage.slice, dev, layout)?;
let dst = crate::cuda_backend::CudaStorage {
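The hunk above only moves the CUDA `Map1Any` implementation of the `ArgSort` custom op out of `CustomOp1::cuda_fwd` and into a dedicated `#[cfg(feature = "cuda")] mod cuda`; the launch itself is unchanged (one block per row, columns padded to the next power of two, `u32` indices written to `dst`). For orientation, a minimal sketch of driving the op through the public tensor API, assuming `Tensor::arg_sort_last_dim` is the entry point, as the xlm-roberta example later in this comparison uses:

```rust
// Minimal sketch: argsort along the last dimension, yielding u32 indices.
// Assumes Tensor::arg_sort_last_dim(asc) as used by the xlm-roberta example below.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[[3u32, 1, 2], [0, 5, 4]], &Device::Cpu)?;
    let asc = t.arg_sort_last_dim(true)?; // indices that would sort each row ascending
    assert_eq!(asc.to_vec2::<u32>()?, vec![vec![1, 2, 0], vec![0, 2, 1]]);
    Ok(())
}
```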
@@ -1,4 +1,4 @@
use crate::{shape::Dim, Error, Result, Shape, Tensor};
use crate::{shape::Dim, Context, Error, Result, Shape, Tensor};

impl Tensor {
/// Concatenates two or more tensors along a particular dimension.
@@ -134,7 +134,7 @@ impl Tensor {
.bt())?
}
}
let next_offset = offsets.last().unwrap() + arg.elem_count();
let next_offset = offsets.last().context("empty offsets")? + arg.elem_count();
offsets.push(next_offset);
}
let shape = Shape::from(cat_dims);
@@ -78,7 +78,7 @@ impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> {
match self.inner.inner.next() {
Some(item) => items.push(item),
None => {
if self.return_last_incomplete_batch {
if self.return_last_incomplete_batch && !items.is_empty() {
break;
}
return None;
@@ -102,7 +102,7 @@ impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> {
ys.push(y)
}
None => {
if self.return_last_incomplete_batch {
if self.return_last_incomplete_batch && !xs.is_empty() && !ys.is_empty() {
break;
}
return None;
@@ -127,7 +127,7 @@ impl<I: Iterator<Item = Result<Tensor>>> Iterator for Batcher<IterResult1<I>> {
match self.inner.inner.next() {
Some(item) => items.push(item),
None => {
if self.return_last_incomplete_batch {
if self.return_last_incomplete_batch && !items.is_empty() {
break;
}
return None;
@@ -154,7 +154,7 @@ impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Iterator for Batcher<IterResu
}
Some(Err(err)) => errs.push(err),
None => {
if self.return_last_incomplete_batch {
if self.return_last_incomplete_batch && !xs.is_empty() && !ys.is_empty() {
break;
}
return None;
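All four hunks above apply the same fix to the different `Batcher` iterator flavours: when the wrapped iterator runs out, a trailing incomplete batch is only emitted if it is non-empty, so a stream whose length divides evenly into the batch size no longer produces a spurious empty batch. A rough usage sketch follows; the builder names (`Batcher::new1`, `batch_size`, `return_last_incomplete_batch`) come from candle-datasets and are assumptions here, not shown in this diff:

```rust
// Rough sketch of the behaviour the batcher fix addresses (builder names assumed).
use candle_core::{Device, Result, Tensor};
use candle_datasets::Batcher;

fn main() -> Result<()> {
    // Four one-element tensors batched in pairs: 4 divides evenly by 2, so with this
    // fix the batcher yields exactly two batches and no trailing empty one.
    let items = (0..4u32).map(|i| Tensor::new(&[i], &Device::Cpu).unwrap());
    let batches: Vec<Tensor> = Batcher::new1(items)
        .batch_size(2)
        .return_last_incomplete_batch(true)
        .collect::<Result<Vec<_>>>()?;
    assert_eq!(batches.len(), 2);
    Ok(())
}
```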
@@ -13,7 +13,7 @@ THUDM/CodeGeeX4 is a versatile model for all AI software development scenarios,

** Running with ~cpu~
#+begin_src shell
cargo run --example codegeex4-9b --release --cpu -- --prompt "please write a insertion sort in rust" --sample-len 300
cargo run --example codegeex4-9b --release -- --cpu --prompt "please write a insertion sort in rust" --sample-len 300
#+end_src

** Output_Example
@@ -6,10 +6,8 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use std::ffi::OsString;
use std::path::PathBuf;

use clap::Parser;
use std::{ffi::OsString, path::PathBuf, sync::Arc};

use candle::DType::{F32, U8};
use candle::{DType, Device, Module, Result, Tensor};
@@ -82,7 +80,7 @@ pub fn main() -> anyhow::Result<()> {
};

let config = DepthAnythingV2Config::vit_small();
let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;
let depth_anything = DepthAnythingV2::new(Arc::new(dinov2), config, vb)?;

let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;
@@ -250,7 +250,11 @@ fn run(args: Args) -> Result<()> {
};
println!("img\n{img}");
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
candle_examples::save_image(&img.i(0)?, "out.jpg")?;
let filename = match args.seed {
None => "out.jpg".to_string(),
Some(s) => format!("out-{s}.jpg"),
};
candle_examples::save_image(&img.i(0)?, filename)?;
Ok(())
}
@@ -7,48 +7,25 @@ GLM-4-9B is the open-source version of the latest generation of pre-trained mode
** Running with ~cuda~

#+begin_src shell
cargo run --example glm4 --release --features cuda
cargo run --example glm4 --release --features cuda -- --prompt "Hello world"
#+end_src

** Running with ~cpu~
#+begin_src shell
cargo run --example glm4 --release -- --cpu
cargo run --example glm4 --release -- --cpu --prompt "Hello world"
#+end_src

** Output Example
#+begin_src shell
cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache .
Finished release [optimized] target(s) in 0.24s
Running `/root/candle/target/release/examples/glm4 --sample-len 500 --cache .`
cargo run --features cuda -r --example glm4 -- --prompt "Hello "

avx: true, neon: false, simd128: false, f16c: true
temp: 0.60 repeat-penalty: 1.20 repeat-last-n: 64
cache path .
retrieved the files in 6.88963ms
loaded the model in 6.113752297s
retrieved the files in 6.454375ms
loaded the model in 3.652383779s
starting the inference loop
[欢迎使用GLM-4,请输入prompt]
请你告诉我什么是FFT
266 tokens generated (34.50 token/s)
Result:
。Fast Fourier Transform (FFT) 是一种快速计算离散傅里叶变换(DFT)的方法,它广泛应用于信号处理、图像处理和数据分析等领域。

具体来说,FFT是一种将时域数据转换为频域数据的算法。在数字信号处理中,我们通常需要知道信号的频率成分,这就需要进行傅立叶变换。传统的傅立叶变换的计算复杂度较高,而 FFT 则大大提高了计算效率,使得大规模的 DFT 换成为可能。

以下是使用 Python 中的 numpy 进行 FFT 的简单示例:

```python
import numpy as np

# 创建一个时域信号
t = np.linspace(0, 1, num=100)
f = np.sin(2*np.pi*5*t) + 3*np.cos(2*np.pi*10*t)

# 对该信号做FFT变换,并计算其幅值谱
fft_result = np.fft.fftshift(np.abs(np.fft.fft(f)))

```

在这个例子中,我们首先创建了一个时域信号 f。然后我们对这个信号进行了 FFT 换,得到了一个频域结果 fft_result。
Hello 2018, hello new year! I’m so excited to be back and sharing with you all my favorite things from the past month. This is a monthly series where I share what’s been inspiring me lately in hopes that it will inspire you too!
...
#+end_src

This example will read prompt from stdin
@ -12,136 +12,117 @@ struct TextGeneration {
|
||||
device: Device,
|
||||
tokenizer: Tokenizer,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
verbose_prompt: bool,
|
||||
args: Args,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
verbose_prompt: bool,
|
||||
device: &Device,
|
||||
dtype: DType,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
fn new(model: Model, tokenizer: Tokenizer, args: Args, device: &Device, dtype: DType) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer,
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
verbose_prompt,
|
||||
args,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, sample_len: usize) -> anyhow::Result<()> {
|
||||
use std::io::BufRead;
|
||||
use std::io::BufReader;
|
||||
fn run(&mut self) -> anyhow::Result<()> {
|
||||
use std::io::Write;
|
||||
let args = &self.args;
|
||||
println!("starting the inference loop");
|
||||
println!("[欢迎使用GLM-4,请输入prompt]");
|
||||
let stdin = std::io::stdin();
|
||||
let reader = BufReader::new(stdin);
|
||||
for line in reader.lines() {
|
||||
let line = line.expect("Failed to read line");
|
||||
|
||||
let tokens = self.tokenizer.encode(line, true).expect("tokens error");
|
||||
if tokens.is_empty() {
|
||||
panic!("Empty prompts are not supported in the chatglm model.")
|
||||
let tokens = self
|
||||
.tokenizer
|
||||
.encode(args.prompt.to_string(), true)
|
||||
.expect("tokens error");
|
||||
if tokens.is_empty() {
|
||||
panic!("Empty prompts are not supported in the chatglm model.")
|
||||
}
|
||||
if args.verbose {
|
||||
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
||||
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
||||
println!("{id:7} -> '{token}'");
|
||||
}
|
||||
if self.verbose_prompt {
|
||||
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
||||
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
||||
println!("{id:7} -> '{token}'");
|
||||
}
|
||||
}
|
||||
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
|
||||
Some(token) => *token,
|
||||
None => panic!("cannot find the endoftext token"),
|
||||
} else {
|
||||
print!("{}", &args.prompt);
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
|
||||
Some(token) => *token,
|
||||
None => panic!("cannot find the endoftext token"),
|
||||
};
|
||||
let mut tokens = tokens.get_ids().to_vec();
|
||||
let mut generated_tokens = 0usize;
|
||||
|
||||
std::io::stdout().flush().expect("output flush error");
|
||||
let start_gen = std::time::Instant::now();
|
||||
|
||||
for index in 0..args.sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input)?;
|
||||
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
args.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
let mut tokens = tokens.get_ids().to_vec();
|
||||
let mut generated_tokens = 0usize;
|
||||
|
||||
std::io::stdout().flush().expect("output flush error");
|
||||
let start_gen = std::time::Instant::now();
|
||||
|
||||
let mut count = 0;
|
||||
let mut result = vec![];
|
||||
for index in 0..sample_len {
|
||||
count += 1;
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input)?;
|
||||
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
let token = self
|
||||
.tokenizer
|
||||
.decode(&[next_token], true)
|
||||
.expect("Token error");
|
||||
if self.verbose_prompt {
|
||||
println!(
|
||||
"[Count: {}] [Raw Token: {}] [Decode Token: {}]",
|
||||
count, next_token, token
|
||||
);
|
||||
}
|
||||
result.push(token);
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
let token = self
|
||||
.tokenizer
|
||||
.decode(&[next_token], true)
|
||||
.expect("token decode error");
|
||||
if args.verbose {
|
||||
println!(
|
||||
"[Count: {}] [Raw Token: {}] [Decode Token: {}]",
|
||||
generated_tokens, next_token, token
|
||||
);
|
||||
} else {
|
||||
print!("{token}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
println!("Result:");
|
||||
for tokens in result {
|
||||
print!("{tokens}");
|
||||
}
|
||||
self.model.reset_kv_cache(); // clean the cache
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(name = "cache", short, long, default_value = ".")]
|
||||
cache_path: String,
|
||||
#[arg(name = "cache", short)]
|
||||
cache_path: Option<String>,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Display the token for the specified prompt.
|
||||
#[arg(long)]
|
||||
verbose_prompt: bool,
|
||||
prompt: String,
|
||||
|
||||
/// Display the tokens for the specified prompt and outputs.
|
||||
#[arg(long)]
|
||||
verbose: bool,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
@ -197,28 +178,32 @@ fn main() -> anyhow::Result<()> {
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
println!("cache path {}", args.cache_path);
|
||||
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
|
||||
.build()
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let api = match args.cache_path.as_ref() {
|
||||
None => hf_hub::api::sync::Api::new()?,
|
||||
Some(path) => {
|
||||
hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(path.to_string().into()))
|
||||
.build()
|
||||
.map_err(anyhow::Error::msg)?
|
||||
}
|
||||
};
|
||||
|
||||
let model_id = match args.model_id {
|
||||
let model_id = match args.model_id.as_ref() {
|
||||
Some(model_id) => model_id.to_string(),
|
||||
None => "THUDM/glm-4-9b".to_string(),
|
||||
};
|
||||
let revision = match args.revision {
|
||||
let revision = match args.revision.as_ref() {
|
||||
Some(rev) => rev.to_string(),
|
||||
None => "main".to_string(),
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
let tokenizer_filename = match args.tokenizer {
|
||||
let tokenizer_filename = match args.tokenizer.as_ref() {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => api
|
||||
.model("THUDM/codegeex4-all-9b".to_string())
|
||||
.get("tokenizer.json")
|
||||
.map_err(anyhow::Error::msg)?,
|
||||
};
|
||||
let filenames = match args.weight_file {
|
||||
let filenames = match args.weight_file.as_ref() {
|
||||
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
|
||||
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
};
|
||||
@ -238,18 +223,7 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
args.verbose_prompt,
|
||||
&device,
|
||||
dtype,
|
||||
);
|
||||
pipeline.run(args.sample_len)?;
|
||||
let mut pipeline = TextGeneration::new(model, tokenizer, args, &device, dtype);
|
||||
pipeline.run()?;
|
||||
Ok(())
|
||||
}
|
||||
|
@@ -477,7 +477,7 @@ fn run(args: Args) -> Result<()> {
),
};

let scheduler = sd_config.build_scheduler(n_steps)?;
let mut scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
if let Some(seed) = seed {
device.set_seed(seed)?;
@@ -539,7 +539,7 @@ fn run(args: Args) -> Result<()> {
};

for idx in 0..num_samples {
let timesteps = scheduler.timesteps();
let timesteps = scheduler.timesteps().to_vec();
let latents = match &init_latent_dist {
Some(init_latent_dist) => {
let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;
candle-examples/examples/xlm-roberta/Readme.md (new file, 30 lines)
@@ -0,0 +1,30 @@
# candle-xlm-roberta

This example demonstrates how to use XLM-RoBERTa models in Candle, which are especially known for their use in reranking. It supports a `fill-mask` task that predicts a word for a masked token, and a `reranker` task that reranks a list of documents for a given query.

## Usage

Fill Mask:
```bash
cargo run --example xlm-roberta --release -- --task fill-mask --model xlm-roberta-base
```
```markdown
Sentence: 0 : Hello I'm a fashion model.
Sentence: 1 : I'm a little boy.
Sentence: 2 : I'm living in berlin.
```

Reranker:
```bash
cargo run --example xlm-roberta --release -- --task reranker --model bge-reranker-base
```
```markdown
Ranking Results:
--------------------------------------------------------------------------------
> Rank #4 | Score: 0.0001 | South Korea is a country in East Asia.
> Rank #5 | Score: 0.0000 | There are forests in the mountains.
> Rank #2 | Score: 0.7314 | Pandas look like bears.
> Rank #3 | Score: 0.6948 | There are some animals with black and white fur.
> Rank #1 | Score: 0.9990 | The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.
--------------------------------------------------------------------------------
```
candle-examples/examples/xlm-roberta/main.rs (new file, 277 lines)
@@ -0,0 +1,277 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::xlm_roberta::{
|
||||
Config, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
|
||||
};
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::{PaddingParams, Tokenizer};
|
||||
|
||||
#[derive(Debug, Clone, ValueEnum)]
|
||||
enum Model {
|
||||
BgeRerankerBase,
|
||||
BgeRerankerLarge,
|
||||
BgeRerankerBaseV2,
|
||||
XLMRobertaBase,
|
||||
XLMRobertaLarge,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, ValueEnum)]
|
||||
enum Task {
|
||||
FillMask,
|
||||
Reranker,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long, default_value = "bge-reranker-base")]
|
||||
model: Model,
|
||||
|
||||
#[arg(long, default_value = "reranker")]
|
||||
task: Task,
|
||||
|
||||
// Path to the tokenizer file.
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
// Path to the weight files.
|
||||
#[arg(long)]
|
||||
weight_files: Option<String>,
|
||||
|
||||
// Path to the config file.
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
/// When set, compute embeddings for this prompt.
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
let api = Api::new()?;
|
||||
let model_id = match &args.model_id {
|
||||
Some(model_id) => model_id.to_string(),
|
||||
None => match args.task {
|
||||
Task::FillMask => match args.model {
|
||||
Model::XLMRobertaBase => "FacebookAI/xlm-roberta-base".to_string(),
|
||||
Model::XLMRobertaLarge => "FacebookAI/xlm-roberta-large".to_string(),
|
||||
_ => anyhow::bail!("BGE models are not supported for fill-mask task"),
|
||||
},
|
||||
Task::Reranker => match args.model {
|
||||
Model::BgeRerankerBase => "BAAI/bge-reranker-base".to_string(),
|
||||
Model::BgeRerankerLarge => "BAAI/bge-reranker-large".to_string(),
|
||||
Model::BgeRerankerBaseV2 => "BAAI/bge-reranker-base-v2-m3".to_string(),
|
||||
_ => anyhow::bail!("XLM-RoBERTa models are not supported for reranker task"),
|
||||
},
|
||||
},
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
|
||||
let tokenizer_filename = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
};
|
||||
|
||||
let config_filename = match args.config_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("config.json")?,
|
||||
};
|
||||
|
||||
let weights_filename = match args.weight_files {
|
||||
Some(files) => PathBuf::from(files),
|
||||
None => match repo.get("model.safetensors") {
|
||||
Ok(safetensors) => safetensors,
|
||||
Err(_) => match repo.get("pytorch_model.bin") {
|
||||
Ok(pytorch_model) => pytorch_model,
|
||||
Err(e) => {
|
||||
return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}", e)));
|
||||
}
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let config: Config = serde_json::from_str(&config)?;
|
||||
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let vb = if weights_filename.ends_with("model.safetensors") {
|
||||
unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[weights_filename], candle::DType::F16, &device)
|
||||
.unwrap()
|
||||
}
|
||||
} else {
|
||||
println!("Loading weights from pytorch_model.bin");
|
||||
VarBuilder::from_pth(&weights_filename, candle::DType::F16, &device).unwrap()
|
||||
};
|
||||
tokenizer
|
||||
.with_padding(Some(PaddingParams {
|
||||
strategy: tokenizers::PaddingStrategy::BatchLongest,
|
||||
pad_id: config.pad_token_id,
|
||||
..Default::default()
|
||||
}))
|
||||
.with_truncation(None)
|
||||
.map_err(E::msg)?;
|
||||
|
||||
match args.task {
|
||||
Task::FillMask => {
|
||||
let prompt = vec![
|
||||
"Hello I'm a <mask> model.".to_string(),
|
||||
"I'm a <mask> boy.".to_string(),
|
||||
"I'm <mask> in berlin.".to_string(),
|
||||
];
|
||||
let model = XLMRobertaForMaskedLM::new(&config, vb)?;
|
||||
|
||||
let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Single(&prompt), &device)?;
|
||||
let attention_mask =
|
||||
get_attention_mask(&tokenizer, TokenizeInput::Single(&prompt), &device)?;
|
||||
|
||||
let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
|
||||
|
||||
let output = model
|
||||
.forward(
|
||||
&input_ids,
|
||||
&attention_mask,
|
||||
&token_type_ids,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)?
|
||||
.to_dtype(candle::DType::F32)?;
|
||||
|
||||
let max_outs = output.argmax(2)?;
|
||||
|
||||
let max_out = max_outs.to_vec2::<u32>()?;
|
||||
let max_out_refs: Vec<&[u32]> = max_out.iter().map(|v| v.as_slice()).collect();
|
||||
let decoded = tokenizer.decode_batch(&max_out_refs, true).unwrap();
|
||||
for (i, sentence) in decoded.iter().enumerate() {
|
||||
println!("Sentence: {} : {}", i + 1, sentence);
|
||||
}
|
||||
}
|
||||
Task::Reranker => {
|
||||
let query = "what is panda?".to_string();
|
||||
|
||||
let documents = ["South Korea is a country in East Asia.".to_string(),
|
||||
"There are forests in the mountains.".to_string(),
|
||||
"Pandas look like bears.".to_string(),
|
||||
"There are some animals with black and white fur.".to_string(),
|
||||
"The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.".to_string()];
|
||||
|
||||
// create pairs of query and documents
|
||||
let pairs = documents
|
||||
.iter()
|
||||
.map(|doc| (query.clone(), doc.clone()))
|
||||
.collect::<Vec<_>>();
|
||||
let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
|
||||
let attention_mask =
|
||||
get_attention_mask(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
|
||||
let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
|
||||
|
||||
let model = XLMRobertaForSequenceClassification::new(1, &config, vb)?;
|
||||
|
||||
let output = model.forward(&input_ids, &attention_mask, &token_type_ids)?;
|
||||
let output = candle_nn::ops::sigmoid(&output)?.t().unwrap();
|
||||
let ranks = output
|
||||
.arg_sort_last_dim(false)?
|
||||
.to_vec2::<u32>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<_>>();
|
||||
println!("\nRanking Results:");
|
||||
println!("{:-<80}", "");
|
||||
documents.iter().enumerate().for_each(|(idx, doc)| {
|
||||
let rank = ranks.iter().position(|&r| r == idx as u32).unwrap();
|
||||
let score = output
|
||||
.get_on_dim(1, idx)
|
||||
.unwrap()
|
||||
.to_dtype(candle::DType::F32)
|
||||
.unwrap()
|
||||
.to_vec1::<f32>()
|
||||
.unwrap();
|
||||
println!("Rank #{:<2} | Score: {:.4} | {}", rank + 1, score[0], doc);
|
||||
});
|
||||
println!("{:-<80}", "");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TokenizeInput<'a> {
|
||||
Single(&'a [String]),
|
||||
Pairs(&'a [(String, String)]),
|
||||
}
|
||||
|
||||
pub fn tokenize_batch(
|
||||
tokenizer: &Tokenizer,
|
||||
input: TokenizeInput,
|
||||
device: &Device,
|
||||
) -> anyhow::Result<Tensor> {
|
||||
let tokens = match input {
|
||||
TokenizeInput::Single(text_batch) => tokenizer
|
||||
.encode_batch(text_batch.to_vec(), true)
|
||||
.map_err(E::msg)?,
|
||||
TokenizeInput::Pairs(pairs) => tokenizer
|
||||
.encode_batch(pairs.to_vec(), true)
|
||||
.map_err(E::msg)?,
|
||||
};
|
||||
|
||||
let token_ids = tokens
|
||||
.iter()
|
||||
.map(|tokens| {
|
||||
let tokens = tokens.get_ids().to_vec();
|
||||
Tensor::new(tokens.as_slice(), device)
|
||||
})
|
||||
.collect::<candle::Result<Vec<_>>>()?;
|
||||
|
||||
Ok(Tensor::stack(&token_ids, 0)?)
|
||||
}
|
||||
|
||||
pub fn get_attention_mask(
|
||||
tokenizer: &Tokenizer,
|
||||
input: TokenizeInput,
|
||||
device: &Device,
|
||||
) -> anyhow::Result<Tensor> {
|
||||
let tokens = match input {
|
||||
TokenizeInput::Single(text_batch) => tokenizer
|
||||
.encode_batch(text_batch.to_vec(), true)
|
||||
.map_err(E::msg)?,
|
||||
TokenizeInput::Pairs(pairs) => tokenizer
|
||||
.encode_batch(pairs.to_vec(), true)
|
||||
.map_err(E::msg)?,
|
||||
};
|
||||
|
||||
let attention_mask = tokens
|
||||
.iter()
|
||||
.map(|tokens| {
|
||||
let tokens = tokens.get_attention_mask().to_vec();
|
||||
Tensor::new(tokens.as_slice(), device)
|
||||
})
|
||||
.collect::<candle::Result<Vec<_>>>()?;
|
||||
Ok(Tensor::stack(&attention_mask, 0)?)
|
||||
}
|
@@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.8.1"
version = "0.8.2"
edition = "2021"

description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"

[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.1" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.2" }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]
@@ -54,6 +54,7 @@ fn main() -> Result<()> {
println!("cargo:rerun-if-changed=kernels/kernel_traits.h");
println!("cargo:rerun-if-changed=kernels/block_info.h");
println!("cargo:rerun-if-changed=kernels/static_switch.h");
println!("cargo:rerun-if-changed=kernels/hardware_info.h");
let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
Err(_) =>
Submodule candle-flash-attn/cutlass updated: 7d49e6c7e2...4c42f73fda
@@ -18,8 +18,9 @@ struct BlockInfo {
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
, leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
, seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k)
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
{
}

@@ -30,13 +31,14 @@ struct BlockInfo {

template <typename index_t>
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
}

const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
const int leftpad_k;
const int seqlen_k_cache;
const int actual_seqlen_k;
};
@@ -7,13 +7,7 @@
#include <cuda.h>
#include <vector>

// #ifdef OLD_GENERATOR_PATH
// #include <ATen/CUDAGeneratorImpl.h>
// #else
// #include <ATen/cuda/CUDAGeneratorImpl.h>
// #endif
//
// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
// #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState

constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
@@ -76,6 +70,7 @@ struct Flash_fwd_params : public Qkv_params {
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
int * __restrict__ leftpad_k;

// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;
@@ -189,6 +184,6 @@ struct Flash_bwd_params : public Flash_fwd_params {
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
// template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
// template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
@@ -53,9 +53,12 @@ extern "C" void run_mha(

int is_bf16,
int is_causal,
int unpadded_lse,

int window_size_left,
int window_size_right
int window_size_right,

float softcap
) {
Flash_fwd_params params;
// Reset the parameters
@@ -99,8 +102,16 @@ extern "C" void run_mha(
params.d_rounded = d_rounded;

// Set the different scale values.
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
if (softcap > 0.0) {
params.softcap = softmax_scale / softcap;
params.scale_softmax = softcap;
params.scale_softmax_log2 = softcap * M_LOG2E;
} else {
// Remove potential NaN
params.softcap = 0.0;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
}

params.p_dropout = 1.; // probability to keep
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
@@ -118,6 +129,7 @@ extern "C" void run_mha(

params.is_seqlens_k_cumulative = true;
params.num_splits = 1;
params.unpadded_lse = unpadded_lse;

cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd(params, stream);
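Reading the new branch together with the `apply_softcap` helper further down in this comparison: when `softcap > 0`, the device code multiplies the raw attention scores by `params.softcap` (set here to `softmax_scale / softcap`) and applies `tanh`, while the remaining factor `scale_softmax = softcap` is folded into the softmax. As far as this diff shows, the effective attention logit is therefore roughly

$$ s' \;=\; \mathrm{softcap} \cdot \tanh\!\left(\frac{\mathrm{softmax\_scale} \cdot (q k^{\top})}{\mathrm{softcap}}\right) $$

which smoothly clamps the logits to (-softcap, softcap) while behaving like the plain scaled product for small values; the `else` branch keeps the original uncapped path and zeroes `params.softcap` to avoid propagating NaNs.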
@@ -1,4 +1,4 @@
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

(This identical one-line hunk, bumping the copyright year from 2023 to 2024, applies to each of the 32 auto-generated flash_fwd kernel source files in this comparison.)
@ -4,6 +4,8 @@

#pragma once

// #include "philox_unpack.cuh" // For at::cuda::philox::unpack

#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
@ -22,14 +24,6 @@ namespace flash {

using namespace cute;

template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
#pragma unroll
for (int i = 0; i < size(tensor); ++i) {
tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
@ -328,7 +322,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
flash::apply_softcap(acc_s, params.softcap);
}

mask.template apply_mask<Is_causal, Is_even_MN>(
@ -394,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
smem_thr_copy_Q, smem_thr_copy_K
);
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
flash::apply_softcap(acc_s, params.softcap);
}

flash::cp_async_wait<0>();
@ -691,7 +685,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
// gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
// We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2);
Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
make_stride(params.rotary_dim / 2, _1{}));
@ -712,9 +706,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// if (cute::thread(8, 0)) { print_tensor(gCos); }
// if (cute::thread(0, 0)) { print_tensor(tRgCos); }

const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
// const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
const index_t row_offset_knew = bidb * params.knew_batch_stride
+ ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
// const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
const index_t row_offset_vnew = bidb * params.vnew_batch_stride
+ ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
// Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
// e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
@ -792,7 +788,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM);
} else {
const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
// If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
// We do this by setting the row stride of gCos / gSin to 0.
Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
@ -886,7 +882,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
flash::apply_softcap(acc_s, params.softcap);
}

@ -961,7 +957,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
smem_thr_copy_Q, smem_thr_copy_K
);
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
flash::apply_softcap(acc_s, params.softcap);
}

flash::cp_async_wait<0>();
@ -1226,7 +1222,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
constexpr int kBlockN = kNThreads / kBlockM;
using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
using GmemTiledCopyOaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
@ -3,11 +3,11 @@
******************************************************************************/

#pragma once

// #include <ATen/cuda/CUDAContext.h>
// #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include "error.h"
#include "static_switch.h"
#include "hardware_info.h"
#include "flash.h"
#include "flash_fwd_kernel.h"

@ -74,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout>;
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
@ -205,7 +205,8 @@ inline bool cuda_is_sm8x() {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
bool is_sm8x = cuda_is_sm8x();
auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
if (is_sm8x) {
@ -228,7 +229,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
bool is_sm8x = cuda_is_sm8x();
auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if constexpr(!Is_dropout) {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@ -262,7 +264,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 160;
bool is_sm8x = cuda_is_sm8x();
auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// For A100, H100, 128 x 32 is the fastest.
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
42
candle-flash-attn/kernels/hardware_info.h
Normal file
@ -0,0 +1,42 @@
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <tuple>
#include <cstdio>

#if !defined(__CUDACC_RTC__)
#include "cuda_runtime.h"
#endif

#define CHECK_CUDA(call)                                                  \
do {                                                                  \
cudaError_t status_ = call;                                       \
if (status_ != cudaSuccess) {                                     \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, \
cudaGetErrorString(status_));                             \
exit(1);                                                      \
}                                                                 \
} while (0)

inline int get_current_device() {
int device;
CHECK_CUDA(cudaGetDevice(&device));
return device;
}

inline std::tuple<int, int> get_compute_capability(int device) {
int capability_major, capability_minor;
CHECK_CUDA(cudaDeviceGetAttribute(&capability_major, cudaDevAttrComputeCapabilityMajor, device));
CHECK_CUDA(cudaDeviceGetAttribute(&capability_minor, cudaDevAttrComputeCapabilityMinor, device));
return {capability_major, capability_minor};
}

inline int get_num_sm(int device) {
int multiprocessor_count;
CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device));
return multiprocessor_count;
}
@ -101,8 +101,8 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
using SmemCopyAtomO = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;
using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;

static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
@ -125,14 +125,14 @@ struct Flash_fwd_kernel_traits : public Base {
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
AutoVectorizingCopyWithAssumedAlignment<128>
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store

@ -144,7 +144,7 @@ struct Flash_fwd_kernel_traits : public Base {
Stride< _16, _1>>
>;
using GmemTiledCopyOaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using GmemLayoutAtomRotcossin = GmemLayoutAtom;
@ -153,7 +153,7 @@ struct Flash_fwd_kernel_traits : public Base {
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
using GmemTiledCopyRotcossinCont = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
};
@ -250,7 +250,7 @@ struct Flash_bwd_kernel_traits : public Base {
composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));

using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;

using SmemLayoutQdOtransposed = decltype(
composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
@ -263,7 +263,7 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutdKV = decltype(tile_to_shape(
SmemLayoutAtomdKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomdKV = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;

using SmemLayoutAtomdQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@ -272,7 +272,7 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutdQ = decltype(tile_to_shape(
SmemLayoutAtomdQ{},
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;

// Double buffer for sQ
static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
@ -303,22 +303,22 @@ struct Flash_bwd_kernel_traits : public Base {
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
AutoVectorizingCopyWithAssumedAlignment<128>
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopydO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydKV = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydQ = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomdQaccum = std::conditional_t<
@ -329,12 +329,12 @@ struct Flash_bwd_kernel_traits : public Base {
Stride< _16, _1>>
>;
using GmemTiledCopydQaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomdQaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store

using GmemTiledCopydQaccumAtomicAdd = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
Stride<_32, _1>>{},
Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store

@ -390,4 +390,22 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
#pragma unroll
for (int i = 0; i < size(tensor); ++i) {
tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
}
}

template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
#pragma unroll
for (int i = 0; i < size(src_tensor); ++i) {
dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace flash
@ -42,9 +42,12 @@ extern "C" {

is_bf16: c_int,
is_causal: c_int,
unpadded_lse: c_int,

window_size_left: c_int,
window_size_right: c_int,

softcap: f32,
);

}
@ -11,6 +11,7 @@ pub struct FlashAttn {
pub alibi_slopes: Option<Tensor>,
pub window_size_left: Option<usize>,
pub window_size_right: Option<usize>,
pub softcap: Option<f32>,
}

fn round_multiple(x: usize, m: usize) -> usize {
@ -199,8 +200,10 @@ impl FlashAttn {
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_bf16 */ is_bf16,
/* is_causal */ is_causal,
/* unpadded_lse */ 0,
/* window_size_left */ window_size_left,
/* window_size_right */ window_size_right,
/* softcap */ self.softcap.unwrap_or(0f32),
)
}

@ -271,6 +274,7 @@ pub fn flash_attn(
alibi_slopes: None,
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -308,6 +312,7 @@ pub fn flash_attn_windowed(
alibi_slopes: None,
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -342,6 +347,7 @@ pub fn flash_attn_alibi(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -381,6 +387,52 @@ pub fn flash_attn_alibi_windowed(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}

/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors `k` and `v` with fewer heads
/// than `q`. The number of heads in `q` must be divisible by the number of heads in `k` and `v`.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Optional alibi slopes tensor with shape `(num_heads_q)`.
/// * `softmax_scale` - Scaling factor for the softmax operation.
/// * `window_size_left` - Optional limit on left attention to value tokens.
/// * `window_size_right` - Optional limit on right attention to value tokens.
/// * `softcap` - Gemma-style softcap applied to the attention logits before the softmax.
///
/// # Causal Mask
///
/// Setting `window_size_left=None` and `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`.
///
/// # Returns
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_alibi_windowed_softcap(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: Option<&Tensor>,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
softcap: f32,
) -> Result<Tensor> {
let op = FlashAttn {
softmax_scale,
alibi_slopes: alibi_slopes.cloned(),
window_size_left,
window_size_right,
softcap: Some(softcap),
};
q.apply_op3(k, v, op)
}
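The doc comment above fully determines the call; a minimal usage sketch follows (the shapes, scale, and softcap constant are illustrative assumptions, not taken from the diff). Softcapping rescales the logits as `softcap * tanh(logits / softcap)` before the softmax, matching the reference implementation in the tests further down.

```rust
use candle::{DType, Device, Result, Tensor};

fn softcap_attention_sketch() -> Result<Tensor> {
    let device = Device::new_cuda(0)?;
    // Hypothetical shapes: (batch=1, seq_len=5, num_heads=3, head_size=8), f16 as the kernel expects.
    let q = Tensor::randn(0f32, 1f32, (1, 5, 3, 8), &device)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1f32, (1, 5, 3, 8), &device)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1f32, (1, 5, 3, 8), &device)?.to_dtype(DType::F16)?;
    candle_flash_attn::flash_attn_alibi_windowed_softcap(
        &q,
        &k,
        &v,
        None,              // alibi_slopes
        (8f32).powf(-0.5), // softmax_scale = 1/sqrt(head_size)
        None,              // window_size_left
        Some(0),           // window_size_right => causal mask
        50.0,              // softcap (illustrative value)
    )
}
```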
@ -394,6 +446,7 @@ struct FlashAttnVarLen {
pub alibi_slopes: Option<Tensor>,
pub window_size_left: Option<usize>,
pub window_size_right: Option<usize>,
pub softcap: Option<f32>,
}

impl FlashAttnVarLen {
@ -466,7 +519,7 @@ impl FlashAttnVarLen {
candle::bail!("the last dim of v must be contiguous {v_stride:?}")
}

let (_total_q, num_heads, head_size_og) = q_l.shape().dims3()?;
let (total_q, num_heads, head_size_og) = q_l.shape().dims3()?;
let (total_k, num_heads_k, _head_size_og) = k_l.shape().dims3()?;
let expected_kv = (total_k, num_heads_k, head_size_og);
if expected_kv != k_l.shape().dims3()? {
@ -549,9 +602,7 @@ impl FlashAttnVarLen {

let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
let softmax_lse = dev
.alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q).w()?;

let is_bf16 = if is_bf16 { 1 } else { 0 };

@ -611,8 +662,10 @@ impl FlashAttnVarLen {
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_bf16 */ is_bf16,
/* is_causal */ is_causal,
/* unpadded_lse */ 1,
/* window_size_left */ window_size_left,
/* window_size_right */ window_size_right,
/* softcap */ self.softcap.unwrap_or(0.0),
)
}

@ -699,6 +752,7 @@ pub fn flash_attn_varlen(
alibi_slopes: None,
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -752,6 +806,7 @@ pub fn flash_attn_varlen_windowed(
alibi_slopes: None,
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -802,6 +857,7 @@ pub fn flash_attn_varlen_alibi(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
@ -857,6 +913,65 @@ pub fn flash_attn_varlen_alibi_windowed(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}

#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q; the number of heads in q has to be divisible by the number of heads in k and v.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Option, alibi slopes tensor with shape `(num_heads_q)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum key/value sequence length for k and v in the batch.
/// * `window_size_left` - Option, limit left attention to value tokens.
/// * `window_size_right` - Option, limit right attention to value tokens.
/// * `softcap` - Gemma-style softcap applied to the attention logits before the softmax.
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`.
pub fn flash_attn_varlen_alibi_windowed_softcap(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: Option<&Tensor>,
seqlens_q: &Tensor,
seqlens_k: &Tensor,
max_seqlen_q: usize,
max_seqlen_k: usize,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
softcap: f32,
) -> Result<Tensor> {
let op = FlashAttnVarLen {
softmax_scale,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: alibi_slopes.cloned(),
window_size_left,
window_size_right,
softcap: Some(softcap),
};
q.apply_op3(k, v, op)
}
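As with the fixed-length variant, a short usage sketch may help; the packed layout, sequence lengths, and softcap value below are illustrative assumptions, not taken from the diff.

```rust
use candle::{DType, Device, Result, Tensor};

fn varlen_softcap_sketch() -> Result<Tensor> {
    let device = Device::new_cuda(0)?;
    // Two sequences of lengths 3 and 5 packed into total_q = total_kv = 8 rows,
    // with 4 heads and head_size 16 (hypothetical values).
    let q = Tensor::randn(0f32, 1f32, (8, 4, 16), &device)?.to_dtype(DType::F16)?;
    let k = q.clone();
    let v = q.clone();
    // Cumulative sequence lengths: batch_size + 1 entries, starting at 0.
    let seqlens = Tensor::new(&[0u32, 3, 8], &device)?;
    candle_flash_attn::flash_attn_varlen_alibi_windowed_softcap(
        &q,
        &k,
        &v,
        None,               // alibi_slopes
        &seqlens,           // seqlens_q
        &seqlens,           // seqlens_k
        5,                  // max_seqlen_q
        5,                  // max_seqlen_k
        (16f32).powf(-0.5), // softmax_scale
        None,               // window_size_left
        Some(0),            // window_size_right => causal mask
        30.0,               // softcap (illustrative value)
    )
}
```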
@ -27,6 +27,20 @@ fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<
Ok(output)
}

fn fa_acausal_softcap(q: &Tensor, k: &Tensor, v: &Tensor, softcap: f32) -> Result<Tensor> {
let in_dtype = q.dtype();
let q = q.to_dtype(DType::F32)?;
let k = k.to_dtype(DType::F32)?;
let v = v.to_dtype(DType::F32)?;
// let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
let att = q.matmul(&k.t()?)?;
let att = (softcap as f64 * ((att / softcap as f64)?.tanh())?)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?;
Ok(output)
}

#[test]
fn flash_attn_acausal() -> Result<()> {
let device = Device::new_cuda(0)?;
@ -89,6 +103,44 @@ fn flash_attn_acausal() -> Result<()> {
Ok(())
}

#[test]
fn flash_attn_acausal_softcap() -> Result<()> {
let device = Device::new_cuda(0)?;
let q = Tensor::arange(0u32, 3 * 5 * 8, &device)?
.to_dtype(DType::F16)?
.reshape((1, 3, 5, 8))?;
let k = (&q / 40.)?;
let v = (&q / 50.)?;
let q = (&q / 30.)?;
let softcap = 5.0f32;

let ys1 = fa_acausal_softcap(&q, &k, &v, softcap.clone())?;
let ys1 = ys1.i(0)?.to_dtype(DType::F32)?;
let ys2 = {
let q = q.transpose(1, 2)?;
let k = k.transpose(1, 2)?;
let v = v.transpose(1, 2)?;
candle_flash_attn::flash_attn_alibi_windowed_softcap(
&q,
&k,
&v,
None, // alibi_slopes //
1.0, // softmax //
None, // window_size_left //
None, // window_size_right //
softcap.clone(), // softcap //
)?
.transpose(1, 2)?
};
let ys2 = ys2.i(0)?.to_dtype(DType::F32)?;
let diff = ys1.sub(&ys2)?.abs()?.flatten_all()?.max(0)?;

assert_eq!(ys1.dims(), &[3, 5, 8]);
assert_eq!(ys2.dims(), &[3, 5, 8]);
assert!(diff.to_vec0::<f32>()?.abs() < 1e-3);
Ok(())
}

#[test]
fn flash_attn_varlen() -> Result<()> {
let device = Device::new_cuda(0)?;
@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.8.1"
version = "0.8.2"
edition = "2021"

description = "CUDA kernels for Candle"

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.8.1"
version = "0.8.2"
edition = "2021"

description = "Metal kernels for Candle"

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.8.1"
version = "0.8.2"
edition = "2021"

description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.8.1" }
candle-nn = { path = "../candle-nn", version = "0.8.1" }
candle = { path = "../candle-core", package = "candle-core", version = "0.8.2" }
candle-nn = { path = "../candle-nn", version = "0.8.2" }
prost = "0.12.1"

[build-dependencies]
@ -3,7 +3,7 @@
//! Functionality for modeling sampling strategies and logits processing in text generation
//! with support for temperature-based sampling, top-k filtering, nucleus sampling (top-p),
//! and combinations thereof.
use candle::{DType, Error, Result, Tensor};
use candle::{Context, DType, Error, Result, Tensor};
use rand::{distributions::Distribution, SeedableRng};

#[derive(Clone, PartialEq, Debug)]
@ -45,7 +45,7 @@ impl LogitsProcessor {
.enumerate()
.max_by(|(_, u), (_, v)| u.total_cmp(v))
.map(|(i, _)| i as u32)
.unwrap();
.context("empty logits")?;
Ok(next_token)
}
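The module doc above describes the sampling strategies this processor supports; a minimal, hedged usage sketch follows (the vocabulary, seed, temperature, and top-p values are illustrative assumptions, not part of the diff).

```rust
use candle::{Device, Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn sample_next_token_sketch() -> Result<u32> {
    // Hypothetical logits over a 4-token vocabulary.
    let logits = Tensor::new(&[0.1f32, 2.0, 0.3, 0.5], &Device::Cpu)?;
    // Seeded sampler combining temperature scaling with nucleus (top-p) sampling.
    let mut processor = LogitsProcessor::new(42, Some(0.8), Some(0.9));
    processor.sample(&logits)
}
```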
@ -6,7 +6,7 @@
//! - 💻 [Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)
//! - 💻 [GH](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py_

use candle::{DType, IndexOp, Module, Result, Shape, Tensor, D};
use candle::{Context, DType, IndexOp, Module, Result, Shape, Tensor, D};
use candle_nn as nn;

use super::{Activation, EncoderConfig};
@ -363,7 +363,7 @@ impl ChineseClipVisionTransformer {
.apply(&self.pre_layer_norm)?;

let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
let encoder_outputs = result.last().unwrap();
let encoder_outputs = result.last().context("no last")?;
let pooled_output = encoder_outputs.i((.., 0, ..))?;
result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
Ok(result)

@ -6,7 +6,7 @@
//! https://github.com/openai/CLIP
//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip

use candle::{IndexOp, Result, Shape, Tensor, D};
use candle::{Context, IndexOp, Result, Shape, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
use nn::Conv2dConfig;
@ -149,7 +149,7 @@ impl ClipVisionTransformer {
.apply(&self.embeddings)?
.apply(&self.pre_layer_norm)?;
let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
let encoder_outputs = result.last().unwrap();
let encoder_outputs = result.last().context("no last")?;
let pooled_output = encoder_outputs.i((.., 0, ..))?;
result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
Ok(result)
@ -4,6 +4,8 @@
//! - ["Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data"](https://github.com/LiheYoung/Depth-Anything)
//!

use std::sync::Arc;

use candle::D::Minus1;
use candle::{Module, Result, Tensor};
use candle_nn::ops::Identity;
@ -365,16 +367,18 @@ impl Scratch {

const NUM_CHANNELS: usize = 4;

pub struct DPTHead<'a> {
conf: &'a DepthAnythingV2Config,
pub struct DPTHead {
projections: Vec<Conv2d>,
resize_layers: Vec<Box<dyn Module>>,
readout_projections: Vec<Sequential>,
scratch: Scratch,
use_class_token: bool,
input_image_size: usize,
target_patch_size: usize,
}

impl<'a> DPTHead<'a> {
pub fn new(conf: &'a DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
impl DPTHead {
pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
let mut projections: Vec<Conv2d> = Vec::with_capacity(conf.out_channel_sizes.len());
for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() {
projections.push(conv2d(
@ -445,20 +449,22 @@ impl<'a> DPTHead<'a> {
let scratch = Scratch::new(conf, vb.pp("scratch"))?;

Ok(Self {
conf,
projections,
resize_layers,
readout_projections,
scratch,
use_class_token: conf.use_class_token,
input_image_size: conf.input_image_size,
target_patch_size: conf.target_patch_size,
})
}
}

impl Module for DPTHead<'_> {
impl Module for DPTHead {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut out: Vec<Tensor> = Vec::with_capacity(NUM_CHANNELS);
for i in 0..NUM_CHANNELS {
let x = if self.conf.use_class_token {
let x = if self.use_class_token {
let x = xs.get(i)?.get(0)?;
let class_token = xs.get(i)?.get(1)?;
let readout = class_token.unsqueeze(1)?.expand(x.shape())?;
@ -473,8 +479,8 @@ impl Module for DPTHead<'_> {
let x = x.permute((0, 2, 1))?.reshape((
x_dims[0],
x_dims[x_dims.len() - 1],
self.conf.target_patch_size,
self.conf.target_patch_size,
self.target_patch_size,
self.target_patch_size,
))?;
let x = self.projections[i].forward(&x)?;

@ -515,25 +521,25 @@ impl Module for DPTHead<'_> {

let out = self.scratch.output_conv1.forward(&path1)?;

let out = out.interpolate2d(self.conf.input_image_size, self.conf.input_image_size)?;
let out = out.interpolate2d(self.input_image_size, self.input_image_size)?;

self.scratch.output_conv2.forward(&out)
}
}

pub struct DepthAnythingV2<'a> {
pretrained: &'a DinoVisionTransformer,
depth_head: DPTHead<'a>,
conf: &'a DepthAnythingV2Config,
pub struct DepthAnythingV2 {
pretrained: Arc<DinoVisionTransformer>,
depth_head: DPTHead,
conf: DepthAnythingV2Config,
}

impl<'a> DepthAnythingV2<'a> {
impl DepthAnythingV2 {
pub fn new(
pretrained: &'a DinoVisionTransformer,
conf: &'a DepthAnythingV2Config,
pretrained: Arc<DinoVisionTransformer>,
conf: DepthAnythingV2Config,
vb: VarBuilder,
) -> Result<Self> {
let depth_head = DPTHead::new(conf, vb.pp("depth_head"))?;
let depth_head = DPTHead::new(&conf, vb.pp("depth_head"))?;

Ok(Self {
pretrained,
@ -543,7 +549,7 @@ impl<'a> DepthAnythingV2<'a> {
}
}

impl Module for DepthAnythingV2<'_> {
impl Module for DepthAnythingV2 {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let features = self.pretrained.get_intermediate_layers(
xs,
@ -3,7 +3,7 @@
//! See:
//! - ["EfficientBERT: Progressively Searching Multilayer Perceptron Architectures for BERT"](https://arxiv.org/abs/2201.00462)
//!
use candle::{Result, Tensor, D};
use candle::{Context, Result, Tensor, D};
use candle_nn as nn;
use nn::{Module, VarBuilder};

@ -289,7 +289,7 @@ impl EfficientNet {
pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
let f_p = p.pp("features");
let first_in_c = configs[0].input_channels;
let last_out_c = configs.last().unwrap().out_channels;
let last_out_c = configs.last().context("no last")?.out_channels;
let final_out_c = 4 * last_out_c;
let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
let nconfigs = configs.len();

@ -5,7 +5,7 @@
//!
//! Implementation based on [timm model](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py)

use candle::{DType, Result, Tensor, D};
use candle::{Context, DType, Result, Tensor, D};
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, linear, linear_no_bias, ops::sigmoid, ops::softmax,
BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
@ -178,7 +178,7 @@ fn squeeze_and_excitation(
// based on the _fuse_bn_tensor method in timm
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
let (gamma, beta) = bn.weight_and_bias().unwrap();
let (gamma, beta) = bn.weight_and_bias().context("no weight-bias")?;
let mu = bn.running_mean();
let sigma = (bn.running_var() + bn.eps())?.sqrt();
let gps = (gamma / sigma)?;
@ -14,7 +14,7 @@ use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer}
use crate::models::llama::{Cache, Llama};
use crate::models::with_tracing::linear;

use candle::{bail, Device, IndexOp, Result, Tensor};
use candle::{bail, Context, Device, IndexOp, Result, Tensor};
use candle_nn::{seq, Activation, Module, Sequential, VarBuilder};
use fancy_regex::Regex;
use utils::get_anyres_image_grid_shape;
@ -145,7 +145,7 @@ impl ClipVisionTower {
let config = if config.is_none() {
ClipVisionConfig::clip_vit_large_patch14_336()
} else {
config.clone().unwrap()
config.clone().context("no config")?
};
let select_layer = match select_layer {
-1 | -2 => select_layer,
@ -262,14 +262,14 @@ impl LLaVA {
let image_features = if mm_patch_merge_type == "flat" {
image_features
.iter()
.map(|x| x.flatten(0, 1).unwrap())
.collect::<Vec<Tensor>>()
.map(|x| x.flatten(0, 1))
.collect::<Result<Vec<Tensor>>>()?
} else if mm_patch_merge_type.starts_with("spatial") {
let mut new_image_features = Vec::new();
for (image_idx, image_feature) in image_features.iter().enumerate() {
let new_image_feature = if image_feature.dims()[0] > 1 {
let base_image_feature = image_feature.get(0).unwrap();
let patch_image_feature = image_feature.i(1..).unwrap();
let base_image_feature = image_feature.get(0)?;
let patch_image_feature = image_feature.i(1..)?;
let height = self.clip_vision_tower.num_patches_per_side();
let width = height;
assert_eq!(height * width, base_image_feature.dims()[0]);
@ -313,16 +313,12 @@ impl LLaVA {
};
Tensor::cat(&[base_image_feature, new_image_feature], 0)?
} else {
let new_image_feature = image_feature.get(0).unwrap();
let new_image_feature = image_feature.get(0)?;
if mm_patch_merge_type.contains("unpad") {
Tensor::cat(
&[
new_image_feature,
self.image_newline.clone().unsqueeze(0).unwrap(),
],
&[new_image_feature, self.image_newline.clone().unsqueeze(0)?],
0,
)
.unwrap()
)?
} else {
new_image_feature
}
@ -262,7 +262,8 @@ impl Attention {
.contiguous()?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
.transpose(1, 2)?
.contiguous()?;

let (query_states, key_states) =
self.rotary_emb

@ -109,4 +109,5 @@ pub mod vit;
pub mod whisper;
pub mod with_tracing;
pub mod wuerstchen;
pub mod xlm_roberta;
pub mod yi;
@ -1,8 +1,8 @@
use candle::{DType, Module, Result, Tensor, D};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder};

fn default_act() -> candle_nn::Activation {
candle_nn::Activation::Gelu
candle_nn::Activation::Silu
}

fn default_hidden_size() -> usize {
@ -58,7 +58,7 @@ impl Config {
num_attention_heads: 16,
head_dim: None,
// Default
hidden_act: candle_nn::Activation::Gelu,
hidden_act: candle_nn::Activation::Silu,
}
}

@ -104,6 +104,7 @@ impl Attention {
&self,
xs: &Tensor,
emb: &RotaryEmbedding,
subsampled_positions: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let (b, patches, _) = xs.dims3()?;
@ -116,7 +117,8 @@ impl Attention {
let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;

let (query_states, key_states) = emb.apply_rotary_emb_qkv(&query_states, &key_states)?;
let (query_states, key_states) =
emb.apply_rotary_emb_qkv(&query_states, &key_states, subsampled_positions)?;
let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;

let attn_weights = match attention_mask {
@ -189,12 +191,16 @@ impl AttentionLayer {
&self,
xs: &Tensor,
emb: &RotaryEmbedding,
subsampled_positions: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let residual = xs;
let xs = self
.attention
.forward(&xs.apply(&self.attention_norm)?, emb, attention_mask)?;
let xs = self.attention.forward(
&xs.apply(&self.attention_norm)?,
emb,
subsampled_positions,
attention_mask,
)?;
let xs = (residual + xs)?;
let residual = &xs;
let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?;
@ -222,11 +228,12 @@ impl Transformer {
&self,
xs: &Tensor,
emb: &RotaryEmbedding,
subsampled_positions: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, emb, attention_mask)?
xs = layer.forward(&xs, emb, subsampled_positions, attention_mask)?
}
Ok(xs)
}
@ -270,10 +277,20 @@ impl RotaryEmbedding {
Ok(Self { cos, sin })
}

fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
subsampled_positions: Option<&Tensor>,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?;
let cos = &self.cos;
let sin = &self.sin;
let (cos, sin) = match subsampled_positions {
None => (&self.cos, &self.sin),
Some(pos) => (
&self.cos.index_select(pos, 0)?,
&self.sin.index_select(pos, 0)?,
),
};
let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?;
let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?;
Ok((q_embed, k_embed))
@ -286,6 +303,7 @@ pub struct Model {
ln_pre: RmsNorm,
transformer: Transformer,
patch_positional_embedding: RotaryEmbedding,
max_image_width: u32,
}

impl Model {
@ -305,20 +323,44 @@ impl Model {
let transformer = Transformer::new(cfg, vb.pp("transformer"))?;
let patch_positional_embedding =
RotaryEmbedding::new(cfg, vb.pp("patch_positional_embedding"))?;
let max_image_width = (cfg.image_size / cfg.patch_size) as u32;
Ok(Self {
patch_conv,
ln_pre,
transformer,
patch_positional_embedding,
max_image_width,
})
}

pub fn position_ids_in_meshgrid(
&self,
num_patches_h: usize,
num_patches_w: usize,
device: &Device,
) -> Result<Tensor> {
let idx = Tensor::arange(0, num_patches_h as u32, device)?;
let idy = Tensor::arange(0, num_patches_w as u32, device)?;
let mesh = Tensor::meshgrid(&[idx, idy], false)?;
let ids = (&mesh[0] * (self.max_image_width as f64) + &mesh[1])?.flatten_all()?;
Ok(ids)
}
}

impl Module for Model {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let patch_embeds = xs.apply(&self.patch_conv)?;
let subsampled_positions = Some(self.position_ids_in_meshgrid(
patch_embeds.dim(2)?,
patch_embeds.dim(3)?,
patch_embeds.device(),
)?);
let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?;
self.transformer
.forward(&patch_embeds, &self.patch_positional_embedding, None)
self.transformer.forward(
&patch_embeds,
&self.patch_positional_embedding,
subsampled_positions.as_ref(),
None,
)
}
}
@ -15,7 +15,7 @@
//!

use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
use candle::{Module, ModuleT, Result, Tensor, D};
use candle::{Context, Module, ModuleT, Result, Tensor, D};
use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
@ -633,7 +633,7 @@ impl ImageClassificationModel {
impl Module for ImageClassificationModel {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let all_hidden_states = self.segformer.forward(x)?;
let hidden_states = all_hidden_states.last().unwrap();
let hidden_states = all_hidden_states.last().context("no last")?;
let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
let mean = hidden_states.mean(1)?;
self.classifier.forward(&mean)

@ -127,7 +127,7 @@ impl DDIMScheduler {

impl Scheduler for DDIMScheduler {
/// Performs a backward step during inference.
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {

@ -171,7 +171,7 @@ impl Scheduler for EulerAncestralDiscreteScheduler {
}

/// Performs a backward step during inference.
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let step_index = self
.timesteps
.iter()

@ -47,6 +47,7 @@ pub mod resnet;
pub mod schedulers;
pub mod unet_2d;
pub mod unet_2d_blocks;
pub mod uni_pc;
pub mod utils;
pub mod vae;

@ -19,7 +19,7 @@ pub trait Scheduler {

fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;

fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
}

/// This represents how beta ranges from its minimum value to the maximum
1005
candle-transformers/src/models/stable_diffusion/uni_pc.rs
Normal file
File diff suppressed because it is too large

@ -204,6 +204,7 @@ pub fn log_mel_spectrogram_<T: Float>(

// ensure that the number of threads is even and less than 12
let n_threads = std::cmp::min(get_num_threads() - get_num_threads() % 2, 12);
let n_threads = std::cmp::max(n_threads, 2);

let hann = Arc::new(hann);
let samples = Arc::new(samples);
545
candle-transformers/src/models/xlm_roberta.rs
Normal file
545
candle-transformers/src/models/xlm_roberta.rs
Normal file
@ -0,0 +1,545 @@
|
||||
use crate::models::with_tracing::{linear, Linear};
|
||||
use candle::{DType, Module, Result, Tensor};
|
||||
use candle_nn::{
|
||||
embedding, layer_norm, ops::softmax_last_dim, Activation, Embedding, LayerNorm, VarBuilder,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct Config {
|
||||
pub hidden_size: usize,
|
||||
pub layer_norm_eps: f64,
|
||||
pub attention_probs_dropout_prob: f32,
|
||||
pub hidden_dropout_prob: f32,
|
||||
pub num_attention_heads: usize,
|
||||
pub position_embedding_type: String,
|
||||
pub intermediate_size: usize,
|
||||
pub hidden_act: Activation,
|
||||
pub num_hidden_layers: usize,
|
||||
pub vocab_size: usize,
|
||||
pub max_position_embeddings: usize,
|
||||
pub type_vocab_size: usize,
|
||||
pub pad_token_id: u32,
|
||||
}
|
||||
|
||||
struct XLMRobertaEmbeddings {
|
||||
word_embeddings: Embedding,
|
||||
position_embeddings: Option<Embedding>,
|
||||
token_type_embeddings: Embedding,
|
||||
layer_norm: LayerNorm,
|
||||
padding_idx: u32,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl XLMRobertaEmbeddings {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let word_embeddings = embedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
vb.pp("word_embeddings"),
|
||||
)?;
|
||||
let position_embeddings = embedding(
|
||||
config.max_position_embeddings,
|
||||
config.hidden_size,
|
||||
vb.pp("position_embeddings"),
|
||||
)?;
|
||||
let token_type_embeddings = embedding(
|
||||
config.type_vocab_size,
|
||||
config.hidden_size,
|
||||
vb.pp("token_type_embeddings"),
|
||||
)?;
|
||||
let layer_norm = layer_norm(
|
||||
config.hidden_size,
|
||||
config.layer_norm_eps,
|
||||
vb.pp("LayerNorm"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
word_embeddings,
|
||||
position_embeddings: Some(position_embeddings),
|
||||
token_type_embeddings,
|
||||
layer_norm,
|
||||
padding_idx: config.pad_token_id,
|
||||
span: tracing::span!(tracing::Level::TRACE, "embeddings"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (_bsize, _) = input_ids.dims2()?;
|
||||
let input_embeddings = self.word_embeddings.forward(input_ids)?;
|
||||
let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
|
||||
let mut embeddings = (&input_embeddings + token_type_embeddings)?;
|
||||
if let Some(position_embeddings) = &self.position_embeddings {
|
||||
let mask = input_ids
|
||||
.ne(self.padding_idx)?
|
||||
.to_dtype(input_embeddings.dtype())?;
|
||||
let cumsum = mask.cumsum(1)?;
|
||||
let position_ids = (cumsum * mask)?
|
||||
.broadcast_add(
|
||||
&Tensor::try_from(self.padding_idx)?
|
||||
.to_dtype(input_embeddings.dtype())?
|
||||
.to_device(input_embeddings.device())?,
|
||||
)?
|
||||
.to_dtype(candle::DType::U32)?;
|
||||
embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?;
|
||||
}
|
||||
let embeddings = self.layer_norm.forward(&embeddings)?;
|
||||
Ok(embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
struct XLMRobertaSelfAttention {
    num_attention_heads: usize,
    attention_head_size: usize,
    all_head_size: usize,
    query: Linear,
    key: Linear,
    value: Linear,
}

impl XLMRobertaSelfAttention {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
        let all_head_size = cfg.num_attention_heads * attention_head_size;
        Ok(Self {
            num_attention_heads: cfg.num_attention_heads,
            attention_head_size,
            all_head_size,
            query: linear(cfg.hidden_size, all_head_size, vb.pp("query"))?,
            key: linear(cfg.hidden_size, all_head_size, vb.pp("key"))?,
            value: linear(cfg.hidden_size, all_head_size, vb.pp("value"))?,
        })
    }

    fn transpose_for_scores(&self, x: &Tensor) -> Result<Tensor> {
        let mut new_x_shape = x.dims().to_vec();
        new_x_shape[2] = self.num_attention_heads;
        new_x_shape.push(self.attention_head_size);
        let x = x.reshape(new_x_shape)?;
        x.permute((0, 2, 1, 3))?.contiguous()
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        attention_mask: &Tensor,
        past_key_value: Option<(&Tensor, &Tensor)>,
        encoder_attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let mixed_query_layer = self.query.forward(hidden_states)?;
        let is_cross_attention = encoder_hidden_states.is_some();
        let (key_layer, value_layer, attention_mask) = if is_cross_attention
            && past_key_value.is_some()
        {
            let key_layer = past_key_value.unwrap().0.clone();
            let value_layer = past_key_value.unwrap().1.clone();
            let attention_mask = encoder_attention_mask.unwrap().clone();
            (key_layer, value_layer, Some(attention_mask))
        } else if is_cross_attention {
            let key_layer =
                self.transpose_for_scores(&self.key.forward(encoder_hidden_states.unwrap())?)?;
            let value_layer =
                self.transpose_for_scores(&self.value.forward(encoder_hidden_states.unwrap())?)?;
            let attention_mask = encoder_attention_mask.unwrap();
            (key_layer, value_layer, Some(attention_mask.clone()))
        } else if past_key_value.is_some() {
            let mut key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
            let mut value_layer = self.transpose_for_scores(&self.value.forward(hidden_states)?)?;
            key_layer = Tensor::cat(
                &[
                    past_key_value.as_ref().unwrap().0.clone(),
                    key_layer,
                ],
                2,
            )?;
            value_layer = Tensor::cat(
                &[past_key_value.as_ref().unwrap().1.clone(), value_layer],
                2,
            )?;
            (key_layer, value_layer, Some(attention_mask.clone()))
        } else {
            let key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
            let value_layer = self.transpose_for_scores(&self.value.forward(hidden_states)?)?;
            (key_layer, value_layer, Some(attention_mask.clone()))
        };

        let query_layer = self.transpose_for_scores(&mixed_query_layer)?;
        let mut attention_scores = query_layer.matmul(&key_layer.transpose(2, 3)?)?;
        let scale = 1f64 / f64::sqrt(self.attention_head_size as f64);

        attention_scores = (attention_scores * scale)?;
        attention_scores = match attention_mask {
            None => attention_scores,
            Some(mask) => {
                attention_scores.broadcast_add(&mask.to_dtype(attention_scores.dtype())?)?
            }
        };
        let attention_probs = softmax_last_dim(&attention_scores)?;

        let context_layer = attention_probs
            .matmul(&value_layer)?
            .permute((0, 2, 1, 3))?
            .contiguous()?;
        let mut new_context_layer_shape =
            context_layer.dims()[..context_layer.dims().len() - 2].to_vec();
        new_context_layer_shape.push(self.all_head_size);
        let context_layer = context_layer.reshape(new_context_layer_shape)?;

        Ok(context_layer)
    }
}

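// A hedged, standalone sketch (not part of the model itself) of what
// `transpose_for_scores` does to the layout: (batch, seq_len, hidden_size)
// becomes (batch, num_heads, seq_len, head_dim). The sizes below
// (batch 2, seq 4, hidden 8 split into 2 heads of 4) are made up for the demo.
fn transpose_for_scores_shape_demo() -> Result<()> {
    use candle::Device;
    let dev = Device::Cpu;
    let x = Tensor::zeros((2, 4, 8), DType::F32, &dev)?;
    let (num_heads, head_dim) = (2, 4);
    let mut shape = x.dims().to_vec();
    shape[2] = num_heads;
    shape.push(head_dim);
    // Same reshape + permute as above.
    let x = x.reshape(shape)?.permute((0, 2, 1, 3))?.contiguous()?;
    assert_eq!(x.dims(), &[2, 2, 4, 4]);
    Ok(())
}
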
struct XLMRobertaSelfOutput {
    dense: Linear,
    layernorm: LayerNorm,
}

impl XLMRobertaSelfOutput {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
        let layernorm =
            candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        Ok(Self { dense, layernorm })
    }

    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let hidden_states = self.dense.forward(hidden_states)?;
        let hidden_states = self.layernorm.forward(&(hidden_states + input_tensor)?)?;
        Ok(hidden_states)
    }
}

struct XLMRobertaAttention {
    output: XLMRobertaSelfOutput,
    self_attention: XLMRobertaSelfAttention,
}

impl XLMRobertaAttention {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let output = XLMRobertaSelfOutput::new(cfg, vb.pp("output"))?;
        let self_attention = XLMRobertaSelfAttention::new(cfg, vb.pp("self"))?;
        Ok(Self {
            output,
            self_attention,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        encoder_attention_mask: Option<&Tensor>,
        past_key_value: Option<(&Tensor, &Tensor)>,
    ) -> Result<(Tensor, Tensor)> {
        let self_outputs = self.self_attention.forward(
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            past_key_value,
            encoder_attention_mask,
        )?;
        let attention_output = self.output.forward(&self_outputs, hidden_states)?;
        Ok((attention_output, self_outputs))
    }
}

struct XLMRobertaOutput {
    dense: Linear,
    layernorm: LayerNorm,
}

impl XLMRobertaOutput {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
        let layernorm =
            candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        Ok(Self { dense, layernorm })
    }

    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let hidden_states = self.dense.forward(hidden_states)?;
        let hidden_states = self.layernorm.forward(&(hidden_states + input_tensor)?)?;
        Ok(hidden_states)
    }
}

struct XLMRobertaIntermediate {
    dense: Linear,
    intermediate_act_fn: Activation,
}

impl XLMRobertaIntermediate {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
        let intermediate_act_fn = cfg.hidden_act;
        Ok(Self {
            dense,
            intermediate_act_fn,
        })
    }

    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let hidden_states = self.dense.forward(hidden_states)?;
        let hidden_states = self.intermediate_act_fn.forward(&hidden_states)?;
        Ok(hidden_states)
    }
}

struct XLMRobertaLayer {
    attention: XLMRobertaAttention,
    intermediate: XLMRobertaIntermediate,
    output: XLMRobertaOutput,
}

impl XLMRobertaLayer {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let attention = XLMRobertaAttention::new(cfg, vb.pp("attention"))?;
        let intermediate = XLMRobertaIntermediate::new(cfg, vb.pp("intermediate"))?;
        let output = XLMRobertaOutput::new(cfg, vb.pp("output"))?;
        Ok(Self {
            attention,
            intermediate,
            output,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        encoder_attention_mask: Option<&Tensor>,
        past_key_value: Option<(&Tensor, &Tensor)>,
    ) -> Result<(Tensor, Tensor)> {
        let self_attention_outputs = self.attention.forward(
            hidden_states,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
        )?;
        let attention_output = self_attention_outputs.0;
        let outputs = self_attention_outputs.1;
        let intermediate_output = self.intermediate.forward(&attention_output)?;
        let layer_output = self
            .output
            .forward(&intermediate_output, &attention_output)?;
        Ok((layer_output, outputs))
    }
}

struct XLMRobertaEncoder {
    layers: Vec<XLMRobertaLayer>,
}

impl XLMRobertaEncoder {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let layers = (0..cfg.num_hidden_layers)
            .map(|i| XLMRobertaLayer::new(cfg, vb.pp(format!("layer.{}", i))))
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { layers })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        encoder_attention_mask: Option<&Tensor>,
        past_key_value: Option<(&Tensor, &Tensor)>,
    ) -> Result<Tensor> {
        let mut hidden_states = hidden_states.clone();
        for layer_module in self.layers.iter() {
            let layer_outputs = layer_module.forward(
                &hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
            )?;
            hidden_states = layer_outputs.0;
        }
        Ok(hidden_states)
    }
}

pub struct XLMRobertaModel {
    encoder: XLMRobertaEncoder,
    embeddings: XLMRobertaEmbeddings,
}

impl XLMRobertaModel {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let encoder = XLMRobertaEncoder::new(cfg, vb.pp("encoder"))?;
        let embeddings = XLMRobertaEmbeddings::load(vb.pp("embeddings"), cfg)?;
        Ok(Self {
            encoder,
            embeddings,
        })
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        attention_mask: &Tensor,
        token_type_ids: &Tensor,
        past_key_value: Option<(&Tensor, &Tensor)>,
        encoder_hidden_states: Option<&Tensor>,
        encoder_attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let hidden_states = self.embeddings.forward(input_ids, token_type_ids)?;
        let attention_mask = prepare_4d_attention_mask(attention_mask, DType::F32, None)?
            .to_device(hidden_states.device())?;
        let hidden_states = self.encoder.forward(
            &hidden_states,
            &attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
        )?;
        Ok(hidden_states)
    }
}

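// A hedged end-to-end sketch of driving `XLMRobertaModel`: the paths, the
// all-ones attention mask, and the assumption that `Config` deserializes from
// a Hugging Face style config.json (via serde_json/anyhow) are illustrative
// only; tokenization is left out entirely.
fn xlm_roberta_encode_demo(config_path: &str, weights_path: &str) -> anyhow::Result<Tensor> {
    use candle::Device;
    use candle_nn::VarBuilder;

    let device = Device::Cpu;
    let cfg: Config = serde_json::from_str(&std::fs::read_to_string(config_path)?)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device)? };
    let model = XLMRobertaModel::new(&cfg, vb)?;

    // Two pre-tokenized ids and a matching attention mask, purely illustrative.
    let input_ids = Tensor::new(&[[0u32, 2u32]], &device)?;
    let attention_mask = Tensor::new(&[[1u32, 1u32]], &device)?;
    let token_type_ids = input_ids.zeros_like()?;
    // Last hidden states, shape (batch, seq_len, hidden_size).
    Ok(model.forward(&input_ids, &attention_mask, &token_type_ids, None, None, None)?)
}
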
struct XLMRobertaLMHead {
    dense: Linear,
    layer_norm: LayerNorm,
}

impl XLMRobertaLMHead {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
        let layer_norm =
            candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layer_norm"))?;
        Ok(Self { dense, layer_norm })
    }

    fn forward(&self, hidden_states: &Tensor, shared_embeddings: &Tensor) -> Result<Tensor> {
        let hidden_states = self.dense.forward(hidden_states)?;
        let hidden_states = candle_nn::Activation::Gelu.forward(&hidden_states)?;
        let hidden_states = self.layer_norm.forward(&hidden_states)?;
        let hidden_states = hidden_states.broadcast_matmul(shared_embeddings)?;
        Ok(hidden_states)
    }
}

pub struct XLMRobertaForMaskedLM {
    roberta: XLMRobertaModel,
    lm_head: XLMRobertaLMHead,
}

impl XLMRobertaForMaskedLM {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let roberta = XLMRobertaModel::new(cfg, vb.pp("roberta"))?;
        let lm_head = XLMRobertaLMHead::new(cfg, vb.pp("lm_head"))?;
        Ok(Self { roberta, lm_head })
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        attention_mask: &Tensor,
        token_type_ids: &Tensor,
        past_key_value: Option<(&Tensor, &Tensor)>,
        encoder_hidden_states: Option<&Tensor>,
        encoder_attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let hidden_states = self.roberta.forward(
            input_ids,
            attention_mask,
            token_type_ids,
            past_key_value,
            encoder_hidden_states,
            encoder_attention_mask,
        )?;
        let lm_logits = self.lm_head.forward(
            &hidden_states,
            &self
                .roberta
                .embeddings
                .word_embeddings
                .embeddings()
                .t()?
                .unsqueeze(0)?,
        )?;
        Ok(lm_logits)
    }
}

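// Hypothetical convenience on top of the masked-LM wrapper above: run the
// forward pass and keep the highest-scoring vocabulary id per position. The
// logits come from projecting onto the transposed word-embedding matrix
// (tied weights), so an argmax over the last axis picks token ids directly.
fn predict_token_ids(
    model: &XLMRobertaForMaskedLM,
    input_ids: &Tensor,
    attention_mask: &Tensor,
    token_type_ids: &Tensor,
) -> Result<Tensor> {
    let logits = model.forward(input_ids, attention_mask, token_type_ids, None, None, None)?;
    // (batch, seq_len) of vocabulary ids.
    logits.argmax(candle::D::Minus1)
}
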
struct XLMRobertaClassificationHead {
    dense: Linear,
    out_proj: Linear,
}

impl XLMRobertaClassificationHead {
    fn new(num_labels: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
        let out_proj = linear(cfg.hidden_size, num_labels, vb.pp("out_proj"))?;
        Ok(Self { dense, out_proj })
    }

    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
        let hidden_states = self.dense.forward(&cls_states)?;
        let hidden_states = candle_nn::Activation::GeluPytorchTanh.forward(&hidden_states)?;
        let hidden_states = self.out_proj.forward(&hidden_states)?;
        Ok(hidden_states)
    }
}

pub struct XLMRobertaForSequenceClassification {
    roberta: XLMRobertaModel,
    classifier: XLMRobertaClassificationHead,
}

impl XLMRobertaForSequenceClassification {
    pub fn new(num_labels: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let roberta = XLMRobertaModel::new(cfg, vb.pp("roberta"))?;
        let classifier = XLMRobertaClassificationHead::new(num_labels, cfg, vb.pp("classifier"))?;
        Ok(Self {
            roberta,
            classifier,
        })
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        attention_mask: &Tensor,
        token_type_ids: &Tensor,
    ) -> Result<Tensor> {
        let hidden_states =
            self.roberta
                .forward(input_ids, attention_mask, token_type_ids, None, None, None)?;
        self.classifier.forward(&hidden_states)
    }
}

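// A similar hedged helper for the sequence classifier: forward pass plus a
// softmax over the label axis. `softmax_last_dim` is the same candle_nn op the
// attention code above already uses.
fn classify_probs(
    model: &XLMRobertaForSequenceClassification,
    input_ids: &Tensor,
    attention_mask: &Tensor,
    token_type_ids: &Tensor,
) -> Result<Tensor> {
    let logits = model.forward(input_ids, attention_mask, token_type_ids)?;
    // (batch, num_labels) probabilities that sum to 1 per row.
    softmax_last_dim(&logits)
}
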
fn prepare_4d_attention_mask(
    mask: &Tensor,
    dtype: DType,
    tgt_len: Option<usize>,
) -> Result<Tensor> {
    let bsz = mask.dim(0)?;
    let src_len = mask.dim(1)?;
    let tgt_len = tgt_len.unwrap_or(src_len);

    let expanded_mask = mask
        .unsqueeze(1)?
        .unsqueeze(2)?
        .expand((bsz, 1, tgt_len, src_len))?
        .to_dtype(dtype)?;

    let inverted_mask = (1.0 - expanded_mask)?;

    (inverted_mask * get_dtype_min_val(dtype))?.to_dtype(dtype)
}

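// A toy check of the mask helper above, with assumed sizes (batch 1, three
// tokens, the last one padding): the 1/0 mask turns into a (1, 1, 3, 3)
// additive mask that is 0.0 on kept positions and f32::MIN on the padded
// column, so the softmax in self-attention effectively ignores it.
fn prepare_mask_demo() -> Result<Tensor> {
    use candle::Device;
    let dev = Device::Cpu;
    let mask = Tensor::new(&[[1u32, 1, 0]], &dev)?; // 1 = real token, 0 = padding
    let additive = prepare_4d_attention_mask(&mask, DType::F32, None)?;
    assert_eq!(additive.dims(), &[1, 1, 3, 3]);
    Ok(additive)
}
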
fn get_dtype_min_val(dtype: DType) -> f64 {
    match dtype {
        DType::F32 => f32::MIN as f64,
        DType::F64 => f64::MIN,
        _ => panic!("Unsupported data type"),
    }
}