mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Fix the matmul example.
This commit is contained in:
@ -154,21 +154,23 @@ fn gemm_config<T>(
|
|||||||
alpha: T,
|
alpha: T,
|
||||||
beta: T,
|
beta: T,
|
||||||
(b, m, n, k): (usize, usize, usize, usize),
|
(b, m, n, k): (usize, usize, usize, usize),
|
||||||
lhs_stride: &[usize],
|
_lhs_stride: &[usize],
|
||||||
rhs_stride: &[usize],
|
_rhs_stride: &[usize],
|
||||||
) -> StridedBatchedConfig<T> {
|
) -> StridedBatchedConfig<T> {
|
||||||
|
// TODO: Handle lhs_stride and rhs_stride.
|
||||||
// https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm
|
// https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm
|
||||||
use cudarc::cublas::sys::cublasOperation_t;
|
use cudarc::cublas::sys::cublasOperation_t;
|
||||||
println!("{:?} {:?} {:?}", lhs_stride, rhs_stride, (b, m, n, k));
|
// The setup below was copied from:
|
||||||
|
// https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
|
||||||
let gemm = GemmConfig {
|
let gemm = GemmConfig {
|
||||||
alpha,
|
alpha,
|
||||||
beta,
|
beta,
|
||||||
m: m as i32,
|
m: n as i32,
|
||||||
n: n as i32,
|
n: m as i32,
|
||||||
k: k as i32,
|
k: k as i32,
|
||||||
lda: m as i32,
|
lda: n as i32,
|
||||||
ldb: k as i32,
|
ldb: k as i32,
|
||||||
ldc: m as i32,
|
ldc: n as i32,
|
||||||
transa: cublasOperation_t::CUBLAS_OP_N,
|
transa: cublasOperation_t::CUBLAS_OP_N,
|
||||||
transb: cublasOperation_t::CUBLAS_OP_N,
|
transb: cublasOperation_t::CUBLAS_OP_N,
|
||||||
};
|
};
|
||||||
@ -341,7 +343,7 @@ impl CudaStorage {
|
|||||||
unsafe {
|
unsafe {
|
||||||
self.device
|
self.device
|
||||||
.blas
|
.blas
|
||||||
.gemm_strided_batched(cfg, lhs, rhs, &mut out)
|
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
|
||||||
}?;
|
}?;
|
||||||
CudaStorageSlice::F32(out)
|
CudaStorageSlice::F32(out)
|
||||||
}
|
}
|
||||||
@ -351,7 +353,7 @@ impl CudaStorage {
|
|||||||
unsafe {
|
unsafe {
|
||||||
self.device
|
self.device
|
||||||
.blas
|
.blas
|
||||||
.gemm_strided_batched(cfg, lhs, rhs, &mut out)
|
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
|
||||||
}?;
|
}?;
|
||||||
CudaStorageSlice::F64(out)
|
CudaStorageSlice::F64(out)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user