Add the broadcast operator.

This commit is contained in:
laurent
2023-06-24 19:16:03 +01:00
parent 96c098b6cd
commit 6b2cd9c51c
4 changed files with 44 additions and 0 deletions

View File

@ -22,6 +22,7 @@ pub(crate) enum Op {
add: f64,
},
ToDType(Tensor),
Broadcast(Tensor),
Exp(Tensor),
Log(Tensor),
Sin(Tensor),

View File

@ -94,6 +94,10 @@ impl Shape {
self.0.len()
}
/// Consumes the shape and returns its dimensions as an owned `Vec<usize>`.
pub fn into_dims(self) -> Vec<usize> {
self.0
}
/// Returns the dimensions of the shape as a borrowed slice.
pub fn dims(&self) -> &[usize] {
&self.0
}

View File

@ -653,6 +653,30 @@ impl Tensor {
}
}
/// Returns a new tensor duplicating data from the original tensor. New dimensions are inserted
/// on the left.
pub fn broadcast<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
    let left_shape = left_shape.into();

    // Resulting shape: the broadcast dims prepended to the existing dims.
    let mut out_dims = left_shape.dims().to_vec();
    out_dims.extend_from_slice(self.shape.dims());

    // A stride of 0 for each prepended dimension makes every index along it
    // alias the same underlying element, so no data is copied.
    let mut out_stride = vec![0usize; left_shape.rank()];
    out_stride.extend_from_slice(&self.stride);

    // Only record the op when this tensor participates in gradient tracking.
    let op = self.track_op().then(|| Op::Broadcast(self.clone()));

    let tensor_ = Tensor_ {
        id: TensorId::new(),
        storage: self.storage.clone(),
        shape: Shape::from(out_dims),
        stride: out_stride,
        op,
        is_variable: self.is_variable,
    };
    Ok(Tensor(Arc::new(tensor_)))
}
pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
let shape = self.shape();
let storage = self.storage.to_dtype(shape, self.stride(), dtype)?;
@ -849,6 +873,7 @@ impl Tensor {
}
}
Op::Reshape(node)
| Op::Broadcast(node)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
@ -978,6 +1003,9 @@ impl Tensor {
start_idx += len;
}
}
Op::Broadcast(_arg) => {
return Err(Error::BackwardNotSupported { op: "broadcast" })
}
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?

View File

@ -138,3 +138,14 @@ fn narrow() -> Result<()> {
);
Ok(())
}
#[test]
fn broadcast() -> Result<()> {
    // Broadcasting a 1-D tensor of 3 elements with left shape (3, 1)
    // yields a (3, 1, 3) tensor where the row is repeated.
    let tensor = Tensor::new(&[3f32, 1., 4.], &Device::Cpu)?;
    let broadcasted = tensor.broadcast((3, 1))?;
    assert_eq!(
        broadcasted.to_vec3::<f32>()?,
        &[[[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]]]
    );
    Ok(())
}