BF16 support for flash-attn. (#737)

Laurent Mazare
2023-09-04 17:35:43 +02:00
committed by GitHub
parent 0d00c06a83
commit f80fd44201


@@ -4,7 +4,7 @@ use candle::backend::BackendStorage;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
 use candle::cuda_backend::WrapErr;
 use candle::{CpuStorage, Layout, Result, Shape, Tensor};
-use half::f16;
+use half::{bf16, f16};
 
 pub struct FlashAttn {
     pub softmax_scale: f32,
@@ -15,24 +15,10 @@ fn round_multiple(x: usize, m: usize) -> usize {
     (x + m - 1) / m * m
 }
 
-impl candle::CustomOp3 for FlashAttn {
-    fn name(&self) -> &'static str {
-        "flash-attn"
-    }
-
-    fn cpu_fwd(
-        &self,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-    ) -> Result<(CpuStorage, Shape)> {
-        candle::bail!("no cpu support for flash-attn")
-    }
-
-    fn cuda_fwd(
+impl FlashAttn {
+    fn cuda_fwd_t<
+        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
+    >(
         &self,
         q: &candle::CudaStorage,
         q_l: &Layout,
@@ -46,9 +32,9 @@ impl candle::CustomOp3 for FlashAttn {
         let out_shape = q_l.shape().clone();
         let out_l = Layout::contiguous(&out_shape);
 
-        let q = q.as_cuda_slice::<f16>()?;
-        let k = k.as_cuda_slice::<f16>()?;
-        let v = v.as_cuda_slice::<f16>()?;
+        let q = q.as_cuda_slice::<T>()?;
+        let k = k.as_cuda_slice::<T>()?;
+        let v = v.as_cuda_slice::<T>()?;
         let q = q.slice(q_l.start_offset()..);
         let k = k.slice(k_l.start_offset()..);
         let v = v.slice(v_l.start_offset()..);
@@ -104,7 +90,7 @@ impl candle::CustomOp3 for FlashAttn {
         let seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
         let elem_count = out_shape.elem_count();
-        let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+        let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
         let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
 
         let causal = if self.causal { 1 } else { 0 };
@@ -155,6 +141,40 @@ impl candle::CustomOp3 for FlashAttn {
     }
 }
 
+impl candle::CustomOp3 for FlashAttn {
+    fn name(&self) -> &'static str {
+        "flash-attn"
+    }
+
+    fn cpu_fwd(
+        &self,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+    ) -> Result<(CpuStorage, Shape)> {
+        candle::bail!("no cpu support for flash-attn")
+    }
+
+    fn cuda_fwd(
+        &self,
+        q: &candle::CudaStorage,
+        q_l: &Layout,
+        k: &candle::CudaStorage,
+        k_l: &Layout,
+        v: &candle::CudaStorage,
+        v_l: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        match q.dtype() {
+            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
+            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
+            dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
+        }
+    }
+}
+
 /// Flash-attention v2 layer.
 ///
 /// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
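
The reworked `cuda_fwd` above is only a dtype dispatcher: it inspects the runtime dtype and hands off to the monomorphized `cuda_fwd_t::<f16>` or `cuda_fwd_t::<bf16>` helper, so adding another dtype later only touches that `match`. A minimal usage sketch of what this enables, assuming the crate's public `candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)` wrapper (not shown in this diff) and a CUDA build of candle:

    // Hedged usage sketch, not part of this commit.
    use candle::{DType, Device, Result, Tensor};

    fn bf16_attention() -> Result<Tensor> {
        let device = Device::new_cuda(0)?;
        // flash-attn expects (batch, seq_len, num_heads, head_dim) inputs.
        let q = Tensor::randn(0f32, 1.0, (2, 64, 8, 64), &device)?.to_dtype(DType::BF16)?;
        let k = q.clone();
        let v = q.clone();
        let softmax_scale = 1.0 / (64f32).sqrt();
        // With this commit, the BF16 dtype is routed to cuda_fwd_t::<bf16>.
        candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, /* causal= */ false)
    }
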
@@ -191,24 +211,10 @@ struct FlashAttnVarLen {
     seqlens_k: Tensor,
 }
 
-impl candle::CustomOp3 for FlashAttnVarLen {
-    fn name(&self) -> &'static str {
-        "flash-attn-varlen"
-    }
-
-    fn cpu_fwd(
-        &self,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-    ) -> Result<(CpuStorage, Shape)> {
-        candle::bail!("no cpu support for flash-attn")
-    }
-
-    fn cuda_fwd(
+impl FlashAttnVarLen {
+    fn cuda_fwd_t<
+        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
+    >(
         &self,
         q: &candle::CudaStorage,
         q_l: &Layout,
@@ -364,6 +370,40 @@ impl candle::CustomOp3 for FlashAttnVarLen {
     }
 }
 
+impl candle::CustomOp3 for FlashAttnVarLen {
+    fn name(&self) -> &'static str {
+        "flash-attn-varlen"
+    }
+
+    fn cpu_fwd(
+        &self,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+    ) -> Result<(CpuStorage, Shape)> {
+        candle::bail!("no cpu support for flash-attn")
+    }
+
+    fn cuda_fwd(
+        &self,
+        q: &candle::CudaStorage,
+        q_l: &Layout,
+        k: &candle::CudaStorage,
+        k_l: &Layout,
+        v: &candle::CudaStorage,
+        v_l: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        match q.dtype() {
+            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
+            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
+            dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
+        }
+    }
+}
+
 #[allow(clippy::too_many_arguments)]
 /// Flash-attention v2 layer with variable-length batching.
 ///
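
The variable-length path gets the same treatment: `FlashAttnVarLen::cuda_fwd` now only matches on the dtype and defers to the generic `cuda_fwd_t` helper. A hedged sketch of a BF16 varlen call, assuming the crate's `flash_attn_varlen` wrapper takes packed (total_tokens, num_heads, head_dim) tensors plus cumulative sequence offsets (its exact signature is not part of this diff):

    // Hedged sketch, not part of this commit. Assumed signature:
    // flash_attn_varlen(q, k, v, seqlens_q, seqlens_k, max_seqlen_q,
    //                   max_seqlen_k, softmax_scale, causal)
    use candle::{DType, Device, Result, Tensor};

    fn bf16_varlen_attention() -> Result<Tensor> {
        let device = Device::new_cuda(0)?;
        // Two sequences of 3 and 5 tokens packed together; cumulative offsets [0, 3, 8].
        let seqlens = Tensor::new(&[0u32, 3u32, 8u32], &device)?;
        let q = Tensor::randn(0f32, 1.0, (8, 8, 64), &device)?.to_dtype(DType::BF16)?;
        let k = q.clone();
        let v = q.clone();
        candle_flash_attn::flash_attn_varlen(
            &q, &k, &v, &seqlens, &seqlens, 5, 5, 1.0 / (64f32).sqrt(), false,
        )
    }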