Mirror of https://github.com/huggingface/candle.git
Flash-Attn upgrade / SoftCap Candle-FlashAttn [2/n] (#2689)
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
* softcap is working, including test and api.
* make softcap test case better

Co-authored-by: laurent <laurent.mazare@gmail.com>
@@ -55,7 +55,9 @@ extern "C" void run_mha(
     int is_causal,

     int window_size_left,
-    int window_size_right
+    int window_size_right,
+
+    float softcap
 ) {
     Flash_fwd_params params;
     // Reset the parameters
@@ -99,8 +101,16 @@ extern "C" void run_mha(
     params.d_rounded = d_rounded;

     // Set the different scale values.
-    params.scale_softmax = softmax_scale;
-    params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+    if (softcap > 0.0) {
+        params.softcap = softmax_scale / softcap;
+        params.scale_softmax = softcap;
+        params.scale_softmax_log2 = softcap * M_LOG2E;
+    } else {
+        // Remove potential NaN
+        params.softcap = 0.0;
+        params.scale_softmax = softmax_scale;
+        params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+    }

     params.p_dropout = 1.; // probability to keep
     params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
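For reference, the branch above folds the softmax scale into Gemma-style logit softcapping: the raw scores are multiplied by `softmax_scale / softcap` before a `tanh` and rescaled by `softcap` afterwards, so the forward pass effectively computes `softmax(softcap * tanh((Q @ K^T) * softmax_scale / softcap)) @ V`. Below is a minimal host-side sketch of that math, not the CUDA kernel; the helper name is ours and it assumes the `candle_core` and `candle_nn` crates as dependencies.

```rust
use candle_core::{Result, Tensor, D};

/// Reference math for the softcap branch above:
/// softmax(softcap * tanh((q @ k^T) * softmax_scale / softcap)) @ v,
/// with q/k/v shaped (batch, heads, seq, head_dim) and no masking or dropout.
fn softcapped_attention(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f64,
    softcap: f64,
) -> Result<Tensor> {
    let scores = q.matmul(&k.t()?)?;
    // Mirrors the parameter folding: params.softcap = softmax_scale / softcap is the
    // pre-tanh factor, params.scale_softmax = softcap is the post-tanh factor.
    let scores = ((scores * (softmax_scale / softcap))?.tanh()? * softcap)?;
    let probs = candle_nn::ops::softmax(&scores, D::Minus1)?;
    probs.matmul(&v.contiguous()?)
}
```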
@@ -45,6 +45,8 @@ extern "C" {

         window_size_left: c_int,
         window_size_right: c_int,
+
+        softcap: f32,
     );

 }
@@ -11,6 +11,7 @@ pub struct FlashAttn
     pub alibi_slopes: Option<Tensor>,
     pub window_size_left: Option<usize>,
     pub window_size_right: Option<usize>,
+    pub softcap: Option<f32>,
 }

 fn round_multiple(x: usize, m: usize) -> usize {
@@ -201,6 +202,7 @@ impl FlashAttn
                 /* is_causal */ is_causal,
                 /* window_size_left */ window_size_left,
                 /* window_size_right */ window_size_right,
+                /* softcap */ self.softcap.unwrap_or(0f32),
             )
         }

@@ -271,6 +273,7 @@ pub fn flash_attn(
         alibi_slopes: None,
         window_size_left,
         window_size_right,
+        softcap: None,
     };
     q.apply_op3(k, v, op)
 }
@@ -308,6 +311,7 @@ pub fn flash_attn_windowed(
         alibi_slopes: None,
         window_size_left,
         window_size_right,
+        softcap: None,
     };
     q.apply_op3(k, v, op)
 }
@@ -342,6 +346,7 @@ pub fn flash_attn_alibi(
         alibi_slopes: Some(alibi_slopes.clone()),
         window_size_left,
         window_size_right,
+        softcap: None,
     };
     q.apply_op3(k, v, op)
 }
@@ -381,6 +386,52 @@ pub fn flash_attn_alibi_windowed(
         alibi_slopes: Some(alibi_slopes.clone()),
         window_size_left,
         window_size_right,
+        softcap: None,
     };
     q.apply_op3(k, v, op)
 }
+
+/// Flash-attention v2 layer.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors `k` and `v` with fewer heads
+/// than `q`. The number of heads in `q` must be divisible by the number of heads in `k` and `v`.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - Optional alibi slopes tensor with shape `(num_heads_q)`.
+/// * `softmax_scale` - Scaling factor for the softmax operation.
+/// * `window_size_left` - Optional limit on left attention to value tokens.
+/// * `window_size_right` - Optional limit on right attention to value tokens.
+/// * `softcap` - Gemma-style softcap applied to the attention logits before the softmax.
+///
+/// # Causal Mask
+///
+/// Setting `window_size_left=None` and `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`.
+///
+/// # Returns
+///
+/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
+pub fn flash_attn_alibi_windowed_softcap(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    alibi_slopes: Option<&Tensor>,
+    softmax_scale: f32,
+    window_size_left: Option<usize>,
+    window_size_right: Option<usize>,
+    softcap: f32,
+) -> Result<Tensor> {
+    let op = FlashAttn {
+        softmax_scale,
+        alibi_slopes: alibi_slopes.cloned(),
+        window_size_left,
+        window_size_right,
+        softcap: Some(softcap),
+    };
+    q.apply_op3(k, v, op)
+}
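A hedged usage sketch of the new entry point above: the function and its argument order come from the diff, while the shapes, the `50.0` softcap value, and the example function name are illustrative assumptions; a CUDA device and the `candle-flash-attn` crate are required.

```rust
use candle_core::{DType, Device, Result, Tensor};

// Hypothetical example: causal softcapped attention over random data.
fn softcap_attention_example() -> Result<Tensor> {
    let device = Device::new_cuda(0)?;
    let (b, seq, heads, head_dim): (usize, usize, usize, usize) = (1, 128, 8, 64);
    // Flash-attn expects (batch, seq_len, num_heads, head_size) tensors in f16/bf16.
    let q = Tensor::randn(0f32, 1., (b, seq, heads, head_dim), &device)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1., (b, seq, heads, head_dim), &device)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1., (b, seq, heads, head_dim), &device)?.to_dtype(DType::F16)?;
    candle_flash_attn::flash_attn_alibi_windowed_softcap(
        &q,
        &k,
        &v,
        None,                          // alibi_slopes
        1. / (head_dim as f32).sqrt(), // softmax_scale
        None,                          // window_size_left
        Some(0),                       // window_size_right => causal mask
        50.0,                          // softcap (e.g. a Gemma-2 style attention logit cap)
    )
}
```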
@@ -394,6 +445,7 @@ struct FlashAttnVarLen
     pub alibi_slopes: Option<Tensor>,
     pub window_size_left: Option<usize>,
     pub window_size_right: Option<usize>,
+    pub softcap: Option<f32>,
 }

 impl FlashAttnVarLen {
@@ -613,6 +665,7 @@ impl FlashAttnVarLen
                 /* is_causal */ is_causal,
                 /* window_size_left */ window_size_left,
                 /* window_size_right */ window_size_right,
+                /* softcap */ self.softcap.unwrap_or(0.0),
             )
         }

@@ -699,6 +752,7 @@ pub fn flash_attn_varlen(
         alibi_slopes: None,
         window_size_left,
         window_size_right,
+        softcap: None,
     };
     q.apply_op3(k, v, op)
 }
@@ -752,6 +806,7 @@ pub fn flash_attn_varlen_windowed(
         alibi_slopes: None,
         window_size_left,
         window_size_right,
+        softcap: None,
     };
     q.apply_op3(k, v, op)
 }
@@ -802,6 +857,7 @@ pub fn flash_attn_varlen_alibi(
         alibi_slopes: Some(alibi_slopes.clone()),
         window_size_left,
         window_size_right,
+        softcap: None,
     };
     q.apply_op3(k, v, op)
 }
@@ -857,6 +913,65 @@ pub fn flash_attn_varlen_alibi_windowed(
         alibi_slopes: Some(alibi_slopes.clone()),
         window_size_left,
         window_size_right,
+        softcap: None,
     };
     q.apply_op3(k, v, op)
 }
+
+#[allow(clippy::too_many_arguments)]
+/// Flash-attention v2 layer with variable-length batching.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors `k` and `v` with fewer
+/// heads than `q`. The number of heads in `q` must be divisible by the number of heads in `k` and `v`.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - Optional alibi slopes tensor with shape `(num_heads_q)`.
+/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index into `q`.
+/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index into `k` and `v`.
+/// * `max_seqlen_q` - The maximum query sequence length in the batch.
+/// * `max_seqlen_k` - The maximum key/value sequence length in the batch.
+/// * `window_size_left` - Optional limit on left attention to value tokens.
+/// * `window_size_right` - Optional limit on right attention to value tokens.
+/// * `softcap` - Gemma-style softcap applied to the attention logits before the softmax.
+///
+/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
+/// `seqlen_1 + seqlen_2`, etc.
+///
+/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
+///
+/// # Causal mask
+///
+/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`.
+pub fn flash_attn_varlen_alibi_windowed_softcap(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    alibi_slopes: Option<&Tensor>,
+    seqlens_q: &Tensor,
+    seqlens_k: &Tensor,
+    max_seqlen_q: usize,
+    max_seqlen_k: usize,
+    softmax_scale: f32,
+    window_size_left: Option<usize>,
+    window_size_right: Option<usize>,
+    softcap: f32,
+) -> Result<Tensor> {
+    let op = FlashAttnVarLen {
+        softmax_scale,
+        max_seqlen_q,
+        max_seqlen_k,
+        seqlens_q: seqlens_q.clone(),
+        seqlens_k: seqlens_k.clone(),
+        alibi_slopes: alibi_slopes.cloned(),
+        window_size_left,
+        window_size_right,
+        softcap: Some(softcap),
+    };
+    q.apply_op3(k, v, op)
+}
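A hedged usage sketch of the variable-length variant above, showing how the cumulative `seqlens` tensors described in the doc comment are built: the function and its argument order come from the diff, while the packed shapes, the softcap value, and the example function name are illustrative assumptions; a CUDA device and the `candle-flash-attn` crate are required.

```rust
use candle_core::{DType, Device, Result, Tensor};

// Hypothetical example: two sequences of lengths 3 and 5 packed into one tensor.
fn varlen_softcap_example() -> Result<Tensor> {
    let device = Device::new_cuda(0)?;
    let (heads, head_dim): (usize, usize) = (8, 64);
    let total: usize = 3 + 5; // total_q = total_kv = sum of the per-sequence lengths
    let q = Tensor::randn(0f32, 1., (total, heads, head_dim), &device)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1., (total, heads, head_dim), &device)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1., (total, heads, head_dim), &device)?.to_dtype(DType::F16)?;
    // Cumulative lengths: batch_size + 1 entries, here [0, 3, 3 + 5].
    let seqlens = Tensor::new(&[0u32, 3, 8], &device)?;
    candle_flash_attn::flash_attn_varlen_alibi_windowed_softcap(
        &q,
        &k,
        &v,
        None,                          // alibi_slopes
        &seqlens,                      // seqlens_q
        &seqlens,                      // seqlens_k
        5,                             // max_seqlen_q
        5,                             // max_seqlen_k
        1. / (head_dim as f32).sqrt(), // softmax_scale
        None,                          // window_size_left
        Some(0),                       // window_size_right => causal mask
        50.0,                          // softcap
    )
}
```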
@@ -27,6 +27,20 @@ fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<
     Ok(output)
 }

+fn fa_acausal_softcap(q: &Tensor, k: &Tensor, v: &Tensor, softcap: f32) -> Result<Tensor> {
+    let in_dtype = q.dtype();
+    let q = q.to_dtype(DType::F32)?;
+    let k = k.to_dtype(DType::F32)?;
+    let v = v.to_dtype(DType::F32)?;
+    // let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
+    let att = q.matmul(&k.t()?)?;
+    let att = (softcap as f64 * ((att / softcap as f64)?.tanh())?)?;
+    let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+    // Convert to contiguous as matmul doesn't support strided vs for now.
+    let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?;
+    Ok(output)
+}
+
 #[test]
 fn flash_attn_acausal() -> Result<()> {
     let device = Device::new_cuda(0)?;
@@ -89,6 +103,44 @@ fn flash_attn_acausal() -> Result<()> {
     Ok(())
 }

+#[test]
+fn flash_attn_acausal_softcap() -> Result<()> {
+    let device = Device::new_cuda(0)?;
+    let q = Tensor::arange(0u32, 3 * 5 * 8, &device)?
+        .to_dtype(DType::F16)?
+        .reshape((1, 3, 5, 8))?;
+    let k = (&q / 40.)?;
+    let v = (&q / 50.)?;
+    let q = (&q / 30.)?;
+    let softcap = 5.0f32;
+
+    let ys1 = fa_acausal_softcap(&q, &k, &v, softcap.clone())?;
+    let ys1 = ys1.i(0)?.to_dtype(DType::F32)?;
+    let ys2 = {
+        let q = q.transpose(1, 2)?;
+        let k = k.transpose(1, 2)?;
+        let v = v.transpose(1, 2)?;
+        candle_flash_attn::flash_attn_alibi_windowed_softcap(
+            &q,
+            &k,
+            &v,
+            None,            // alibi_slopes //
+            1.0,             // softmax //
+            None,            // window_size_left //
+            None,            // window_size_right //
+            softcap.clone(), // softcap //
+        )?
+        .transpose(1, 2)?
+    };
+    let ys2 = ys2.i(0)?.to_dtype(DType::F32)?;
+    let diff = ys1.sub(&ys2)?.abs()?.flatten_all()?.max(0)?;
+
+    assert_eq!(ys1.dims(), &[3, 5, 8]);
+    assert_eq!(ys2.dims(), &[3, 5, 8]);
+    assert!(diff.to_vec0::<f32>()?.abs() < 1e-3);
+    Ok(())
+}
+
 #[test]
 fn flash_attn_varlen() -> Result<()> {
     let device = Device::new_cuda(0)?;