diff --git a/candle-examples/examples/mamba/main.rs b/candle-examples/examples/mamba/main.rs
index 4802f960..b8c8bb70 100644
--- a/candle-examples/examples/mamba/main.rs
+++ b/candle-examples/examples/mamba/main.rs
@@ -54,6 +54,7 @@ impl TextGeneration {
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         use std::io::Write;
         self.tokenizer.clear();
+        let dtype = self.model.dtype();
         let mut tokens = self
             .tokenizer
             .tokenizer()
@@ -66,7 +67,7 @@
             Some(token) => token,
             None => anyhow::bail!("cannot find the </s> token"),
         };
-        let mut state = State::new(1, &self.config, &self.device)?;
+        let mut state = State::new(1, &self.config, dtype, &self.device)?;
         let mut next_logits = None;
         for &t in tokens.iter() {
             let input = Tensor::new(&[t], &self.device)?;
@@ -84,7 +85,7 @@
                 Some(logits) => logits,
                 None => anyhow::bail!("cannot work on an empty prompt"),
             };
-            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+            let logits = logits.squeeze(0)?.to_dtype(dtype)?;
             let logits = if self.repeat_penalty == 1. {
                 logits
             } else {
@@ -210,6 +211,9 @@ struct Args {
     #[arg(long)]
     config_file: Option<String>,
 
+    #[arg(long, default_value = "f32")]
+    dtype: String,
+
     /// Penalty to be applied for repeating tokens, 1. means no penalty.
     #[arg(long, default_value_t = 1.1)]
     repeat_penalty: f32,
@@ -220,6 +224,7 @@
 }
 
 fn main() -> Result<()> {
+    use std::str::FromStr;
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
 
@@ -279,7 +284,8 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
     let device = candle_examples::device(args.cpu)?;
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+    let dtype = DType::from_str(&args.dtype)?;
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
     let model = Model::new(&config, vb.pp("backbone"))?;
     println!("loaded the model in {:?}", start.elapsed());
 
diff --git a/candle-examples/examples/yolo-v8/assets/bike.pp.jpg b/candle-examples/examples/yolo-v8/assets/bike.pp.jpg
new file mode 100644
index 00000000..a46b8e84
Binary files /dev/null and b/candle-examples/examples/yolo-v8/assets/bike.pp.jpg differ
diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs
index 5fea27b9..e9d4af7e 100644
--- a/candle-transformers/src/models/falcon.rs
+++ b/candle-transformers/src/models/falcon.rs
@@ -179,7 +179,9 @@ impl FalconRotaryEmbedding {
 
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let on_true = Tensor::new(on_true, on_false.device())?
+        .to_dtype(on_false.dtype())?
+        .broadcast_as(shape.dims())?;
     let m = mask.where_cond(&on_true, on_false)?;
     Ok(m)
 }
diff --git a/candle-transformers/src/models/mamba.rs b/candle-transformers/src/models/mamba.rs
index 836327ee..a75ee87a 100644
--- a/candle-transformers/src/models/mamba.rs
+++ b/candle-transformers/src/models/mamba.rs
@@ -1,4 +1,3 @@
-#![allow(unused)]
 /// A fast implementation of mamba for inference only.
 /// This is based on: https://github.com/LaurentMazare/mamba.rs
 use crate::models::with_tracing::{linear, linear_no_bias, Linear};
@@ -38,12 +37,12 @@ pub struct State {
 }
 
 impl State {
-    pub fn new(batch_size: usize, cfg: &Config, device: &Device) -> Result<Self> {
+    pub fn new(batch_size: usize, cfg: &Config, dtype: DType, device: &Device) -> Result<Self> {
         let mut hs = Vec::with_capacity(cfg.n_layer);
         let mut prev_xs = Vec::with_capacity(cfg.n_layer);
         for _i in 0..cfg.n_layer {
-            let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), DType::F32, device)?;
-            let x = Tensor::zeros((batch_size, cfg.d_inner()), DType::F32, device)?;
+            let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), dtype, device)?;
+            let x = Tensor::zeros((batch_size, cfg.d_inner()), dtype, device)?;
             hs.push(h);
             prev_xs.push([x.clone(), x.clone(), x.clone(), x.clone()]);
         }
@@ -128,8 +127,8 @@ impl MambaBlock {
         let delta = delta.apply(&self.dt_proj)?;
         // softplus
         let delta = (delta.exp()? + 1.)?.log()?;
-        let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
-        let d = self.d.to_dtype(candle::DType::F32)?;
+        let a = self.a_log.to_dtype(delta.dtype())?.exp()?.neg()?;
+        let d = self.d.to_dtype(delta.dtype())?;
 
         // Selective scan part
         // Eqn (2a), page 3, h_t = Ab h_{t-1} + Bb x_t
@@ -178,6 +177,7 @@ pub struct Model {
     layers: Vec<ResidualBlock>,
     norm_f: RmsNorm,
     lm_head: Linear,
+    dtype: DType,
 }
 
 impl Model {
@@ -196,6 +196,7 @@
             layers,
             norm_f,
             lm_head,
+            dtype: vb.dtype(),
         })
     }
 
@@ -208,4 +209,8 @@
         state.pos += 1;
         xs.apply(&self.norm_f)?.apply(&self.lm_head)
     }
+
+    pub fn dtype(&self) -> DType {
+        self.dtype
+    }
 }
diff --git a/candle-transformers/src/utils.rs b/candle-transformers/src/utils.rs
index 3cbcac5c..d29995ed 100644
--- a/candle-transformers/src/utils.rs
+++ b/candle-transformers/src/utils.rs
@@ -2,7 +2,7 @@ use candle::{Result, Tensor};
 
 pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
     let device = logits.device();
-    let mut logits = logits.to_vec1::<f32>()?;
+    let mut logits = logits.to_dtype(candle::DType::F32)?.to_vec1::<f32>()?;
     let mut already_seen = std::collections::HashSet::new();
     for token_id in context {
         if already_seen.contains(token_id) {
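
Note on the utils.rs hunk: `apply_repeat_penalty` now converts the logits to f32 before reading them out into a `Vec<f32>`, so it also works when the model runs in f16/bf16; the penalty rule itself is unchanged. As a minimal standalone sketch of that rule on plain `f32` slices (no candle types; the name mirrors the real function but this is not the library code):

use std::collections::HashSet;

// Each token from the context is penalized at most once: positive
// logits are divided by the penalty and negative ones multiplied,
// shrinking the scores of already-seen tokens towards zero.
fn apply_repeat_penalty(logits: &mut [f32], penalty: f32, context: &[u32]) {
    let mut already_seen = HashSet::new();
    for &token_id in context {
        if !already_seen.insert(token_id) {
            continue;
        }
        if let Some(logit) = logits.get_mut(token_id as usize) {
            if *logit >= 0.0 {
                *logit /= penalty
            } else {
                *logit *= penalty
            }
        }
    }
}

On the example side, the new `--dtype` flag defaults to `"f32"` and is parsed with `DType::from_str`, so the mamba example should accept e.g. `--dtype bf16` to load the weights, state, and recurrence in bf16 end to end.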