From c44e5346f40f3825c60dc8cab113867753916400 Mon Sep 17 00:00:00 2001
From: laurent
Date: Tue, 27 Jun 2023 17:37:09 +0100
Subject: [PATCH] Add some helper functions.

---
 candle-core/examples/llama/main.rs |  7 +------
 candle-core/src/tensor.rs          | 25 +++++++++++++++++++++++++
 2 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index baf0cdb8..eb681f4b 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -306,12 +306,7 @@ impl CausalSelfAttention {
         let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
         let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
         let rope = Tensor::cat(&[&re, &im], rank - 1)?;
-        // TODO: Add the flatten op.
-        let mut dims = rope.dims().to_vec();
-        let v1 = dims.pop().unwrap();
-        let v2 = dims.pop().unwrap();
-        dims.push(v1 * v2);
-        let rope = rope.reshape(dims)?;
+        let rope = rope.flatten(Some(rope.rank() - 2), None)?;
         Ok(rope)
     }
 
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 43007ef5..b90bb2d6 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -652,6 +652,31 @@ impl Tensor {
         &self.op
     }
 
+    pub fn sum_all(&self) -> Result<Tensor> {
+        let dims: Vec<_> = (0..self.rank()).collect();
+        self.sum(&dims)
+    }
+
+    pub fn flatten(&self, start_dim: Option<usize>, end_dim: Option<usize>) -> Result<Tensor> {
+        let start_dim = start_dim.unwrap_or(0);
+        let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
+        if start_dim < end_dim {
+            let dims = self.dims();
+            let mut dst_dims = dims[..start_dim].to_vec();
+            dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
+            if end_dim + 1 < dims.len() {
+                dst_dims.extend(&dims[end_dim + 1..]);
+            }
+            self.reshape(dst_dims)
+        } else {
+            Ok(self.clone())
+        }
+    }
+
+    pub fn flatten_all(&self) -> Result<Tensor> {
+        self.flatten(None, None)
+    }
+
     /// Returns a tensor that is a transposed version of the input, the two last dimensions of the
     /// input are swapped.
     pub fn t(&self) -> Result<Tensor> {
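
For reference, a minimal usage sketch of the new helpers, not part of the patch itself. The candle_core import path and the Tensor::zeros / Device::Cpu / DType::F32 constructors are assumptions about the crate's public API at this point in its history. flatten collapses dimensions start_dim..=end_dim into a single dimension, defaults either bound to the full range when it is None, and returns the tensor unchanged when start_dim >= end_dim; flatten_all and sum_all are the corresponding whole-tensor shortcuts.

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // A rank-3 tensor of shape (2, 3, 4).
    let t = Tensor::zeros((2, 3, 4), DType::F32, &device)?;

    // Collapse the last two dimensions, as the llama rope code now does:
    // (2, 3, 4) -> (2, 12).
    let last_two = t.flatten(Some(t.rank() - 2), None)?;
    assert_eq!(last_two.dims(), &[2, 12]);

    // Collapse every dimension into one: (2, 3, 4) -> (24,).
    let flat = t.flatten_all()?;
    assert_eq!(flat.dims(), &[24]);

    // Sum over every dimension of the tensor.
    let total = t.sum_all()?;
    println!("sum over all dims has shape {:?}", total.dims());
    Ok(())
}

The rope call site benefits from this because the head dimension split into (re, im) halves is rejoined by flattening the last two axes back into one, without hand-rolling the reshape as the removed TODO block did.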