mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Test the broadcasting binary ops. (#196)
This commit is contained in:
@ -305,6 +305,7 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
|
||||
}
|
||||
}
|
||||
|
||||
// Similar to binary_map but with vectorized variants.
|
||||
fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
|
@ -1,5 +1,5 @@
|
||||
mod test_utils;
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor};
|
||||
use test_utils::to_vec3_round;
|
||||
|
||||
fn zeros(device: &Device) -> Result<()> {
|
||||
@ -350,6 +350,107 @@ fn matmul(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn broadcasting(device: &Device) -> Result<()> {
|
||||
let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
|
||||
let t2 = Tensor::new(&[100f32, 200f32], device)?;
|
||||
let s = t1.broadcast_add(&t2.reshape((2, 1))?)?;
|
||||
assert_eq!(
|
||||
s.to_vec3::<f32>()?,
|
||||
&[
|
||||
[[100.0, 101.0, 102.0], [203.0, 204.0, 205.0]],
|
||||
[[106.0, 107.0, 108.0], [209.0, 210.0, 211.0]],
|
||||
[[112.0, 113.0, 114.0], [215.0, 216.0, 217.0]],
|
||||
[[118.0, 119.0, 120.0], [221.0, 222.0, 223.0]]
|
||||
]
|
||||
);
|
||||
let s = t1.t()?.broadcast_add(&t2)?;
|
||||
assert_eq!(
|
||||
s.to_vec3::<f32>()?,
|
||||
&[
|
||||
[[100.0, 203.0], [101.0, 204.0], [102.0, 205.0]],
|
||||
[[106.0, 209.0], [107.0, 210.0], [108.0, 211.0]],
|
||||
[[112.0, 215.0], [113.0, 216.0], [114.0, 217.0]],
|
||||
[[118.0, 221.0], [119.0, 222.0], [120.0, 223.0]]
|
||||
]
|
||||
);
|
||||
let s = t1.broadcast_sub(&t2.reshape((2, 1))?)?;
|
||||
assert_eq!(
|
||||
s.to_vec3::<f32>()?,
|
||||
&[
|
||||
[[-100.0, -99.0, -98.0], [-197.0, -196.0, -195.0]],
|
||||
[[-94.0, -93.0, -92.0], [-191.0, -190.0, -189.0]],
|
||||
[[-88.0, -87.0, -86.0], [-185.0, -184.0, -183.0]],
|
||||
[[-82.0, -81.0, -80.0], [-179.0, -178.0, -177.0]]
|
||||
]
|
||||
);
|
||||
let s = t1.t()?.broadcast_sub(&t2)?;
|
||||
assert_eq!(
|
||||
s.to_vec3::<f32>()?,
|
||||
&[
|
||||
[[-100.0, -197.0], [-99.0, -196.0], [-98.0, -195.0]],
|
||||
[[-94.0, -191.0], [-93.0, -190.0], [-92.0, -189.0]],
|
||||
[[-88.0, -185.0], [-87.0, -184.0], [-86.0, -183.0]],
|
||||
[[-82.0, -179.0], [-81.0, -178.0], [-80.0, -177.0]]
|
||||
]
|
||||
);
|
||||
// Test a narrowed version as this uses a layout start_offset.
|
||||
let t1 = t1.i(2..)?;
|
||||
let s = t1.broadcast_add(&t2.reshape((2, 1))?)?;
|
||||
assert_eq!(
|
||||
s.to_vec3::<f32>()?,
|
||||
&[
|
||||
[[112.0, 113.0, 114.0], [215.0, 216.0, 217.0]],
|
||||
[[118.0, 119.0, 120.0], [221.0, 222.0, 223.0]]
|
||||
]
|
||||
);
|
||||
let s = t1.t()?.broadcast_add(&t2)?;
|
||||
assert_eq!(
|
||||
s.to_vec3::<f32>()?,
|
||||
&[
|
||||
[[112.0, 215.0], [113.0, 216.0], [114.0, 217.0]],
|
||||
[[118.0, 221.0], [119.0, 222.0], [120.0, 223.0]]
|
||||
]
|
||||
);
|
||||
let s = t1.broadcast_sub(&t2.reshape((2, 1))?)?;
|
||||
assert_eq!(
|
||||
s.to_vec3::<f32>()?,
|
||||
&[
|
||||
[[-88.0, -87.0, -86.0], [-185.0, -184.0, -183.0]],
|
||||
[[-82.0, -81.0, -80.0], [-179.0, -178.0, -177.0]]
|
||||
]
|
||||
);
|
||||
let s = t1.t()?.broadcast_sub(&t2)?;
|
||||
assert_eq!(
|
||||
s.to_vec3::<f32>()?,
|
||||
&[
|
||||
[[-88.0, -185.0], [-87.0, -184.0], [-86.0, -183.0]],
|
||||
[[-82.0, -179.0], [-81.0, -178.0], [-80.0, -177.0]]
|
||||
]
|
||||
);
|
||||
let t3 = Tensor::new(1f32, device)?.broadcast_div(&t2)?;
|
||||
let s = t1.broadcast_mul(&t2.reshape((2, 1))?)?;
|
||||
let s_div = t1.broadcast_div(&t3.reshape((2, 1))?)?;
|
||||
assert_eq!(
|
||||
s.to_vec3::<f32>()?,
|
||||
&[
|
||||
[[1200.0, 1300.0, 1400.0], [3000.0, 3200.0, 3400.0]],
|
||||
[[1800.0, 1900.0, 2000.0], [4200.0, 4400.0, 4600.0]]
|
||||
]
|
||||
);
|
||||
assert_eq!(s.to_vec3::<f32>()?, s_div.to_vec3::<f32>()?,);
|
||||
let s = t1.t()?.broadcast_mul(&t2)?;
|
||||
let s_div = t1.t()?.broadcast_div(&t3)?;
|
||||
assert_eq!(
|
||||
s.to_vec3::<f32>()?,
|
||||
&[
|
||||
[[1200.0, 3000.0], [1300.0, 3200.0], [1400.0, 3400.0]],
|
||||
[[1800.0, 4200.0], [1900.0, 4400.0], [2000.0, 4600.0]]
|
||||
]
|
||||
);
|
||||
assert_eq!(s.to_vec3::<f32>()?, s_div.to_vec3::<f32>()?,);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Per-device test registrations. `test_device!` is defined outside this view;
// presumably it generates the test functions named by the 2nd and 3rd
// arguments, each calling the function named by the 1st — confirm against the
// macro definition.
test_device!(zeros, zeros_cpu, zeros_gpu);
|
||||
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
|
||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
|
||||
@ -362,3 +463,4 @@ test_device!(binary_op, binary_op_cpu, binary_op_gpu);
|
||||
// Remaining per-device registrations, ending with the `broadcasting` test.
// NOTE(review): the `_cpu` / `_gpu` suffixes suggest one generated test per
// backend, but the macro body is not visible here — verify.
test_device!(softmax, softmax_cpu, softmax_gpu);
|
||||
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
|
||||
test_device!(matmul, matmul_cpu, matmul_gpu);
|
||||
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
|
||||
|
Reference in New Issue
Block a user