Proper error on unsupported dtypes when using gemm. (#813)

This commit is contained in:
Laurent Mazare
2023-09-11 12:10:51 +01:00
committed by GitHub
parent d7b9fec849
commit 70f38c2069

View File

@ -1371,8 +1371,9 @@ impl Map2 for MatMul {
) -> Result<Vec<T>> {
use gemm::{gemm, Parallelism};
if T::DTYPE == DType::BF16 {
return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?;
match T::DTYPE {
DType::F16 | DType::F32 | DType::F64 => {}
_ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
}
let (b, m, n, k) = self.0;