Add flash attention (#241)

* Add a first flash-attn kernel, importing the flash-attn v2 code from Dao-AILab.

* More flash attn.

* Set up the flash attn parameters.

* Get things to compile locally.

* Move the flash attention files to a different directory.

* Build the static C library with nvcc (a hedged build.rs sketch follows this list).

* Add more flash attention.

* Update the build setup.

* Better caching.

* Exclude flash attention from the default workspace.

* Put flash-attn behind a feature gate (a call-site sketch follows this list).

* Get the flash attn kernel to run.

* Move the flags to a more appropriate place.

* Enable flash attention in llama.

* Use flash attention in llama.
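
For the nvcc build step mentioned above, here is a minimal build.rs sketch using the cc crate; the kernel file name and the gencode flags are illustrative assumptions, not the exact build script added in this PR.

// build.rs sketch: compile the flash-attn CUDA kernels into a static library
// with nvcc via the cc crate (file name and flags are placeholders).
fn main() {
    println!("cargo:rerun-if-changed=kernels/");
    cc::Build::new()
        .cuda(true) // route the build through nvcc
        .flag("-gencode")
        .flag("arch=compute_80,code=sm_80") // target sm_80 (Ampere)
        .include("kernels")
        .file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
        .compile("flashattn"); // emits libflashattn.a and links it
}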
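
For the feature gate and the llama integration, the call site could take roughly this shape; the apply_op3 entry point, the candle_nn softmax helper, and the unmasked fallback are assumptions based on current candle, not the PR's exact code.

// Sketch: gate the CUDA kernel behind a Cargo feature in the llama example.
use candle::{Result, Tensor, D};

#[cfg(feature = "flash-attn")]
fn attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    // Dispatch to the custom op defined in candle-flash-attn (see the new
    // lib.rs in the diff below); apply_op3 is candle's CustomOp3 entry point.
    q.apply_op3(k, v, candle_flash_attn::FlashHdim32Sm80)
}

#[cfg(not(feature = "flash-attn"))]
fn attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    // Plain softmax(Q K^T / sqrt(d)) V fallback; causal masking omitted here.
    let d = q.dim(D::Minus1)? as f64;
    let att = (q.matmul(&k.t()?)? / d.sqrt())?;
    let att = candle_nn::ops::softmax(&att, D::Minus1)?;
    att.matmul(v)
}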
Author: Laurent Mazare
Date: 2023-07-26 07:48:10 +01:00
Committed by: GitHub
Parent: c97d51243c
Commit: d9f9c859af
22 changed files with 2699 additions and 9 deletions


@@ -0,0 +1,59 @@
mod ffi;

use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, Error, Layout, Result, Shape};
use half::f16;

pub struct FlashHdim32Sm80;

impl candle::CustomOp3 for FlashHdim32Sm80 {
    fn name(&self) -> &'static str {
        "flash-hdim32-sm80"
    }

    // Flash attention is CUDA only: the CPU forward pass always errors out.
    fn cpu_fwd(
        &self,
        _: &CpuStorage,
        _: &Layout,
        _: &CpuStorage,
        _: &Layout,
        _: &CpuStorage,
        _: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        Err(Error::Wrapped("no cpu support for flash-attn".into()))
    }

    fn cuda_fwd(
        &self,
        q: &candle::CudaStorage,
        _q_l: &Layout,
        k: &candle::CudaStorage,
        _k_l: &Layout,
        v: &candle::CudaStorage,
        _v_l: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        let dev = q.device();
        // Placeholder output shape while the kernel is being wired up.
        let out_shape = Shape::from(&[1]);
        let q = q.as_cuda_slice::<f16>()?;
        let k = k.as_cuda_slice::<f16>()?;
        let v = v.as_cuda_slice::<f16>()?;
        let elem_count = out_shape.elem_count();
        let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
        unsafe {
            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
            let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
            let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
            let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
            // Call into the C/CUDA flash-attn kernel; the sizes, strides, and
            // softmax scale are still hard-coded placeholders at this stage.
            ffi::run_mha(
                q_ptr, k_ptr, v_ptr, dst_ptr, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1.0, 1, 1,
                1, 1, 1,
            )
        }
        let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev.clone());
        Ok((dst, out_shape))
    }
}