Fix two cuda bugs (matmul and where_cond).

This commit is contained in:
laurent
2023-06-27 11:31:04 +01:00
parent d7f729fb8f
commit 380d61e990
3 changed files with 5 additions and 5 deletions

View File

@ -1,7 +1,7 @@
clean-ptx: clean-ptx:
find target -name "*.ptx" -type f -delete find target -name "*.ptx" -type f -delete
echo "" > kernels/src/lib.rs echo "" > candle-kernels/src/lib.rs
touch kernels/build.rs touch candle-kernels/build.rs
clean: clean:
cargo clean cargo clean

View File

@ -301,8 +301,8 @@ fn gemm_config<T>(
Ok(StridedBatchedConfig { Ok(StridedBatchedConfig {
batch_size: b as i32, batch_size: b as i32,
gemm, gemm,
stride_a: (m * k) as i64, stride_a: (n * k) as i64,
stride_b: (n * k) as i64, stride_b: (m * k) as i64,
stride_c: (m * n) as i64, stride_c: (m * n) as i64,
}) })
} }

View File

@ -14,7 +14,7 @@ extern "C" __global__ void FN_NAME( \
const size_t *dims = info; \ const size_t *dims = info; \
const size_t *strides = info + num_dims; \ const size_t *strides = info + num_dims; \
const size_t *strides_t = info + 2*num_dims; \ const size_t *strides_t = info + 2*num_dims; \
const size_t *strides_f = info + 2*num_dims; \ const size_t *strides_f = info + 3*num_dims; \
if (is_contiguous(num_dims, dims, strides) \ if (is_contiguous(num_dims, dims, strides) \
&& is_contiguous(num_dims, dims, strides_f) \ && is_contiguous(num_dims, dims, strides_f) \
&& is_contiguous(num_dims, dims, strides_t)) { \ && is_contiguous(num_dims, dims, strides_t)) { \