diff --git a/src/error.rs b/src/error.rs
index 7a3b9960..31bcf4c3 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -90,6 +90,9 @@ pub enum Error {
     /// I/O error.
     #[error(transparent)]
     Io(#[from] std::io::Error),
+
+    #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
+    BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
diff --git a/src/tensor.rs b/src/tensor.rs
index 5d883329..fb8c0960 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -605,8 +605,8 @@ impl Tensor {
             storage: Arc::new(self.storage.try_clone()?),
             shape: self.shape.clone(),
             stride: self.stride.clone(),
-            op: self.op.clone(),
-            is_variable: self.is_variable,
+            op: None, // TODO
+            is_variable: false,
         };
         Ok(Tensor(Arc::new(tensor_)))
     }
@@ -654,7 +654,7 @@ impl Tensor {
             shape: self.shape.clone(),
             stride: self.stride.clone(),
             op,
-            is_variable: self.is_variable,
+            is_variable: false,
         };
         Ok(Tensor(Arc::new(tensor_)))
     }
@@ -662,28 +662,60 @@ impl Tensor {
 
     /// Returns a new tensor duplicating data from the original tensor. New dimensions are inserted
     /// on the left.
-    pub fn broadcast<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
+    pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
         let left_shape = left_shape.into();
+        let mut dims = left_shape.into_dims();
+        dims.extend(self.shape.dims());
+        self.broadcast_as(dims)
+    }
+
+    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
         let op = if self.track_op() {
             Some(Op::Broadcast(self.clone()))
         } else {
             None
         };
-        let mut stride = vec![0; left_shape.rank()];
-        stride.extend_from_slice(&self.stride);
-        let mut dims = left_shape.into_dims();
-        dims.extend(self.shape.dims());
+        let shape = shape.into();
+        if shape.rank() < self.rank() {
+            return Err(Error::BroadcastIncompatibleShapes {
+                src_shape: self.shape().clone(),
+                dst_shape: shape,
+            });
+        }
+        let added_dims = shape.rank() - self.rank();
+        let mut stride = vec![0; added_dims];
+        for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
+            .iter()
+            .zip(self.dims().iter().zip(self.stride()))
+        {
+            let s = if dst_dim == src_dim {
+                src_stride
+            } else if src_dim != 1 {
+                return Err(Error::BroadcastIncompatibleShapes {
+                    src_shape: self.shape().clone(),
+                    dst_shape: shape,
+                });
+            } else {
+                0
+            };
+            stride.push(s)
+        }
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
-            shape: Shape::from(dims),
+            shape,
             stride,
             op,
-            is_variable: self.is_variable,
+            is_variable: false,
         };
         Ok(Tensor(Arc::new(tensor_)))
     }
 
+    /// An alias for broadcast_as.
+    pub fn expand<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
+        self.broadcast_as(shape)
+    }
+
     pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
         let shape = self.shape();
         let storage = self.storage.to_dtype(shape, self.stride(), dtype)?;
@@ -706,8 +738,8 @@ impl Tensor {
         Ok(from_storage(
             storage,
             shape.clone(),
-            self.op.clone(),
-            self.is_variable,
+            None, // TODO
+            false,
         ))
     }
 }
diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs
index c5a332e7..1dde769a 100644
--- a/tests/tensor_tests.rs
+++ b/tests/tensor_tests.rs
@@ -144,7 +144,7 @@ fn broadcast() -> Result<()> {
     let data = &[3f32, 1., 4.];
     let tensor = Tensor::new(data, &Device::Cpu)?;
     assert_eq!(
         tensor.broadcast_left((3, 1))?.to_vec3::<f32>()?,
+        tensor.broadcast_left((3, 1))?.to_vec3::<f32>()?,
         &[[[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]]]
     );
     Ok(())
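
Note: the stride computation in `broadcast_as` is the whole of the broadcast mechanism here: dimensions added on the left, and source dimensions of size 1, get stride 0, so reads revisit the same storage instead of copying it. Below is a minimal standalone sketch of the same rule over plain `&[usize]` shapes and strides, with no dependency on the crate's `Shape` type; the `broadcast_strides` helper name is hypothetical, for illustration only.

```rust
// Sketch of the broadcast stride rule used by `broadcast_as`: matching
// dims keep their source stride; added left dims and size-1 source dims
// get stride 0, so the same elements are revisited rather than copied.
fn broadcast_strides(
    src_shape: &[usize],
    src_stride: &[usize],
    dst_shape: &[usize],
) -> Result<Vec<usize>, String> {
    if dst_shape.len() < src_shape.len() {
        return Err(format!("cannot broadcast {src_shape:?} to {dst_shape:?}"));
    }
    let added_dims = dst_shape.len() - src_shape.len();
    // Dimensions added on the left never advance through the source buffer.
    let mut stride = vec![0; added_dims];
    for (&dst_dim, (&src_dim, &src_st)) in dst_shape[added_dims..]
        .iter()
        .zip(src_shape.iter().zip(src_stride))
    {
        if dst_dim == src_dim {
            stride.push(src_st) // same size: step as in the source
        } else if src_dim == 1 {
            stride.push(0) // size-1 dim: repeat the single element
        } else {
            return Err(format!("cannot broadcast {src_shape:?} to {dst_shape:?}"));
        }
    }
    Ok(stride)
}

fn main() {
    // broadcast_left((3, 1)) on a contiguous (3,) vector, i.e. (3,) -> (3, 1, 3):
    // both added left dims get stride 0, the original dim keeps stride 1.
    // This is the layout behind the tensor_tests.rs assertion above.
    assert_eq!(broadcast_strides(&[3], &[1], &[3, 1, 3]), Ok(vec![0, 0, 1]));
    // A (1, 4) row expanded to (3, 4): the size-1 dim repeats via stride 0.
    assert_eq!(broadcast_strides(&[1, 4], &[4, 1], &[3, 4]), Ok(vec![0, 1]));
    // (2,) -> (3,) is rejected, mirroring BroadcastIncompatibleShapes.
    assert!(broadcast_strides(&[2], &[1], &[3]).is_err());
}
```

Because only the shape and stride vectors change while `storage` is shared via `Arc`, a broadcast costs O(rank) time and no extra tensor memory; kernels that later walk the tensor through its strides see the repeated elements.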