mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Add the cuda dequantize f16 kernels. (#2137)
* Add the cuda dequantize f16 kernels. * Expose the cuda kernels. * Add some testing + fix. * Test the other cases too. * A few more tests. * Add an environment variable to enable the dequantize f16 + matmul behavior.
This commit is contained in:
@ -3,7 +3,7 @@ use candle_core::{
|
||||
quantized::{self, GgmlDType},
|
||||
test_device,
|
||||
test_utils::to_vec2_round,
|
||||
Device, IndexOp, Module, Result, Tensor,
|
||||
DType, Device, IndexOp, Module, Result, Tensor,
|
||||
};
|
||||
use quantized::{k_quants, GgmlType};
|
||||
use rand::prelude::*;
|
||||
@ -225,6 +225,13 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
assert_eq!(
|
||||
dst.to_vec1::<f32>()?,
|
||||
&[
|
||||
@ -251,6 +258,13 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
assert_eq!(
|
||||
round_vector(&dst.to_vec1::<f32>()?),
|
||||
&[
|
||||
@ -277,6 +291,13 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
assert_eq!(
|
||||
round_vector(&dst.to_vec1::<f32>()?),
|
||||
&[
|
||||
@ -303,6 +324,13 @@ fn quantize_q5_1(device: &Device) -> Result<()> {
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
assert_eq!(
|
||||
round_vector(&dst.to_vec1::<f32>()?),
|
||||
&[
|
||||
@ -387,6 +415,13 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
|
||||
let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
|
||||
if error > max_error {
|
||||
bail!(
|
||||
@ -404,6 +439,13 @@ fn quantize_q2k(device: &Device) -> Result<()> {
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
@ -423,6 +465,13 @@ fn quantize_q2k(device: &Device) -> Result<()> {
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
@ -437,6 +486,13 @@ fn quantize_q3k(device: &Device) -> Result<()> {
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
@ -456,6 +512,13 @@ fn quantize_q3k(device: &Device) -> Result<()> {
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
@ -470,6 +533,13 @@ fn quantize_q4k(device: &Device) -> Result<()> {
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
@ -489,6 +559,13 @@ fn quantize_q4k(device: &Device) -> Result<()> {
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
@ -503,6 +580,13 @@ fn quantize_q5k(device: &Device) -> Result<()> {
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
@ -522,6 +606,13 @@ fn quantize_q5k(device: &Device) -> Result<()> {
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
@ -536,6 +627,13 @@ fn quantize_q6k(device: &Device) -> Result<()> {
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
@ -555,6 +653,13 @@ fn quantize_q6k(device: &Device) -> Result<()> {
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
@ -569,6 +674,13 @@ fn quantize_q8k(device: &Device) -> Result<()> {
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
@ -588,6 +700,13 @@ fn quantize_q8k(device: &Device) -> Result<()> {
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
|
Reference in New Issue
Block a user