Add the sliding window. (#986)

This commit is contained in:
Laurent Mazare
2023-09-28 18:26:33 +02:00
committed by GitHub
parent 716ab2ccdc
commit 23b3576c47

View File

@ -299,7 +299,6 @@ pub struct Model {
layers: Vec<DecoderLayer>, layers: Vec<DecoderLayer>,
norm: RmsNorm, norm: RmsNorm,
lm_head: Linear, lm_head: Linear,
#[allow(unused)]
sliding_window: usize, sliding_window: usize,
device: Device, device: Device,
dtype: DType, dtype: DType,
@ -338,7 +337,15 @@ impl Model {
) -> Result<Tensor> { ) -> Result<Tensor> {
// Sliding window mask? // Sliding window mask?
let mask: Vec<_> = (0..tgt_len) let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. })) .flat_map(|i| {
(0..tgt_len).map(move |j| {
if i < j || j + self.sliding_window < i {
f32::NEG_INFINITY
} else {
0.
}
})
})
.collect(); .collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 { let mask = if seqlen_offset > 0 {