mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 20:38:06 +00:00
Support alternative dtypes for mamba (#2036)
* Allow different dtypes in mamba.
* Add a dtype flag.
This commit is contained in:
@@ -2,7 +2,7 @@ use candle::{Result, Tensor};
|
||||
|
||||
pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
|
||||
let device = logits.device();
|
||||
let mut logits = logits.to_vec1::<f32>()?;
|
||||
let mut logits = logits.to_dtype(candle::DType::F32)?.to_vec1::<f32>()?;
|
||||
let mut already_seen = std::collections::HashSet::new();
|
||||
for token_id in context {
|
||||
if already_seen.contains(token_id) {
|
||||
|
Reference in New Issue
Block a user