diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index be8f7b07..96a2b809 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -142,6 +142,9 @@ pub enum Error { #[error("{op} expects at least one tensor")] OpRequiresAtLeastOneTensor { op: &'static str }, + #[error("{op} expects at least two tensors")] + OpRequiresAtLeastTwoTensors { op: &'static str }, + #[error("backward is not supported for {op}")] BackwardNotSupported { op: &'static str }, diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 381c768b..facbd8ed 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -540,6 +540,69 @@ impl Tensor { Ok(inp) } + /// Creates grids of coordinates specified by the 1D inputs. + /// + /// # Arguments + /// + /// * `args` - A slice of 1D tensors. + /// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the + /// first dimension corresponds to the cardinality of the second input and the second + /// dimension corresponds to the cardinality of the first input. If ij is selected, the + /// dimensions are in the same order as the cardinality of the inputs. + /// + /// # Examples + /// + /// ``` + /// use candle_core::{Tensor, Device, Shape}; + /// # fn dummy() -> Result<(), Box> { + /// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?; + /// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?; + /// + /// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?; + /// + /// assert_eq!(2, grids_xy.len()); + /// assert_eq!(Vec::from([3]), grids_xy[0].shape().clone().into_dims()); + /// assert_eq!(Vec::from([3]), grids_xy[0].shape().clone().into_dims()); + /// + /// assert_eq!(grids_xy[0].to_vec1::()?, &[1., 1., 1., 2., 2., 2., 3., 3., 3.]); + /// + /// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?; + /// + /// assert_eq!(grids_xy[0].to_vec1::()?, &[1., 2., 3., 1., 2., 3., 1., 2., 3.]); + /// # Ok(()) } + /// ``` + /// + /// # Errors + /// + /// * Will return `Err` if `args` contains less than 2 tensors. + /// + pub fn meshgrid>(args: &[A], xy_indexing: bool) -> Result> { + if args.is_empty() || args.len() <= 1 { + Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())? + } + + let mut grids = Vec::with_capacity(args.len()); + for idx in 0..args.len() { + // Repeat the tensor across the given dimensions to infiltrate it in the resulting grid + let repeats: Vec<_> = args + .iter() + .map(|t| t.as_ref().dim(0)) + .collect::>()?; + let repeated_tensor = args[idx].as_ref().clone().repeat(repeats.clone())?; + + // Reshape the tensor to match the dimensions of the grid + let reshaped_tensor = repeated_tensor.reshape(repeats)?; + + grids.push(reshaped_tensor); + } + + if xy_indexing { + Ok(grids) + } else { + Ok(grids.into_iter().rev().collect()) + } + } + /// This operation multiplies the input tensor by `mul` then adds `add` and return the result. /// The input values `mul` and `add` are casted to the appropriate type so some rounding might /// be performed.