Use softmax-last-dim in the quantized example. (#848)

Author: Laurent Mazare (committed by GitHub)
Date: 2023-09-14 18:29:24 +02:00
Parent: a0c6d5548c
Commit: 0a647875ec
2 changed files with 23 additions and 20 deletions


@@ -144,7 +144,7 @@ impl LayerWeights {
     let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
     let mask = mask.broadcast_as(att.shape())?;
     let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-    let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+    let att = candle_nn::ops::softmax_last_dim(&att)?;
     // Convert to contiguous as matmul doesn't support strided vs for now.
     let y = att.matmul(&v.contiguous()?)?;
     let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
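
For context, `candle_nn::ops::softmax_last_dim` is a specialized path for the common case of a softmax over the last dimension, so it should produce the same values as the generic `softmax(&att, D::Minus1)` it replaces here. The standalone sketch below (not part of this commit; it assumes the `candle_core` and `candle_nn` crates, and a made-up attention-score shape of `(batch, heads, seq, seq)` mirroring the code above) checks that the two ops agree up to floating point rounding:

// Minimal sketch, not from this commit: verify that softmax_last_dim
// agrees with the generic softmax applied over D::Minus1.
use candle_core::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Hypothetical attention-score tensor: (batch, heads, seq, seq).
    let att = Tensor::rand(0f32, 1f32, (1, 4, 8, 8), &dev)?;
    let generic = candle_nn::ops::softmax(&att, D::Minus1)?;
    let fused = candle_nn::ops::softmax_last_dim(&att)?;
    // The two results should match up to floating point rounding.
    let diff = (generic - fused)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    assert!(diff < 1e-4);
    println!("total abs diff: {diff}");
    Ok(())
}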