mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add the causal mask in mixformer. (#937)
This commit is contained in:
@ -75,6 +75,20 @@ impl Module for Embedding {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..size)
|
||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (size, size), device)
|
||||
}
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
@ -198,6 +212,7 @@ struct MHA {
|
||||
rotary_emb: RotaryEmbedding,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
head_dim: usize,
|
||||
n_head: usize,
|
||||
softmax_scale: f64,
|
||||
span: tracing::Span,
|
||||
}
|
||||
@ -214,6 +229,7 @@ impl MHA {
|
||||
wqkv,
|
||||
out_proj,
|
||||
head_dim,
|
||||
n_head: cfg.n_head,
|
||||
kv_cache: None,
|
||||
rotary_emb,
|
||||
softmax_scale,
|
||||
@ -221,7 +237,7 @@ impl MHA {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (b_size, seq_len, _n_embd) = xs.dims3()?;
|
||||
let qkv = self
|
||||
@ -249,9 +265,16 @@ impl MHA {
|
||||
let v = v.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
|
||||
let attn_weights = (q.matmul(&k.t()?)? * self.softmax_scale)?; // b*h, t, s
|
||||
|
||||
// TODO: Add the causal mask.
|
||||
// causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1)
|
||||
// scores = scores + causal_mask.to(dtype=scores.dtype)
|
||||
let attn_weights = match mask {
|
||||
None => attn_weights,
|
||||
Some(mask) => masked_fill(
|
||||
&attn_weights,
|
||||
&mask.broadcast_left(b_size * self.n_head)?,
|
||||
f32::NEG_INFINITY,
|
||||
)?,
|
||||
};
|
||||
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
|
||||
|
||||
// output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
|
||||
@ -287,11 +310,11 @@ impl ParallelBlock {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let residual = xs;
|
||||
let xs = xs.apply(&self.ln)?;
|
||||
let attn_outputs = self.mixer.forward(&xs)?;
|
||||
let attn_outputs = self.mixer.forward(&xs, mask)?;
|
||||
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
|
||||
attn_outputs + feed_forward_hidden_states + residual
|
||||
}
|
||||
@ -327,8 +350,13 @@ impl MixFormerSequentialForCausalLM {
|
||||
let _enter = self.span.enter();
|
||||
let (_b_size, seq_len) = xs.dims2()?;
|
||||
let mut xs = xs.apply(&self.embedding)?;
|
||||
let mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
Some(get_mask(seq_len, xs.device())?)
|
||||
};
|
||||
for block in self.blocks.iter_mut() {
|
||||
xs = block.forward(&xs)?
|
||||
xs = block.forward(&xs, mask.as_ref())?
|
||||
}
|
||||
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
|
||||
}
|
||||
|
Reference in New Issue
Block a user