Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Use flash-attn in gemma. (#2195)

* Use flash-attn in gemma.
* Fix flash-attn for head dim 256.
@@ -193,6 +193,9 @@ struct Args {
     /// The model to use.
     #[arg(long, default_value = "2b")]
     which: Which,
+
+    #[arg(long)]
+    use_flash_attn: bool,
 }
 
 fn main() -> Result<()> {
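A side note, not part of the commit: with clap's derive API, an `#[arg(long)]` on a `bool` field surfaces as a `--use-flash-attn` switch that is `false` unless passed, so the example's default behaviour is unchanged. A minimal sketch, assuming clap with the `derive` feature:

use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Enable the flash-attn CUDA kernels for the attention computation.
    #[arg(long)]
    use_flash_attn: bool,
}

fn main() {
    let args = Args::parse();
    // Prints `false` unless the binary is invoked with `--use-flash-attn`.
    println!("use_flash_attn = {}", args.use_flash_attn);
}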
@@ -270,7 +273,7 @@ fn main() -> Result<()> {
         DType::F32
     };
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-    let model = Model::new(&config, vb)?;
+    let model = Model::new(args.use_flash_attn, &config, vb)?;
 
     println!("loaded the model in {:?}", start.elapsed());
 
@@ -42,6 +42,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
         // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
         // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
+        if (smem_size >= 48 * 1024) {
+            cudaFuncSetAttribute(
+                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+        }
         // int ctas_per_sm;
         // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
         //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
@@ -139,7 +139,9 @@ impl FlashAttn {
 
         let elem_count = out_shape.elem_count();
         let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
-        let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
+        let softmax_lse = dev
+            .alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
+            .w()?;
 
         let is_bf16 = if is_bf16 { 1 } else { 0 };
 
@@ -73,13 +73,6 @@ struct RotaryEmbedding {
     cos: Tensor,
 }
 
-fn rotate_half(xs: &Tensor) -> Result<Tensor> {
-    let last_dim = xs.dim(D::Minus1)?;
-    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
-    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
-    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
-}
-
 impl RotaryEmbedding {
     fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
         let dim = cfg.head_dim;
@@ -94,7 +87,6 @@ impl RotaryEmbedding {
             .to_dtype(dtype)?
             .reshape((max_seq_len, 1))?;
         let freqs = t.matmul(&inv_freq)?;
-        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
         Ok(Self {
             sin: freqs.sin()?,
             cos: freqs.cos()?,
@@ -110,10 +102,8 @@ impl RotaryEmbedding {
         let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
-        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
-        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
-        let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
-        let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
+        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
+        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
         Ok((q_embed, k_embed))
     }
 }
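Side note, an assumption about the library rather than text from the commit: `candle_nn::rotary_emb::rope` appears to apply the non-interleaved rotation internally from half-width cos/sin tables of shape (seq_len, head_dim / 2) on contiguous (b, h, seq_len, head_dim) inputs, which is why the `Tensor::cat(&[&freqs, &freqs], ...)` duplication and the local `rotate_half` helper are removed above. A minimal standalone sketch, assuming `candle-core` and `candle-nn` as dependencies:

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b, h, t, d) = (1usize, 2, 6, 8);
    // Half-width frequency table, mirroring RotaryEmbedding::new above.
    let inv_freq: Vec<f32> = (0..d / 2)
        .map(|i| 1f32 / 10000f32.powf(2.0 * i as f32 / d as f32))
        .collect();
    let inv_freq = Tensor::new(inv_freq, &dev)?.reshape((1, d / 2))?;
    let positions = Tensor::arange(0u32, t as u32, &dev)?
        .to_dtype(DType::F32)?
        .reshape((t, 1))?;
    let freqs = positions.matmul(&inv_freq)?; // (seq_len, head_dim / 2)
    let q = Tensor::randn(0f32, 1f32, (b, h, t, d), &dev)?;
    // rope expects a contiguous (b, h, seq_len, head_dim) tensor.
    let q_rot = candle_nn::rotary_emb::rope(&q.contiguous()?, &freqs.cos()?, &freqs.sin()?)?;
    println!("{:?}", q_rot.dims()); // [1, 2, 6, 8]
    Ok(())
}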
@@ -163,10 +153,16 @@ struct Attention {
     head_dim: usize,
     rotary_emb: Arc<RotaryEmbedding>,
     kv_cache: Option<(Tensor, Tensor)>,
+    use_flash_attn: bool,
 }
 
 impl Attention {
-    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+    fn new(
+        rotary_emb: Arc<RotaryEmbedding>,
+        use_flash_attn: bool,
+        cfg: &Config,
+        vb: VarBuilder,
+    ) -> Result<Self> {
         let hidden_sz = cfg.hidden_size;
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
@@ -188,6 +184,7 @@ impl Attention {
             head_dim,
             rotary_emb,
             kv_cache: None,
+            use_flash_attn,
         })
     }
 
@@ -231,7 +228,14 @@ impl Attention {
         let value_states =
             crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
 
-        let attn_output = {
+        let attn_output = if self.use_flash_attn {
+            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
+            let q = query_states.transpose(1, 2)?;
+            let k = key_states.transpose(1, 2)?;
+            let v = value_states.transpose(1, 2)?;
+            let scale = 1f32 / (self.head_dim as f32).sqrt();
+            flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)?
+        } else {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
             let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
 
@@ -253,6 +257,22 @@ impl Attention {
     }
 }
 
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+    unimplemented!("compile with '--features flash-attn'")
+}
+
 #[derive(Debug, Clone)]
 struct DecoderLayer {
     self_attn: Attention,
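Gemma configs use a head_dim of 256, which is presumably why the kernel-side shared-memory opt-in and the larger softmax_lse buffer above accompany this example change; when the crate is built with the flash-attn feature, the wrapper simply forwards to `candle_flash_attn::flash_attn`. A hedged sketch of calling it directly, assuming a CUDA device, the `candle-flash-attn` crate, and bf16 inputs in the (batch, seq_len, num_heads, head_dim) layout noted in the comment above (it will not run without a GPU):

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::new_cuda(0)?;
    let (b_sz, seq_len, n_heads, head_dim) = (1usize, 16, 8, 256);
    let q = Tensor::randn(0f32, 1f32, (b_sz, seq_len, n_heads, head_dim), &dev)?
        .to_dtype(DType::BF16)?;
    let k = q.clone();
    let v = q.clone();
    let softmax_scale = 1f32 / (head_dim as f32).sqrt();
    // causal = true mirrors the attention_mask.is_some() case in the model code.
    let out = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, true)?;
    println!("{:?}", out.dims()); // [1, 16, 8, 256]
    Ok(())
}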
@@ -262,8 +282,13 @@ struct DecoderLayer {
 }
 
 impl DecoderLayer {
-    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
+    fn new(
+        rotary_emb: Arc<RotaryEmbedding>,
+        use_flash_attn: bool,
+        cfg: &Config,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let self_attn = Attention::new(rotary_emb, use_flash_attn, cfg, vb.pp("self_attn"))?;
         let mlp = MLP::new(cfg, vb.pp("mlp"))?;
         let input_layernorm =
             RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
@@ -312,7 +337,7 @@ pub struct Model {
 }
 
 impl Model {
-    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+    pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let vb_m = vb.pp("model");
         let embed_tokens =
             candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
@@ -320,7 +345,8 @@ impl Model {
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
         let vb_l = vb_m.pp("layers");
         for layer_idx in 0..cfg.num_hidden_layers {
-            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
+            let layer =
+                DecoderLayer::new(rotary_emb.clone(), use_flash_attn, cfg, vb_l.pp(layer_idx))?;
             layers.push(layer)
         }
         let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;