mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Use flash-attn in gemma. (#2195)
* Use flash-attn in gemma.
* Fix flash-attn for head dim 256.
This commit is contained in:
@@ -193,6 +193,9 @@ struct Args {
|
||||
/// The model to use.
|
||||
#[arg(long, default_value = "2b")]
|
||||
which: Which,
|
||||
|
||||
+ #[arg(long)]
|
||||
+ use_flash_attn: bool,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@@ -270,7 +273,7 @@ fn main() -> Result<()> {
|
||||
DType::F32
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
- let model = Model::new(&config, vb)?;
|
||||
+ let model = Model::new(args.use_flash_attn, &config, vb)?;
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
|
Reference in New Issue
Block a user