Add the broadcast operator.

This commit is contained in:
laurent
2023-06-24 19:16:03 +01:00
parent 96c098b6cd
commit 6b2cd9c51c
4 changed files with 44 additions and 0 deletions

View File

@ -138,3 +138,14 @@ fn narrow() -> Result<()> {
);
Ok(())
}
#[test]
fn broadcast() -> Result<()> {
let data = &[3f32, 1., 4.];
let tensor = Tensor::new(data, &Device::Cpu)?;
assert_eq!(
tensor.broadcast((3, 1))?.to_vec3::<f32>()?,
&[[[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]]]
);
Ok(())
}