diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 2365a34d..9a2602f4 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -23,7 +23,7 @@ pub use device::{Device, DeviceLocation};
 pub use dtype::{DType, WithDType};
 pub use error::{Error, Result};
 pub use layout::Layout;
-pub use shape::Shape;
+pub use shape::{Shape, D};
 pub use storage::Storage;
 use strided_index::StridedIndex;
 pub use tensor::{Tensor, TensorId};
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index cc068004..1152dc3e 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -183,6 +183,45 @@ impl Shape {
     }
 }
 
+pub trait Dim {
+    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
+}
+
+impl Dim for usize {
+    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
+        let dim = *self;
+        if dim >= shape.dims().len() {
+            Err(Error::DimOutOfRange {
+                shape: shape.clone(),
+                dim,
+                op,
+            })?
+        } else {
+            Ok(dim)
+        }
+    }
+}
+
+pub enum D {
+    Minus1,
+    Minus2,
+}
+
+impl Dim for D {
+    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
+        let rank = shape.rank();
+        match self {
+            Self::Minus1 if rank >= 1 => Ok(rank - 1),
+            Self::Minus2 if rank >= 2 => Ok(rank - 2),
+            _ => Err(Error::DimOutOfRange {
+                shape: shape.clone(),
+                dim: 42, // TODO: Have an adequate error
+                op,
+            }),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 95f663f0..1eb92e6a 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,3 +1,4 @@
+use crate::shape::Dim;
 use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::Arc;
 
@@ -362,9 +363,9 @@ impl Tensor {
 
     /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
     /// ranges from `start` to `start + len`.
-    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
+    pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
         let dims = self.dims();
-        self.check_dim(dim, "narrow")?;
+        let dim = dim.to_index(self.shape(), "narrow")?;
         if start + len > dims[dim] {
             Err(Error::NarrowInvalidArgs {
                 shape: self.shape().clone(),
@@ -392,8 +393,8 @@ impl Tensor {
         }
     }
 
-    pub fn softmax(&self, dim: usize) -> Result<Self> {
-        self.check_dim(dim, "softmax")?;
+    pub fn softmax<D: Dim>(&self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "softmax")?;
         // TODO: unify the two branches.
         if self.device().is_cuda() {
             // We do not have a cuda kernel for divide_by_sum_over_dim so split
@@ -692,14 +693,22 @@ impl Tensor {
         self.sum(&dims)
     }
 
-    pub fn flatten(&self, start_dim: Option<usize>, end_dim: Option<usize>) -> Result<Tensor> {
+    fn flatten_<D1: Dim, D2: Dim>(
+        &self,
+        start_dim: Option<D1>,
+        end_dim: Option<D2>,
+    ) -> Result<Tensor> {
         if self.rank() == 0 {
             self.reshape(1)
         } else {
-            let start_dim = start_dim.unwrap_or(0);
-            let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
-            self.check_dim(start_dim, "flatten")?;
-            self.check_dim(end_dim, "flatten")?;
+            let start_dim = match start_dim {
+                None => 0,
+                Some(dim) => dim.to_index(self.shape(), "flatten")?,
+            };
+            let end_dim = match end_dim {
+                None => self.rank() - 1,
+                Some(dim) => dim.to_index(self.shape(), "flatten")?,
+            };
             if start_dim < end_dim {
                 let dims = self.dims();
                 let mut dst_dims = dims[..start_dim].to_vec();
@@ -714,8 +723,20 @@ impl Tensor {
         }
     }
 
+    pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Tensor> {
+        self.flatten_(Some(start_dim), Some(end_dim))
+    }
+
+    pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor> {
+        self.flatten_(None::<usize>, Some(end_dim))
+    }
+
+    pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor> {
+        self.flatten_(Some(start_dim), None::<usize>)
+    }
+
     pub fn flatten_all(&self) -> Result<Tensor> {
-        self.flatten(None, None)
+        self.flatten_(None::<usize>, None::<usize>)
     }
 
     pub fn get(&self, i: usize) -> Result<Tensor> {
@@ -743,9 +764,9 @@ impl Tensor {
 
     /// Returns a tensor that is a transposed version of the input, the given dimensions are
     /// swapped.
-    pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Tensor> {
-        self.check_dim(dim1, "transpose")?;
-        self.check_dim(dim2, "transpose")?;
+    pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
+        let dim1 = dim1.to_index(self.shape(), "transpose")?;
+        let dim2 = dim2.to_index(self.shape(), "transpose")?;
         let op = if self.track_op() {
             Some(Op::Transpose(self.clone(), dim1, dim2))
         } else {
@@ -929,23 +950,23 @@ impl Tensor {
         }
     }
 
-    pub fn squeeze(&self, index: usize) -> Result<Self> {
+    pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
         // The PyTorch semantics are to return the same tensor if the target dimension
         // does not have a size of 1.
         let dims = self.dims();
-        self.check_dim(index, "squeeze")?;
-        if dims[index] == 1 {
+        let dim = dim.to_index(self.shape(), "squeeze")?;
+        if dims[dim] == 1 {
             let mut dims = dims.to_vec();
-            dims.remove(index);
+            dims.remove(dim);
             self.reshape(dims)
         } else {
             Ok(self.clone())
         }
     }
 
-    pub fn unsqueeze(&self, index: usize) -> Result<Self> {
+    pub fn unsqueeze(&self, dim: usize) -> Result<Self> {
         let mut dims = self.dims().to_vec();
-        dims.insert(index, 1);
+        dims.insert(dim, 1);
         self.reshape(dims)
     }
 
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 4396326d..11d01a6a 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -386,12 +386,12 @@ impl BertSelfAttention {
 
         let attention_scores = query_layer.matmul(&key_layer.t()?)?;
         let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
-        let attention_probs = attention_scores.softmax(attention_scores.rank() - 1)?;
+        let attention_probs = attention_scores.softmax(candle::D::Minus1)?;
         let attention_probs = self.dropout.forward(&attention_probs)?;
 
         let context_layer = attention_probs.matmul(&value_layer)?;
         let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
-        let context_layer = context_layer.flatten(Some(context_layer.rank() - 2), None)?;
+        let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
         Ok(context_layer)
     }
 }
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 73db15e0..d254eeed 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -283,19 +283,18 @@ impl CausalSelfAttention {
         dims.push(v / 2);
         dims.push(2);
         let x = x.reshape(dims)?;
-        let rank = x.rank();
-        let re_x = x.narrow(rank - 1, 0, 1)?;
-        let im_x = x.narrow(rank - 1, 1, 1)?;
+        let re_x = x.narrow(candle::D::Minus1, 0, 1)?;
+        let im_x = x.narrow(candle::D::Minus1, 1, 1)?;
         let re_f = freqs_cis
-            .narrow(rank - 1, 0, 1)?
+            .narrow(candle::D::Minus1, 0, 1)?
             .broadcast_as(re_x.shape())?;
         let im_f = freqs_cis
-            .narrow(rank - 1, 1, 1)?
+            .narrow(candle::D::Minus1, 1, 1)?
             .broadcast_as(im_x.shape())?;
         let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
         let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
-        let rope = Tensor::cat(&[&re, &im], rank - 1)?;
-        let rope = rope.flatten(Some(rope.rank() - 2), None)?;
+        let rope = Tensor::cat(&[&re, &im], re.rank() - 1)?;
+        let rope = rope.flatten_from(candle::D::Minus2)?;
         Ok(rope)
     }
 
@@ -339,7 +338,7 @@ impl CausalSelfAttention {
         let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
         let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let att = att.softmax(att.rank() - 1)?;
+        let att = att.softmax(candle::D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(0, 1)?.reshape(&[t, c])?;
@@ -537,7 +536,7 @@ async fn main() -> Result<()> {
 
         let next_token = if let Some(temperature) = args.temperature {
             println!("Sampling with temperature {temperature:?}");
-            let prs = (&logits / temperature)?.softmax(logits.rank() - 1)?;
+            let prs = (&logits / temperature)?.softmax(candle::D::Minus1)?;
             let logits_v: Vec<f32> = prs.to_vec1()?;
             let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
 
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 6ea3e536..fad3e91c 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -109,7 +109,7 @@ impl Decode {
         };
         tokens.push(next_token);
         let prob = logits
-            .softmax(logits.rank() - 1)?
+            .softmax(candle::D::Minus1)?
            .get(next_token as usize)?
            .to_scalar::<f32>()? as f64;
         if next_token == EOT_TOKEN || tokens.len() > model.config.n_text_ctx {
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index bf322c51..a1973f27 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -342,8 +342,8 @@ impl MultiHeadAttention {
             let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
             qk = qk.broadcast_add(&mask)?
         }
-        let w = qk.softmax(qk.rank() - 1)?;
-        let wv = w.matmul(&v)?.transpose(1, 2)?.flatten(Some(2), None)?;
+        let w = qk.softmax(candle::D::Minus1)?;
+        let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
         Ok(wv)
     }
 }