mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Add a cli argument to easily switch the dtype. (#161)
This commit is contained in:
@ -14,11 +14,6 @@ use tokenizers::Tokenizer;
|
||||
mod model;
|
||||
use model::{Config, Falcon};
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
const DTYPE: DType = DType::F32;
|
||||
#[cfg(not(feature = "mkl"))]
|
||||
const DTYPE: DType = DType::BF16;
|
||||
|
||||
struct TextGeneration {
|
||||
model: Falcon,
|
||||
device: Device,
|
||||
@ -99,6 +94,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// Use f32 computations rather than bf16.
|
||||
#[arg(long)]
|
||||
use_f32: bool,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
@ -151,7 +150,12 @@ fn main() -> Result<()> {
|
||||
.map(|f| Ok(f.deserialize()?))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let vb = VarBuilder::from_safetensors(weights, DTYPE, &device);
|
||||
let dtype = if args.use_f32 {
|
||||
DType::F32
|
||||
} else {
|
||||
DType::BF16
|
||||
};
|
||||
let vb = VarBuilder::from_safetensors(weights, dtype, &device);
|
||||
let config = Config::falcon7b();
|
||||
config.validate()?;
|
||||
let model = Falcon::load(vb, config)?;
|
||||
|
Reference in New Issue
Block a user