mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Support alternative dtypes for mamba (#2036)
* Allow different dtypes in mamba. * Add a dtype flag.
This commit is contained in:
@ -179,7 +179,9 @@ impl FalconRotaryEmbedding {
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let on_true = Tensor::new(on_true, on_false.device())?
|
||||
.to_dtype(on_false.dtype())?
|
||||
.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
|
Reference in New Issue
Block a user