mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Use u8 tensors for masks. (#273)
This commit is contained in:
@ -24,7 +24,7 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
|
||||
fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
|
||||
.flat_map(|i| (0..t).map(move |j| u8::from(j <= i)))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (t, t), device)?;
|
||||
Ok(mask)
|
||||
|
Reference in New Issue
Block a user