mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Use u8 tensors for masks. (#273)
This commit is contained in:
@ -179,9 +179,8 @@ impl Cache {
|
||||
if let Some(mask) = masks.get(&t) {
|
||||
Ok(mask.clone())
|
||||
} else {
|
||||
// TODO: If we support bool or u8 tensors, this would be better.
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
|
||||
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
|
||||
masks.insert(t, mask.clone());
|
||||
|
Reference in New Issue
Block a user