mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Merge pull request #13 from LaurentMazare/cuda-bugfixes
Fix two cuda bugs (matmul and where_cond).
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
clean-ptx:
|
||||
find target -name "*.ptx" -type f -delete
|
||||
- echo "" > kernels/src/lib.rs
|
||||
- touch kernels/build.rs
|
||||
+ echo "" > candle-kernels/src/lib.rs
|
||||
+ touch candle-kernels/build.rs
|
||||
|
||||
clean:
|
||||
cargo clean
|
@ -301,8 +301,8 @@ fn gemm_config<T>(
|
||||
Ok(StridedBatchedConfig {
|
||||
batch_size: b as i32,
|
||||
gemm,
|
||||
- stride_a: (m * k) as i64,
|
||||
- stride_b: (n * k) as i64,
|
||||
+ stride_a: (n * k) as i64,
|
||||
+ stride_b: (m * k) as i64,
|
||||
stride_c: (m * n) as i64,
|
||||
})
|
||||
}
|
||||
|
@ -14,7 +14,7 @@ extern "C" __global__ void FN_NAME( \
|
||||
const size_t *dims = info; \
|
||||
const size_t *strides = info + num_dims; \
|
||||
const size_t *strides_t = info + 2*num_dims; \
|
||||
- const size_t *strides_f = info + 2*num_dims; \
|
||||
+ const size_t *strides_f = info + 3*num_dims; \
|
||||
if (is_contiguous(num_dims, dims, strides) \
|
||||
&& is_contiguous(num_dims, dims, strides_f) \
|
||||
&& is_contiguous(num_dims, dims, strides_t)) { \
|
||||
|
Reference in New Issue
Block a user