mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add a sort function. (#2134)
This commit is contained in:
@ -219,4 +219,21 @@ impl Tensor {
|
|||||||
// No need for a backward pass for arg sort.
|
// No need for a backward pass for arg sort.
|
||||||
self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
|
self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Sorts the tensor along the last dimension, returns the sorted tensor together with the
|
||||||
|
/// sorted indexes.
|
||||||
|
///
|
||||||
|
/// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
|
||||||
|
/// descending order. The sort is unstable so there is no guarantees on the final order when it
|
||||||
|
/// comes to ties.
|
||||||
|
pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
|
||||||
|
if !self.is_contiguous() {
|
||||||
|
return Err(crate::Error::RequiresContiguous {
|
||||||
|
op: "sort_last_dim",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let asort = self.arg_sort_last_dim(asc)?;
|
||||||
|
let sorted = self.gather(&asort, crate::D::Minus1)?;
|
||||||
|
Ok((sorted, asort))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -109,6 +109,24 @@ fn asort(device: &Device) -> Result<()> {
|
|||||||
indexes.to_vec2::<u32>()?,
|
indexes.to_vec2::<u32>()?,
|
||||||
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
|
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
|
||||||
);
|
);
|
||||||
|
let (sorted, indexes) = tensor.sort_last_dim(true)?;
|
||||||
|
assert_eq!(
|
||||||
|
indexes.to_vec2::<u32>()?,
|
||||||
|
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
sorted.to_vec2::<f32>()?,
|
||||||
|
[[1.0, 1.1, 3.0, 4.0, 5.0], [1.0, 2.0, 2.1, 7.0, 8.0]]
|
||||||
|
);
|
||||||
|
let (sorted, indexes) = tensor.sort_last_dim(false)?;
|
||||||
|
assert_eq!(
|
||||||
|
indexes.to_vec2::<u32>()?,
|
||||||
|
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
sorted.to_vec2::<f32>()?,
|
||||||
|
[[5.0, 4.0, 3.0, 1.1, 1.0], [8.0, 7.0, 2.1, 2.0, 1.0]]
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user