Fix the matmul example.

Author: laurent
Date:   2023-06-22 21:11:41 +01:00
Parent: 6463d661d8
Commit: 2231c717d5

@@ -154,21 +154,23 @@ fn gemm_config<T>(
     alpha: T,
     beta: T,
     (b, m, n, k): (usize, usize, usize, usize),
-    lhs_stride: &[usize],
-    rhs_stride: &[usize],
+    _lhs_stride: &[usize],
+    _rhs_stride: &[usize],
 ) -> StridedBatchedConfig<T> {
+    // TODO: Handle lhs_stride and rhs_stride.
     // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm
     use cudarc::cublas::sys::cublasOperation_t;
-    println!("{:?} {:?} {:?}", lhs_stride, rhs_stride, (b, m, n, k));
+    // The setup below was copied from:
+    // https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
     let gemm = GemmConfig {
         alpha,
         beta,
-        m: m as i32,
-        n: n as i32,
+        m: n as i32,
+        n: m as i32,
         k: k as i32,
-        lda: m as i32,
+        lda: n as i32,
         ldb: k as i32,
-        ldc: m as i32,
+        ldc: n as i32,
         transa: cublasOperation_t::CUBLAS_OP_N,
         transb: cublasOperation_t::CUBLAS_OP_N,
     };
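
Why the m/n swap above works, shown in a minimal CPU-only sketch (not part of the commit; gemm_col_major and the example shapes are made up for illustration): cuBLAS assumes column-major storage while these buffers are row-major, and a row-major m x k matrix shares its memory layout with the column-major k x m matrix that is its transpose. Asking a column-major gemm for C^T = B^T * A^T therefore yields the row-major C = A * B, provided the dimensions, leading dimensions, and operand order are swapped exactly as in the GemmConfig above (m <- n, n <- m, lda = n, ldb = k, ldc = n).

/// Reference column-major GEMM: c (m x n) = a (m x k) * b (k x n), no transposes,
/// mirroring what cuBLAS computes with CUBLAS_OP_N for both operands.
fn gemm_col_major(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    lda: usize,
    b: &[f32],
    ldb: usize,
    c: &mut [f32],
    ldc: usize,
) {
    for j in 0..n {
        for i in 0..m {
            let mut acc = 0.0;
            for p in 0..k {
                // Column-major indexing: element (row, col) lives at col * ld + row.
                acc += a[p * lda + i] * b[j * ldb + p];
            }
            c[j * ldc + i] = acc;
        }
    }
}

fn main() {
    // Row-major lhs (m x k) and rhs (k x n).
    let (m, n, k) = (2, 3, 4);
    let lhs: Vec<f32> = (0..m * k).map(|v| v as f32).collect();
    let rhs: Vec<f32> = (0..k * n).map(|v| v as f32 * 0.5).collect();

    // Row-major reference: expected[i][j] = sum_p lhs[i][p] * rhs[p][j].
    let mut expected = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                expected[i * n + j] += lhs[i * k + p] * rhs[p * n + j];
            }
        }
    }

    // Column-major call with the roles swapped, mirroring the GemmConfig above:
    // m <- n, n <- m, lda <- n, ldb <- k, ldc <- n, and rhs passed as the first operand.
    // The (n x m) column-major result shares its layout with the (m x n) row-major one.
    let mut out = vec![0.0f32; m * n];
    gemm_col_major(n, m, k, &rhs, n, &lhs, k, &mut out, n);

    assert_eq!(out, expected);
    println!("swapped column-major gemm reproduces the row-major product");
}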
@@ -341,7 +343,7 @@ impl CudaStorage {
             unsafe {
                 self.device
                     .blas
-                    .gemm_strided_batched(cfg, lhs, rhs, &mut out)
+                    .gemm_strided_batched(cfg, rhs, lhs, &mut out)
             }?;
             CudaStorageSlice::F32(out)
         }
@@ -351,7 +353,7 @@ impl CudaStorage {
             unsafe {
                 self.device
                     .blas
-                    .gemm_strided_batched(cfg, lhs, rhs, &mut out)
+                    .gemm_strided_batched(cfg, rhs, lhs, &mut out)
             }?;
             CudaStorageSlice::F64(out)
         }
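
At the call sites the same reasoning dictates the argument order: cuBLAS's first ("A") operand must be the buffer playing the role of B^T, i.e. the row-major rhs, so gemm_strided_batched now receives rhs before lhs. Below is a hedged CPU reference of what the whole batched op should produce; the per-batch strides of m*k, k*n, and m*n elements are an assumption for contiguous row-major tensors, and batched_row_major_matmul is an illustrative helper, not part of the commit. A correct GPU path with the swapped operands should match it batch by batch.

/// Illustrative CPU reference for a contiguous, row-major strided-batched matmul:
/// lhs has shape (b, m, k), rhs has shape (b, k, n), the result has shape (b, m, n).
fn batched_row_major_matmul(
    b: usize,
    m: usize,
    n: usize,
    k: usize,
    lhs: &[f32],
    rhs: &[f32],
) -> Vec<f32> {
    let mut out = vec![0.0f32; b * m * n];
    for batch in 0..b {
        // Per-batch strides for contiguous tensors: m*k, k*n, and m*n elements.
        let lhs = &lhs[batch * m * k..][..m * k];
        let rhs = &rhs[batch * k * n..][..k * n];
        let out = &mut out[batch * m * n..][..m * n];
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0.0;
                for p in 0..k {
                    acc += lhs[i * k + p] * rhs[p * n + j];
                }
                out[i * n + j] = acc;
            }
        }
    }
    out
}

fn main() {
    let (b, m, n, k) = (2, 2, 3, 4);
    let lhs: Vec<f32> = (0..b * m * k).map(|v| v as f32).collect();
    let rhs: Vec<f32> = (0..b * k * n).map(|v| v as f32 * 0.1).collect();
    let out = batched_row_major_matmul(b, m, n, k, &lhs, &rhs);
    println!("{out:?}");
}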