Merge pull request #13 from LaurentMazare/cuda-bugfixes

Fix two CUDA bugs (matmul and where_cond).
This commit is contained in:
Laurent Mazare
2023-06-27 11:32:26 +01:00
committed by GitHub
3 changed files with 5 additions and 5 deletions

View File

@ -1,7 +1,7 @@
clean-ptx:
find target -name "*.ptx" -type f -delete
-echo "" > kernels/src/lib.rs
-touch kernels/build.rs
+echo "" > candle-kernels/src/lib.rs
+touch candle-kernels/build.rs
clean:
cargo clean

View File

@ -301,8 +301,8 @@ fn gemm_config<T>(
Ok(StridedBatchedConfig {
batch_size: b as i32,
gemm,
-stride_a: (m * k) as i64,
-stride_b: (n * k) as i64,
+stride_a: (n * k) as i64,
+stride_b: (m * k) as i64,
stride_c: (m * n) as i64,
})
}

View File

@ -14,7 +14,7 @@ extern "C" __global__ void FN_NAME( \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
const size_t *strides_t = info + 2*num_dims; \
-const size_t *strides_f = info + 2*num_dims; \
+const size_t *strides_f = info + 3*num_dims; \
if (is_contiguous(num_dims, dims, strides) \
&& is_contiguous(num_dims, dims, strides_f) \
&& is_contiguous(num_dims, dims, strides_t)) { \