mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add a simpler way to specify the dim index for some ops.
This commit is contained in:
@ -283,19 +283,18 @@ impl CausalSelfAttention {
|
||||
dims.push(v / 2);
|
||||
dims.push(2);
|
||||
let x = x.reshape(dims)?;
|
||||
let rank = x.rank();
|
||||
let re_x = x.narrow(rank - 1, 0, 1)?;
|
||||
let im_x = x.narrow(rank - 1, 1, 1)?;
|
||||
let re_x = x.narrow(candle::D::Minus1, 0, 1)?;
|
||||
let im_x = x.narrow(candle::D::Minus1, 1, 1)?;
|
||||
let re_f = freqs_cis
|
||||
.narrow(rank - 1, 0, 1)?
|
||||
.narrow(candle::D::Minus1, 0, 1)?
|
||||
.broadcast_as(re_x.shape())?;
|
||||
let im_f = freqs_cis
|
||||
.narrow(rank - 1, 1, 1)?
|
||||
.narrow(candle::D::Minus1, 1, 1)?
|
||||
.broadcast_as(im_x.shape())?;
|
||||
let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
|
||||
let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
|
||||
let rope = Tensor::cat(&[&re, &im], rank - 1)?;
|
||||
let rope = rope.flatten(Some(rope.rank() - 2), None)?;
|
||||
let rope = Tensor::cat(&[&re, &im], re.rank() - 1)?;
|
||||
let rope = rope.flatten_from(candle::D::Minus2)?;
|
||||
Ok(rope)
|
||||
}
|
||||
|
||||
@ -339,7 +338,7 @@ impl CausalSelfAttention {
|
||||
let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
|
||||
let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
|
||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||
let att = att.softmax(att.rank() - 1)?;
|
||||
let att = att.softmax(candle::D::Minus1)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = att.matmul(&v.contiguous()?)?;
|
||||
let y = y.transpose(0, 1)?.reshape(&[t, c])?;
|
||||
@ -537,7 +536,7 @@ async fn main() -> Result<()> {
|
||||
|
||||
let next_token = if let Some(temperature) = args.temperature {
|
||||
println!("Sampling with temperature {temperature:?}");
|
||||
let prs = (&logits / temperature)?.softmax(logits.rank() - 1)?;
|
||||
let prs = (&logits / temperature)?.softmax(candle::D::Minus1)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
||||
|
||||
|
Reference in New Issue
Block a user