Support flash-attn in quantized phi3. (#2194)

Laurent Mazare
2024-05-18 17:12:56 +02:00
committed by GitHub
parent 01545f7303
commit eefc1c77ef
2 changed files with 50 additions and 11 deletions

candle-examples/examples/quantized-phi/main.rs

@@ -90,6 +90,9 @@ struct Args {
     /// The model size to use.
     #[arg(long, default_value = "phi-3b")]
     which: Which,
+
+    #[arg(long)]
+    use_flash_attn: bool,
 }
 
 impl Args {
@@ -213,7 +216,13 @@ fn main() -> anyhow::Result<()> {
         );
         match args.which {
             Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
-            Which::Phi3 => Model::Phi3(Phi3::from_gguf(1, model, &mut file, &device)?),
+            Which::Phi3 => Model::Phi3(Phi3::from_gguf(
+                1,
+                args.use_flash_attn,
+                model,
+                &mut file,
+                &device,
+            )?),
             Which::Phi3b => Model::Phi3b(Phi3b::from_gguf(model, &mut file, &device)?),
         }
     };

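The example's new `--use-flash-attn` flag is forwarded as the second argument of `Phi3::from_gguf`. As a minimal sketch of driving the updated signature from user code (assuming the published crate names candle-core and candle-transformers, with `load_phi3` an illustrative helper that is not part of this commit):

use candle_core::quantized::gguf_file;
use candle_core::Device;
use candle_transformers::models::quantized_phi3::ModelWeights as Phi3;

fn load_phi3(path: &str, use_flash_attn: bool, device: &Device) -> anyhow::Result<Phi3> {
    let mut file = std::fs::File::open(path)?;
    let content = gguf_file::Content::read(&mut file)?;
    // batch_size stays hard-coded to 1, as in the example; the new second
    // argument toggles the flash-attn branch in every attention layer.
    Ok(Phi3::from_gguf(1, use_flash_attn, content, &mut file, device)?)
}

The kernels themselves are only compiled in with the flash-attn cargo feature (CUDA only), so a run would look something like `cargo run --release --example quantized-phi --features flash-attn -- --which phi-3 --use-flash-attn`.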
candle-transformers/src/models/quantized_phi3.rs

@@ -69,6 +69,7 @@ struct LayerWeights {
     sin: Tensor,
     neg_inf: Tensor,
     kv_cache: KvCache,
+    use_flash_attn: bool,
     span_attn: tracing::Span,
     span_rot: tracing::Span,
 }
@@ -125,23 +126,50 @@ impl LayerWeights {
         let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
         let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
 
-        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let att = match mask {
-            None => att,
-            Some(mask) => {
-                let mask = mask.broadcast_as(att.shape())?;
-                masked_fill(&att, &mask, &self.neg_inf)?
-            }
+        let y = if self.use_flash_attn {
+            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
+            let q = q.to_dtype(DType::BF16)?.transpose(1, 2)?;
+            let k = k.to_dtype(DType::BF16)?.transpose(1, 2)?;
+            let v = v.to_dtype(DType::BF16)?.transpose(1, 2)?;
+            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+            flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
+                .to_dtype(DType::F32)?
+                .transpose(1, 2)?
+        } else {
+            let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
+            let att = match mask {
+                None => att,
+                Some(mask) => {
+                    let mask = mask.broadcast_as(att.shape())?;
+                    masked_fill(&att, &mask, &self.neg_inf)?
+                }
+            };
+            let att = candle_nn::ops::softmax_last_dim(&att)?;
+            // Convert to contiguous as matmul doesn't support strided vs for now.
+            att.matmul(&v.contiguous()?)?
         };
-        let att = candle_nn::ops::softmax_last_dim(&att)?;
-        // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
         let y = self.attn_output.forward(&y)?;
         Ok(y)
     }
 }
 
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+    unimplemented!("compile with '--features flash-attn'")
+}
+
 #[derive(Debug, Clone)]
 pub struct ModelWeights {
     tok_embeddings: Embedding,
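A note on the branch above: the quantized model runs attention in F32 with tensors laid out as (b_sz, n_head, seq_len, head_dim), while candle's flash-attn op takes (b_sz, seq_len, n_head, head_dim) in a half dtype, hence the transpose and dtype round-trip around the kernel call. The contract, restated as standalone helpers (illustrative names, not part of the commit):

use candle_core::{DType, Result, Tensor};

// Into the layout/dtype flash-attn expects: (b, s, h, d) in BF16.
fn to_flash_layout(x: &Tensor) -> Result<Tensor> {
    x.to_dtype(DType::BF16)?.transpose(1, 2)
}

// Back to the quantized path's layout/dtype: (b, h, s, d) in F32.
fn from_flash_layout(y: &Tensor) -> Result<Tensor> {
    y.to_dtype(DType::F32)?.transpose(1, 2)
}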
@@ -176,6 +204,7 @@ fn precomput_freqs_cis(
 impl ModelWeights {
     pub fn from_gguf<R: std::io::Seek + std::io::Read>(
         batch_size: usize,
+        use_flash_attn: bool,
         ct: gguf_file::Content,
         reader: &mut R,
         device: &Device,
@@ -242,6 +271,7 @@ impl ModelWeights {
             sin: sin.clone(),
             neg_inf: neg_inf.clone(),
             kv_cache,
+            use_flash_attn,
             span_attn,
             span_rot,
         })
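Finally, the `causal` argument passed to the kernel is `seq_len > 1`: with the KV cache, incremental decoding processes a single query token that may attend to everything cached, so no mask is needed, while prompt processing lets the kernel apply causal masking internally instead of the fallback's explicit `masked_fill`. A sketch of the mask that `causal = true` stands in for (assumed to mirror the fallback path; not part of the commit):

use candle_core::{Device, Result, Tensor};

// 1 marks positions j > i, i.e. the entries the fallback fills with -inf
// before the softmax.
fn causal_mask(seq_len: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<u8> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (seq_len, seq_len), device)
}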