diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 3cdc538a..fa96ba18 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1371,8 +1371,9 @@ impl Map2 for MatMul { ) -> Result> { use gemm::{gemm, Parallelism}; - if T::DTYPE == DType::BF16 { - return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?; + match T::DTYPE { + DType::F16 | DType::F32 | DType::F64 => {} + _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?, } let (b, m, n, k) = self.0;