Test the broadcasting binary ops. (#196)

This commit is contained in:
Laurent Mazare
2023-07-19 07:18:36 +02:00
committed by GitHub
parent fd55fc9592
commit 76dcc7a381
2 changed files with 104 additions and 1 deletions

View File

@ -305,6 +305,7 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
}
}
// Similar to binary_map but with vectorized variants.
fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
lhs_l: &Layout,
rhs_l: &Layout,

View File

@ -1,5 +1,5 @@
mod test_utils;
use candle::{DType, Device, Result, Tensor};
use candle::{DType, Device, IndexOp, Result, Tensor};
use test_utils::to_vec3_round;
fn zeros(device: &Device) -> Result<()> {
@ -350,6 +350,107 @@ fn matmul(device: &Device) -> Result<()> {
Ok(())
}
fn broadcasting(device: &Device) -> Result<()> {
let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
let t2 = Tensor::new(&[100f32, 200f32], device)?;
let s = t1.broadcast_add(&t2.reshape((2, 1))?)?;
assert_eq!(
s.to_vec3::<f32>()?,
&[
[[100.0, 101.0, 102.0], [203.0, 204.0, 205.0]],
[[106.0, 107.0, 108.0], [209.0, 210.0, 211.0]],
[[112.0, 113.0, 114.0], [215.0, 216.0, 217.0]],
[[118.0, 119.0, 120.0], [221.0, 222.0, 223.0]]
]
);
let s = t1.t()?.broadcast_add(&t2)?;
assert_eq!(
s.to_vec3::<f32>()?,
&[
[[100.0, 203.0], [101.0, 204.0], [102.0, 205.0]],
[[106.0, 209.0], [107.0, 210.0], [108.0, 211.0]],
[[112.0, 215.0], [113.0, 216.0], [114.0, 217.0]],
[[118.0, 221.0], [119.0, 222.0], [120.0, 223.0]]
]
);
let s = t1.broadcast_sub(&t2.reshape((2, 1))?)?;
assert_eq!(
s.to_vec3::<f32>()?,
&[
[[-100.0, -99.0, -98.0], [-197.0, -196.0, -195.0]],
[[-94.0, -93.0, -92.0], [-191.0, -190.0, -189.0]],
[[-88.0, -87.0, -86.0], [-185.0, -184.0, -183.0]],
[[-82.0, -81.0, -80.0], [-179.0, -178.0, -177.0]]
]
);
let s = t1.t()?.broadcast_sub(&t2)?;
assert_eq!(
s.to_vec3::<f32>()?,
&[
[[-100.0, -197.0], [-99.0, -196.0], [-98.0, -195.0]],
[[-94.0, -191.0], [-93.0, -190.0], [-92.0, -189.0]],
[[-88.0, -185.0], [-87.0, -184.0], [-86.0, -183.0]],
[[-82.0, -179.0], [-81.0, -178.0], [-80.0, -177.0]]
]
);
// Test a narrowed version as this uses a layout start_offset.
let t1 = t1.i(2..)?;
let s = t1.broadcast_add(&t2.reshape((2, 1))?)?;
assert_eq!(
s.to_vec3::<f32>()?,
&[
[[112.0, 113.0, 114.0], [215.0, 216.0, 217.0]],
[[118.0, 119.0, 120.0], [221.0, 222.0, 223.0]]
]
);
let s = t1.t()?.broadcast_add(&t2)?;
assert_eq!(
s.to_vec3::<f32>()?,
&[
[[112.0, 215.0], [113.0, 216.0], [114.0, 217.0]],
[[118.0, 221.0], [119.0, 222.0], [120.0, 223.0]]
]
);
let s = t1.broadcast_sub(&t2.reshape((2, 1))?)?;
assert_eq!(
s.to_vec3::<f32>()?,
&[
[[-88.0, -87.0, -86.0], [-185.0, -184.0, -183.0]],
[[-82.0, -81.0, -80.0], [-179.0, -178.0, -177.0]]
]
);
let s = t1.t()?.broadcast_sub(&t2)?;
assert_eq!(
s.to_vec3::<f32>()?,
&[
[[-88.0, -185.0], [-87.0, -184.0], [-86.0, -183.0]],
[[-82.0, -179.0], [-81.0, -178.0], [-80.0, -177.0]]
]
);
let t3 = Tensor::new(1f32, device)?.broadcast_div(&t2)?;
let s = t1.broadcast_mul(&t2.reshape((2, 1))?)?;
let s_div = t1.broadcast_div(&t3.reshape((2, 1))?)?;
assert_eq!(
s.to_vec3::<f32>()?,
&[
[[1200.0, 1300.0, 1400.0], [3000.0, 3200.0, 3400.0]],
[[1800.0, 1900.0, 2000.0], [4200.0, 4400.0, 4600.0]]
]
);
assert_eq!(s.to_vec3::<f32>()?, s_div.to_vec3::<f32>()?,);
let s = t1.t()?.broadcast_mul(&t2)?;
let s_div = t1.t()?.broadcast_div(&t3)?;
assert_eq!(
s.to_vec3::<f32>()?,
&[
[[1200.0, 3000.0], [1300.0, 3200.0], [1400.0, 3400.0]],
[[1800.0, 4200.0], [1900.0, 4400.0], [2000.0, 4600.0]]
]
);
assert_eq!(s.to_vec3::<f32>()?, s_div.to_vec3::<f32>()?,);
Ok(())
}
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
@ -362,3 +463,4 @@ test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(softmax, softmax_cpu, softmax_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu);
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);