Implemented meshgrid (#1174)

* Implemented meshgrid

* Resolved feedback from LaurentMazare

* Rustfmt

* Updated docstring

* Removed outdated error mode from docstring
This commit is contained in:
Wouter Doppenberg
2023-10-25 13:49:11 +02:00
committed by GitHub
parent b6053b938b
commit e4c9adfdbe
2 changed files with 66 additions and 0 deletions

View File

@ -142,6 +142,9 @@ pub enum Error {
#[error("{op} expects at least one tensor")]
OpRequiresAtLeastOneTensor { op: &'static str },
#[error("{op} expects at least two tensors")]
OpRequiresAtLeastTwoTensors { op: &'static str },
#[error("backward is not supported for {op}")]
BackwardNotSupported { op: &'static str },

View File

@ -540,6 +540,69 @@ impl Tensor {
Ok(inp)
}
/// Creates grids of coordinates specified by the 1D inputs.
///
/// # Arguments
///
/// * `args` - A slice of 1D tensors.
/// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
/// first dimension corresponds to the cardinality of the second input and the second
/// dimension corresponds to the cardinality of the first input. If ij is selected, the
/// dimensions are in the same order as the cardinality of the inputs.
///
/// # Examples
///
/// ```
/// use candle_core::{Tensor, Device, Shape};
/// # fn dummy() -> Result<(), Box<dyn std::error::Error>> {
/// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
/// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
///
/// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?;
///
/// assert_eq!(2, grids_xy.len());
/// assert_eq!(Vec::from([3]), grids_xy[0].shape().clone().into_dims());
/// assert_eq!(Vec::from([3]), grids_xy[0].shape().clone().into_dims());
///
/// assert_eq!(grids_xy[0].to_vec1::<f32>()?, &[1., 1., 1., 2., 2., 2., 3., 3., 3.]);
///
/// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?;
///
/// assert_eq!(grids_xy[0].to_vec1::<f32>()?, &[1., 2., 3., 1., 2., 3., 1., 2., 3.]);
/// # Ok(()) }
/// ```
///
/// # Errors
///
/// * Will return `Err` if `args` contains less than 2 tensors.
///
pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
if args.is_empty() || args.len() <= 1 {
Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
}
let mut grids = Vec::with_capacity(args.len());
for idx in 0..args.len() {
// Repeat the tensor across the given dimensions to infiltrate it in the resulting grid
let repeats: Vec<_> = args
.iter()
.map(|t| t.as_ref().dim(0))
.collect::<Result<_>>()?;
let repeated_tensor = args[idx].as_ref().clone().repeat(repeats.clone())?;
// Reshape the tensor to match the dimensions of the grid
let reshaped_tensor = repeated_tensor.reshape(repeats)?;
grids.push(reshaped_tensor);
}
if xy_indexing {
Ok(grids)
} else {
Ok(grids.into_iter().rev().collect())
}
}
/// This operation multiplies the input tensor by `mul` then adds `add` and return the result.
/// The input values `mul` and `add` are casted to the appropriate type so some rounding might
/// be performed.