diff --git a/src/shape.rs b/src/shape.rs index ebc497cf..aa66e706 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -1,7 +1,7 @@ use crate::{Error, Result}; #[derive(Clone, PartialEq, Eq)] -pub struct Shape(pub(crate) Vec<usize>); +pub struct Shape(Vec<usize>); impl std::fmt::Debug for Shape { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -63,6 +63,12 @@ impl From<(usize, usize, usize)> for Shape { } } +impl From<Vec<usize>> for Shape { + fn from(dims: Vec<usize>) -> Self { + Self(dims) + } +} + macro_rules! extract_dims { ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => { pub fn $fn_name(&self) -> Result<$out_type> { @@ -142,6 +148,11 @@ impl Shape { } true } + + pub fn extend(mut self, additional_dims: &[usize]) -> Self { + self.0.extend(additional_dims); + self + } } #[cfg(test)] diff --git a/src/tensor.rs b/src/tensor.rs index 161a4787..66807594 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -283,9 +283,8 @@ impl Tensor { }); } - let mut c_shape: Vec<_> = a_dims[..dim - 2].into(); - c_shape.extend(&[m, n]); - let c_shape = Shape(c_shape); + let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]); + let c_stride = c_shape.stride_contiguous(); let batching: usize = a_dims[..dim - 2].iter().product(); let storage = self.storage.matmul_impl( @@ -297,8 +296,8 @@ impl Tensor { let tensor_ = Tensor_ { id: TensorId::new(), storage, - shape: c_shape.clone(), - stride: c_shape.stride_contiguous(), + shape: c_shape, + stride: c_stride, op: Some(Op::Matmul(self.clone(), rhs.clone())), is_variable: false, }; @@ -414,7 +413,6 @@ impl Tensor { pub fn t(&self) -> Result { let mut stride = self.stride().to_vec(); - let mut shape = self.shape().clone(); let n = stride.len(); if n < 2 { return Err(Error::UnexpectedNumberOfDims { @@ -423,12 +421,13 @@ impl Tensor { shape: self.shape().clone(), }); } - (shape.0[n - 2], shape.0[n - 1]) = (shape.0[n - 1], shape.0[n - 2]); + let mut dims = self.shape().dims().to_vec(); + (dims[n - 2], dims[n - 1]) = (dims[n - 1], dims[n - 2]); 
(stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]); let tensor_ = Tensor_ { id: TensorId::new(), storage: self.storage.clone(), - shape, + shape: Shape::from(dims), stride, // TODO The op should have a backward op: None,