Fix some shape issues in falcon. (#95)

* Fix some shape issues.

* Use different dtypes.
This commit is contained in:
Laurent Mazare
2023-07-06 19:23:54 +01:00
committed by GitHub
parent 4afa461b34
commit 0f679fe42e
2 changed files with 21 additions and 7 deletions

View File

@ -10,7 +10,10 @@ use clap::Parser;
mod model;
use model::{Config, Falcon, VarBuilder};
const DTYPE: DType = DType::F16;
#[cfg(feature = "mkl")]
const DTYPE: DType = DType::F32;
#[cfg(not(feature = "mkl"))]
const DTYPE: DType = DType::BF16;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]