From aebffcfc1367b87941b70fba9b13d02a10f98809 Mon Sep 17 00:00:00 2001
From: laurent
Date: Thu, 22 Jun 2023 19:44:26 +0100
Subject: [PATCH] Add a matmul cuda example.

---
 examples/cuda_basics.rs | 9 ++++++---
 src/cuda_backend.rs     | 9 +++++----
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/examples/cuda_basics.rs b/examples/cuda_basics.rs
index 3db613f6..cdb0ac94 100644
--- a/examples/cuda_basics.rs
+++ b/examples/cuda_basics.rs
@@ -1,5 +1,5 @@
 use anyhow::Result;
-use candle::{DType, Device, Tensor};
+use candle::{Device, Tensor};
 
 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
@@ -9,7 +9,10 @@
     let z = (y + x * 3.)?;
     println!("{:?}", z.to_vec1::<f32>()?);
     println!("{:?}", z.sqrt()?.to_vec1::<f32>()?);
-    let x = Tensor::ones((3, 2), DType::F32, &device)?;
-    println!("{:?}", x.to_vec2::<f32>()?);
+    let x = Tensor::new(&[[11f32, 22.], [33., 44.], [55., 66.], [77., 78.]], &device)?;
+    let y = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &device)?;
+    println!("{:?}", y.to_vec2::<f32>()?);
+    let z = x.matmul(&y)?;
+    println!("{:?}", z.to_vec2::<f32>()?);
     Ok(())
 }
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 806d6c26..b474fe37 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -159,14 +159,15 @@ fn gemm_config(
 ) -> StridedBatchedConfig {
     // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm
     use cudarc::cublas::sys::cublasOperation_t;
+    println!("{:?} {:?} {:?}", lhs_stride, rhs_stride, (b, m, n, k));
     let gemm = GemmConfig {
         alpha,
         beta,
         m: m as i32,
         n: n as i32,
         k: k as i32,
-        lda: lhs_stride[lhs_stride.len() - 2] as i32,
-        ldb: rhs_stride[rhs_stride.len() - 2] as i32,
+        lda: m as i32,
+        ldb: k as i32,
         ldc: m as i32,
         transa: cublasOperation_t::CUBLAS_OP_N,
         transb: cublasOperation_t::CUBLAS_OP_N,
@@ -174,8 +175,8 @@
     StridedBatchedConfig {
         batch_size: b as i32,
         gemm,
-        stride_a: lhs_stride[0] as i64,
-        stride_b: rhs_stride[0] as i64,
+        stride_a: (m * k) as i64,
+        stride_b: (n * k) as i64,
         stride_c: (m * n * k) as i64,
     }
 }
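
Note, not part of the patch: a quick way to sanity-check the values the new
example should print is a plain-Rust reference product of the same 4x2 and 2x3
matrices, with no CUDA or candle required. This is only an illustrative sketch;
the matmul helper below is hypothetical and simply assumes Tensor::matmul
computes the usual row-major matrix product.

fn matmul(a: &[Vec<f32>], b: &[Vec<f32>]) -> Vec<Vec<f32>> {
    // Naive triple loop: c[i][j] = sum over p of a[i][p] * b[p][j].
    let (m, k, n) = (a.len(), b.len(), b[0].len());
    let mut c = vec![vec![0f32; n]; m];
    for i in 0..m {
        for p in 0..k {
            for j in 0..n {
                c[i][j] += a[i][p] * b[p][j];
            }
        }
    }
    c
}

fn main() {
    // Same data as the updated examples/cuda_basics.rs: x is 4x2, y is 2x3.
    let x = vec![
        vec![11f32, 22.],
        vec![33., 44.],
        vec![55., 66.],
        vec![77., 78.],
    ];
    let y = vec![vec![1f32, 2., 3.], vec![4., 5., 6.]];
    // Expected 4x3 result:
    // [[99, 132, 165], [209, 286, 363], [319, 440, 561], [389, 544, 699]]
    println!("{:?}", matmul(&x, &y));
}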