Add a test for scatter add. (#238)

* Add a test for scatter add (segfaults on gpus for now).

* Bugfix for the scatter add cuda kernel.
This commit is contained in:
Laurent Mazare
2023-07-25 09:12:14 +01:00
committed by GitHub
parent 18cc73954a
commit 944d70bd9a
3 changed files with 47 additions and 8 deletions

View File

@ -846,7 +846,7 @@ impl<'a> Map2InPlace for ScatterAdd<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
_dst_shape: &Shape,
dst_shape: &Shape,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
@ -874,11 +874,11 @@ impl<'a> Map2InPlace for ScatterAdd<'a> {
let left_sz: usize = src_l.dims()[..dim].iter().product();
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_sz = src_l.dims()[dim];
let ids_dim_sz = ids_l.dims()[dim];
let dst_dim_sz = dst_shape.dims()[dim];
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
// SAFETY: Set later by running the kernel.
let params = (ids, &src, dst, left_sz, src_dim_sz, ids_dim_sz, right_sz);
let params = (ids, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(())

View File

@ -389,6 +389,46 @@ fn index_add(device: &Device) -> Result<()> {
Ok(())
}
fn scatter_add(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
t.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0]
]
);
let ids = Tensor::new(&[[0u32, 1, 2], [3, 4, 0], [3, 3, 1], [2, 0, 4]], device)?;
let init = Tensor::ones((4, 5), DType::F32, device)?;
let hs = init.scatter_add(&ids, &t, 1)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[1.0, 2.0, 3.0, 1.0, 1.0],
[6.0, 1.0, 1.0, 4.0, 5.0],
[1.0, 9.0, 1.0, 14.0, 1.0],
[11.0, 1.0, 10.0, 1.0, 12.0]
]
);
let init = Tensor::ones((6, 3), DType::F32, device)?;
let hs = init.scatter_add(&ids, &t, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[1.0, 11.0, 6.0],
[1.0, 2.0, 9.0],
[10.0, 1.0, 3.0],
[10.0, 8.0, 1.0],
[1.0, 5.0, 12.0],
[1.0, 1.0, 1.0]
]
);
Ok(())
}
fn gather(device: &Device) -> Result<()> {
let ids = Tensor::new(&[[0u32], [2u32], [1u32], [0u32]], device)?;
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
@ -588,3 +628,4 @@ test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);

View File

@ -144,7 +144,6 @@ extern "C" __global__ void FN_NAME( \
template<typename T, typename I>
__device__ void scatter_add(
const I *ids,
const size_t ids_dim_size,
const T *inp,
T *out,
const size_t left_size,
@ -156,8 +155,8 @@ __device__ void scatter_add(
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
const size_t pre = i / right_size;
const size_t post = i % right_size;
for (unsigned int j = 0; j < ids_dim_size; ++j) {
const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
const size_t idx = ids[src_i];
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
@ -168,14 +167,13 @@ __device__ void scatter_add(
#define SA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const INDEX_TYPENAME *ids, \
const size_t ids_dim_size, \
const TYPENAME *inp, \
TYPENAME *out, \
const size_t left_size, \
const size_t src_dim_size, \
const size_t dst_dim_size, \
const size_t right_size \
) { scatter_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
) { scatter_add(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
#if __CUDA_ARCH__ >= 800