mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the sliding window. (#986)
This commit is contained in:
@ -299,7 +299,6 @@ pub struct Model {
|
|||||||
layers: Vec<DecoderLayer>,
|
layers: Vec<DecoderLayer>,
|
||||||
norm: RmsNorm,
|
norm: RmsNorm,
|
||||||
lm_head: Linear,
|
lm_head: Linear,
|
||||||
#[allow(unused)]
|
|
||||||
sliding_window: usize,
|
sliding_window: usize,
|
||||||
device: Device,
|
device: Device,
|
||||||
dtype: DType,
|
dtype: DType,
|
||||||
@ -338,7 +337,15 @@ impl Model {
|
|||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
// Sliding window mask?
|
// Sliding window mask?
|
||||||
let mask: Vec<_> = (0..tgt_len)
|
let mask: Vec<_> = (0..tgt_len)
|
||||||
.flat_map(|i| (0..tgt_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
|
.flat_map(|i| {
|
||||||
|
(0..tgt_len).map(move |j| {
|
||||||
|
if i < j || j + self.sliding_window < i {
|
||||||
|
f32::NEG_INFINITY
|
||||||
|
} else {
|
||||||
|
0.
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||||
let mask = if seqlen_offset > 0 {
|
let mask = if seqlen_offset > 0 {
|
||||||
|
Reference in New Issue
Block a user