mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add a sort function. (#2134)
This commit is contained in:
@ -219,4 +219,21 @@ impl Tensor {
|
||||
// No need for a backward pass for arg sort.
|
||||
self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
|
||||
}
|
||||
|
||||
/// Sorts the tensor along the last dimension, returns the sorted tensor together with the
|
||||
/// sorted indexes.
|
||||
///
|
||||
/// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
|
||||
/// descending order. The sort is unstable so there is no guarantees on the final order when it
|
||||
/// comes to ties.
|
||||
pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
|
||||
if !self.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous {
|
||||
op: "sort_last_dim",
|
||||
});
|
||||
}
|
||||
let asort = self.arg_sort_last_dim(asc)?;
|
||||
let sorted = self.gather(&asort, crate::D::Minus1)?;
|
||||
Ok((sorted, asort))
|
||||
}
|
||||
}
|
||||
|
@ -109,6 +109,24 @@ fn asort(device: &Device) -> Result<()> {
|
||||
indexes.to_vec2::<u32>()?,
|
||||
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
|
||||
);
|
||||
let (sorted, indexes) = tensor.sort_last_dim(true)?;
|
||||
assert_eq!(
|
||||
indexes.to_vec2::<u32>()?,
|
||||
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
|
||||
);
|
||||
assert_eq!(
|
||||
sorted.to_vec2::<f32>()?,
|
||||
[[1.0, 1.1, 3.0, 4.0, 5.0], [1.0, 2.0, 2.1, 7.0, 8.0]]
|
||||
);
|
||||
let (sorted, indexes) = tensor.sort_last_dim(false)?;
|
||||
assert_eq!(
|
||||
indexes.to_vec2::<u32>()?,
|
||||
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
|
||||
);
|
||||
assert_eq!(
|
||||
sorted.to_vec2::<f32>()?,
|
||||
[[5.0, 4.0, 3.0, 1.1, 1.0], [8.0, 7.0, 2.1, 2.0, 1.0]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user