Use flash-attn for mistral. (#1004)

This commit is contained in:
Laurent Mazare
2023-09-30 13:15:10 +02:00
committed by GitHub
parent 87e3a4e175
commit 4021272875
2 changed files with 41 additions and 9 deletions

View File

@ -113,6 +113,9 @@ struct Args {
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
prompt: String,
@ -207,7 +210,7 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::config_7b_v0_1();
let config = Config::config_7b_v0_1(args.use_flash_attn);
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16