mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add a cli argument to easily switch the dtype. (#161)
In the falcon example:

@@ -14,11 +14,6 @@ use tokenizers::Tokenizer;
 mod model;
 use model::{Config, Falcon};
 
-#[cfg(feature = "mkl")]
-const DTYPE: DType = DType::F32;
-#[cfg(not(feature = "mkl"))]
-const DTYPE: DType = DType::BF16;
-
 struct TextGeneration {
     model: Falcon,
     device: Device,
@@ -99,6 +94,10 @@ struct Args {
     #[arg(long)]
     prompt: String,
 
+    /// Use f32 computations rather than bf16.
+    #[arg(long)]
+    use_f32: bool,
+
     /// The temperature used to generate samples.
     #[arg(long)]
     temperature: Option<f64>,
@@ -151,7 +150,12 @@ fn main() -> Result<()> {
         .map(|f| Ok(f.deserialize()?))
         .collect::<Result<Vec<_>>>()?;
 
-    let vb = VarBuilder::from_safetensors(weights, DTYPE, &device);
+    let dtype = if args.use_f32 {
+        DType::F32
+    } else {
+        DType::BF16
+    };
+    let vb = VarBuilder::from_safetensors(weights, dtype, &device);
     let config = Config::falcon7b();
     config.validate()?;
     let model = Falcon::load(vb, config)?;
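The change replaces a compile-time const DTYPE, previously toggled by the mkl feature (MKL builds forced F32), with a value chosen at runtime from a clap flag. Below is a minimal self-contained sketch of that pattern, assuming the clap derive API and the candle crate's DType enum that the example already uses; the standalone binary itself is hypothetical, not the example's actual main:

use candle::DType;
use clap::Parser;

#[derive(Parser)]
struct Args {
    /// Use f32 computations rather than bf16.
    #[arg(long)]
    use_f32: bool,
}

fn main() {
    let args = Args::parse();
    // clap derives the switch name from the field name, so this is --use-f32.
    let dtype = if args.use_f32 { DType::F32 } else { DType::BF16 };
    println!("selected dtype: {dtype:?}");
}

clap's derive kebab-cases field names by default, so use_f32: bool becomes a --use-f32 switch that defaults to false, keeping bf16 as the default precision.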
In the llama example:

@@ -24,10 +24,6 @@ mod model;
 use model::{Config, Llama};
 
 const MAX_SEQ_LEN: usize = 4096;
-#[cfg(feature = "mkl")]
-const DTYPE: DType = DType::F32;
-#[cfg(not(feature = "mkl"))]
-const DTYPE: DType = DType::F16;
 const DEFAULT_PROMPT: &str = r"
 EDWARD:
 I wonder how our princely father 'scaped,
@@ -127,6 +123,10 @@ struct Args {
     /// The initial prompt.
     #[arg(long)]
     prompt: Option<String>,
+
+    /// Use f32 computations rather than f16.
+    #[arg(long)]
+    use_f32: bool,
 }
 
 fn main() -> Result<()> {
@@ -140,9 +140,10 @@ fn main() -> Result<()> {
     };
     let config = Config::config_7b();
     let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
+    let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
     let (llama, tokenizer_filename) = match args.npy {
         Some(filename) => {
-            let vb = VarBuilder::from_npz(filename, DTYPE, &device)?;
+            let vb = VarBuilder::from_npz(filename, dtype, &device)?;
             let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
             (Llama::load(vb, &cache, &config)?, tokenizer)
         }
@@ -170,7 +171,7 @@ fn main() -> Result<()> {
             .map(|h| Ok(h.deserialize()?))
             .collect::<Result<Vec<_>>>()?;
 
-            let vb = VarBuilder::from_safetensors(tensors, DTYPE, &device);
+            let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
             (Llama::load(vb, &cache, &config)?, tokenizer_filename)
         }
     };
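The llama example defaults to F16 rather than BF16, and the dtype chosen once in main is threaded through both weight-loading paths (from_npz and from_safetensors). A sketch of that threading, under the same assumptions as above and with the actual loader calls replaced by illustrative println! lines:

use candle::DType;
use clap::Parser;

#[derive(Parser)]
struct Args {
    /// Use f32 computations rather than f16.
    #[arg(long)]
    use_f32: bool,

    /// Optional path to weights stored as .npz.
    #[arg(long)]
    npy: Option<String>,
}

fn main() {
    let args = Args::parse();
    // Pick the dtype once, then hand it to whichever loading path runs.
    let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
    match &args.npy {
        Some(path) => println!("load {path} with dtype {dtype:?} via the npz path"),
        None => println!("load safetensors shards with dtype {dtype:?}"),
    }
}

Either way, passing --use-f32 upgrades the computation to full precision, while omitting the flag keeps the reduced-precision default.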