Mirror of https://github.com/huggingface/candle.git, synced 2025-06-20 04:00:28 +00:00
Support alternative dtypes for mamba (#2036)

* Allow different dtypes in mamba.
* Add a dtype flag.
candle-transformers/src/models/falcon.rs
@@ -179,7 +179,9 @@ impl FalconRotaryEmbedding {
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let on_true = Tensor::new(on_true, on_false.device())?
+        .to_dtype(on_false.dtype())?
+        .broadcast_as(shape.dims())?;
     let m = mask.where_cond(&on_true, on_false)?;
     Ok(m)
 }
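Note: the cast matters because Tensor::new builds the fill value as an F32 scalar, while where_cond needs both branches to share the input's dtype. A minimal sketch of the fixed path, with made-up shapes and a BF16 input standing in for a half-precision model:

    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        // Stand-in for activations from a model loaded in BF16.
        let on_false = Tensor::zeros((2, 3), DType::BF16, &device)?;
        let mask = Tensor::ones((2, 3), DType::U8, &device)?;
        // Without the to_dtype call the scalar stays F32 and where_cond
        // rejects the F32/BF16 mismatch.
        let on_true = Tensor::new(f32::NEG_INFINITY, &device)?
            .to_dtype(on_false.dtype())?
            .broadcast_as(mask.shape().dims())?;
        let filled = mask.where_cond(&on_true, &on_false)?;
        println!("{filled}");
        Ok(())
    }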
candle-transformers/src/models/mamba.rs
@@ -1,4 +1,3 @@
-#![allow(unused)]
 /// A fast implementation of mamba for inference only.
 /// This is based on: https://github.com/LaurentMazare/mamba.rs
 use crate::models::with_tracing::{linear, linear_no_bias, Linear};
@@ -38,12 +37,12 @@ pub struct State {
 }

 impl State {
-    pub fn new(batch_size: usize, cfg: &Config, device: &Device) -> Result<Self> {
+    pub fn new(batch_size: usize, cfg: &Config, dtype: DType, device: &Device) -> Result<Self> {
         let mut hs = Vec::with_capacity(cfg.n_layer);
         let mut prev_xs = Vec::with_capacity(cfg.n_layer);
         for _i in 0..cfg.n_layer {
-            let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), DType::F32, device)?;
-            let x = Tensor::zeros((batch_size, cfg.d_inner()), DType::F32, device)?;
+            let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), dtype, device)?;
+            let x = Tensor::zeros((batch_size, cfg.d_inner()), dtype, device)?;
             hs.push(h);
             prev_xs.push([x.clone(), x.clone(), x.clone(), x.clone()]);
         }
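With the extra parameter, callers choose the dtype of the recurrent state so that it matches the model weights instead of always getting F32 buffers. A hedged usage sketch (fresh_state is a hypothetical helper, not part of the crate):

    use candle::{DType, Device, Result};
    use candle_transformers::models::mamba::{Config, State};

    // Batch-size-1 state whose buffers match a model loaded in BF16.
    fn fresh_state(cfg: &Config, device: &Device) -> Result<State> {
        State::new(1, cfg, DType::BF16, device)
    }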
@@ -128,8 +127,8 @@ impl MambaBlock {
         let delta = delta.apply(&self.dt_proj)?;
         // softplus
         let delta = (delta.exp()? + 1.)?.log()?;
-        let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
-        let d = self.d.to_dtype(candle::DType::F32)?;
+        let a = self.a_log.to_dtype(delta.dtype())?.exp()?.neg()?;
+        let d = self.d.to_dtype(delta.dtype())?;

         // Selective scan part
         // Eqn (2a), page 3, h_t = Ab h_{t-1} + Bb x_t
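The inline softplus is just ln(1 + e^x) spelled out with tensor ops, and it runs in whatever dtype delta carries; that is why a and d are now cast to delta.dtype() rather than forced to F32. As a standalone sketch:

    use candle::{Result, Tensor};

    // softplus(x) = ln(1 + e^x), matching `(delta.exp()? + 1.)?.log()?` above.
    fn softplus(xs: &Tensor) -> Result<Tensor> {
        (xs.exp()? + 1.0)?.log()
    }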
@@ -178,6 +177,7 @@ pub struct Model {
     layers: Vec<ResidualBlock>,
     norm_f: RmsNorm,
     lm_head: Linear,
+    dtype: DType,
 }

 impl Model {
@@ -196,6 +196,7 @@ impl Model {
             layers,
             norm_f,
             lm_head,
+            dtype: vb.dtype(),
         })
     }

@@ -208,4 +209,8 @@ impl Model {
         state.pos += 1;
         xs.apply(&self.norm_f)?.apply(&self.lm_head)
     }
+
+    pub fn dtype(&self) -> DType {
+        self.dtype
+    }
 }
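Recording vb.dtype() at construction time lets downstream code query the model and allocate a matching state. A hedged sketch, assuming safetensors weights and the `backbone` prefix the mamba example uses (the load helper itself is hypothetical):

    use candle::{DType, Device, Result};
    use candle_nn::VarBuilder;
    use candle_transformers::models::mamba::{Config, Model, State};

    fn load(cfg: &Config, files: &[std::path::PathBuf], device: &Device) -> Result<(Model, State)> {
        // The dtype passed here decides the weight dtype...
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(files, DType::BF16, device)? };
        let model = Model::new(cfg, vb.pp("backbone"))?;
        // ...and dtype() reports it back so the state stays consistent.
        let state = State::new(1, cfg, model.dtype(), device)?;
        Ok((model, state))
    }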
candle-transformers/src/utils.rs
@@ -2,7 +2,7 @@ use candle::{Result, Tensor};

 pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
     let device = logits.device();
-    let mut logits = logits.to_vec1::<f32>()?;
+    let mut logits = logits.to_dtype(candle::DType::F32)?.to_vec1::<f32>()?;
     let mut already_seen = std::collections::HashSet::new();
     for token_id in context {
         if already_seen.contains(token_id) {
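With the added cast, the penalty math always runs over f32 values, so half-precision logits from a BF16/F16 model no longer trip the to_vec1::<f32> call. A hedged call-site sketch (penalize and the 64-token window are illustrative, not from the crate):

    use candle::{Result, Tensor};
    use candle_transformers::utils::apply_repeat_penalty;

    fn penalize(logits: &Tensor, tokens: &[u32]) -> Result<Tensor> {
        // Penalize repeats over a recent window of the context; the
        // logits may now be BF16/F16 and are converted to F32 internally.
        let start_at = tokens.len().saturating_sub(64);
        apply_repeat_penalty(logits, 1.1, &tokens[start_at..])
    }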