mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add a cli argument to easily switch the dtype. (#161)
This commit is contained in:
@ -14,11 +14,6 @@ use tokenizers::Tokenizer;
|
|||||||
mod model;
|
mod model;
|
||||||
use model::{Config, Falcon};
|
use model::{Config, Falcon};
|
||||||
|
|
||||||
#[cfg(feature = "mkl")]
|
|
||||||
const DTYPE: DType = DType::F32;
|
|
||||||
#[cfg(not(feature = "mkl"))]
|
|
||||||
const DTYPE: DType = DType::BF16;
|
|
||||||
|
|
||||||
struct TextGeneration {
|
struct TextGeneration {
|
||||||
model: Falcon,
|
model: Falcon,
|
||||||
device: Device,
|
device: Device,
|
||||||
@ -99,6 +94,10 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
prompt: String,
|
prompt: String,
|
||||||
|
|
||||||
|
/// Use f32 computations rather than bf16.
|
||||||
|
#[arg(long)]
|
||||||
|
use_f32: bool,
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
/// The temperature used to generate samples.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
temperature: Option<f64>,
|
temperature: Option<f64>,
|
||||||
@ -151,7 +150,12 @@ fn main() -> Result<()> {
|
|||||||
.map(|f| Ok(f.deserialize()?))
|
.map(|f| Ok(f.deserialize()?))
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
|
||||||
let vb = VarBuilder::from_safetensors(weights, DTYPE, &device);
|
let dtype = if args.use_f32 {
|
||||||
|
DType::F32
|
||||||
|
} else {
|
||||||
|
DType::BF16
|
||||||
|
};
|
||||||
|
let vb = VarBuilder::from_safetensors(weights, dtype, &device);
|
||||||
let config = Config::falcon7b();
|
let config = Config::falcon7b();
|
||||||
config.validate()?;
|
config.validate()?;
|
||||||
let model = Falcon::load(vb, config)?;
|
let model = Falcon::load(vb, config)?;
|
||||||
|
@ -24,10 +24,6 @@ mod model;
|
|||||||
use model::{Config, Llama};
|
use model::{Config, Llama};
|
||||||
|
|
||||||
const MAX_SEQ_LEN: usize = 4096;
|
const MAX_SEQ_LEN: usize = 4096;
|
||||||
#[cfg(feature = "mkl")]
|
|
||||||
const DTYPE: DType = DType::F32;
|
|
||||||
#[cfg(not(feature = "mkl"))]
|
|
||||||
const DTYPE: DType = DType::F16;
|
|
||||||
const DEFAULT_PROMPT: &str = r"
|
const DEFAULT_PROMPT: &str = r"
|
||||||
EDWARD:
|
EDWARD:
|
||||||
I wonder how our princely father 'scaped,
|
I wonder how our princely father 'scaped,
|
||||||
@ -127,6 +123,10 @@ struct Args {
|
|||||||
/// The initial prompt.
|
/// The initial prompt.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
prompt: Option<String>,
|
prompt: Option<String>,
|
||||||
|
|
||||||
|
/// Use f32 computations rather than f16.
|
||||||
|
#[arg(long)]
|
||||||
|
use_f32: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
@ -140,9 +140,10 @@ fn main() -> Result<()> {
|
|||||||
};
|
};
|
||||||
let config = Config::config_7b();
|
let config = Config::config_7b();
|
||||||
let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
|
let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
|
||||||
|
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
|
||||||
let (llama, tokenizer_filename) = match args.npy {
|
let (llama, tokenizer_filename) = match args.npy {
|
||||||
Some(filename) => {
|
Some(filename) => {
|
||||||
let vb = VarBuilder::from_npz(filename, DTYPE, &device)?;
|
let vb = VarBuilder::from_npz(filename, dtype, &device)?;
|
||||||
let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
|
let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
|
||||||
(Llama::load(vb, &cache, &config)?, tokenizer)
|
(Llama::load(vb, &cache, &config)?, tokenizer)
|
||||||
}
|
}
|
||||||
@ -170,7 +171,7 @@ fn main() -> Result<()> {
|
|||||||
.map(|h| Ok(h.deserialize()?))
|
.map(|h| Ok(h.deserialize()?))
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
|
||||||
let vb = VarBuilder::from_safetensors(tensors, DTYPE, &device);
|
let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
|
||||||
(Llama::load(vb, &cache, &config)?, tokenizer_filename)
|
(Llama::load(vb, &cache, &config)?, tokenizer_filename)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
Reference in New Issue
Block a user