diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index 4757d2b1..5cc7b065 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -14,11 +14,6 @@ use tokenizers::Tokenizer;
 mod model;
 use model::{Config, Falcon};
 
-#[cfg(feature = "mkl")]
-const DTYPE: DType = DType::F32;
-#[cfg(not(feature = "mkl"))]
-const DTYPE: DType = DType::BF16;
-
 struct TextGeneration {
     model: Falcon,
     device: Device,
@@ -99,6 +94,10 @@ struct Args {
     #[arg(long)]
     prompt: String,
 
+    /// Use f32 computations rather than bf16.
+    #[arg(long)]
+    use_f32: bool,
+
     /// The temperature used to generate samples.
     #[arg(long)]
     temperature: Option<f64>,
@@ -151,7 +150,12 @@ fn main() -> Result<()> {
         .map(|f| Ok(f.deserialize()?))
         .collect::<Result<Vec<_>>>()?;
 
-    let vb = VarBuilder::from_safetensors(weights, DTYPE, &device);
+    let dtype = if args.use_f32 {
+        DType::F32
+    } else {
+        DType::BF16
+    };
+    let vb = VarBuilder::from_safetensors(weights, dtype, &device);
     let config = Config::falcon7b();
     config.validate()?;
     let model = Falcon::load(vb, config)?;
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 301b870a..7ba87c70 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -24,10 +24,6 @@ mod model;
 use model::{Config, Llama};
 
 const MAX_SEQ_LEN: usize = 4096;
-#[cfg(feature = "mkl")]
-const DTYPE: DType = DType::F32;
-#[cfg(not(feature = "mkl"))]
-const DTYPE: DType = DType::F16;
 const DEFAULT_PROMPT: &str = r"
 EDWARD:
 I wonder how our princely father 'scaped,
@@ -127,6 +123,10 @@ struct Args {
     /// The initial prompt.
     #[arg(long)]
     prompt: Option<String>,
+
+    /// Use f32 computations rather than f16.
+    #[arg(long)]
+    use_f32: bool,
 }
 
 fn main() -> Result<()> {
@@ -140,9 +140,10 @@ fn main() -> Result<()> {
     };
     let config = Config::config_7b();
     let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
+    let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
     let (llama, tokenizer_filename) = match args.npy {
         Some(filename) => {
-            let vb = VarBuilder::from_npz(filename, DTYPE, &device)?;
+            let vb = VarBuilder::from_npz(filename, dtype, &device)?;
             let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
            (Llama::load(vb, &cache, &config)?, tokenizer)
         }
@@ -170,7 +171,7 @@
                 .map(|h| Ok(h.deserialize()?))
                 .collect::<Result<Vec<_>>>()?;
 
-            let vb = VarBuilder::from_safetensors(tensors, DTYPE, &device);
+            let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
             (Llama::load(vb, &cache, &config)?, tokenizer_filename)
         }
     };
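
A note on the design choice: both changes replace a compile-time, feature-gated DTYPE constant with a runtime choice, so a single binary can run in either precision. Below is a minimal, self-contained sketch (hypothetical, not part of the patch) of the same pattern. It assumes the clap crate, whose derive macro maps the use_f32 field to a --use-f32 switch; the real examples pass candle's DType::F32 / DType::F16 (or DType::BF16 for falcon) into VarBuilder rather than a string.

use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Use f32 computations rather than f16.
    #[arg(long)]
    use_f32: bool,
}

fn main() {
    let args = Args::parse();
    // Runtime selection replaces the old `#[cfg(feature = "mkl")]` consts.
    let dtype = if args.use_f32 { "f32" } else { "f16" };
    println!("running with dtype {dtype}");
}

With the patch applied, an invocation along the lines of cargo run --example llama --release -- --use-f32 (run from the candle checkout; the exact command may vary with the build setup) opts into f32, while omitting the flag keeps the half-precision default.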