mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Implemented meshgrid (#1174)
* Implemented meshgrid * Resolved feedback from LaurentMazare * Rustfmt * Updated docstring * Removed outdated error mode from docstring
This commit is contained in:

committed by
GitHub

parent
b6053b938b
commit
e4c9adfdbe
@ -142,6 +142,9 @@ pub enum Error {
|
||||
#[error("{op} expects at least one tensor")]
|
||||
OpRequiresAtLeastOneTensor { op: &'static str },
|
||||
|
||||
#[error("{op} expects at least two tensors")]
|
||||
OpRequiresAtLeastTwoTensors { op: &'static str },
|
||||
|
||||
#[error("backward is not supported for {op}")]
|
||||
BackwardNotSupported { op: &'static str },
|
||||
|
||||
|
@ -540,6 +540,69 @@ impl Tensor {
|
||||
Ok(inp)
|
||||
}
|
||||
|
||||
/// Creates grids of coordinates specified by the 1D inputs.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `args` - A slice of 1D tensors.
|
||||
/// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
|
||||
/// first dimension corresponds to the cardinality of the second input and the second
|
||||
/// dimension corresponds to the cardinality of the first input. If ij is selected, the
|
||||
/// dimensions are in the same order as the cardinality of the inputs.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use candle_core::{Tensor, Device, Shape};
|
||||
/// # fn dummy() -> Result<(), Box<dyn std::error::Error>> {
|
||||
/// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
|
||||
/// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
|
||||
///
|
||||
/// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?;
|
||||
///
|
||||
/// assert_eq!(2, grids_xy.len());
|
||||
/// assert_eq!(Vec::from([3]), grids_xy[0].shape().clone().into_dims());
|
||||
/// assert_eq!(Vec::from([3]), grids_xy[0].shape().clone().into_dims());
|
||||
///
|
||||
/// assert_eq!(grids_xy[0].to_vec1::<f32>()?, &[1., 1., 1., 2., 2., 2., 3., 3., 3.]);
|
||||
///
|
||||
/// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?;
|
||||
///
|
||||
/// assert_eq!(grids_xy[0].to_vec1::<f32>()?, &[1., 2., 3., 1., 2., 3., 1., 2., 3.]);
|
||||
/// # Ok(()) }
|
||||
/// ```
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// * Will return `Err` if `args` contains less than 2 tensors.
|
||||
///
|
||||
pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
|
||||
if args.is_empty() || args.len() <= 1 {
|
||||
Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
|
||||
}
|
||||
|
||||
let mut grids = Vec::with_capacity(args.len());
|
||||
for idx in 0..args.len() {
|
||||
// Repeat the tensor across the given dimensions to infiltrate it in the resulting grid
|
||||
let repeats: Vec<_> = args
|
||||
.iter()
|
||||
.map(|t| t.as_ref().dim(0))
|
||||
.collect::<Result<_>>()?;
|
||||
let repeated_tensor = args[idx].as_ref().clone().repeat(repeats.clone())?;
|
||||
|
||||
// Reshape the tensor to match the dimensions of the grid
|
||||
let reshaped_tensor = repeated_tensor.reshape(repeats)?;
|
||||
|
||||
grids.push(reshaped_tensor);
|
||||
}
|
||||
|
||||
if xy_indexing {
|
||||
Ok(grids)
|
||||
} else {
|
||||
Ok(grids.into_iter().rev().collect())
|
||||
}
|
||||
}
|
||||
|
||||
/// This operation multiplies the input tensor by `mul` then adds `add` and return the result.
|
||||
/// The input values `mul` and `add` are casted to the appropriate type so some rounding might
|
||||
/// be performed.
|
||||
|
Reference in New Issue
Block a user