From e2bfbcb79ce80e45078ff6b7fefad18a09ef6554 Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 5 Jul 2023 20:39:08 +0100
Subject: [PATCH] Support dim indexes in cat.

---
 candle-core/src/tensor.rs              |  6 ++++--
 candle-examples/examples/llama/main.rs | 21 ++++++++++-----------
 2 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 1eb92e6a..e5d80ff4 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -970,10 +970,11 @@ impl Tensor {
         self.reshape(dims)
     }
 
-    pub fn stack<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
+    pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" });
         }
+        let dim = dim.to_index(args[0].as_ref().shape(), "stack")?;
         let args = args
             .iter()
             .map(|t| t.as_ref().unsqueeze(dim))
@@ -981,7 +982,7 @@ impl Tensor {
         Self::cat(&args, dim)
     }
 
-    pub fn cat<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
+    pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
         }
@@ -989,6 +990,7 @@ impl Tensor {
         if args.len() == 1 {
             return Ok(arg0.clone());
         }
+        let dim = dim.to_index(arg0.shape(), "cat")?;
         for arg in args {
             arg.as_ref().check_dim(dim, "cat")?;
         }
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index d254eeed..9f87b59a 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -15,7 +15,7 @@ use anyhow::{Error as E, Result};
 use clap::Parser;
 use rand::{distributions::Distribution, SeedableRng};
 
-use candle::{DType, Device, Tensor};
+use candle::{DType, Device, Tensor, D};
 use candle_hub::{api::Api, Repo, RepoType};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -283,18 +283,18 @@ impl CausalSelfAttention {
         dims.push(v / 2);
         dims.push(2);
         let x = x.reshape(dims)?;
-        let re_x = x.narrow(candle::D::Minus1, 0, 1)?;
-        let im_x = x.narrow(candle::D::Minus1, 1, 1)?;
+        let re_x = x.narrow(D::Minus1, 0, 1)?;
+        let im_x = x.narrow(D::Minus1, 1, 1)?;
         let re_f = freqs_cis
-            .narrow(candle::D::Minus1, 0, 1)?
+            .narrow(D::Minus1, 0, 1)?
             .broadcast_as(re_x.shape())?;
         let im_f = freqs_cis
-            .narrow(candle::D::Minus1, 1, 1)?
+            .narrow(D::Minus1, 1, 1)?
             .broadcast_as(im_x.shape())?;
         let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
         let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
-        let rope = Tensor::cat(&[&re, &im], re.rank() - 1)?;
-        let rope = rope.flatten_from(candle::D::Minus2)?;
+        let rope = Tensor::cat(&[&re, &im], D::Minus1)?;
+        let rope = rope.flatten_from(D::Minus2)?;
         Ok(rope)
     }
 
@@ -338,7 +338,7 @@ impl CausalSelfAttention {
         let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
         let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let att = att.softmax(candle::D::Minus1)?;
+        let att = att.softmax(D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(0, 1)?.reshape(&[t, c])?;
@@ -424,8 +424,7 @@ fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
     let shape = [1, MAX_SEQ_LEN, n_elem / 2, 1];
     let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
    let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
-    let last_dim = idx_theta_cos.rank() - 1;
-    Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], last_dim)?)
+    Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], D::Minus1)?)
 }
 
 #[derive(Parser, Debug)]
@@ -536,7 +535,7 @@ async fn main() -> Result<()> {
         let next_token = if let Some(temperature) = args.temperature {
             println!("Sampling with temperature {temperature:?}");
-            let prs = (&logits / temperature)?.softmax(candle::D::Minus1)?;
+            let prs = (&logits / temperature)?.softmax(D::Minus1)?;
             let logits_v: Vec<f32> = prs.to_vec1()?;
             let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
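
Usage note (not part of the patch): with the new `D: Dim` bound, `Tensor::cat` and
`Tensor::stack` accept either a plain `usize` index or the `D::Minus1`/`D::Minus2`
helpers that count from the last dimension, as the llama example above now does.
A minimal sketch of the difference, assuming `Tensor::zeros`, tuple shapes, and the
`Result` re-export behave as in candle at the time; shapes are illustrative:

    use candle::{DType, Device, Result, Tensor, D};

    fn cat_demo() -> Result<()> {
        let dev = Device::Cpu;
        let a = Tensor::zeros((2, 3), DType::F32, &dev)?;
        let b = Tensor::zeros((2, 3), DType::F32, &dev)?;

        // Plain usize indexes still work: concatenate along the first dimension.
        let y0 = Tensor::cat(&[&a, &b], 0)?;

        // D::Minus1 is resolved against the first argument's shape, so this
        // concatenates along the last dimension without computing `rank() - 1`.
        let y1 = Tensor::cat(&[&a, &b], D::Minus1)?;

        // stack unsqueezes each tensor at `dim` before concatenating.
        let y2 = Tensor::stack(&[&a, &b], 0)?;

        assert_eq!(y0.dims(), &[4, 3]);
        assert_eq!(y1.dims(), &[2, 6]);
        assert_eq!(y2.dims(), &[2, 2, 3]);
        Ok(())
    }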