mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Compare commits
1 Commits
0.7.0
...
meshgrid-f
Author | SHA1 | Date | |
---|---|---|---|
20da4f44ef |
@ -552,24 +552,24 @@ impl Tensor {
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// ```rust
|
||||
/// use candle_core::{Tensor, Device, Shape};
|
||||
/// # fn dummy() -> Result<(), Box<dyn std::error::Error>> {
|
||||
/// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
|
||||
/// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
|
||||
///
|
||||
/// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?;
|
||||
///
|
||||
/// assert_eq!(2, grids_xy.len());
|
||||
/// assert_eq!(Vec::from([3]), grids_xy[0].shape().clone().into_dims());
|
||||
/// assert_eq!(Vec::from([3]), grids_xy[0].shape().clone().into_dims());
|
||||
/// assert_eq!(grids_xy.len(), 2);
|
||||
/// assert_eq!(grids_xy[0].dims(), &[3, 3]);
|
||||
///
|
||||
/// assert_eq!(grids_xy[0].to_vec1::<f32>()?, &[1., 1., 1., 2., 2., 2., 3., 3., 3.]);
|
||||
/// assert_eq!(grids_xy[0].to_vec2::<f32>()?, &[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]);
|
||||
/// assert_eq!(grids_xy[1].to_vec2::<f32>()?, &[[4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]);
|
||||
///
|
||||
/// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?;
|
||||
///
|
||||
/// assert_eq!(grids_xy[0].to_vec1::<f32>()?, &[1., 2., 3., 1., 2., 3., 1., 2., 3.]);
|
||||
/// # Ok(()) }
|
||||
/// assert_eq!(grids_ij[0].to_vec2::<f32>()?, &[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]);
|
||||
/// assert_eq!(grids_ij[1].to_vec2::<f32>()?, &[[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
///
|
||||
/// # Errors
|
||||
@ -577,30 +577,34 @@ impl Tensor {
|
||||
/// * Will return `Err` if `args` contains less than 2 tensors.
|
||||
///
|
||||
pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
|
||||
if args.is_empty() || args.len() <= 1 {
|
||||
if args.len() <= 1 {
|
||||
Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
|
||||
}
|
||||
let args: Vec<_> = if xy_indexing {
|
||||
args.iter().rev().collect()
|
||||
} else {
|
||||
args.iter().collect()
|
||||
};
|
||||
|
||||
let mut shape = Vec::with_capacity(args.len());
|
||||
for arg in args.iter() {
|
||||
shape.push(arg.as_ref().dims1()?)
|
||||
}
|
||||
|
||||
let mut grids = Vec::with_capacity(args.len());
|
||||
for idx in 0..args.len() {
|
||||
// Repeat the tensor across the given dimensions to infiltrate it in the resulting grid
|
||||
let repeats: Vec<_> = args
|
||||
.iter()
|
||||
.map(|t| t.as_ref().dim(0))
|
||||
.collect::<Result<_>>()?;
|
||||
let repeated_tensor = args[idx].as_ref().clone().repeat(repeats.clone())?;
|
||||
|
||||
// Reshape the tensor to match the dimensions of the grid
|
||||
let reshaped_tensor = repeated_tensor.reshape(repeats)?;
|
||||
|
||||
grids.push(reshaped_tensor);
|
||||
let mut ones = vec![1usize; args.len()];
|
||||
ones[idx] = shape[idx];
|
||||
let arg = args[idx].as_ref().reshape(ones)?;
|
||||
let mut repeats = shape.clone();
|
||||
repeats[idx] = 1;
|
||||
let repeated_tensor = arg.repeat(repeats)?;
|
||||
grids.push(repeated_tensor);
|
||||
}
|
||||
|
||||
if xy_indexing {
|
||||
Ok(grids)
|
||||
} else {
|
||||
Ok(grids.into_iter().rev().collect())
|
||||
grids.reverse();
|
||||
}
|
||||
Ok(grids)
|
||||
}
|
||||
|
||||
/// This operation multiplies the input tensor by `mul` then adds `add` and return the result.
|
||||
|
Reference in New Issue
Block a user