mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Support alternative dtypes for mamba (#2036)
* Allow different dtypes in mamba.
* Add a dtype flag.
This commit is contained in:
@@ -54,6 +54,7 @@ impl TextGeneration {
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
+ let dtype = self.model.dtype();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
@@ -66,7 +67,7 @@ impl TextGeneration {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the </s> token"),
|
||||
};
|
||||
- let mut state = State::new(1, &self.config, &self.device)?;
|
||||
+ let mut state = State::new(1, &self.config, dtype, &self.device)?;
|
||||
let mut next_logits = None;
|
||||
for &t in tokens.iter() {
|
||||
let input = Tensor::new(&[t], &self.device)?;
|
||||
@@ -84,7 +85,7 @@ impl TextGeneration {
|
||||
Some(logits) => logits,
|
||||
None => anyhow::bail!("cannot work on an empty prompt"),
|
||||
};
|
||||
- let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
+ let logits = logits.squeeze(0)?.to_dtype(dtype)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
@@ -210,6 +211,9 @@ struct Args {
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
+ #[arg(long, default_value = "f32")]
|
||||
+ dtype: String,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
@@ -220,6 +224,7 @@ struct Args {
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
+ use std::str::FromStr;
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
@@ -279,7 +284,8 @@ fn main() -> Result<()> {
|
||||
let start = std::time::Instant::now();
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
- let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
+ let dtype = DType::from_str(&args.dtype)?;
|
||||
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Model::new(&config, vb.pp("backbone"))?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
|
Reference in New Issue
Block a user