Fix a cuda warning. (#2693)

This commit is contained in:
Laurent Mazare
2024-12-31 09:06:10 +01:00
committed by GitHub
parent 460616fc84
commit e38e2a85dd

View File

@ -52,42 +52,16 @@ impl ArgSort {
} }
} }
impl crate::CustomOp1 for ArgSort { #[cfg(feature = "cuda")]
fn name(&self) -> &'static str { mod cuda {
"argsort" use super::*;
}
fn cpu_fwd(
&self,
storage: &crate::CpuStorage,
layout: &crate::Layout,
) -> Result<(crate::CpuStorage, crate::Shape)> {
let sort_indexes = match storage {
crate::CpuStorage::U8(vs) => self.asort(vs, layout),
crate::CpuStorage::U32(vs) => self.asort(vs, layout),
crate::CpuStorage::I64(vs) => self.asort(vs, layout),
crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
crate::CpuStorage::F16(vs) => self.asort(vs, layout),
crate::CpuStorage::F32(vs) => self.asort(vs, layout),
crate::CpuStorage::F64(vs) => self.asort(vs, layout),
};
let sort_indexes = crate::CpuStorage::U32(sort_indexes);
Ok((sort_indexes, layout.shape().into()))
}
#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
storage: &crate::CudaStorage,
layout: &crate::Layout,
) -> Result<(crate::CudaStorage, crate::Shape)> {
use crate::cuda_backend::cudarc::driver::{ use crate::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits, CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
}; };
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr}; use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
use crate::{CudaDevice, WithDType}; use crate::{CudaDevice, WithDType};
impl Map1Any for ArgSort { impl crate::cuda_backend::Map1Any for ArgSort {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>( fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self, &self,
src: &CudaSlice<T>, src: &CudaSlice<T>,
@ -119,8 +93,39 @@ impl crate::CustomOp1 for ArgSort {
Ok(S::U32(dst)) Ok(S::U32(dst))
} }
} }
}
impl crate::CustomOp1 for ArgSort {
fn name(&self) -> &'static str {
"argsort"
}
fn cpu_fwd(
&self,
storage: &crate::CpuStorage,
layout: &crate::Layout,
) -> Result<(crate::CpuStorage, crate::Shape)> {
let sort_indexes = match storage {
crate::CpuStorage::U8(vs) => self.asort(vs, layout),
crate::CpuStorage::U32(vs) => self.asort(vs, layout),
crate::CpuStorage::I64(vs) => self.asort(vs, layout),
crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
crate::CpuStorage::F16(vs) => self.asort(vs, layout),
crate::CpuStorage::F32(vs) => self.asort(vs, layout),
crate::CpuStorage::F64(vs) => self.asort(vs, layout),
};
let sort_indexes = crate::CpuStorage::U32(sort_indexes);
Ok((sort_indexes, layout.shape().into()))
}
#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
storage: &crate::CudaStorage,
layout: &crate::Layout,
) -> Result<(crate::CudaStorage, crate::Shape)> {
use crate::backend::BackendStorage; use crate::backend::BackendStorage;
use crate::cuda_backend::Map1Any;
let dev = storage.device(); let dev = storage.device();
let slice = self.map(&storage.slice, dev, layout)?; let slice = self.map(&storage.slice, dev, layout)?;
let dst = crate::cuda_backend::CudaStorage { let dst = crate::cuda_backend::CudaStorage {