Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Add flash-attn support for stable-lm. (#1052)
candle-examples/examples/stablelm/main.rs
@@ -220,7 +220,7 @@ fn main() -> Result<()> {
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

     let start = std::time::Instant::now();
-    let config = Config::stablelm_3b_4e1t();
+    let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
     let (model, device) = {
         let device = candle_examples::device(args.cpu)?;
         let dtype = if device.is_cuda() {
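To exercise the new path, the example has to be compiled with the flash-attn Cargo feature and asked to use it at runtime, along the lines of `cargo run --example stablelm --release --features flash-attn -- --use-flash-attn`. The `--use-flash-attn` flag spelling is inferred from `args.use_flash_attn`, and the feature name on the examples crate is likewise an assumption rather than something stated in this diff.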
candle-transformers/src/models/stable_lm.rs
@@ -19,10 +19,11 @@ pub struct Config {
     pub(crate) max_position_embeddings: usize,
     pub(crate) norm_eps: f64,
     pub(crate) use_cache: bool,
+    pub(crate) use_flash_attn: bool,
 }

 impl Config {
-    pub fn stablelm_3b_4e1t() -> Self {
+    pub fn stablelm_3b_4e1t(use_flash_attn: bool) -> Self {
         Self {
             vocab_size: 50304,
             intermediate_size: 6912,
@@ -36,6 +37,7 @@ impl Config {
             max_position_embeddings: 4096,
             norm_eps: 1e-5,
             use_cache: true,
+            use_flash_attn,
         }
     }

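Note that `stablelm_3b_4e1t` no longer has a zero-argument form, so every existing caller must now state explicitly whether flash attention should be used; the change to the example's `main` above is exactly that update.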
@@ -134,6 +136,22 @@ impl Module for MLP {
     }
 }

+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+    unimplemented!("compile with '--features flash-attn'")
+}
+
 #[derive(Debug)]
 struct Attention {
     q_proj: Linear,
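The two definitions added above are a standard Cargo feature gate: with the flash-attn feature enabled the wrapper forwards to candle_flash_attn, and without it a stub with the same signature keeps every call site compiling and only panics if it is actually reached. Below is a minimal standalone sketch of the same pattern, using purely illustrative names rather than candle's.

#[cfg(feature = "flash-attn")]
fn fused_kernel(x: f32) -> f32 {
    // stand-in for the real fused implementation
    x * 2.0
}

#[cfg(not(feature = "flash-attn"))]
fn fused_kernel(_: f32) -> f32 {
    // same signature as above, so callers compile either way;
    // it only fails if the fallback is actually invoked
    unimplemented!("compile with '--features flash-attn'")
}

fn main() {
    // Only take the gated path when the feature is actually enabled.
    if cfg!(feature = "flash-attn") {
        println!("fused result: {}", fused_kernel(1.0));
    } else {
        println!("flash-attn feature disabled; falling back to the eager path");
    }
}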
@@ -149,6 +167,7 @@ struct Attention {
     kv_cache: Option<(Tensor, Tensor)>,
     use_cache: bool,
     rotary_ndims: usize,
+    use_flash_attn: bool,
 }

 impl Attention {
@@ -175,6 +194,7 @@ impl Attention {
             kv_cache: None,
             use_cache: cfg.use_cache,
             rotary_ndims: cfg.rotary_ndims(),
+            use_flash_attn: cfg.use_flash_attn,
         })
     }

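These two hunks are plumbing: the flag stored on `Config` is copied into each `Attention` layer at construction time, so the kernel choice is made once when the model is built rather than re-read from the config on every forward pass.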
@@ -238,7 +258,14 @@ impl Attention {
         let key_states = self.repeat_kv(key_states)?.contiguous()?;
         let value_states = self.repeat_kv(value_states)?.contiguous()?;

-        let attn_output = {
+        let attn_output = if self.use_flash_attn {
+            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
+            let q = query_states.transpose(1, 2)?;
+            let k = key_states.transpose(1, 2)?;
+            let v = value_states.transpose(1, 2)?;
+            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+            flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)?
+        } else {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
             let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
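In the dispatch above, the eager path keeps its tensors in (b_sz, n_heads, seq_len, head_dim) layout and scales with an f64, while `candle_flash_attn::flash_attn` expects (b_sz, seq_len, n_heads, head_dim), hence the `transpose(1, 2)` before the call and again on its result. Passing `q_len > 1` as the causal argument applies the causal mask while the prompt is being processed and skips it for single-token decoding steps, where there is nothing later in the sequence to mask.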