More general broadcast setup.

Author: laurent
Date: 2023-06-25 08:55:09 +01:00
parent 213445c0e5
commit 7ccf27dda2
3 changed files with 48 additions and 13 deletions


@@ -90,6 +90,9 @@ pub enum Error {
     /// I/O error.
     #[error(transparent)]
     Io(#[from] std::io::Error),
+    #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
+    BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
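
Note: the new variant uses the same thiserror #[error(...)] pattern as the rest of the enum, so the message interpolates the Debug form of both shapes. A minimal sketch of how it renders, using a stand-in Shape newtype rather than the crate's actual type (requires the thiserror dependency):

// Standalone sketch; this `Shape` is a stand-in newtype, not the crate's type.
use thiserror::Error;

#[derive(Debug)]
pub struct Shape(Vec<usize>);

#[derive(Error, Debug)]
pub enum Error {
    #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
    BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
}

fn main() {
    let err = Error::BroadcastIncompatibleShapes {
        src_shape: Shape(vec![2]),
        dst_shape: Shape(vec![2, 3]),
    };
    // Prints: cannot broadcast Shape([2]) to Shape([2, 3])
    println!("{err}");
}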


@@ -605,8 +605,8 @@ impl Tensor {
             storage: Arc::new(self.storage.try_clone()?),
             shape: self.shape.clone(),
             stride: self.stride.clone(),
-            op: self.op.clone(),
-            is_variable: self.is_variable,
+            op: None, // TODO
+            is_variable: false,
         };
         Ok(Tensor(Arc::new(tensor_)))
     }
@@ -654,7 +654,7 @@ impl Tensor {
             shape: self.shape.clone(),
             stride: self.stride.clone(),
             op,
-            is_variable: self.is_variable,
+            is_variable: false,
         };
         Ok(Tensor(Arc::new(tensor_)))
     }
@@ -662,28 +662,60 @@ impl Tensor {
     /// Returns a new tensor duplicating data from the original tensor. New dimensions are inserted
     /// on the left.
-    pub fn broadcast<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
+    pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
         let left_shape = left_shape.into();
+        let mut dims = left_shape.into_dims();
+        dims.extend(self.shape.dims());
+        self.broadcast_as(dims)
+    }
+
+    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
         let op = if self.track_op() {
             Some(Op::Broadcast(self.clone()))
         } else {
             None
         };
-        let mut stride = vec![0; left_shape.rank()];
-        stride.extend_from_slice(&self.stride);
-        let mut dims = left_shape.into_dims();
-        dims.extend(self.shape.dims());
+        let shape = shape.into();
+        if shape.rank() < self.rank() {
+            return Err(Error::BroadcastIncompatibleShapes {
+                src_shape: self.shape().clone(),
+                dst_shape: shape,
+            });
+        }
+        let added_dims = shape.rank() - self.rank();
+        let mut stride = vec![0; added_dims];
+        for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
+            .iter()
+            .zip(self.dims().iter().zip(self.stride()))
+        {
+            let s = if dst_dim == src_dim {
+                src_stride
+            } else if src_dim != 1 {
+                return Err(Error::BroadcastIncompatibleShapes {
+                    src_shape: self.shape().clone(),
+                    dst_shape: shape.clone(),
+                });
+            } else {
+                0
+            };
+            stride.push(s)
+        }
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
-            shape: Shape::from(dims),
+            shape,
             stride,
             op,
-            is_variable: self.is_variable,
+            is_variable: false,
         };
         Ok(Tensor(Arc::new(tensor_)))
     }
 
+    /// An alias for broadcast_as.
+    pub fn expand<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
+        self.broadcast_as(shape)
+    }
+
     pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
         let shape = self.shape();
         let storage = self.storage.to_dtype(shape, self.stride(), dtype)?;
@@ -706,8 +738,8 @@ impl Tensor {
         Ok(from_storage(
             storage,
             shape.clone(),
-            self.op.clone(),
-            self.is_variable,
+            None, // TODO
+            false,
         ))
     }
 }
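
For intuition on the stride computation in broadcast_as above: every axis added on the left, and every size-1 source axis that gets stretched, receives stride 0, so the broadcast result aliases the original storage instead of copying it. A standalone sketch of the same rule (broadcast_stride is an illustrative helper, not part of this commit):

/// Illustrative re-statement of the stride rule in broadcast_as: new leading
/// axes and stretched size-1 axes get stride 0; matching axes keep their stride.
fn broadcast_stride(
    src_dims: &[usize],
    src_stride: &[usize],
    dst_dims: &[usize],
) -> Option<Vec<usize>> {
    if dst_dims.len() < src_dims.len() {
        return None; // fewer target dims than source dims: incompatible
    }
    let added = dst_dims.len() - src_dims.len();
    let mut stride = vec![0; added];
    for (&dst, (&src, &s)) in dst_dims[added..]
        .iter()
        .zip(src_dims.iter().zip(src_stride))
    {
        if dst == src {
            stride.push(s); // dimension unchanged: keep the original stride
        } else if src == 1 {
            stride.push(0); // size-1 axis stretched: revisit the same element
        } else {
            return None; // mismatch that is not a broadcast: incompatible
        }
    }
    Some(stride)
}

fn main() {
    // A contiguous (3,) tensor has stride [1]; broadcast to (2, 3) it becomes
    // stride [0, 1], so both rows read the same three elements.
    assert_eq!(broadcast_stride(&[3], &[1], &[2, 3]), Some(vec![0, 1]));
    // (2,) cannot broadcast to (2, 3): trailing dims 2 and 3 do not match.
    assert_eq!(broadcast_stride(&[2], &[1], &[2, 3]), None);
}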


@@ -144,7 +144,7 @@ fn broadcast() -> Result<()> {
     let data = &[3f32, 1., 4.];
     let tensor = Tensor::new(data, &Device::Cpu)?;
     assert_eq!(
-        tensor.broadcast((3, 1))?.to_vec3::<f32>()?,
+        tensor.broadcast_left((3, 1))?.to_vec3::<f32>()?,
         &[[[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]]]
     );
     Ok(())
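
Usage after the rename, mirroring this test; a sketch assuming the public API shown in the hunks above (the candle import path and the 3-tuple shape conversion are assumptions):

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let t = Tensor::new(&[3f32, 1., 4.], &Device::Cpu)?;
    // broadcast_left inserts the given dims on the left: (3,) -> (3, 1, 3).
    let b = t.broadcast_left((3, 1))?;
    assert_eq!(b.to_vec3::<f32>()?[0], [[3.0, 1.0, 4.0]]);
    // broadcast_as / expand take the full target shape instead.
    let e = t.expand((3, 1, 3))?;
    assert_eq!(e.to_vec3::<f32>()?, b.to_vec3::<f32>()?);
    Ok(())
}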