Support alternative dtypes for mamba (#2036)

* Allow different dtypes in mamba.

* Add a dtype flag.
This commit is contained in:
Laurent Mazare
2024-04-10 18:10:01 +02:00
committed by GitHub
parent a4d5a414e3
commit b81ecf712d
5 changed files with 24 additions and 11 deletions

View File

@ -2,7 +2,7 @@ use candle::{Result, Tensor};
pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
let device = logits.device();
let mut logits = logits.to_vec1::<f32>()?;
let mut logits = logits.to_dtype(candle::DType::F32)?.to_vec1::<f32>()?;
let mut already_seen = std::collections::HashSet::new();
for token_id in context {
if already_seen.contains(token_id) {