Support alternative dtypes for mamba (#2036)

* Allow different dtypes in mamba.

* Add a dtype flag.
This commit is contained in:
Laurent Mazare
2024-04-10 18:10:01 +02:00
committed by GitHub
parent a4d5a414e3
commit b81ecf712d
5 changed files with 24 additions and 11 deletions

View File

@ -179,7 +179,9 @@ impl FalconRotaryEmbedding {
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let on_true = Tensor::new(on_true, on_false.device())?
.to_dtype(on_false.dtype())?
.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}