Add a cli argument to easily switch the dtype. (#161)

This commit is contained in:
Laurent Mazare
2023-07-13 19:18:49 +01:00
committed by GitHub
parent ded93a1169
commit 3c02ea56b0
2 changed files with 17 additions and 12 deletions

View File

@ -14,11 +14,6 @@ use tokenizers::Tokenizer;
mod model;
use model::{Config, Falcon};
#[cfg(feature = "mkl")]
const DTYPE: DType = DType::F32;
#[cfg(not(feature = "mkl"))]
const DTYPE: DType = DType::BF16;
struct TextGeneration {
model: Falcon,
device: Device,
@ -99,6 +94,10 @@ struct Args {
#[arg(long)]
prompt: String,
/// Use f32 computations rather than bf16.
#[arg(long)]
use_f32: bool,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
@ -151,7 +150,12 @@ fn main() -> Result<()> {
.map(|f| Ok(f.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let vb = VarBuilder::from_safetensors(weights, DTYPE, &device);
let dtype = if args.use_f32 {
DType::F32
} else {
DType::BF16
};
let vb = VarBuilder::from_safetensors(weights, dtype, &device);
let config = Config::falcon7b();
config.validate()?;
let model = Falcon::load(vb, config)?;