mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00
BF16 support for flash-attn. (#737)
@@ -4,7 +4,7 @@ use candle::backend::BackendStorage;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
 use candle::cuda_backend::WrapErr;
 use candle::{CpuStorage, Layout, Result, Shape, Tensor};
-use half::f16;
+use half::{bf16, f16};
 
 pub struct FlashAttn {
     pub softmax_scale: f32,
@@ -15,24 +15,10 @@ fn round_multiple(x: usize, m: usize) -> usize {
     (x + m - 1) / m * m
 }
 
-impl candle::CustomOp3 for FlashAttn {
-    fn name(&self) -> &'static str {
-        "flash-attn"
-    }
-
-    fn cpu_fwd(
-        &self,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-    ) -> Result<(CpuStorage, Shape)> {
-        candle::bail!("no cpu support for flash-attn")
-    }
-
-    fn cuda_fwd(
+impl FlashAttn {
+    fn cuda_fwd_t<
+        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
+    >(
         &self,
         q: &candle::CudaStorage,
         q_l: &Layout,
@@ -46,9 +32,9 @@ impl candle::CustomOp3 for FlashAttn {
         let out_shape = q_l.shape().clone();
         let out_l = Layout::contiguous(&out_shape);
 
-        let q = q.as_cuda_slice::<f16>()?;
-        let k = k.as_cuda_slice::<f16>()?;
-        let v = v.as_cuda_slice::<f16>()?;
+        let q = q.as_cuda_slice::<T>()?;
+        let k = k.as_cuda_slice::<T>()?;
+        let v = v.as_cuda_slice::<T>()?;
         let q = q.slice(q_l.start_offset()..);
         let k = k.slice(k_l.start_offset()..);
         let v = v.slice(v_l.start_offset()..);
@@ -104,7 +90,7 @@ impl candle::CustomOp3 for FlashAttn {
         let seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
         let elem_count = out_shape.elem_count();
-        let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+        let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
         let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
 
         let causal = if self.causal { 1 } else { 0 };
@@ -155,6 +141,40 @@ impl candle::CustomOp3 for FlashAttn {
     }
 }
 
+impl candle::CustomOp3 for FlashAttn {
+    fn name(&self) -> &'static str {
+        "flash-attn"
+    }
+
+    fn cpu_fwd(
+        &self,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+    ) -> Result<(CpuStorage, Shape)> {
+        candle::bail!("no cpu support for flash-attn")
+    }
+
+    fn cuda_fwd(
+        &self,
+        q: &candle::CudaStorage,
+        q_l: &Layout,
+        k: &candle::CudaStorage,
+        k_l: &Layout,
+        v: &candle::CudaStorage,
+        v_l: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        match q.dtype() {
+            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
+            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
+            dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
+        }
+    }
+}
+
 /// Flash-attention v2 layer.
 ///
 /// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
@@ -191,24 +211,10 @@ struct FlashAttnVarLen {
     seqlens_k: Tensor,
 }
 
-impl candle::CustomOp3 for FlashAttnVarLen {
-    fn name(&self) -> &'static str {
-        "flash-attn-varlen"
-    }
-
-    fn cpu_fwd(
-        &self,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-    ) -> Result<(CpuStorage, Shape)> {
-        candle::bail!("no cpu support for flash-attn")
-    }
-
-    fn cuda_fwd(
+impl FlashAttnVarLen {
+    fn cuda_fwd_t<
+        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
+    >(
         &self,
         q: &candle::CudaStorage,
         q_l: &Layout,
@@ -364,6 +370,40 @@ impl candle::CustomOp3 for FlashAttnVarLen {
     }
 }
 
+impl candle::CustomOp3 for FlashAttnVarLen {
+    fn name(&self) -> &'static str {
+        "flash-attn-varlen"
+    }
+
+    fn cpu_fwd(
+        &self,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+    ) -> Result<(CpuStorage, Shape)> {
+        candle::bail!("no cpu support for flash-attn")
+    }
+
+    fn cuda_fwd(
+        &self,
+        q: &candle::CudaStorage,
+        q_l: &Layout,
+        k: &candle::CudaStorage,
+        k_l: &Layout,
+        v: &candle::CudaStorage,
+        v_l: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        match q.dtype() {
+            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
+            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
+            dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
+        }
+    }
+}
+
 #[allow(clippy::too_many_arguments)]
 /// Flash-attention v2 layer with variable-length batching.
 ///
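
For reference, a minimal usage sketch of the BF16 path enabled by this change. It is illustrative only: it assumes the crate's public `flash_attn(q, k, v, softmax_scale, causal)` wrapper around the `FlashAttn` op, a CUDA device, and the `candle-core` crate imported as `candle`; shapes and values are made up.

use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::new_cuda(0)?;
    // Illustrative shapes: (batch_size, seq_len, num_heads, head_dim).
    let q = Tensor::randn(0f32, 1.0, (1, 128, 8, 64), &dev)?.to_dtype(DType::BF16)?;
    let k = Tensor::randn(0f32, 1.0, (1, 128, 8, 64), &dev)?.to_dtype(DType::BF16)?;
    let v = Tensor::randn(0f32, 1.0, (1, 128, 8, 64), &dev)?.to_dtype(DType::BF16)?;
    // Computes softmax(Q @ K^T . softmax_scale) @ V with the usual 1/sqrt(head_dim) scale.
    let softmax_scale = 1f32 / (64f32).sqrt();
    let out = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, /* causal */ true)?;
    // The output keeps the query shape and dtype; any dtype other than f16/bf16 now hits
    // the "flash-attn is only supported for f16/bf16" error from the dispatch above.
    assert_eq!(out.dims(), q.dims());
    assert_eq!(out.dtype(), DType::BF16);
    Ok(())
}

The design is a common pattern for custom ops: the generic `cuda_fwd_t` helper is monomorphized for f16 and bf16, and the non-generic `CustomOp3::cuda_fwd` entry point selects the instantiation by matching on the runtime dtype.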