diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index e1cae41c..e6e7b415 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -361,6 +361,16 @@ impl Tensor { Self::new_impl(array, shape, device, false) } + /// Returns a new tensor with all the elements having the same specified value. Note that + /// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed. + pub fn full>( + value: D, + shape: S, + device: &Device, + ) -> Result { + Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape) + } + /// Creates a new 1D tensor from an iterator. pub fn from_iter( iter: impl IntoIterator, diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index a4548d56..e83fb55b 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -32,6 +32,14 @@ fn ones(device: &Device) -> Result<()> { Ok(()) } +fn full(device: &Device) -> Result<()> { + assert_eq!( + Tensor::full(42u32, (2, 3), device)?.to_vec2::()?, + [[42, 42, 42], [42, 42, 42]], + ); + Ok(()) +} + fn arange(device: &Device) -> Result<()> { assert_eq!( Tensor::arange(0u8, 5u8, device)?.to_vec1::()?, @@ -1072,6 +1080,7 @@ fn randn(device: &Device) -> Result<()> { test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal); test_device!(ones, ones_cpu, ones_gpu, ones_metal); +test_device!(full, full_cpu, full_gpu, full_metal); test_device!(arange, arange_cpu, arange_gpu, arange_metal); test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal); test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);