mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add the broadcast operator.
This commit is contained in:
@ -138,3 +138,14 @@ fn narrow() -> Result<()> {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn broadcast() -> Result<()> {
|
||||
let data = &[3f32, 1., 4.];
|
||||
let tensor = Tensor::new(data, &Device::Cpu)?;
|
||||
assert_eq!(
|
||||
tensor.broadcast((3, 1))?.to_vec3::<f32>()?,
|
||||
&[[[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user