mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Fix for the index-select cuda setup. (#1022)
* Fix for index-select.
* Better fix + add some testing.
This commit is contained in:
@ -891,7 +891,7 @@ impl<'a> Map1 for IndexSelect<'a> {
|
|||||||
};
|
};
|
||||||
let left_size: usize = src_l.dims()[..self.2].iter().product();
|
let left_size: usize = src_l.dims()[..self.2].iter().product();
|
||||||
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
|
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
|
||||||
let dim_size = src_l.dims()[self.2];
|
let dim_size = ids_shape.elem_count();
|
||||||
let dst_el = ids_shape.elem_count() * left_size * right_size;
|
let dst_el = ids_shape.elem_count() * left_size * right_size;
|
||||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
|
||||||
|
@ -653,6 +653,21 @@ fn index_select(device: &Device) -> Result<()> {
|
|||||||
hs.to_vec2::<f32>()?,
|
hs.to_vec2::<f32>()?,
|
||||||
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
|
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
|
||||||
);
|
);
|
||||||
|
// Prior to https://github.com/huggingface/candle/pull/1022,
|
||||||
|
// there was a bug where the last values in the result tensor were set to 0.
|
||||||
|
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
|
||||||
|
let hs = t.index_select(&ids, 0)?;
|
||||||
|
assert_eq!(
|
||||||
|
hs.to_vec2::<f32>()?,
|
||||||
|
&[
|
||||||
|
[0.0, 1.0, 2.0],
|
||||||
|
[6.0, 7.0, 8.0],
|
||||||
|
[3.0, 4.0, 5.0],
|
||||||
|
[0.0, 1.0, 2.0],
|
||||||
|
[6.0, 7.0, 8.0],
|
||||||
|
[3.0, 4.0, 5.0],
|
||||||
|
]
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user