use crate::{Result, Tensor};
use rayon::prelude::*;
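
/// Argsort implemented as a custom op: `last_dim` is the size of the last
/// dimension, each contiguous row of that length is sorted independently, and
/// `asc` selects ascending or descending order.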
#[derive(Debug, Clone, Copy)]
struct ArgSort {
    asc: bool,
    last_dim: usize,
}

impl ArgSort {
    fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
        #[allow(clippy::uninit_vec)]
        // Safety: indexes are set later in the parallelized section.
        let mut sort_indexes = unsafe {
            let el_count = layout.shape().elem_count();
            let mut v = Vec::with_capacity(el_count);
            v.set_len(el_count);
            v
        };
        if self.asc {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
                    indexes.sort_by(|&i, &j| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        } else {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
                    indexes.sort_by(|&j, &i| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        }
        sort_indexes
    }
}
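
// For a single row such as vs = [3.0, 1.0, 2.0] with last_dim = 3 and asc set,
// the indexes start as [0, 1, 2] and sorting them by the values they point at
// yields [1, 2, 0], i.e. vs[1] <= vs[2] <= vs[0]. Incomparable values such as
// NaN are mapped to `Ordering::Greater` instead of panicking, so their final
// position is unspecified.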

#[cfg(feature = "cuda")]
mod cuda {
    use super::*;
    use crate::cuda_backend::cudarc::driver::{
        CudaSlice, DeviceRepr, LaunchConfig, ValidAsZeroBits,
    };
    use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
    use crate::{CudaDevice, WithDType};

    impl crate::cuda_backend::Map1Any for ArgSort {
        fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
            &self,
            src: &CudaSlice<T>,
            dev: &CudaDevice,
            layout: &crate::Layout,
            _wrap: W,
        ) -> Result<S> {
            use cudarc::driver::PushKernelArg;

            let slice = match layout.contiguous_offsets() {
                None => crate::bail!("input has to be contiguous"),
                Some((o1, o2)) => src.slice(o1..o2),
            };
            let elem_count = layout.shape().elem_count();
            let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
            let func = if self.asc {
                dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
            } else {
                dev.get_or_load_func(&kernel_name::<T>("asort_desc"), &kernels::SORT)?
            };
            let ncols = self.last_dim;
            let nrows = elem_count / ncols;
            let ncols_pad = next_power_of_2(ncols);
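            // One block per row (grid.y), one thread per padded column, and
            // enough shared memory to hold the ncols_pad u32 indexes being
            // sorted; e.g. ncols = 1000 gives ncols_pad = 1024 threads and
            // 4096 bytes of shared memory. The power-of-two padding is
            // presumably what the sorting-network style kernel requires.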
            let cfg = LaunchConfig {
                grid_dim: (1, nrows as u32, 1),
                block_dim: (ncols_pad as u32, 1, 1),
                shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
            };
            let stream = dev.cuda_stream();
            let mut builder = stream.launch_builder(&func);
            let ncols = ncols as i32;
            let ncols_pad = ncols_pad as i32;
            builder.arg(&slice).arg(&dst).arg(&ncols).arg(&ncols_pad);
            unsafe { builder.launch(cfg) }.w()?;
            Ok(S::U32(dst))
        }
    }
}

impl crate::CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }

    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, crate::Shape)> {
        let sort_indexes = match storage {
            crate::CpuStorage::U8(vs) => self.asort(vs, layout),
            crate::CpuStorage::U32(vs) => self.asort(vs, layout),
            crate::CpuStorage::I64(vs) => self.asort(vs, layout),
            crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F32(vs) => self.asort(vs, layout),
            crate::CpuStorage::F64(vs) => self.asort(vs, layout),
        };
        let sort_indexes = crate::CpuStorage::U32(sort_indexes);
        Ok((sort_indexes, layout.shape().into()))
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &crate::CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CudaStorage, crate::Shape)> {
        use crate::backend::BackendStorage;
        use crate::cuda_backend::Map1Any;
        let dev = storage.device();
        let slice = self.map(&storage.slice, dev, layout)?;
        let dst = crate::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, crate::Shape)> {
        use crate::backend::BackendStorage;
        use crate::DType;

        let name = {
            if self.asc {
                match storage.dtype() {
                    DType::BF16 => "asort_asc_bf16",
                    DType::F16 => "asort_asc_f16",
                    DType::F32 => "asort_asc_f32",
                    DType::F64 => "asort_asc_f64",
                    DType::U8 => "asort_asc_u8",
                    DType::U32 => "asort_asc_u32",
                    DType::I64 => "asort_asc_i64",
                }
            } else {
                match storage.dtype() {
                    DType::BF16 => "asort_desc_bf16",
                    DType::F16 => "asort_desc_f16",
                    DType::F32 => "asort_desc_f32",
                    DType::F64 => "asort_desc_f64",
                    DType::U8 => "asort_desc_u8",
                    DType::U32 => "asort_desc_u32",
                    DType::I64 => "asort_desc_i64",
                }
            }
        };
        let device = storage.device();
        let kernels = device.kernels();
        let command_buffer = device.command_buffer()?;
        let el = layout.shape().elem_count();
        let ncols = self.last_dim;
        let nrows = el / ncols;
        let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
        let dst = device.new_buffer(el, DType::U32, "asort")?;
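        // As on the CUDA side, pad the column count to the next power of two
        // for the sort kernel.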
        let mut ncols_pad = 1;
        while ncols_pad < ncols {
            ncols_pad *= 2;
        }
        candle_metal_kernels::call_arg_sort(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            nrows,
            ncols,
            ncols_pad,
            src,
            &dst,
        )
        .map_err(crate::Error::wrap)?;
        let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
        Ok((dst, layout.shape().clone()))
    }
}
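
/// Returns the smallest power of two greater than or equal to `x`; for these
/// sizes this matches `usize::next_power_of_two` (e.g. 3 -> 4, 1000 -> 1024).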
#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
    let mut n = 1;
    while n < x {
        n *= 2
    }
    n
}

impl Tensor {
    /// Returns the indices that sort the tensor along the last dimension.
    ///
    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
    /// descending order. The sort is unstable, so there are no guarantees on the final order
    /// of ties.
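    ///
    /// A usage sketch (assuming the crate is consumed as `candle_core` and a
    /// CPU device):
    ///
    /// ```rust
    /// # use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[3f32, 1., 2.], &Device::Cpu)?;
    /// let indexes = t.arg_sort_last_dim(true)?;
    /// assert_eq!(indexes.to_vec1::<u32>()?, vec![1, 2, 0]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```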
    pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(crate::Error::RequiresContiguous {
                op: "arg_sort_last_dim",
            });
        }
        let last_dim = match self.dims().last() {
            None => crate::bail!("empty last-dim in arg-sort"),
            Some(last_dim) => *last_dim,
        };
        // No need for a backward pass for arg sort.
        self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
    }

    /// Sorts the tensor along the last dimension, returns the sorted tensor together with the
    /// sorted indexes.
    ///
    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
    /// descending order. The sort is unstable, so there are no guarantees on the final order
    /// of ties.
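    ///
    /// A usage sketch (same assumptions as above):
    ///
    /// ```rust
    /// # use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[3f32, 1., 2.], &Device::Cpu)?;
    /// let (sorted, indexes) = t.sort_last_dim(true)?;
    /// assert_eq!(sorted.to_vec1::<f32>()?, vec![1., 2., 3.]);
    /// assert_eq!(indexes.to_vec1::<u32>()?, vec![1, 2, 0]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```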
    pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
        if !self.is_contiguous() {
            return Err(crate::Error::RequiresContiguous {
                op: "sort_last_dim",
            });
        }
        let asort = self.arg_sort_last_dim(asc)?;
        let sorted = self.gather(&asort, crate::D::Minus1)?;
        Ok((sorted, asort))
    }
}