Add the const-set op. (#2910)

* Add the const-set op.

* Cuda implementation.

* Bugfix.

* Metal cleanup.

* Add the metal kernels.

* Add some testing.

* Finish the metal implementation.

* Bump the version.
This commit is contained in:
Laurent Mazare
2025-04-19 10:07:02 +02:00
committed by GitHub
parent b2904a830b
commit a4c56a958e
20 changed files with 414 additions and 209 deletions

View File

@ -25,10 +25,12 @@ fn ones(device: &Device) -> Result<()> {
Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
assert_eq!(
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
if !device.is_metal() {
assert_eq!(
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
}
assert_eq!(
Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
[
@ -63,6 +65,26 @@ fn ones(device: &Device) -> Result<()> {
}
fn full(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((3, 4), DType::U32, device)?;
tensor.const_set(42u32.into())?;
assert_eq!(
tensor.to_vec2::<u32>()?,
[[42, 42, 42, 42], [42, 42, 42, 42], [42, 42, 42, 42]]
);
tensor.i((.., 2))?.const_set(1337u32.into())?;
assert_eq!(
tensor.to_vec2::<u32>()?,
[[42, 42, 1337, 42], [42, 42, 1337, 42], [42, 42, 1337, 42]]
);
tensor.i((2, ..))?.const_set(1u32.into())?;
assert_eq!(
tensor.to_vec2::<u32>()?,
[[42, 42, 1337, 42], [42, 42, 1337, 42], [1, 1, 1, 1]]
);
Ok(())
}
fn const_set(device: &Device) -> Result<()> {
assert_eq!(
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
[[42, 42, 42], [42, 42, 42]],
@ -1509,6 +1531,7 @@ fn zero_dim(device: &Device) -> Result<()> {
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
test_device!(const_set, cs_cpu, cs_gpu, cs_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);