use crate::{Result, Tensor};
use rayon::prelude::*;

#[derive(Debug, Clone, Copy)]
struct ArgSort {
    asc: bool,
    last_dim: usize,
}

impl ArgSort {
    fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
        #[allow(clippy::uninit_vec)]
        // Safety: indexes are set later in the parallelized section.
        let mut sort_indexes = unsafe {
            let el_count = layout.shape().elem_count();
            let mut v = Vec::with_capacity(el_count);
            v.set_len(el_count);
            v
        };
        if self.asc {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
                    indexes.sort_by(|&i, &j| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        } else {
            // Descending order: identical to the branch above, with the
            // comparator arguments swapped.
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
                    indexes.sort_by(|&j, &i| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        }
        sort_indexes
    }
}

#[cfg(feature = "cuda")]
mod cuda {
    use super::*;
    use crate::cuda_backend::cudarc::driver::{
        CudaSlice, DeviceRepr, LaunchConfig, ValidAsZeroBits,
    };
    use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
    use crate::{CudaDevice, WithDType};

    impl crate::cuda_backend::Map1Any for ArgSort {
        fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
            &self,
            src: &CudaSlice<T>,
            dev: &CudaDevice,
            layout: &crate::Layout,
            _wrap: W,
        ) -> Result<S> {
            use cudarc::driver::PushKernelArg;

            let slice = match layout.contiguous_offsets() {
                None => crate::bail!("input has to be contiguous"),
                Some((o1, o2)) => src.slice(o1..o2),
            };
            let elem_count = layout.shape().elem_count();
            let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
            let func = if self.asc {
                dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
            } else {
                dev.get_or_load_func(&kernel_name::<T>("asort_desc"), &kernels::SORT)?
            };
            let ncols = self.last_dim;
            let nrows = elem_count / ncols;
            // The kernel sorts one row per block and requires the column count
            // padded to a power of two for its shared-memory index buffer.
            let ncols_pad = next_power_of_2(ncols);
            let cfg = LaunchConfig {
                grid_dim: (1, nrows as u32, 1),
                block_dim: (ncols_pad as u32, 1, 1),
                shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
            };
            let stream = dev.cuda_stream();
            let mut builder = stream.launch_builder(&func);
            let ncols = ncols as i32;
            let ncols_pad = ncols_pad as i32;
            builder.arg(&slice).arg(&dst).arg(&ncols).arg(&ncols_pad);
            unsafe { builder.launch(cfg) }.w()?;
            Ok(S::U32(dst))
        }
    }
}

impl crate::CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }

    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, crate::Shape)> {
        let sort_indexes = match storage {
            crate::CpuStorage::U8(vs) => self.asort(vs, layout),
            crate::CpuStorage::U32(vs) => self.asort(vs, layout),
            crate::CpuStorage::I64(vs) => self.asort(vs, layout),
            crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F32(vs) => self.asort(vs, layout),
            crate::CpuStorage::F64(vs) => self.asort(vs, layout),
        };
        let sort_indexes = crate::CpuStorage::U32(sort_indexes);
        Ok((sort_indexes, layout.shape().into()))
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &crate::CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CudaStorage, crate::Shape)> {
        use crate::backend::BackendStorage;
        use crate::cuda_backend::Map1Any;
        let dev = storage.device();
        let slice = self.map(&storage.slice, dev, layout)?;
        let dst = crate::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, crate::Shape)> {
        use crate::backend::BackendStorage;
        use crate::DType;

        // Pick the kernel specialized for the dtype and sort direction.
        let name = {
            if self.asc {
                match storage.dtype() {
                    DType::BF16 => "asort_asc_bf16",
                    DType::F16 => "asort_asc_f16",
                    DType::F32 => "asort_asc_f32",
                    DType::F64 => "asort_asc_f64",
                    DType::U8 => "asort_asc_u8",
                    DType::U32 => "asort_asc_u32",
                    DType::I64 => "asort_asc_i64",
                }
            } else {
                match storage.dtype() {
                    DType::BF16 => "asort_desc_bf16",
                    DType::F16 => "asort_desc_f16",
                    DType::F32 => "asort_desc_f32",
                    DType::F64 => "asort_desc_f64",
                    DType::U8 => "asort_desc_u8",
                    DType::U32 => "asort_desc_u32",
                    DType::I64 => "asort_desc_i64",
                }
            }
        };
        let device = storage.device();
        let kernels = device.kernels();
        let command_buffer = device.command_buffer()?;
        let el = layout.shape().elem_count();
        let ncols = self.last_dim;
        let nrows = el / ncols;
        let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
        let dst = device.new_buffer(el, DType::U32, "asort")?;
        // Pad the number of columns to the next power of two, as expected by
        // the Metal sort kernel.
        let mut ncols_pad = 1;
        while ncols_pad < ncols {
            ncols_pad *= 2;
        }
        candle_metal_kernels::call_arg_sort(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            nrows,
            ncols,
            ncols_pad,
            src,
            &dst,
        )
        .map_err(crate::Error::wrap)?;
        let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
        Ok((dst, layout.shape().clone()))
    }
}

#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
    let mut n = 1;
    while n < x {
        n *= 2
    }
    n
}

impl Tensor {
    /// Returns the indices that sort the tensor along the last dimension.
    ///
    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
    /// descending order. The sort is unstable so there are no guarantees on the final order
    /// when it comes to ties.
    pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(crate::Error::RequiresContiguous {
                op: "arg_sort_last_dim",
            });
        }
        let last_dim = match self.dims().last() {
            None => crate::bail!("empty last-dim in arg-sort"),
            Some(last_dim) => *last_dim,
        };
        // No need for a backward pass for arg sort.
        self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
    }

    /// Sorts the tensor along the last dimension, returning the sorted tensor together with the
    /// sorted indexes.
    ///
    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
    /// descending order. The sort is unstable so there are no guarantees on the final order
    /// when it comes to ties.
    pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
        if !self.is_contiguous() {
            return Err(crate::Error::RequiresContiguous {
                op: "sort_last_dim",
            });
        }
        let asort = self.arg_sort_last_dim(asc)?;
        // Reuse the computed indices to gather the values in sorted order.
        let sorted = self.gather(&asort, crate::D::Minus1)?;
        Ok((sorted, asort))
    }
}
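
// The tests below are a minimal usage sketch rather than part of the original
// module: they assume the crate-level `Device` re-export and the `to_vec1`
// accessor, and only exercise the CPU backend.
#[cfg(test)]
mod tests {
    use crate::{Device, Tensor};

    #[test]
    fn next_power_of_2_rounds_up() {
        // The padding helper used to size the GPU sort launches rounds up to
        // the next power of two and leaves exact powers of two unchanged.
        assert_eq!(super::next_power_of_2(3), 4);
        assert_eq!(super::next_power_of_2(8), 8);
        assert_eq!(super::next_power_of_2(9), 16);
    }

    #[test]
    fn sort_last_dim_cpu() -> crate::Result<()> {
        // Distinct values avoid ties, so the unstable sort has a unique answer.
        let t = Tensor::new(&[3f32, 1., 4., 2., 5.], &Device::Cpu)?;
        let indexes = t.arg_sort_last_dim(true)?;
        assert_eq!(indexes.to_vec1::<u32>()?, [1, 3, 0, 2, 4]);
        // `sort_last_dim` returns the sorted values together with the indexes.
        let (sorted, indexes) = t.sort_last_dim(false)?;
        assert_eq!(sorted.to_vec1::<f32>()?, [5., 4., 3., 2., 1.]);
        assert_eq!(indexes.to_vec1::<u32>()?, [4, 2, 0, 3, 1]);
        Ok(())
    }
}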