Enable the test for meshgrid + fix the implementation. (#1175)

This commit is contained in:
Laurent Mazare
2023-10-25 13:47:54 +01:00
committed by GitHub
parent e4c9adfdbe
commit c698e17619

View File

@ -552,24 +552,24 @@ impl Tensor {
///
/// # Examples
///
/// ```
/// ```rust
/// use candle_core::{Tensor, Device, Shape};
/// # fn dummy() -> Result<(), Box<dyn std::error::Error>> {
/// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
/// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
///
/// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?;
///
/// assert_eq!(2, grids_xy.len());
/// assert_eq!(Vec::from([3]), grids_xy[0].shape().clone().into_dims());
/// assert_eq!(Vec::from([3]), grids_xy[0].shape().clone().into_dims());
/// assert_eq!(grids_xy.len(), 2);
/// assert_eq!(grids_xy[0].dims(), &[3, 3]);
///
/// assert_eq!(grids_xy[0].to_vec1::<f32>()?, &[1., 1., 1., 2., 2., 2., 3., 3., 3.]);
/// assert_eq!(grids_xy[0].to_vec2::<f32>()?, &[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]);
/// assert_eq!(grids_xy[1].to_vec2::<f32>()?, &[[4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]);
///
/// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?;
///
/// assert_eq!(grids_xy[0].to_vec1::<f32>()?, &[1., 2., 3., 1., 2., 3., 1., 2., 3.]);
/// # Ok(()) }
/// assert_eq!(grids_ij[0].to_vec2::<f32>()?, &[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]);
/// assert_eq!(grids_ij[1].to_vec2::<f32>()?, &[[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
///
/// # Errors
@ -577,30 +577,34 @@ impl Tensor {
/// * Will return `Err` if `args` contains less than 2 tensors.
///
pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
if args.is_empty() || args.len() <= 1 {
if args.len() <= 1 {
Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
}
let args: Vec<_> = if xy_indexing {
args.iter().rev().collect()
} else {
args.iter().collect()
};
let mut shape = Vec::with_capacity(args.len());
for arg in args.iter() {
shape.push(arg.as_ref().dims1()?)
}
let mut grids = Vec::with_capacity(args.len());
for idx in 0..args.len() {
// Repeat the tensor across the given dimensions to infiltrate it in the resulting grid
let repeats: Vec<_> = args
.iter()
.map(|t| t.as_ref().dim(0))
.collect::<Result<_>>()?;
let repeated_tensor = args[idx].as_ref().clone().repeat(repeats.clone())?;
// Reshape the tensor to match the dimensions of the grid
let reshaped_tensor = repeated_tensor.reshape(repeats)?;
grids.push(reshaped_tensor);
let mut ones = vec![1usize; args.len()];
ones[idx] = shape[idx];
let arg = args[idx].as_ref().reshape(ones)?;
let mut repeats = shape.clone();
repeats[idx] = 1;
let repeated_tensor = arg.repeat(repeats)?;
grids.push(repeated_tensor);
}
if xy_indexing {
Ok(grids)
} else {
Ok(grids.into_iter().rev().collect())
grids.reverse();
}
Ok(grids)
}
/// This operation multiplies the input tensor by `mul` then adds `add` and return the result.