diff --git a/examples/cuda_basics.rs b/examples/cuda_basics.rs index 969d6e20..6f95723d 100644 --- a/examples/cuda_basics.rs +++ b/examples/cuda_basics.rs @@ -3,6 +3,11 @@ use candle::{Device, Tensor}; fn main() -> Result<()> { let device = Device::new_cuda(0)?; + let x = Tensor::new(&[[11f32, 22.], [33., 44.], [55., 66.], [77., 78.]], &device)?; + println!("> {:?}", x.sum(&[0])?.to_vec2::<f32>()?); + println!("> {:?}", x.sum(&[1])?.to_vec2::<f32>()?); + println!("> {:?}", x.sum(&[0, 1])?.to_vec2::<f32>()?); + let x = Tensor::new(&[3f32, 1., 4., 1., 5.], &device)?; println!("{:?}", x.to_vec1::<f32>()?); let y = Tensor::new(&[2f32, 7., 1., 8., 2.], &device)?; diff --git a/kernels/src/reduce.cu b/kernels/src/reduce.cu index b9780000..0214ca88 100644 --- a/kernels/src/reduce.cu +++ b/kernels/src/reduce.cu @@ -20,7 +20,7 @@ extern "C" __global__ void FN_NAME( \ for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \ size_t stride = sum_dims_s[nd]; \ size_t pre = dst_index / stride; \ - size_t post = dst_index / stride; \ + size_t post = dst_index % stride; \ dst_index = (pre / sum_dims_l[nd]) * stride + post; \ } \ out[dst_index] += inp[i]; \ @@ -33,7 +33,7 @@ extern "C" __global__ void FN_NAME( \ for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \ size_t stride = sum_dims_s[nd]; \ size_t pre = dst_index / stride; \ - size_t post = dst_index / stride; \ + size_t post = dst_index % stride; \ dst_index = (pre / sum_dims_l[nd]) * stride + post; \ } \ out[dst_index] += inp[strided_i]; \ diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index 2c96cc6b..85f55568 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -314,7 +314,14 @@ impl CudaStorage { .iter() .map(|&d| src_dims[d + 1..].iter().product::<usize>()) .collect(); - let cfg = LaunchConfig::for_num_elems(el as u32); + // let cfg = LaunchConfig::for_num_elems(el as u32); + // TODO: Hack to run the computation on a single thread, replace with a proper distributed + // algorithm. 
+ let cfg = LaunchConfig { + grid_dim: (1, 1, 1), + block_dim: (1, 1, 1), + shared_mem_bytes: 0, + }; let dev = self.device(); let ds = dev.htod_copy([src_dims, stride, &sum_dims_l, &sum_dims_s].concat())?; let slice = match &self.slice {