mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Cublas fixes.
This commit is contained in:
@ -214,11 +214,11 @@ fn gemm_config<T>(
|
|||||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||||
// The a tensor has dims batching, m, k
|
// The a tensor has dims batching, k, n (rhs)
|
||||||
let transa = if lhs_m1 == 1 && lhs_m2 == k {
|
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
|
||||||
cublasOperation_t::CUBLAS_OP_N
|
(n as i32, cublasOperation_t::CUBLAS_OP_N)
|
||||||
} else if rhs_m1 == m && rhs_m2 == 1 {
|
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||||
cublasOperation_t::CUBLAS_OP_T
|
(k as i32, cublasOperation_t::CUBLAS_OP_T)
|
||||||
} else {
|
} else {
|
||||||
Err(CudaError::MatMulNonContiguous {
|
Err(CudaError::MatMulNonContiguous {
|
||||||
lhs_stride: lhs_stride.to_vec(),
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
@ -226,11 +226,11 @@ fn gemm_config<T>(
|
|||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?
|
})?
|
||||||
};
|
};
|
||||||
// The b tensor has dims batching, k, n
|
// The b tensor has dims batching, m, k (lhs)
|
||||||
let transb = if rhs_m1 == 1 && rhs_m2 == n {
|
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
||||||
cublasOperation_t::CUBLAS_OP_N
|
(k as i32, cublasOperation_t::CUBLAS_OP_N)
|
||||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||||
cublasOperation_t::CUBLAS_OP_T
|
(m as i32, cublasOperation_t::CUBLAS_OP_T)
|
||||||
} else {
|
} else {
|
||||||
Err(CudaError::MatMulNonContiguous {
|
Err(CudaError::MatMulNonContiguous {
|
||||||
lhs_stride: lhs_stride.to_vec(),
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
@ -246,8 +246,8 @@ fn gemm_config<T>(
|
|||||||
m: n as i32,
|
m: n as i32,
|
||||||
n: m as i32,
|
n: m as i32,
|
||||||
k: k as i32,
|
k: k as i32,
|
||||||
lda: n as i32,
|
lda,
|
||||||
ldb: k as i32,
|
ldb,
|
||||||
ldc: n as i32,
|
ldc: n as i32,
|
||||||
transa,
|
transa,
|
||||||
transb,
|
transb,
|
||||||
|
Reference in New Issue
Block a user