From 70f38c20693ad26c6fffa359e734bde92124c0dd Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 11 Sep 2023 12:10:51 +0100 Subject: [PATCH] Proper error on unsupported dtypes when using gemm. (#813) --- candle-core/src/cpu_backend.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 3cdc538a..fa96ba18 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1371,8 +1371,9 @@ impl Map2 for MatMul { ) -> Result> { use gemm::{gemm, Parallelism}; - if T::DTYPE == DType::BF16 { - return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?; + match T::DTYPE { + DType::F16 | DType::F32 | DType::F64 => {} + _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?, } let (b, m, n, k) = self.0;