Support negative steps in arange. (#1218)

This commit is contained in:
Laurent Mazare
2023-10-30 08:40:54 +01:00
committed by GitHub
parent 174b208052
commit 5fc66bd4ba
2 changed files with 33 additions and 3 deletions

View File

@ -385,11 +385,21 @@ impl Tensor {
step: D,
device: &Device,
) -> Result<Self> {
if D::is_zero(&step) {
crate::bail!("step cannot be zero")
}
let mut data = vec![];
let mut current = start;
while current < end {
data.push(current);
current += step;
if step >= D::zero() {
while current < end {
data.push(current);
current += step;
}
} else {
while current > end {
data.push(current);
current += step;
}
}
let len = data.len();
Self::from_vec_impl(data, len, device, false)

View File

@ -29,7 +29,26 @@ fn ones(device: &Device) -> Result<()> {
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
Ok(())
}
fn arange(device: &Device) -> Result<()> {
assert_eq!(
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
[0, 1, 2, 3, 4],
);
assert_eq!(
Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::<u8>()?,
[0, 2, 4],
);
assert_eq!(
Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::<u8>()?,
[0, 3],
);
assert_eq!(
Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::<i64>()?,
[5, 4, 3, 2, 1],
);
Ok(())
}
@ -1037,6 +1056,7 @@ fn randn(device: &Device) -> Result<()> {
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(ones, ones_cpu, ones_gpu);
test_device!(arange, arange_cpu, arange_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);