diff --git a/src/tensor.rs b/src/tensor.rs index ff6cb3dc..816308e0 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -301,6 +301,17 @@ impl Tensor { self.id } + pub fn is_contiguous(&self) -> bool { + let mut acc = 1; + for (&stride, &dim) in self.stride.iter().zip(self.shape.dims().iter()).rev() { + if stride != acc { + return false; + } + acc *= dim; + } + true + } + /// Return all the nodes that lead to this value in a topologically sorted vec, the first /// elements having dependencies on the latter ones, e.g. the first element if any is the /// argument. @@ -432,3 +443,44 @@ impl Tensor { Ok(grads) } } + +macro_rules! bin_trait { + ($trait:ident, $fn1:ident) => { + impl> std::ops::$trait for Tensor { + type Output = Result; + + fn $fn1(self, rhs: B) -> Self::Output { + Tensor::$fn1(&self, rhs.borrow()) + } + } + + impl> std::ops::$trait for &Tensor { + type Output = Result; + + fn $fn1(self, rhs: B) -> Self::Output { + Tensor::$fn1(&self, rhs.borrow()) + } + } + + impl> std::ops::$trait> for Tensor { + type Output = Result; + + fn $fn1(self, rhs: Result) -> Self::Output { + Tensor::$fn1(&self, rhs?.borrow()) + } + } + + impl> std::ops::$trait> for &Tensor { + type Output = Result; + + fn $fn1(self, rhs: Result) -> Self::Output { + Tensor::$fn1(&self, rhs?.borrow()) + } + } + }; +} + +bin_trait!(Add, add); +bin_trait!(Sub, sub); +bin_trait!(Mul, mul); +bin_trait!(Div, div); diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs index bea32336..01f6f66c 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -35,3 +35,21 @@ fn tensor_2d() -> Result<()> { assert_eq!(content, data); Ok(()) } + +#[test] +fn binary_op() -> Result<()> { + let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]]; + let tensor = Tensor::new(data, Device::Cpu)?; + let data2 = &[[5f32, 5., 5., 5., 5.], [2., 1., 7., 8., 2.]]; + let tensor2 = Tensor::new(data2, Device::Cpu)?; + let tensor = (&tensor + (&tensor * &tensor)? / (&tensor + &tensor2))?; + let dims = tensor.shape().r2()?; + assert_eq!(dims, (2, 5)); + let content: Vec> = tensor.to_vec2()?; + assert_eq!(content[0], [4.125, 1.1666666, 5.7777777, 1.1666666, 7.5]); + assert_eq!(content[1], [3.0, 1.5, 10.5, 12.0, 3.0]); + let tensor = (&tensor - &tensor)?; + let content: Vec> = tensor.to_vec2()?; + assert_eq!(content[0], [0., 0., 0., 0., 0.]); + Ok(()) +}