Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 19:58:35 +00:00)
Add flash attention (#241)
* Add some flash-attn kernel, import the code for flash-attn v2 from Dao-AILab.
* More flash attn.
* Set up the flash attn parameters.
* Get things to compile locally.
* Move the flash attention files in a different directory.
* Build the static C library with nvcc.
* Add more flash attention.
* Update the build part.
* Better caching.
* Exclude flash attention from the default workspace.
* Put flash-attn behind a feature gate.
* Get the flash attn kernel to run.
* Move the flags to a more appropriate place.
* Enable flash attention in llama.
* Use flash attention in llama.
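As a rough illustration of the "Build the static C library with nvcc" and "Put flash-attn behind a feature gate" steps above, a build.rs along the following lines could compile the vendored flash-attn v2 kernels into a static library for the ffi module to link against. This is a sketch only: the cc build-dependency, the kernel file name and the sm_80 flag are assumptions, not taken from this commit.

// build.rs -- hypothetical sketch, not the build script from this commit.
// Assumes `cc` is listed as a build-dependency and that the flash-attn v2
// sources imported from Dao-AILab live under a local `kernels/` directory.
fn main() {
    // Placeholder kernel source; the real file names come from flash-attn v2.
    let kernel = "kernels/flash_fwd_hdim32_fp16_sm80.cu";
    println!("cargo:rerun-if-changed={kernel}");
    cc::Build::new()
        .cuda(true) // invoke nvcc rather than the host C++ compiler
        .flag("-arch=sm_80") // matches the sm80 target of FlashHdim32Sm80
        .file(kernel)
        .compile("flashattention"); // emits libflashattention.a plus link directives
    // The CUDA runtime also has to be linked; cc's cuda support covers the
    // common case, but an explicit `cargo:rustc-link-lib=cudart` may be needed.
}

Gating the candle-flash-attn dependency behind a Cargo feature in the downstream crates, as the commit message describes, keeps this nvcc step out of the default workspace build.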
candle-flash-attn/src/ffi.rs (new file, 35 lines)
@@ -0,0 +1,35 @@
use core::ffi::{c_int, c_void};

extern "C" {
    pub(crate) fn run_mha(
        q_ptr: *const c_void,
        k_ptr: *const c_void,
        v_ptr: *const c_void,
        o_ptr: *const c_void,

        q_batch_stride: u32,
        k_batch_stride: u32,
        v_batch_stride: u32,
        q_row_stride: u32,
        k_row_stride: u32,
        v_row_stride: u32,
        q_head_stride: u32,
        k_head_stride: u32,
        v_head_stride: u32,

        b: u32,
        h: u32,
        h_k: u32,
        d: u32,
        d_rounded: u32,
        softmax_scale: f32,

        seqlen_q: u32,
        seqlen_k: u32,
        seqlen_q_rounded: u32,
        seqlen_k_rounded: u32,

        is_causal: c_int,
    );

}
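The stride arguments to run_mha describe the q/k/v memory layout in elements. For a contiguous (batch, seqlen, num_heads, head_dim) tensor, the layout flash-attn v2 expects, they could be derived as in the sketch below. This helper is illustrative only and not part of the commit; the initial cuda_fwd in lib.rs further down still passes placeholder ones for these arguments.

// Hypothetical helper: element strides for a contiguous
// (batch, seqlen, num_heads, head_dim) tensor. Not part of this commit.
struct MhaStrides {
    batch_stride: u32, // elements between consecutive batch entries
    row_stride: u32,   // elements between consecutive sequence positions
    head_stride: u32,  // elements between consecutive heads
}

fn contiguous_strides(seqlen: u32, num_heads: u32, head_dim: u32) -> MhaStrides {
    MhaStrides {
        batch_stride: seqlen * num_heads * head_dim,
        row_stride: num_heads * head_dim,
        head_stride: head_dim,
    }
}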
candle-flash-attn/src/lib.rs (new file, 59 lines)
@@ -0,0 +1,59 @@
mod ffi;

use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, Error, Layout, Result, Shape};
use half::f16;

pub struct FlashHdim32Sm80;

impl candle::CustomOp3 for FlashHdim32Sm80 {
    fn name(&self) -> &'static str {
        "flash-hdim32-sm80"
    }

    fn cpu_fwd(
        &self,
        _: &CpuStorage,
        _: &Layout,
        _: &CpuStorage,
        _: &Layout,
        _: &CpuStorage,
        _: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        Err(Error::Wrapped("no cpu support for flash-attn".into()))
    }

    fn cuda_fwd(
        &self,
        q: &candle::CudaStorage,
        _q_l: &Layout,
        k: &candle::CudaStorage,
        _k_l: &Layout,
        v: &candle::CudaStorage,
        _v_l: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        let dev = q.device();
        let out_shape = Shape::from(&[1]);
        let q = q.as_cuda_slice::<f16>()?;
        let k = k.as_cuda_slice::<f16>()?;
        let v = v.as_cuda_slice::<f16>()?;
        let elem_count = out_shape.elem_count();
        let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;

        unsafe {
            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
            let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
            let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
            let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
            ffi::run_mha(
                q_ptr, k_ptr, v_ptr, dst_ptr, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1.0, 1, 1,
                1, 1, 1,
            )
        }

        let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev.clone());
        Ok((dst, out_shape))
    }
}
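A CustomOp3 like FlashHdim32Sm80 is meant to be applied to tensors rather than called directly. Below is a minimal call-site sketch, assuming candle exposes an apply_op3 method on Tensor for CustomOp3 implementations and that q, k and v are f16 tensors on a CUDA device; both are assumptions not shown in this diff, and at this stage of the commit the op's output shape is still a placeholder [1].

// Hypothetical call site; assumes Tensor::apply_op3 dispatches to cuda_fwd.
use candle::{DType, Device, Result, Tensor};

fn flash_attn(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    q.apply_op3(k, v, candle_flash_attn::FlashHdim32Sm80)
}

fn main() -> Result<()> {
    let dev = Device::new_cuda(0)?;
    // Placeholder (batch, seqlen, num_heads, head_dim) shapes in f16.
    let q = Tensor::zeros((1, 8, 1, 32), DType::F16, &dev)?;
    let (k, v) = (q.clone(), q.clone());
    let _out = flash_attn(&q, &k, &v)?;
    Ok(())
}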