From c698e176190eca52ff754e10d05becc972c75a5c Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 25 Oct 2023 13:47:54 +0100 Subject: [PATCH] Enable the test for meshgrid + fix the implementation. (#1175) --- candle-core/src/tensor.rs | 52 +++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index facbd8ed..9dea62fa 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -552,24 +552,24 @@ impl Tensor { /// /// # Examples /// - /// ``` + /// ```rust /// use candle_core::{Tensor, Device, Shape}; - /// # fn dummy() -> Result<(), Box> { /// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?; /// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?; /// /// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?; /// - /// assert_eq!(2, grids_xy.len()); - /// assert_eq!(Vec::from([3]), grids_xy[0].shape().clone().into_dims()); - /// assert_eq!(Vec::from([3]), grids_xy[0].shape().clone().into_dims()); + /// assert_eq!(grids_xy.len(), 2); + /// assert_eq!(grids_xy[0].dims(), &[3, 3]); /// - /// assert_eq!(grids_xy[0].to_vec1::()?, &[1., 1., 1., 2., 2., 2., 3., 3., 3.]); + /// assert_eq!(grids_xy[0].to_vec2::()?, &[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]); + /// assert_eq!(grids_xy[1].to_vec2::()?, &[[4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]); /// /// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?; /// - /// assert_eq!(grids_xy[0].to_vec1::()?, &[1., 2., 3., 1., 2., 3., 1., 2., 3.]); - /// # Ok(()) } + /// assert_eq!(grids_ij[0].to_vec2::()?, &[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]); + /// assert_eq!(grids_ij[1].to_vec2::()?, &[[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]); + /// # Ok::<(), candle_core::Error>(()) /// ``` /// /// # Errors @@ -577,30 +577,34 @@ impl Tensor { /// * Will return `Err` if `args` contains less than 2 tensors. /// pub fn meshgrid>(args: &[A], xy_indexing: bool) -> Result> { - if args.is_empty() || args.len() <= 1 { + if args.len() <= 1 { Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())? } + let args: Vec<_> = if xy_indexing { + args.iter().rev().collect() + } else { + args.iter().collect() + }; + + let mut shape = Vec::with_capacity(args.len()); + for arg in args.iter() { + shape.push(arg.as_ref().dims1()?) + } let mut grids = Vec::with_capacity(args.len()); for idx in 0..args.len() { - // Repeat the tensor across the given dimensions to infiltrate it in the resulting grid - let repeats: Vec<_> = args - .iter() - .map(|t| t.as_ref().dim(0)) - .collect::>()?; - let repeated_tensor = args[idx].as_ref().clone().repeat(repeats.clone())?; - - // Reshape the tensor to match the dimensions of the grid - let reshaped_tensor = repeated_tensor.reshape(repeats)?; - - grids.push(reshaped_tensor); + let mut ones = vec![1usize; args.len()]; + ones[idx] = shape[idx]; + let arg = args[idx].as_ref().reshape(ones)?; + let mut repeats = shape.clone(); + repeats[idx] = 1; + let repeated_tensor = arg.repeat(repeats)?; + grids.push(repeated_tensor); } - if xy_indexing { - Ok(grids) - } else { - Ok(grids.into_iter().rev().collect()) + grids.reverse(); } + Ok(grids) } /// This operation multiplies the input tensor by `mul` then adds `add` and return the result.