mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
BF16 support for flash-attn. (#737)
This commit is contained in:
@ -4,7 +4,7 @@ use candle::backend::BackendStorage;
|
||||
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||
use candle::cuda_backend::WrapErr;
|
||||
use candle::{CpuStorage, Layout, Result, Shape, Tensor};
|
||||
use half::f16;
|
||||
use half::{bf16, f16};
|
||||
|
||||
pub struct FlashAttn {
|
||||
pub softmax_scale: f32,
|
||||
@ -15,24 +15,10 @@ fn round_multiple(x: usize, m: usize) -> usize {
|
||||
(x + m - 1) / m * m
|
||||
}
|
||||
|
||||
impl candle::CustomOp3 for FlashAttn {
|
||||
fn name(&self) -> &'static str {
|
||||
"flash-attn"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
_: &CpuStorage,
|
||||
_: &Layout,
|
||||
_: &CpuStorage,
|
||||
_: &Layout,
|
||||
_: &CpuStorage,
|
||||
_: &Layout,
|
||||
) -> Result<(CpuStorage, Shape)> {
|
||||
candle::bail!("no cpu support for flash-attn")
|
||||
}
|
||||
|
||||
fn cuda_fwd(
|
||||
impl FlashAttn {
|
||||
fn cuda_fwd_t<
|
||||
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
|
||||
>(
|
||||
&self,
|
||||
q: &candle::CudaStorage,
|
||||
q_l: &Layout,
|
||||
@ -46,9 +32,9 @@ impl candle::CustomOp3 for FlashAttn {
|
||||
let out_shape = q_l.shape().clone();
|
||||
let out_l = Layout::contiguous(&out_shape);
|
||||
|
||||
let q = q.as_cuda_slice::<f16>()?;
|
||||
let k = k.as_cuda_slice::<f16>()?;
|
||||
let v = v.as_cuda_slice::<f16>()?;
|
||||
let q = q.as_cuda_slice::<T>()?;
|
||||
let k = k.as_cuda_slice::<T>()?;
|
||||
let v = v.as_cuda_slice::<T>()?;
|
||||
let q = q.slice(q_l.start_offset()..);
|
||||
let k = k.slice(k_l.start_offset()..);
|
||||
let v = v.slice(v_l.start_offset()..);
|
||||
@ -104,7 +90,7 @@ impl candle::CustomOp3 for FlashAttn {
|
||||
let seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
||||
|
||||
let elem_count = out_shape.elem_count();
|
||||
let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
|
||||
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
|
||||
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
|
||||
|
||||
let causal = if self.causal { 1 } else { 0 };
|
||||
@ -155,6 +141,40 @@ impl candle::CustomOp3 for FlashAttn {
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::CustomOp3 for FlashAttn {
|
||||
fn name(&self) -> &'static str {
|
||||
"flash-attn"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
_: &CpuStorage,
|
||||
_: &Layout,
|
||||
_: &CpuStorage,
|
||||
_: &Layout,
|
||||
_: &CpuStorage,
|
||||
_: &Layout,
|
||||
) -> Result<(CpuStorage, Shape)> {
|
||||
candle::bail!("no cpu support for flash-attn")
|
||||
}
|
||||
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
q: &candle::CudaStorage,
|
||||
q_l: &Layout,
|
||||
k: &candle::CudaStorage,
|
||||
k_l: &Layout,
|
||||
v: &candle::CudaStorage,
|
||||
v_l: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
match q.dtype() {
|
||||
candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
|
||||
candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
|
||||
dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Flash-attention v2 layer.
|
||||
///
|
||||
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
|
||||
@ -191,24 +211,10 @@ struct FlashAttnVarLen {
|
||||
seqlens_k: Tensor,
|
||||
}
|
||||
|
||||
impl candle::CustomOp3 for FlashAttnVarLen {
|
||||
fn name(&self) -> &'static str {
|
||||
"flash-attn-varlen"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
_: &CpuStorage,
|
||||
_: &Layout,
|
||||
_: &CpuStorage,
|
||||
_: &Layout,
|
||||
_: &CpuStorage,
|
||||
_: &Layout,
|
||||
) -> Result<(CpuStorage, Shape)> {
|
||||
candle::bail!("no cpu support for flash-attn")
|
||||
}
|
||||
|
||||
fn cuda_fwd(
|
||||
impl FlashAttnVarLen {
|
||||
fn cuda_fwd_t<
|
||||
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
|
||||
>(
|
||||
&self,
|
||||
q: &candle::CudaStorage,
|
||||
q_l: &Layout,
|
||||
@ -364,6 +370,40 @@ impl candle::CustomOp3 for FlashAttnVarLen {
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::CustomOp3 for FlashAttnVarLen {
|
||||
fn name(&self) -> &'static str {
|
||||
"flash-attn-varlen"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
_: &CpuStorage,
|
||||
_: &Layout,
|
||||
_: &CpuStorage,
|
||||
_: &Layout,
|
||||
_: &CpuStorage,
|
||||
_: &Layout,
|
||||
) -> Result<(CpuStorage, Shape)> {
|
||||
candle::bail!("no cpu support for flash-attn")
|
||||
}
|
||||
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
q: &candle::CudaStorage,
|
||||
q_l: &Layout,
|
||||
k: &candle::CudaStorage,
|
||||
k_l: &Layout,
|
||||
v: &candle::CudaStorage,
|
||||
v_l: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
match q.dtype() {
|
||||
candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
|
||||
candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
|
||||
dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Flash-attention v2 layer with variable-length batching.
|
||||
///
|
||||
|
Reference in New Issue
Block a user