mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Cublas fixes.
This commit is contained in:
@ -214,11 +214,11 @@ fn gemm_config<T>(
|
||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||
// The a tensor has dims batching, m, k
|
||||
let transa = if lhs_m1 == 1 && lhs_m2 == k {
|
||||
cublasOperation_t::CUBLAS_OP_N
|
||||
} else if rhs_m1 == m && rhs_m2 == 1 {
|
||||
cublasOperation_t::CUBLAS_OP_T
|
||||
// The a tensor has dims batching, k, n (rhs)
|
||||
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
|
||||
(n as i32, cublasOperation_t::CUBLAS_OP_N)
|
||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||
(k as i32, cublasOperation_t::CUBLAS_OP_T)
|
||||
} else {
|
||||
Err(CudaError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
@ -226,11 +226,11 @@ fn gemm_config<T>(
|
||||
mnk: (m, n, k),
|
||||
})?
|
||||
};
|
||||
// The b tensor has dims batching, k, n
|
||||
let transb = if rhs_m1 == 1 && rhs_m2 == n {
|
||||
cublasOperation_t::CUBLAS_OP_N
|
||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||
cublasOperation_t::CUBLAS_OP_T
|
||||
// The b tensor has dims batching, m, k (lhs)
|
||||
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
||||
(k as i32, cublasOperation_t::CUBLAS_OP_N)
|
||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||
(m as i32, cublasOperation_t::CUBLAS_OP_T)
|
||||
} else {
|
||||
Err(CudaError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
@ -246,8 +246,8 @@ fn gemm_config<T>(
|
||||
m: n as i32,
|
||||
n: m as i32,
|
||||
k: k as i32,
|
||||
lda: n as i32,
|
||||
ldb: k as i32,
|
||||
lda,
|
||||
ldb,
|
||||
ldc: n as i32,
|
||||
transa,
|
||||
transb,
|
||||
|
Reference in New Issue
Block a user