mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add a simpler way to specify the dim index for some ops.
This commit is contained in:
@ -109,7 +109,7 @@ impl Decode {
|
||||
};
|
||||
tokens.push(next_token);
|
||||
let prob = logits
|
||||
.softmax(logits.rank() - 1)?
|
||||
.softmax(candle::D::Minus1)?
|
||||
.get(next_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
if next_token == EOT_TOKEN || tokens.len() > model.config.n_text_ctx {
|
||||
|
@ -342,8 +342,8 @@ impl MultiHeadAttention {
|
||||
let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
|
||||
qk = qk.broadcast_add(&mask)?
|
||||
}
|
||||
let w = qk.softmax(qk.rank() - 1)?;
|
||||
let wv = w.matmul(&v)?.transpose(1, 2)?.flatten(Some(2), None)?;
|
||||
let w = qk.softmax(candle::D::Minus1)?;
|
||||
let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
|
||||
Ok(wv)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user