From 76dcc7a381872de056c40fb274986cd94d134d96 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 19 Jul 2023 07:18:36 +0200
Subject: [PATCH] Test the broadcasting binary ops. (#196)

---
 candle-core/src/cpu_backend.rs    |   1 +
 candle-core/tests/tensor_tests.rs | 104 +++++++++++++++++++++++++++++-
 2 files changed, 104 insertions(+), 1 deletion(-)

diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 449c1e8b..015a162d 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -305,6 +305,7 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     }
 }
 
+// Similar to binary_map but with vectorized variants.
 fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
     lhs_l: &Layout,
     rhs_l: &Layout,
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 7b73cd7a..7fc8a195 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1,5 +1,5 @@
 mod test_utils;
-use candle::{DType, Device, Result, Tensor};
+use candle::{DType, Device, IndexOp, Result, Tensor};
 use test_utils::to_vec3_round;
 
 fn zeros(device: &Device) -> Result<()> {
@@ -350,6 +350,107 @@ fn matmul(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn broadcasting(device: &Device) -> Result<()> {
+    let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
+    let t2 = Tensor::new(&[100f32, 200f32], device)?;
+    let s = t1.broadcast_add(&t2.reshape((2, 1))?)?;
+    assert_eq!(
+        s.to_vec3::<f32>()?,
+        &[
+            [[100.0, 101.0, 102.0], [203.0, 204.0, 205.0]],
+            [[106.0, 107.0, 108.0], [209.0, 210.0, 211.0]],
+            [[112.0, 113.0, 114.0], [215.0, 216.0, 217.0]],
+            [[118.0, 119.0, 120.0], [221.0, 222.0, 223.0]]
+        ]
+    );
+    let s = t1.t()?.broadcast_add(&t2)?;
+    assert_eq!(
+        s.to_vec3::<f32>()?,
+        &[
+            [[100.0, 203.0], [101.0, 204.0], [102.0, 205.0]],
+            [[106.0, 209.0], [107.0, 210.0], [108.0, 211.0]],
+            [[112.0, 215.0], [113.0, 216.0], [114.0, 217.0]],
+            [[118.0, 221.0], [119.0, 222.0], [120.0, 223.0]]
+        ]
+    );
+    let s = t1.broadcast_sub(&t2.reshape((2, 1))?)?;
+    assert_eq!(
+        s.to_vec3::<f32>()?,
+        &[
+            [[-100.0, -99.0, -98.0], [-197.0, -196.0, -195.0]],
+            [[-94.0, -93.0, -92.0], [-191.0, -190.0, -189.0]],
+            [[-88.0, -87.0, -86.0], [-185.0, -184.0, -183.0]],
+            [[-82.0, -81.0, -80.0], [-179.0, -178.0, -177.0]]
+        ]
+    );
+    let s = t1.t()?.broadcast_sub(&t2)?;
+    assert_eq!(
+        s.to_vec3::<f32>()?,
+        &[
+            [[-100.0, -197.0], [-99.0, -196.0], [-98.0, -195.0]],
+            [[-94.0, -191.0], [-93.0, -190.0], [-92.0, -189.0]],
+            [[-88.0, -185.0], [-87.0, -184.0], [-86.0, -183.0]],
+            [[-82.0, -179.0], [-81.0, -178.0], [-80.0, -177.0]]
+        ]
+    );
+    // Test a narrowed version as this uses a layout start_offset.
+    let t1 = t1.i(2..)?;
+    let s = t1.broadcast_add(&t2.reshape((2, 1))?)?;
+    assert_eq!(
+        s.to_vec3::<f32>()?,
+        &[
+            [[112.0, 113.0, 114.0], [215.0, 216.0, 217.0]],
+            [[118.0, 119.0, 120.0], [221.0, 222.0, 223.0]]
+        ]
+    );
+    let s = t1.t()?.broadcast_add(&t2)?;
+    assert_eq!(
+        s.to_vec3::<f32>()?,
+        &[
+            [[112.0, 215.0], [113.0, 216.0], [114.0, 217.0]],
+            [[118.0, 221.0], [119.0, 222.0], [120.0, 223.0]]
+        ]
+    );
+    let s = t1.broadcast_sub(&t2.reshape((2, 1))?)?;
+    assert_eq!(
+        s.to_vec3::<f32>()?,
+        &[
+            [[-88.0, -87.0, -86.0], [-185.0, -184.0, -183.0]],
+            [[-82.0, -81.0, -80.0], [-179.0, -178.0, -177.0]]
+        ]
+    );
+    let s = t1.t()?.broadcast_sub(&t2)?;
+    assert_eq!(
+        s.to_vec3::<f32>()?,
+        &[
+            [[-88.0, -185.0], [-87.0, -184.0], [-86.0, -183.0]],
+            [[-82.0, -179.0], [-81.0, -178.0], [-80.0, -177.0]]
+        ]
+    );
+    let t3 = Tensor::new(1f32, device)?.broadcast_div(&t2)?;
+    let s = t1.broadcast_mul(&t2.reshape((2, 1))?)?;
+    let s_div = t1.broadcast_div(&t3.reshape((2, 1))?)?;
+    assert_eq!(
+        s.to_vec3::<f32>()?,
+        &[
+            [[1200.0, 1300.0, 1400.0], [3000.0, 3200.0, 3400.0]],
+            [[1800.0, 1900.0, 2000.0], [4200.0, 4400.0, 4600.0]]
+        ]
+    );
+    assert_eq!(s.to_vec3::<f32>()?, s_div.to_vec3::<f32>()?,);
+    let s = t1.t()?.broadcast_mul(&t2)?;
+    let s_div = t1.t()?.broadcast_div(&t3)?;
+    assert_eq!(
+        s.to_vec3::<f32>()?,
+        &[
+            [[1200.0, 3000.0], [1300.0, 3200.0], [1400.0, 3400.0]],
+            [[1800.0, 4200.0], [1900.0, 4400.0], [2000.0, 4600.0]]
+        ]
+    );
+    assert_eq!(s.to_vec3::<f32>()?, s_div.to_vec3::<f32>()?,);
+    Ok(())
+}
+
 test_device!(zeros, zeros_cpu, zeros_gpu);
 test_device!(add_mul, add_mul_cpu, add_mul_gpu);
 test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
@@ -362,3 +463,4 @@ test_device!(binary_op, binary_op_cpu, binary_op_gpu);
 test_device!(softmax, softmax_cpu, softmax_gpu);
 test_device!(embeddings, embeddings_cpu, embeddings_gpu);
 test_device!(matmul, matmul_cpu, matmul_gpu);
+test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
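
For review context (not part of the patch itself), here is a minimal standalone sketch of the broadcasting behaviour the new test exercises, assuming the same `candle` crate paths as the test file and the CPU device; the expected values are copied from the first assertion in `broadcasting`:

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // A (4, 2, 3) tensor holding 0..24 and a length-2 tensor reshaped to (2, 1).
    let t1 = Tensor::arange(0f32, 24f32, &device)?.reshape((4, 2, 3))?;
    let t2 = Tensor::new(&[100f32, 200f32], &device)?.reshape((2, 1))?;
    // The (2, 1) operand is right-aligned against (4, 2, 3): its size-1 trailing
    // dim stretches over the 3 columns and it repeats over the leading dim of 4,
    // so 100 is added to the first row of each block and 200 to the second.
    let s = t1.broadcast_add(&t2)?;
    assert_eq!(
        s.to_vec3::<f32>()?,
        &[
            [[100.0, 101.0, 102.0], [203.0, 204.0, 205.0]],
            [[106.0, 107.0, 108.0], [209.0, 210.0, 211.0]],
            [[112.0, 113.0, 114.0], [215.0, 216.0, 217.0]],
            [[118.0, 119.0, 120.0], [221.0, 222.0, 223.0]]
        ]
    );
    Ok(())
}

The narrowed `t1.i(2..)` cases in the test then repeat the same ops on a view with a non-zero layout start_offset, as the in-test comment notes.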