mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Check the bounds in the cuda indexing kernels. (#2908)
* Check the bounds in the cuda indexing kernels. * Another check.
This commit is contained in:
@ -826,6 +826,31 @@ fn embeddings(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_fail() -> Result<()> {
|
||||
// Check that an error is properly reported on out of bounds.
|
||||
let ids = Tensor::new(&[4u32, 2u32, 1u32], &Device::Cpu)?;
|
||||
let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &Device::Cpu)?;
|
||||
let hs = t.index_select(&ids, 0);
|
||||
assert!(hs.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// The test below triggers an unwinding panic as there is a panic within the
|
||||
// #[cfg(feature = "cuda")]
|
||||
// #[test]
|
||||
// #[should_panic]
|
||||
// fn index_select_fail_gpu() {
|
||||
// // Check that a panic happens for out of bounds in cuda
|
||||
// if let Ok(device) = Device::new_cuda(0) {
|
||||
// if let Ok(ids) = Tensor::new(&[4u32, 2u32, 1u32], &device) {
|
||||
// if let Ok(t) = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &device) {
|
||||
// let _ = t.index_select(&ids, 0);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
fn cmp(device: &Device) -> Result<()> {
|
||||
let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
|
||||
let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?;
|
||||
|
Reference in New Issue
Block a user