Fix the matmul example.

This commit is contained in:
laurent
2023-06-22 21:11:41 +01:00
parent 6463d661d8
commit 2231c717d5

View File

@@ -154,21 +154,23 @@ fn gemm_config<T>(
alpha: T, alpha: T,
beta: T, beta: T,
(b, m, n, k): (usize, usize, usize, usize), (b, m, n, k): (usize, usize, usize, usize),
lhs_stride: &[usize], _lhs_stride: &[usize],
rhs_stride: &[usize], _rhs_stride: &[usize],
) -> StridedBatchedConfig<T> { ) -> StridedBatchedConfig<T> {
// TODO: Handle lhs_stride and rhs_stride.
// https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm
use cudarc::cublas::sys::cublasOperation_t; use cudarc::cublas::sys::cublasOperation_t;
println!("{:?} {:?} {:?}", lhs_stride, rhs_stride, (b, m, n, k)); // The setup below was copied from:
// https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
let gemm = GemmConfig { let gemm = GemmConfig {
alpha, alpha,
beta, beta,
m: m as i32, m: n as i32,
n: n as i32, n: m as i32,
k: k as i32, k: k as i32,
lda: m as i32, lda: n as i32,
ldb: k as i32, ldb: k as i32,
ldc: m as i32, ldc: n as i32,
transa: cublasOperation_t::CUBLAS_OP_N, transa: cublasOperation_t::CUBLAS_OP_N,
transb: cublasOperation_t::CUBLAS_OP_N, transb: cublasOperation_t::CUBLAS_OP_N,
}; };
@@ -341,7 +343,7 @@ impl CudaStorage {
unsafe { unsafe {
self.device self.device
.blas .blas
.gemm_strided_batched(cfg, lhs, rhs, &mut out) .gemm_strided_batched(cfg, rhs, lhs, &mut out)
}?; }?;
CudaStorageSlice::F32(out) CudaStorageSlice::F32(out)
} }
@@ -351,7 +353,7 @@ impl CudaStorage {
unsafe { unsafe {
self.device self.device
.blas .blas
.gemm_strided_batched(cfg, lhs, rhs, &mut out) .gemm_strided_batched(cfg, rhs, lhs, &mut out)
}?; }?;
CudaStorageSlice::F64(out) CudaStorageSlice::F64(out)
} }