From 7ec345c2ebfc8378f0a5bc99c981f516c8dc6964 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 10 Nov 2023 23:31:09 +0100 Subject: [PATCH] Adding the test scaffolding. --- candle-core/src/test_utils.rs | 8 ++- candle-core/tests/conv_tests.rs | 35 ++++++++++--- candle-core/tests/grad_tests.rs | 32 +++++++++--- candle-core/tests/layout_tests.rs | 2 +- candle-core/tests/pool_tests.rs | 10 ++-- candle-core/tests/tensor_tests.rs | 82 ++++++++++++++++++++----------- 6 files changed, 121 insertions(+), 48 deletions(-) diff --git a/candle-core/src/test_utils.rs b/candle-core/src/test_utils.rs index 8ff73fc0..3b8fb904 100644 --- a/candle-core/src/test_utils.rs +++ b/candle-core/src/test_utils.rs @@ -4,7 +4,7 @@ use crate::{Result, Tensor}; macro_rules! test_device { // TODO: Switch to generating the two last arguments automatically once concat_idents is // stable. https://github.com/rust-lang/rust/issues/29599 - ($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => { + ($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => { #[test] fn $test_cpu() -> Result<()> { $fn_name(&Device::Cpu) @@ -15,6 +15,12 @@ macro_rules! test_device { fn $test_cuda() -> Result<()> { $fn_name(&Device::new_cuda(0)?) } + + #[cfg(feature = "metal")] + #[test] + fn $test_metal() -> Result<()> { + $fn_name(&Device::new_metal(0)?) 
+ } }; } diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index a5375c11..39c6cec0 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -563,14 +563,35 @@ fn conv2d_grad(dev: &Device) -> Result<()> { Ok(()) } -test_device!(conv1d, conv1d_cpu, conv1d_gpu); -test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu); -test_device!(conv2d, conv2d_cpu, conv2d_gpu); +test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal); +test_device!( + conv1d_small, + conv1d_small_cpu, + conv1d_small_gpu, + conv1d_small_metal +); +test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal); test_device!( conv2d_non_square, conv2d_non_square_cpu, - conv2d_non_square_gpu + conv2d_non_square_gpu, + conv2d_non_square_metal +); +test_device!( + conv2d_small, + conv2d_small_cpu, + conv2d_small_gpu, + conv2d_small_metal +); +test_device!( + conv2d_smaller, + conv2d_smaller_cpu, + conv2d_smaller_gpu, + conv2d_smaller_metal +); +test_device!( + conv2d_grad, + conv2d_grad_cpu, + conv2d_grad_gpu, + conv2d_grad_metal ); -test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu); -test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu); -test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu); diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs index 6413ea2e..791532f2 100644 --- a/candle-core/tests/grad_tests.rs +++ b/candle-core/tests/grad_tests.rs @@ -315,9 +315,29 @@ fn binary_grad(device: &Device) -> Result<()> { Ok(()) } -test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu); -test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu); -test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu); -test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu); -test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu); -test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu); +test_device!( + simple_grad, + simple_grad_cpu, + simple_grad_gpu, + simple_grad_metal +);
+test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal); +test_device!( + matmul_grad, + matmul_grad_cpu, + matmul_grad_gpu, + matmul_grad_metal +); +test_device!( + grad_descent, + grad_descent_cpu, + grad_descent_gpu, + grad_descent_metal +); +test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal); +test_device!( + binary_grad, + binary_grad_cpu, + binary_grad_gpu, + binary_grad_metal +); diff --git a/candle-core/tests/layout_tests.rs b/candle-core/tests/layout_tests.rs index 1b29476f..e0618850 100644 --- a/candle-core/tests/layout_tests.rs +++ b/candle-core/tests/layout_tests.rs @@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> { Ok(()) } -test_device!(contiguous, contiguous_cpu, contiguous_gpu); +test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal); #[test] fn strided_blocks() -> Result<()> { diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs index c6db194d..a3708ec4 100644 --- a/candle-core/tests/pool_tests.rs +++ b/candle-core/tests/pool_tests.rs @@ -98,15 +98,17 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> { Ok(()) } -test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu); +test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal); test_device!( avg_pool2d_pytorch, avg_pool2d_pytorch_cpu, - avg_pool2d_pytorch_gpu + avg_pool2d_pytorch_gpu, + avg_pool2d_pytorch_metal ); -test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu); +test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal); test_device!( upsample_nearest2d, upsample_nearest2d_cpu, - upsample_nearest2d_gpu + upsample_nearest2d_gpu, + upsample_nearest2d_metal ); diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index f565972a..eb684909 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1070,35 +1070,59 @@ fn randn(device: &Device) -> Result<()> { Ok(()) } -test_device!(zeros, zeros_cpu, 
zeros_gpu); -test_device!(ones, ones_cpu, ones_gpu); -test_device!(arange, arange_cpu, arange_gpu); -test_device!(add_mul, add_mul_cpu, add_mul_gpu); -test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu); -test_device!(narrow, narrow_cpu, narrow_gpu); -test_device!(broadcast, broadcast_cpu, broadcast_gpu); -test_device!(cat, cat_cpu, cat_gpu); -test_device!(sum, sum_cpu, sum_gpu); -test_device!(min, min_cpu, min_gpu); -test_device!(max, max_cpu, max_gpu); -test_device!(argmax, argmax_cpu, argmax_gpu); -test_device!(argmin, argmin_cpu, argmin_gpu); -test_device!(transpose, transpose_cpu, transpose_gpu); -test_device!(unary_op, unary_op_cpu, unary_op_gpu); -test_device!(binary_op, binary_op_cpu, binary_op_gpu); -test_device!(embeddings, embeddings_cpu, embeddings_gpu); -test_device!(cmp, cmp_cpu, cmp_gpu); -test_device!(matmul, matmul_cpu, matmul_gpu); -test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu); -test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu); -test_device!(index_select, index_select_cpu, index_select_gpu); -test_device!(index_add, index_add_cpu, index_add_gpu); -test_device!(gather, gather_cpu, gather_gpu); -test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu); -test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu); -test_device!(randn, randn_cpu, randn_gpu); -test_device!(clamp, clamp_cpu, clamp_gpu); -test_device!(var, var_cpu, var_gpu); +test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal); +test_device!(ones, ones_cpu, ones_gpu, ones_metal); +test_device!(arange, arange_cpu, arange_gpu, arange_metal); +test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal); +test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal); +test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal); +test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal); +test_device!(cat, cat_cpu, cat_gpu, cat_metal); +test_device!(sum, sum_cpu, sum_gpu, sum_metal); +test_device!(min, min_cpu, 
min_gpu, min_metal); +test_device!(max, max_cpu, max_gpu, max_metal); +test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal); +test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal); +test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal); +test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal); +test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal); +test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal); +test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal); +test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal); +test_device!( + broadcast_matmul, + broadcast_matmul_cpu, + broadcast_matmul_gpu, + broadcast_matmul_metal +); +test_device!( + broadcasting, + broadcasting_cpu, + broadcasting_gpu, + broadcasting_metal +); +test_device!( + index_select, + index_select_cpu, + index_select_gpu, + index_select_metal +); +test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal); +test_device!(gather, gather_cpu, gather_gpu, gather_metal); +test_device!( + scatter_add, + scatter_add_cpu, + scatter_add_gpu, + scatter_add_metal +); +test_device!( + slice_scatter, + slice_scatter_cpu, + slice_scatter_gpu, + slice_scatter_metal +); +test_device!(randn, randn_cpu, randn_gpu, randn_metal); +test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal); // There was originally a bug on the CPU implementation for randn // https://github.com/huggingface/candle/issues/381