diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index c6f2364d..adcdc59d 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -385,11 +385,21 @@ impl Tensor { step: D, device: &Device, ) -> Result { + if D::is_zero(&step) { + crate::bail!("step cannot be zero") + } let mut data = vec![]; let mut current = start; - while current < end { - data.push(current); - current += step; + if step >= D::zero() { + while current < end { + data.push(current); + current += step; + } + } else { + while current > end { + data.push(current); + current += step; + } } let len = data.len(); Self::from_vec_impl(data, len, device, false) diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 899efcf3..734cb7e8 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -29,7 +29,26 @@ fn ones(device: &Device) -> Result<()> { Tensor::ones((2, 3), DType::F64, device)?.to_vec2::()?, [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], ); + Ok(()) +} +fn arange(device: &Device) -> Result<()> { + assert_eq!( + Tensor::arange(0u8, 5u8, device)?.to_vec1::()?, + [0, 1, 2, 3, 4], + ); + assert_eq!( + Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::()?, + [0, 2, 4], + ); + assert_eq!( + Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::()?, + [0, 3], + ); + assert_eq!( + Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::()?, + [5, 4, 3, 2, 1], + ); Ok(()) } @@ -1037,6 +1056,7 @@ fn randn(device: &Device) -> Result<()> { test_device!(zeros, zeros_cpu, zeros_gpu); test_device!(ones, ones_cpu, ones_gpu); +test_device!(arange, arange_cpu, arange_gpu); test_device!(add_mul, add_mul_cpu, add_mul_gpu); test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu); test_device!(narrow, narrow_cpu, narrow_gpu);