diff --git a/src/op.rs b/src/op.rs index 523154fb..fa1373f6 100644 --- a/src/op.rs +++ b/src/op.rs @@ -22,6 +22,7 @@ pub(crate) enum Op { add: f64, }, ToDType(Tensor), + Broadcast(Tensor), Exp(Tensor), Log(Tensor), Sin(Tensor), diff --git a/src/shape.rs b/src/shape.rs index aa66e706..88430c9d 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -94,6 +94,10 @@ impl Shape { self.0.len() } + pub fn into_dims(self) -> Vec<usize> { + self.0 + } + pub fn dims(&self) -> &[usize] { &self.0 } diff --git a/src/tensor.rs b/src/tensor.rs index b25a23c2..17a9d0ae 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -653,6 +653,30 @@ impl Tensor { } } + /// Returns a new tensor duplicating data from the original tensor. New dimensions are inserted + /// on the left. + pub fn broadcast<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> { + let left_shape = left_shape.into(); + let op = if self.track_op() { + Some(Op::Broadcast(self.clone())) + } else { + None + }; + let mut stride = vec![0; left_shape.rank()]; + stride.extend_from_slice(&self.stride); + let mut dims = left_shape.into_dims(); + dims.extend(self.shape.dims()); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + shape: Shape::from(dims), + stride, + op, + is_variable: self.is_variable, + }; + Ok(Tensor(Arc::new(tensor_))) + } + pub fn to_dtype(&self, dtype: DType) -> Result<Self> { let shape = self.shape(); let storage = self.storage.to_dtype(shape, self.stride(), dtype)?; @@ -849,6 +873,7 @@ impl Tensor { } } Op::Reshape(node) + | Op::Broadcast(node) | Op::ToDType(node) | Op::ToDevice(node) | Op::Transpose(node, _, _) @@ -978,6 +1003,9 @@ impl Tensor { start_idx += len; } } + Op::Broadcast(_arg) => { + return Err(Error::BackwardNotSupported { op: "broadcast" }) + } Op::ToDType(arg) => { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)? 
diff --git a/tests/tensor_tests.rs index a12f2b4d..c5a332e7 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -138,3 +138,14 @@ fn narrow() -> Result<()> { ); Ok(()) } + +#[test] +fn broadcast() -> Result<()> { + let data = &[3f32, 1., 4.]; + let tensor = Tensor::new(data, &Device::Cpu)?; + assert_eq!( + tensor.broadcast((3, 1))?.to_vec3::<f32>()?, + &[[[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]]] + ); + Ok(()) +}