Add flash-attn support for stable-lm. (#1052)

This commit is contained in:
Laurent Mazare
2023-10-07 21:12:54 +01:00
committed by GitHub
parent d833527fda
commit 823fe23f9b
2 changed files with 30 additions and 3 deletions

View File

@ -220,7 +220,7 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::stablelm_3b_4e1t();
let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
let (model, device) = {
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {