mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Support negative steps in arange. (#1218)
This commit is contained in:
@ -385,12 +385,22 @@ impl Tensor {
|
|||||||
step: D,
|
step: D,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
|
if D::is_zero(&step) {
|
||||||
|
crate::bail!("step cannot be zero")
|
||||||
|
}
|
||||||
let mut data = vec![];
|
let mut data = vec![];
|
||||||
let mut current = start;
|
let mut current = start;
|
||||||
|
if step >= D::zero() {
|
||||||
while current < end {
|
while current < end {
|
||||||
data.push(current);
|
data.push(current);
|
||||||
current += step;
|
current += step;
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
while current > end {
|
||||||
|
data.push(current);
|
||||||
|
current += step;
|
||||||
|
}
|
||||||
|
}
|
||||||
let len = data.len();
|
let len = data.len();
|
||||||
Self::from_vec_impl(data, len, device, false)
|
Self::from_vec_impl(data, len, device, false)
|
||||||
}
|
}
|
||||||
|
@ -29,7 +29,26 @@ fn ones(device: &Device) -> Result<()> {
|
|||||||
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
|
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
|
||||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||||
);
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn arange(device: &Device) -> Result<()> {
|
||||||
|
assert_eq!(
|
||||||
|
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
|
||||||
|
[0, 1, 2, 3, 4],
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::<u8>()?,
|
||||||
|
[0, 2, 4],
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::<u8>()?,
|
||||||
|
[0, 3],
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::<i64>()?,
|
||||||
|
[5, 4, 3, 2, 1],
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1037,6 +1056,7 @@ fn randn(device: &Device) -> Result<()> {
|
|||||||
|
|
||||||
test_device!(zeros, zeros_cpu, zeros_gpu);
|
test_device!(zeros, zeros_cpu, zeros_gpu);
|
||||||
test_device!(ones, ones_cpu, ones_gpu);
|
test_device!(ones, ones_cpu, ones_gpu);
|
||||||
|
test_device!(arange, arange_cpu, arange_gpu);
|
||||||
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
|
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
|
||||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
|
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
|
||||||
test_device!(narrow, narrow_cpu, narrow_gpu);
|
test_device!(narrow, narrow_cpu, narrow_gpu);
|
||||||
|
Reference in New Issue
Block a user