Enable the test for meshgrid + fix the implementation. (#1175)

This commit is contained in:
Laurent Mazare
2023-10-25 13:47:54 +01:00
committed by GitHub
parent e4c9adfdbe
commit c698e17619

View File

@ -552,24 +552,24 @@ impl Tensor {
/// ///
/// # Examples /// # Examples
/// ///
/// ``` /// ```rust
/// use candle_core::{Tensor, Device, Shape}; /// use candle_core::{Tensor, Device, Shape};
/// # fn dummy() -> Result<(), Box<dyn std::error::Error>> {
/// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?; /// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
/// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?; /// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
/// ///
/// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?; /// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?;
/// ///
/// assert_eq!(2, grids_xy.len()); /// assert_eq!(grids_xy.len(), 2);
/// assert_eq!(Vec::from([3]), grids_xy[0].shape().clone().into_dims()); /// assert_eq!(grids_xy[0].dims(), &[3, 3]);
/// assert_eq!(Vec::from([3]), grids_xy[0].shape().clone().into_dims());
/// ///
/// assert_eq!(grids_xy[0].to_vec1::<f32>()?, &[1., 1., 1., 2., 2., 2., 3., 3., 3.]); /// assert_eq!(grids_xy[0].to_vec2::<f32>()?, &[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]);
/// assert_eq!(grids_xy[1].to_vec2::<f32>()?, &[[4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]);
/// ///
/// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?; /// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?;
/// ///
/// assert_eq!(grids_xy[0].to_vec1::<f32>()?, &[1., 2., 3., 1., 2., 3., 1., 2., 3.]); /// assert_eq!(grids_ij[0].to_vec2::<f32>()?, &[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]);
/// # Ok(()) } /// assert_eq!(grids_ij[1].to_vec2::<f32>()?, &[[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]);
/// # Ok::<(), candle_core::Error>(())
/// ``` /// ```
/// ///
/// # Errors /// # Errors
@ -577,30 +577,34 @@ impl Tensor {
/// * Will return `Err` if `args` contains less than 2 tensors. /// * Will return `Err` if `args` contains less than 2 tensors.
/// ///
pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> { pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
if args.is_empty() || args.len() <= 1 { if args.len() <= 1 {
Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())? Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
} }
let args: Vec<_> = if xy_indexing {
args.iter().rev().collect()
} else {
args.iter().collect()
};
let mut shape = Vec::with_capacity(args.len());
for arg in args.iter() {
shape.push(arg.as_ref().dims1()?)
}
let mut grids = Vec::with_capacity(args.len()); let mut grids = Vec::with_capacity(args.len());
for idx in 0..args.len() { for idx in 0..args.len() {
// Repeat the tensor across the given dimensions to infiltrate it in the resulting grid let mut ones = vec![1usize; args.len()];
let repeats: Vec<_> = args ones[idx] = shape[idx];
.iter() let arg = args[idx].as_ref().reshape(ones)?;
.map(|t| t.as_ref().dim(0)) let mut repeats = shape.clone();
.collect::<Result<_>>()?; repeats[idx] = 1;
let repeated_tensor = args[idx].as_ref().clone().repeat(repeats.clone())?; let repeated_tensor = arg.repeat(repeats)?;
grids.push(repeated_tensor);
// Reshape the tensor to match the dimensions of the grid
let reshaped_tensor = repeated_tensor.reshape(repeats)?;
grids.push(reshaped_tensor);
} }
if xy_indexing { if xy_indexing {
Ok(grids) grids.reverse();
} else {
Ok(grids.into_iter().rev().collect())
} }
Ok(grids)
} }
/// This operation multiplies the input tensor by `mul` then adds `add` and return the result. /// This operation multiplies the input tensor by `mul` then adds `add` and return the result.