diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs index e3ed41e5..a645d8b1 100644 --- a/candle-core/src/indexer.rs +++ b/candle-core/src/indexer.rs @@ -141,28 +141,114 @@ impl IndexOp for Tensor where T: Into, { + ///```rust + /// use candle_core::{Tensor, DType, Device, IndexOp}; + /// let a = Tensor::new(&[ + /// [0., 1.], + /// [2., 3.], + /// [4., 5.] + /// ], &Device::Cpu)?; + /// + /// let b = a.i(0)?; + /// assert_eq!(b.shape().dims(), &[2]); + /// assert_eq!(b.to_vec1::()?, &[0., 1.]); + /// + /// let c = a.i(..2)?; + /// assert_eq!(c.shape().dims(), &[2, 2]); + /// assert_eq!(c.to_vec2::()?, &[ + /// [0., 1.], + /// [2., 3.] + /// ]); + /// + /// let d = a.i(1..)?; + /// assert_eq!(d.shape().dims(), &[2, 2]); + /// assert_eq!(d.to_vec2::()?, &[ + /// [2., 3.], + /// [4., 5.] + /// ]); + /// # Ok::<(), candle_core::Error>(()) fn i(&self, index: T) -> Result { self.index(&[index.into()]) } } +impl IndexOp<(A,)> for Tensor +where + A: Into, +{ + ///```rust + /// use candle_core::{Tensor, DType, Device, IndexOp}; + /// let a = Tensor::new(&[ + /// [0f32, 1.], + /// [2. , 3.], + /// [4. , 5.] + /// ], &Device::Cpu)?; + /// + /// let b = a.i((0,))?; + /// assert_eq!(b.shape().dims(), &[2]); + /// assert_eq!(b.to_vec1::()?, &[0., 1.]); + /// + /// let c = a.i((..2,))?; + /// assert_eq!(c.shape().dims(), &[2, 2]); + /// assert_eq!(c.to_vec2::()?, &[ + /// [0., 1.], + /// [2., 3.] + /// ]); + /// + /// let d = a.i((1..,))?; + /// assert_eq!(d.shape().dims(), &[2, 2]); + /// assert_eq!(d.to_vec2::()?, &[ + /// [2., 3.], + /// [4., 5.] + /// ]); + /// # Ok::<(), candle_core::Error>(()) + fn i(&self, (a,): (A,)) -> Result { + self.index(&[a.into()]) + } +} +#[allow(non_snake_case)] +impl IndexOp<(A, B)> for Tensor +where + A: Into, + B: Into, +{ + ///```rust + /// use candle_core::{Tensor, DType, Device, IndexOp}; + /// let a = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.], [6., 7., 8.]], &Device::Cpu)?; + /// + /// let b = a.i((1, 0))?; + /// assert_eq!(b.to_vec0::()?, 3.); + /// + /// let c = a.i((..2, 1))?; + /// assert_eq!(c.shape().dims(), &[2]); + /// assert_eq!(c.to_vec1::()?, &[1., 4.]); + /// + /// let d = a.i((2.., ..))?; + /// assert_eq!(c.shape().dims(), &[2]); + /// assert_eq!(c.to_vec1::()?, &[1., 4.]); + /// # Ok::<(), candle_core::Error>(()) + fn i(&self, (a, b): (A, B)) -> Result { + self.index(&[a.into(), b.into()]) + } +} + macro_rules! index_op_tuple { - ($($t:ident),+) => { + ($doc:tt, $($t:ident),+) => { #[allow(non_snake_case)] impl<$($t),*> IndexOp<($($t,)*)> for Tensor where $($t: Into,)* { + #[doc=$doc] fn i(&self, ($($t,)*): ($($t,)*)) -> Result { self.index(&[$($t.into(),)*]) } } }; } -index_op_tuple!(A); -index_op_tuple!(A, B); -index_op_tuple!(A, B, C); -index_op_tuple!(A, B, C, D); -index_op_tuple!(A, B, C, D, E); -index_op_tuple!(A, B, C, D, E, F); -index_op_tuple!(A, B, C, D, E, F, G); + +index_op_tuple!("see [TensorIndex#method.i]", A, B, C); +index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D); +index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E); +index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F); +index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F, G); diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index e8b02605..a2c3f428 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -370,6 +370,15 @@ impl Tensor { /// Returns a new tensor with all the elements having the same specified value. Note that /// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed. + ///```rust + /// use candle_core::{Tensor, Device}; + /// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec2::()?, &[ + /// [3.5, 3.5, 3.5, 3.5], + /// [3.5, 3.5, 3.5, 3.5], + /// ]); + /// # Ok::<(), candle_core::Error>(()) pub fn full>( value: D, shape: S, @@ -379,6 +388,13 @@ impl Tensor { } /// Creates a new 1D tensor from an iterator. + ///```rust + /// use candle_core::{Tensor, Device}; + /// let a = Tensor::from_iter( [1.0, 2.0, 3.0, 4.0].into_iter(), &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec1::()?, &[1.0, 2.0, 3.0, 4.0]); + /// # Ok::<(), candle_core::Error>(()) + /// ``` pub fn from_iter( iter: impl IntoIterator, device: &Device, @@ -390,12 +406,26 @@ impl Tensor { /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common /// difference `1` from `start`. + ///```rust + /// use candle_core::{Tensor, Device}; + /// let a = Tensor::arange(2., 5., &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec1::()?, &[2., 3., 4.]); + /// # Ok::<(), candle_core::Error>(()) + /// ``` pub fn arange(start: D, end: D, device: &Device) -> Result { Self::arange_step(start, end, D::one(), device) } /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common /// difference `step` from `start`. + ///```rust + /// use candle_core::{Tensor, Device}; + /// let a = Tensor::arange_step(2.0, 4.0, 0.5, &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec1::()?, &[2.0, 2.5, 3.0, 3.5]); + /// # Ok::<(), candle_core::Error>(()) + /// ``` pub fn arange_step( start: D, end: D, @@ -441,6 +471,16 @@ impl Tensor { /// Creates a new tensor initialized with values from the input vector. The number of elements /// in this vector must be the same as the number of elements defined by the shape. /// If the device is cpu, no data copy is made. + ///```rust + /// use candle_core::{Tensor, Device}; + /// let a = Tensor::from_vec(vec!{1., 2., 3., 4., 5., 6.}, (2, 3), &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec2::()?, &[ + /// [1., 2., 3.], + /// [4., 5., 6.] + /// ]); + /// # Ok::<(), candle_core::Error>(()) + /// ``` pub fn from_vec, D: crate::WithDType>( data: Vec, shape: S, @@ -451,6 +491,17 @@ impl Tensor { /// Creates a new tensor initialized with values from the input slice. The number of elements /// in this vector must be the same as the number of elements defined by the shape. + ///```rust + /// use candle_core::{Tensor, Device}; + /// let values = vec![1., 2., 3., 4., 5., 6., 7., 8.]; + /// let a = Tensor::from_slice(&values[1..7], (2, 3), &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec2::()?, &[ + /// [2., 3., 4.], + /// [5., 6., 7.] + /// ]); + /// # Ok::<(), candle_core::Error>(()) + /// ``` pub fn from_slice, D: crate::WithDType>( array: &[D], shape: S, @@ -732,6 +783,30 @@ impl Tensor { /// Returns a new tensor that is a narrowed version of the input, the dimension `dim` /// ranges from `start` to `start + len`. + /// ``` + /// use candle_core::{Tensor, Device}; + /// let a = Tensor::new(&[ + /// [0f32, 1., 2.], + /// [3. , 4., 5.], + /// [6. , 7., 8.] + /// ], &Device::Cpu)?; + /// + /// let b = a.narrow(0, 1, 2)?; + /// assert_eq!(b.shape().dims(), &[2, 3]); + /// assert_eq!(b.to_vec2::()?, &[ + /// [3., 4., 5.], + /// [6., 7., 8.] + /// ]); + /// + /// let c = a.narrow(1, 1, 1)?; + /// assert_eq!(c.shape().dims(), &[3, 1]); + /// assert_eq!(c.to_vec2::()?, &[ + /// [1.], + /// [4.], + /// [7.] + /// ]); + /// # Ok::<(), candle_core::Error>(()) + /// ``` pub fn narrow(&self, dim: D, start: usize, len: usize) -> Result { let dims = self.dims(); let dim = dim.to_index(self.shape(), "narrow")?;