From f135b7963d8817d23b4069132713f6ad1a5e6af2 Mon Sep 17 00:00:00 2001 From: Laurent Mazare <laurent.mazare@gmail.com> Date: Mon, 15 Apr 2024 20:00:28 +0200 Subject: [PATCH] Fix for the batch dim in the quantized matmul example. (#2073) * Fix for the batch dim in the quantized matmul example. * Enable more tests on cuda. * Add a test for qmm with a batch. * Fix the zeros-dim test on metal. --- candle-core/src/metal_backend/device.rs | 2 +- candle-core/src/quantized/cuda.rs | 2 +- candle-core/tests/quantized_tests.rs | 72 ++++++++++++------------- 3 files changed, 38 insertions(+), 38 deletions(-) diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index fdeca13f..44af7649 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -283,5 +283,5 @@ impl MetalDevice { } fn buf_size(size: NSUInteger) -> NSUInteger { - (size - 1).next_power_of_two() as NSUInteger + size.saturating_sub(1).next_power_of_two() as NSUInteger } diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 3c827e59..54b1da41 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -464,7 +464,7 @@ impl QCudaStorage { /* x_rows */ n, /* x_cols */ k, /* y_rows */ k, - /* y_cols */ m, + /* y_cols */ b * m, self.device(), )? }; diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index a2629341..223accc4 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -3,7 +3,7 @@ use candle_core::{ quantized::{self, GgmlDType}, test_device, test_utils::to_vec2_round, - Device, Module, Result, Tensor, + Device, IndexOp, Module, Result, Tensor, }; use quantized::{k_quants, GgmlType}; use rand::prelude::*; @@ -47,18 +47,14 @@ fn test_matmul( } fn quantized_matmul(device: &Device) -> Result<()> { - // TODO Enable this later when we enable cuda.
- if device.is_cuda() { - return Ok(()); - } let (m, k, n) = (3, 64, 4); - let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>(); - let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?; + let lhs_s = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>(); + let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?; let mut dst = vec![42.; 3 * 4]; let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>(); k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; - k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?; + k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?; assert_eq!( dst.iter().map(|x| x.round()).collect::<Vec<_>>(), &[ @@ -67,7 +63,7 @@ fn quantized_matmul(device: &Device) -> Result<()> { ] ); let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?; - let mm = tensor_lhs.matmul(&tensor_rhs)?; + let mm = lhs.matmul(&tensor_rhs)?; assert_eq!( mm.to_vec2::<f32>()?, &[ @@ -79,7 +75,7 @@ fn quantized_matmul(device: &Device) -> Result<()> { let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?; let matmul = quantized::QMatMul::from_qtensor(qtensor)?; - let res = matmul.forward(&tensor_lhs)?; + let res = matmul.forward(&lhs)?; match device { Device::Metal(_) => assert_eq!( to_vec2_round(&res, 0)?, &[ [84946.0, 214126.0, 344757.0, 473798.0], [213458.0, 604350.0, 1000469.0, 1387990.0], [341970.0, 994574.0, 1656181.0, 2302182.0] ] ), - _ => assert_eq!( + Device::Cuda(_) => assert_eq!( + to_vec2_round(&res, 0)?, + &[ + [84866.0, 214045.0, 344676.0, 473707.0], + [213425.0, 604313.0, 1000431.0, 1387960.0], + [342030.0, 994630.0, 1656248.0, 2302250.0] + ] + ), + Device::Cpu => assert_eq!( to_vec2_round(&res, 0)?, &[ [85120.0, 214562.0, 345455.0, 474748.0], [213475.0, 604465.0, 1000686.0, 1388317.0], [341970.0, 994574.0, 1656181.0, 2302182.0] ] ), } - test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?; - Ok(()) } fn quantized_matmul_neg(device: &Device) -> Result<()> { - // TODO Enable this later when we enable cuda.
- if device.is_cuda() { - return Ok(()); - } let (m, k, n) = (3, 64, 4); - let lhs = (0..(m * k)) + let lhs_s = (0..(m * k)) .map(|v| v as f32 - (m * k) as f32 / 2.0) .collect::<Vec<_>>(); - let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?; + let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?; let mut dst = vec![42.; 3 * 4]; let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; let rhs = (0..k * n) @@ -121,7 +119,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> { .collect::<Vec<_>>(); let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?; k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; - k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?; + k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?; assert_eq!( dst.iter().map(|x| x.round()).collect::<Vec<_>>(), &[ @@ -129,7 +127,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> { -196472.0, 63012.0, 324585.0, 587902.0 ] ); - let mm = tensor_lhs.matmul(&tensor_rhs)?; + let mm = lhs.matmul(&tensor_rhs)?; assert_eq!( to_vec2_round(&mm, 0)?, &[ @@ -141,7 +139,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> { let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?; let matmul = quantized::QMatMul::from_qtensor(qtensor)?; - let res = matmul.forward(&tensor_lhs)?; + let res = matmul.forward(&lhs)?; match device { Device::Metal(_) => assert_eq!( to_vec2_round(&res, 0)?, &[ [243666.0, -19714.0, -285433.0, -550453.0], [23782.0, 21654.0, 19400.0, 18369.0], [-196102.0, 63022.0, 324233.0, 587191.0] ] ), - _ => assert_eq!( + Device::Cuda(_) => assert_eq!( + to_vec2_round(&res, 0)?, + &[ + [243740.0, -19762.0, -285476.0, -550498.0], + [23774.0, 21645.0, 19395.0, 18364.0], + [-196045.0, 63030.0, 324120.0, 587079.0] + ] + ), + Device::Cpu => assert_eq!( to_vec2_round(&res, 0)?, &[ [243524.0, -19596.0, -285051.0, -549815.0], [23777.0, 21651.0, 19398.0, 18367.0], [-196472.0, 63012.0, 324585.0, 587902.0] ] ), } - + let lhs2 = Tensor::stack(&[&lhs, &lhs], 0)?; + let res2 =
matmul.forward(&lhs2)?; + let res2 = res2.i(1)?; + let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?; + assert_eq!(diff, 0.); Ok(()) } -test_device!( - quantized_matmul, - quantized_matmul_cpu, - quantized_matmul_cuda, - quantized_matmul_metal -); -test_device!( - quantized_matmul_neg, - quantized_matmul_neg_cpu, - quantized_matmul_neg_cuda, - quantized_matmul_neg_metal -); +test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal); +test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal); fn quantize_q4_0(device: &Device) -> Result<()> { let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();