mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Optimize the cat operation on contiguous tensors (#1855)
* Add a specialized kernel for copy2d. * Move the cat operations. * Avoid transpositions in cat. * Bugfix. * Bugfix for the cuda kernel. * Add a benchmark. * Add more testing. * Test fix. * Faster kernel. * Add the missing kernel. * Tweak the test. * Add a metal kernel. * Fix for the metal kernel. * Get the tests to pass on metal. * Also use this opportunity to fix the metal kernel for ELU. * Add some bf16 kernels. * Clippy fixes.
This commit is contained in:
@ -672,6 +672,31 @@ fn cat(device: &Device) -> Result<()> {
|
||||
[2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0]
|
||||
]
|
||||
);
|
||||
|
||||
// 3D
|
||||
let t1 = Tensor::arange(0, 48i64, device)?.reshape((2, 6, 4))?;
|
||||
let t2 = Tensor::arange(100, 124i64, device)?.reshape((2, 3, 4))?;
|
||||
let t3 = Tensor::arange(10000, 10032i64, device)?.reshape((2, 4, 4))?;
|
||||
|
||||
let t_cat = Tensor::cat(&[&t1, &t2, &t3], 1)?;
|
||||
|
||||
let t1 = t1.t()?.contiguous()?.t()?;
|
||||
let t2 = t2.t()?.contiguous()?.t()?;
|
||||
let t3 = t3.t()?.contiguous()?.t()?;
|
||||
let t_cat2 = Tensor::cat(&[&t1, &t2, &t3], 1)?;
|
||||
|
||||
let diff = t_cat.eq(&t_cat2)?.to_dtype(DType::F32)?.sum_all()?;
|
||||
assert_eq!(diff.to_vec0::<f32>()?, 104.0);
|
||||
assert_eq!(t_cat.i((0, 0, 0))?.to_vec0::<i64>()?, 0);
|
||||
assert_eq!(t_cat.i((0, 4, 0))?.to_vec0::<i64>()?, 16);
|
||||
assert_eq!(t_cat.i((0, 5, 0))?.to_vec0::<i64>()?, 20);
|
||||
assert_eq!(t_cat.i((1, 5, 0))?.to_vec0::<i64>()?, 44);
|
||||
assert_eq!(t_cat.i((0, 6, 0))?.to_vec0::<i64>()?, 100);
|
||||
assert_eq!(t_cat.i((1, 6, 0))?.to_vec0::<i64>()?, 112);
|
||||
assert_eq!(t_cat.i((0, 6, 1))?.to_vec0::<i64>()?, 101);
|
||||
assert_eq!(t_cat.i((0, 7, 1))?.to_vec0::<i64>()?, 105);
|
||||
assert_eq!(t_cat.i((0, 12, 1))?.to_vec0::<i64>()?, 10013);
|
||||
assert_eq!(t_cat.i((1, 12, 3))?.to_vec0::<i64>()?, 10031);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user