Add a cli argument to easily switch the dtype. (#161)

Laurent Mazare
2023-07-13 19:18:49 +01:00
committed by GitHub
parent ded93a1169
commit 3c02ea56b0
2 changed files with 17 additions and 12 deletions

File: falcon example

@@ -14,11 +14,6 @@ use tokenizers::Tokenizer;
 mod model;
 use model::{Config, Falcon};
 
-#[cfg(feature = "mkl")]
-const DTYPE: DType = DType::F32;
-#[cfg(not(feature = "mkl"))]
-const DTYPE: DType = DType::BF16;
-
 struct TextGeneration {
     model: Falcon,
     device: Device,
@@ -99,6 +94,10 @@ struct Args {
     #[arg(long)]
     prompt: String,
 
+    /// Use f32 computations rather than bf16.
+    #[arg(long)]
+    use_f32: bool,
+
     /// The temperature used to generate samples.
     #[arg(long)]
     temperature: Option<f64>,
@@ -151,7 +150,12 @@ fn main() -> Result<()> {
         .map(|f| Ok(f.deserialize()?))
         .collect::<Result<Vec<_>>>()?;
 
-    let vb = VarBuilder::from_safetensors(weights, DTYPE, &device);
+    let dtype = if args.use_f32 {
+        DType::F32
+    } else {
+        DType::BF16
+    };
+    let vb = VarBuilder::from_safetensors(weights, dtype, &device);
     let config = Config::falcon7b();
     config.validate()?;
     let model = Falcon::load(vb, config)?;

File: llama example

@@ -24,10 +24,6 @@ mod model;
 use model::{Config, Llama};
 
 const MAX_SEQ_LEN: usize = 4096;
-#[cfg(feature = "mkl")]
-const DTYPE: DType = DType::F32;
-#[cfg(not(feature = "mkl"))]
-const DTYPE: DType = DType::F16;
 const DEFAULT_PROMPT: &str = r"
 EDWARD:
 I wonder how our princely father 'scaped,
@@ -127,6 +123,10 @@ struct Args {
     /// The initial prompt.
     #[arg(long)]
     prompt: Option<String>,
+
+    /// Use f32 computations rather than f16.
+    #[arg(long)]
+    use_f32: bool,
 }
 
 fn main() -> Result<()> {
@@ -140,9 +140,10 @@ fn main() -> Result<()> {
     };
     let config = Config::config_7b();
     let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
+    let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
     let (llama, tokenizer_filename) = match args.npy {
         Some(filename) => {
-            let vb = VarBuilder::from_npz(filename, DTYPE, &device)?;
+            let vb = VarBuilder::from_npz(filename, dtype, &device)?;
             let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
             (Llama::load(vb, &cache, &config)?, tokenizer)
         }
@@ -170,7 +171,7 @@ fn main() -> Result<()> {
                 .map(|h| Ok(h.deserialize()?))
                 .collect::<Result<Vec<_>>>()?;
 
-            let vb = VarBuilder::from_safetensors(tensors, DTYPE, &device);
+            let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
             (Llama::load(vb, &cache, &config)?, tokenizer_filename)
         }
     };
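
Both files apply the same change: the compile-time DTYPE constant, previously gated on the mkl feature, becomes a runtime choice driven by a clap flag. A minimal self-contained sketch of that pattern, assuming clap's derive feature (the DType enum below is a stand-in for candle's, so the snippet compiles on its own):

use clap::Parser;

// Stand-in for the real DType enum; only here so the sketch is self-contained.
#[derive(Debug, Clone, Copy)]
enum DType {
    F32,
    BF16,
}

#[derive(Parser, Debug)]
struct Args {
    /// Use f32 computations rather than bf16.
    #[arg(long)]
    use_f32: bool,
}

fn main() {
    let args = Args::parse();
    // clap derives the kebab-case flag `--use-f32` from the field name.
    let dtype = if args.use_f32 { DType::F32 } else { DType::BF16 };
    println!("computing in {dtype:?}");
}

With this in place, passing --use-f32 forces full-precision computation in either example, while omitting it keeps the half-precision default (bf16 for falcon, f16 for llama).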