Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Use flash-attn for mistral. (#1004)
candle-examples/examples/mistral/main.rs:
@@ -113,6 +113,9 @@ struct Args {
     #[arg(long)]
     tracing: bool,
 
+    #[arg(long)]
+    use_flash_attn: bool,
+
     #[arg(long)]
     prompt: String,
 
@@ -207,7 +210,7 @@ fn main() -> Result<()> {
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let start = std::time::Instant::now();
-    let config = Config::config_7b_v0_1();
+    let config = Config::config_7b_v0_1(args.use_flash_attn);
     let device = candle_examples::device(args.cpu)?;
     let dtype = if device.is_cuda() {
         DType::BF16
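With these changes the example gains a --use-flash-attn flag, and the flash-attn Cargo feature has to be enabled at build time for the flag to do anything. A plausible invocation (example name and feature wiring assumed from the diff, not verified against the repo) would be:

cargo run --example mistral --release --features flash-attn -- --use-flash-attn --prompt "Hello"

Without the feature, the fallback stub added to the model file below panics as soon as the flash path is taken.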
candle-transformers/src/models/mistral.rs:
@@ -17,10 +17,11 @@ pub struct Config {
     rms_norm_eps: f64,
     rope_theta: f64,
     sliding_window: usize,
+    use_flash_attn: bool,
 }
 
 impl Config {
-    pub fn config_7b_v0_1() -> Self {
+    pub fn config_7b_v0_1(use_flash_attn: bool) -> Self {
         Self {
             vocab_size: 32000,
             hidden_size: 4096,
@@ -33,6 +34,7 @@ impl Config {
             rms_norm_eps: 1e-5,
             rope_theta: 10_000.,
             sliding_window: 4096,
+            use_flash_attn,
         }
     }
 }
@@ -142,6 +144,22 @@ impl Module for MLP {
     }
 }
 
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+    unimplemented!("compile with '--features flash-attn'")
+}
+
 #[derive(Debug)]
 struct Attention {
     q_proj: Linear,
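The two flash_attn definitions above share one signature, so the call site compiles identically whether or not the flash-attn feature is enabled; the stub only panics if the flash path is actually executed. A minimal, self-contained sketch of that cfg-gated fallback pattern (illustrative only; the "fast-path" feature name is hypothetical and would need to be declared in Cargo.toml):

// Hypothetical feature name; not part of the candle commit.
#[cfg(feature = "fast-path")]
fn accelerated(x: f32) -> f32 {
    x * 2.0 // stand-in for the optimized implementation
}

// Fallback with the identical signature, so callers need no cfg of their own.
#[cfg(not(feature = "fast-path"))]
fn accelerated(_: f32) -> f32 {
    unimplemented!("compile with '--features fast-path'")
}

fn main() {
    // Panics at runtime unless the crate is built with `--features fast-path`.
    println!("{}", accelerated(21.0));
}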
@@ -155,6 +173,7 @@ struct Attention {
     hidden_size: usize,
     rotary_emb: Arc<RotaryEmbedding>,
     kv_cache: Option<(Tensor, Tensor)>,
+    use_flash_attn: bool,
 }
 
 impl Attention {
@@ -180,6 +199,7 @@ impl Attention {
             hidden_size: hidden_sz,
             rotary_emb,
             kv_cache: None,
+            use_flash_attn: cfg.use_flash_attn,
         })
     }
 
@@ -234,6 +254,14 @@ impl Attention {
         let key_states = self.repeat_kv(key_states)?;
         let value_states = self.repeat_kv(value_states)?;
 
+        let attn_output = if self.use_flash_attn {
+            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
+            let q = query_states.transpose(1, 2)?;
+            let k = key_states.transpose(1, 2)?;
+            let v = value_states.transpose(1, 2)?;
+            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+            flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)?
+        } else {
         let scale = 1f64 / f64::sqrt(self.head_dim as f64);
         let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
 
@@ -242,7 +270,8 @@ impl Attention {
             Some(mask) => attn_weights.broadcast_add(mask)?,
         };
         let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
-        let attn_output = attn_weights.matmul(&value_states)?;
+            attn_weights.matmul(&value_states)?
+        };
         attn_output
             .transpose(1, 2)?
             .reshape((b_sz, q_len, self.hidden_size))?
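For reference, a standalone sketch of the shape handling behind the two branches above: the non-flash path works on (b_sz, n_heads, seq_len, head_dim) tensors and scales by 1/sqrt(head_dim), while flash-attn expects (b_sz, seq_len, n_heads, head_dim), hence the transpose(1, 2) calls around the flash_attn call. This is illustrative only (CPU, f32, candle-core and candle-nn assumed as dependencies), not code from the commit:

use candle_core::{Device, Result, Tensor};

// Naive scaled dot-product attention in the (b_sz, n_heads, seq_len, head_dim)
// layout used by the non-flash branch of the diff (no causal mask applied here).
fn naive_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let head_dim = q.dim(3)?;
    let scale = 1f64 / f64::sqrt(head_dim as f64);
    let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
    let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
    attn_weights.matmul(v)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // (b_sz=1, n_heads=2, seq_len=4, head_dim=8)
    let q = Tensor::randn(0f32, 1f32, (1, 2, 4, 8), &dev)?;
    let k = Tensor::randn(0f32, 1f32, (1, 2, 4, 8), &dev)?;
    let v = Tensor::randn(0f32, 1f32, (1, 2, 4, 8), &dev)?;
    let out = naive_attention(&q, &k, &v)?;
    println!("naive attention output: {:?}", out.dims()); // [1, 2, 4, 8]
    // flash-attn instead takes (b_sz, seq_len, n_heads, head_dim).
    let q_flash = q.transpose(1, 2)?;
    println!("flash-attn input layout: {:?}", q_flash.dims()); // [1, 4, 2, 8]
    Ok(())
}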