diff --git a/examples/cuda_basics.rs b/examples/cuda_basics.rs index cdb0ac94..969d6e20 100644 --- a/examples/cuda_basics.rs +++ b/examples/cuda_basics.rs @@ -14,5 +14,13 @@ fn main() -> Result<()> { println!("{:?}", y.to_vec2::()?); let z = x.matmul(&y)?; println!("{:?}", z.to_vec2::()?); + let x = Tensor::new( + &[[11f32, 22.], [33., 44.], [55., 66.], [77., 78.]], + &Device::Cpu, + )?; + let y = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; + println!("{:?}", y.to_vec2::()?); + let z = x.matmul(&y)?; + println!("{:?}", z.to_vec2::()?); Ok(()) } diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index b474fe37..7a22df4f 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -177,7 +177,7 @@ fn gemm_config( gemm, stride_a: (m * k) as i64, stride_b: (n * k) as i64, - stride_c: (m * n * k) as i64, + stride_c: (m * n) as i64, } } @@ -332,7 +332,7 @@ impl CudaStorage { lhs_stride: &[usize], rhs_stride: &[usize], ) -> Result { - let elem_count = b * m * n * k; + let elem_count = b * m * n; let dev = &self.device; let slice = match (&self.slice, &rhs.slice) { (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {