mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add the sliding window. (#986)
This commit is contained in:
@ -299,7 +299,6 @@ pub struct Model {
|
||||
layers: Vec<DecoderLayer>,
|
||||
norm: RmsNorm,
|
||||
lm_head: Linear,
|
||||
#[allow(unused)]
|
||||
sliding_window: usize,
|
||||
device: Device,
|
||||
dtype: DType,
|
||||
@ -338,7 +337,15 @@ impl Model {
|
||||
) -> Result<Tensor> {
|
||||
// Sliding window mask?
|
||||
let mask: Vec<_> = (0..tgt_len)
|
||||
.flat_map(|i| (0..tgt_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
|
||||
.flat_map(|i| {
|
||||
(0..tgt_len).map(move |j| {
|
||||
if i < j || j + self.sliding_window < i {
|
||||
f32::NEG_INFINITY
|
||||
} else {
|
||||
0.
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||
let mask = if seqlen_offset > 0 {
|
||||
|
Reference in New Issue
Block a user