Support alternative dtypes for mamba (#2036)

* Allow different dtypes in mamba.
* Add a dtype flag.
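With the new flag, the mamba example can run in reduced precision. A hypothetical invocation (the example name and the --prompt flag are assumed from the repository layout; any name accepted by DType::from_str, such as "bf16" or "f16", can replace the "f32" default):

cargo run --example mamba --release -- --prompt "Mamba is the" --dtype bf16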
candle-examples/examples/mamba/main.rs

@@ -54,6 +54,7 @@ impl TextGeneration {
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         use std::io::Write;
         self.tokenizer.clear();
+        let dtype = self.model.dtype();
         let mut tokens = self
             .tokenizer
             .tokenizer()
@@ -66,7 +67,7 @@ impl TextGeneration {
             Some(token) => token,
             None => anyhow::bail!("cannot find the </s> token"),
         };
-        let mut state = State::new(1, &self.config, &self.device)?;
+        let mut state = State::new(1, &self.config, dtype, &self.device)?;
         let mut next_logits = None;
         for &t in tokens.iter() {
             let input = Tensor::new(&[t], &self.device)?;
@@ -84,7 +85,7 @@ impl TextGeneration {
             Some(logits) => logits,
             None => anyhow::bail!("cannot work on an empty prompt"),
         };
-        let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+        let logits = logits.squeeze(0)?.to_dtype(dtype)?;
         let logits = if self.repeat_penalty == 1. {
             logits
         } else {
@@ -210,6 +211,9 @@ struct Args {
     #[arg(long)]
     config_file: Option<String>,
 
+    #[arg(long, default_value = "f32")]
+    dtype: String,
+
     /// Penalty to be applied for repeating tokens, 1. means no penalty.
     #[arg(long, default_value_t = 1.1)]
     repeat_penalty: f32,
@@ -220,6 +224,7 @@ struct Args {
 }
 
 fn main() -> Result<()> {
+    use std::str::FromStr;
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
 
@@ -279,7 +284,8 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
     let device = candle_examples::device(args.cpu)?;
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+    let dtype = DType::from_str(&args.dtype)?;
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
     let model = Model::new(&config, vb.pp("backbone"))?;
     println!("loaded the model in {:?}", start.elapsed());
 
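The dtype flag travels as a plain string and is parsed through DType's FromStr impl, which the diff pulls in with use std::str::FromStr. A minimal sketch of just the parsing step, assuming candle's usual dtype names (the match avoids depending on the error type's trait impls):

use std::str::FromStr;

use candle::DType;

fn main() {
    // "f32" is the flag's default; "bf16" and "f16" are the reduced-precision names.
    for name in ["f32", "bf16", "f16"] {
        match DType::from_str(name) {
            Ok(dtype) => println!("{name} -> {dtype:?}"),
            Err(_) => eprintln!("cannot parse {name} as a dtype"),
        }
    }
}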
BIN  candle-examples/examples/yolo-v8/assets/bike.pp.jpg  (new binary file, 175 KiB; not shown)
candle-transformers/src/models/falcon.rs

@@ -179,7 +179,9 @@ impl FalconRotaryEmbedding {
 
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let on_true = Tensor::new(on_true, on_false.device())?
+        .to_dtype(on_false.dtype())?
+        .broadcast_as(shape.dims())?;
     let m = mask.where_cond(&on_true, on_false)?;
     Ok(m)
 }
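The added to_dtype call is what makes masked_fill dtype-agnostic: Tensor::new(on_true, ...) produces an f32 scalar, and where_cond requires both branches to share one dtype, so a bf16 on_false would otherwise trigger a dtype-mismatch error. A minimal sketch of the fixed helper applied to a bf16 tensor (shapes and values are illustrative):

use candle::{DType, Device, Result, Tensor};

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    // Cast the scalar fill value to the target dtype before broadcasting so
    // that where_cond sees the same dtype on both branches.
    let on_true = Tensor::new(on_true, on_false.device())?
        .to_dtype(on_false.dtype())?
        .broadcast_as(mask.shape().dims())?;
    mask.where_cond(&on_true, on_false)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let xs = Tensor::zeros((2, 3), DType::BF16, &device)?;
    let mask = Tensor::new(&[[1u8, 0, 1], [0, 1, 0]], &device)?;
    println!("{}", masked_fill(&xs, &mask, f32::NEG_INFINITY)?);
    Ok(())
}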
candle-transformers/src/models/mamba.rs

@@ -1,4 +1,3 @@
-#![allow(unused)]
 /// A fast implementation of mamba for inference only.
 /// This is based on: https://github.com/LaurentMazare/mamba.rs
 use crate::models::with_tracing::{linear, linear_no_bias, Linear};
@@ -38,12 +37,12 @@ pub struct State {
 }
 
 impl State {
-    pub fn new(batch_size: usize, cfg: &Config, device: &Device) -> Result<Self> {
+    pub fn new(batch_size: usize, cfg: &Config, dtype: DType, device: &Device) -> Result<Self> {
         let mut hs = Vec::with_capacity(cfg.n_layer);
         let mut prev_xs = Vec::with_capacity(cfg.n_layer);
         for _i in 0..cfg.n_layer {
-            let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), DType::F32, device)?;
-            let x = Tensor::zeros((batch_size, cfg.d_inner()), DType::F32, device)?;
+            let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), dtype, device)?;
+            let x = Tensor::zeros((batch_size, cfg.d_inner()), dtype, device)?;
             hs.push(h);
             prev_xs.push([x.clone(), x.clone(), x.clone(), x.clone()]);
         }
@@ -128,8 +127,8 @@ impl MambaBlock {
         let delta = delta.apply(&self.dt_proj)?;
         // softplus
         let delta = (delta.exp()? + 1.)?.log()?;
-        let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
-        let d = self.d.to_dtype(candle::DType::F32)?;
+        let a = self.a_log.to_dtype(delta.dtype())?.exp()?.neg()?;
+        let d = self.d.to_dtype(delta.dtype())?;
 
         // Selective scan part
         // Eqn (2a), page 3, h_t = Ab h_{t-1} + Bb x_t
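The softplus here is spelled out as ln(1 + e^x) rather than calling a library helper, and the a_log/d parameters now follow delta's dtype instead of being pinned to f32. A standalone sketch of the same softplus expression (input values are arbitrary):

use candle::{Device, Result, Tensor};

// softplus(x) = ln(1 + exp(x)), the same expression as
// (delta.exp()? + 1.)?.log()? in the hunk above.
fn softplus(xs: &Tensor) -> Result<Tensor> {
    (xs.exp()? + 1.)?.log()
}

fn main() -> Result<()> {
    let xs = Tensor::new(&[-2.0f32, 0.0, 2.0], &Device::Cpu)?;
    // softplus(0) = ln 2 ≈ 0.693
    println!("{}", softplus(&xs)?);
    Ok(())
}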
candle-transformers/src/models/mamba.rs (continued)

@@ -178,6 +177,7 @@ pub struct Model {
     layers: Vec<ResidualBlock>,
     norm_f: RmsNorm,
     lm_head: Linear,
+    dtype: DType,
 }
 
 impl Model {
@@ -196,6 +196,7 @@ impl Model {
             layers,
             norm_f,
             lm_head,
+            dtype: vb.dtype(),
         })
     }
 
@@ -208,4 +209,8 @@ impl Model {
         state.pos += 1;
         xs.apply(&self.norm_f)?.apply(&self.lm_head)
     }
+
+    pub fn dtype(&self) -> DType {
+        self.dtype
+    }
 }
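The model now remembers the dtype of the VarBuilder it was loaded from, so TextGeneration::run can ask model.dtype() instead of hard-coding f32 when allocating state. A minimal sketch of that plumbing using a zero-initialized VarBuilder (no real checkpoint involved):

use candle::{DType, Device, Result};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    // The VarBuilder fixes the dtype once at load time; anything built from
    // it can cache vb.dtype(), as Model::new does in the hunk above.
    let vb = VarBuilder::zeros(DType::BF16, &Device::Cpu);
    assert_eq!(vb.dtype(), DType::BF16);
    Ok(())
}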
candle-transformers/src/utils.rs

@@ -2,7 +2,7 @@ use candle::{Result, Tensor};
 
 pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
     let device = logits.device();
-    let mut logits = logits.to_vec1::<f32>()?;
+    let mut logits = logits.to_dtype(candle::DType::F32)?.to_vec1::<f32>()?;
     let mut already_seen = std::collections::HashSet::new();
     for token_id in context {
         if already_seen.contains(token_id) {
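Casting to f32 before to_vec1::<f32>() is what lets apply_repeat_penalty accept half-precision logits: to_vec1 requires the tensor's dtype to match the requested element type, so a bf16 tensor would previously have errored out. A small sketch exercising the helper on bf16 logits (values and token ids are illustrative):

use candle::{DType, Device, Result, Tensor};
use candle_transformers::utils::apply_repeat_penalty;

fn main() -> Result<()> {
    let device = Device::Cpu;
    let logits = Tensor::new(&[1.5f32, 0.5, 2.0], &device)?.to_dtype(DType::BF16)?;
    // Tokens 0 and 2 appear in the recent context, so both get penalized.
    let penalized = apply_repeat_penalty(&logits, 1.1, &[2, 2, 0])?;
    println!("{penalized}");
    Ok(())
}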