mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Fix a cuda warning. (#2693)
This commit is contained in:
@ -52,6 +52,49 @@ impl ArgSort {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
mod cuda {
|
||||||
|
use super::*;
|
||||||
|
use crate::cuda_backend::cudarc::driver::{
|
||||||
|
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
||||||
|
};
|
||||||
|
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
|
||||||
|
use crate::{CudaDevice, WithDType};
|
||||||
|
|
||||||
|
impl crate::cuda_backend::Map1Any for ArgSort {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
||||||
|
&self,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
layout: &crate::Layout,
|
||||||
|
_wrap: W,
|
||||||
|
) -> Result<S> {
|
||||||
|
let slice = match layout.contiguous_offsets() {
|
||||||
|
None => crate::bail!("input has to be contiguous"),
|
||||||
|
Some((o1, o2)) => src.slice(o1..o2),
|
||||||
|
};
|
||||||
|
let elem_count = layout.shape().elem_count();
|
||||||
|
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
|
||||||
|
let func = if self.asc {
|
||||||
|
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
|
||||||
|
} else {
|
||||||
|
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
|
||||||
|
};
|
||||||
|
let ncols = self.last_dim;
|
||||||
|
let nrows = elem_count / ncols;
|
||||||
|
let ncols_pad = next_power_of_2(ncols);
|
||||||
|
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
|
||||||
|
let cfg = LaunchConfig {
|
||||||
|
grid_dim: (1, nrows as u32, 1),
|
||||||
|
block_dim: (ncols_pad as u32, 1, 1),
|
||||||
|
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
|
||||||
|
};
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
Ok(S::U32(dst))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl crate::CustomOp1 for ArgSort {
|
impl crate::CustomOp1 for ArgSort {
|
||||||
fn name(&self) -> &'static str {
|
fn name(&self) -> &'static str {
|
||||||
"argsort"
|
"argsort"
|
||||||
@ -81,46 +124,8 @@ impl crate::CustomOp1 for ArgSort {
|
|||||||
storage: &crate::CudaStorage,
|
storage: &crate::CudaStorage,
|
||||||
layout: &crate::Layout,
|
layout: &crate::Layout,
|
||||||
) -> Result<(crate::CudaStorage, crate::Shape)> {
|
) -> Result<(crate::CudaStorage, crate::Shape)> {
|
||||||
use crate::cuda_backend::cudarc::driver::{
|
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
|
||||||
};
|
|
||||||
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr};
|
|
||||||
use crate::{CudaDevice, WithDType};
|
|
||||||
|
|
||||||
impl Map1Any for ArgSort {
|
|
||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
|
||||||
&self,
|
|
||||||
src: &CudaSlice<T>,
|
|
||||||
dev: &CudaDevice,
|
|
||||||
layout: &crate::Layout,
|
|
||||||
_wrap: W,
|
|
||||||
) -> Result<S> {
|
|
||||||
let slice = match layout.contiguous_offsets() {
|
|
||||||
None => crate::bail!("input has to be contiguous"),
|
|
||||||
Some((o1, o2)) => src.slice(o1..o2),
|
|
||||||
};
|
|
||||||
let elem_count = layout.shape().elem_count();
|
|
||||||
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
|
|
||||||
let func = if self.asc {
|
|
||||||
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
|
|
||||||
} else {
|
|
||||||
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
|
|
||||||
};
|
|
||||||
let ncols = self.last_dim;
|
|
||||||
let nrows = elem_count / ncols;
|
|
||||||
let ncols_pad = next_power_of_2(ncols);
|
|
||||||
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
|
|
||||||
let cfg = LaunchConfig {
|
|
||||||
grid_dim: (1, nrows as u32, 1),
|
|
||||||
block_dim: (ncols_pad as u32, 1, 1),
|
|
||||||
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
|
|
||||||
};
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
Ok(S::U32(dst))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
use crate::backend::BackendStorage;
|
use crate::backend::BackendStorage;
|
||||||
|
use crate::cuda_backend::Map1Any;
|
||||||
let dev = storage.device();
|
let dev = storage.device();
|
||||||
let slice = self.map(&storage.slice, dev, layout)?;
|
let slice = self.map(&storage.slice, dev, layout)?;
|
||||||
let dst = crate::cuda_backend::CudaStorage {
|
let dst = crate::cuda_backend::CudaStorage {
|
||||||
|
Reference in New Issue
Block a user