Use flash-attn in gemma. (#2195)

* Use flash-attn in gemma.

* Fix flash-attn for head dim 256.
This commit is contained in:
Laurent Mazare
2024-05-18 19:18:59 +02:00
committed by GitHub
parent eefc1c77ef
commit 7ebc3548e1
4 changed files with 55 additions and 20 deletions

View File

@ -193,6 +193,9 @@ struct Args {
/// The model to use.
#[arg(long, default_value = "2b")]
which: Which,
#[arg(long)]
use_flash_attn: bool,
}
fn main() -> Result<()> {
@ -270,7 +273,7 @@ fn main() -> Result<()> {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
let model = Model::new(args.use_flash_attn, &config, vb)?;
println!("loaded the model in {:?}", start.elapsed());