mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Fix a cuda warning. (#2693)
This commit is contained in:
@ -52,42 +52,16 @@ impl ArgSort {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl crate::CustomOp1 for ArgSort {
|
#[cfg(feature = "cuda")]
|
||||||
fn name(&self) -> &'static str {
|
mod cuda {
|
||||||
"argsort"
|
use super::*;
|
||||||
}
|
|
||||||
|
|
||||||
fn cpu_fwd(
|
|
||||||
&self,
|
|
||||||
storage: &crate::CpuStorage,
|
|
||||||
layout: &crate::Layout,
|
|
||||||
) -> Result<(crate::CpuStorage, crate::Shape)> {
|
|
||||||
let sort_indexes = match storage {
|
|
||||||
crate::CpuStorage::U8(vs) => self.asort(vs, layout),
|
|
||||||
crate::CpuStorage::U32(vs) => self.asort(vs, layout),
|
|
||||||
crate::CpuStorage::I64(vs) => self.asort(vs, layout),
|
|
||||||
crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
|
|
||||||
crate::CpuStorage::F16(vs) => self.asort(vs, layout),
|
|
||||||
crate::CpuStorage::F32(vs) => self.asort(vs, layout),
|
|
||||||
crate::CpuStorage::F64(vs) => self.asort(vs, layout),
|
|
||||||
};
|
|
||||||
let sort_indexes = crate::CpuStorage::U32(sort_indexes);
|
|
||||||
Ok((sort_indexes, layout.shape().into()))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "cuda")]
|
|
||||||
fn cuda_fwd(
|
|
||||||
&self,
|
|
||||||
storage: &crate::CudaStorage,
|
|
||||||
layout: &crate::Layout,
|
|
||||||
) -> Result<(crate::CudaStorage, crate::Shape)> {
|
|
||||||
use crate::cuda_backend::cudarc::driver::{
|
use crate::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
||||||
};
|
};
|
||||||
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr};
|
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
|
||||||
use crate::{CudaDevice, WithDType};
|
use crate::{CudaDevice, WithDType};
|
||||||
|
|
||||||
impl Map1Any for ArgSort {
|
impl crate::cuda_backend::Map1Any for ArgSort {
|
||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
||||||
&self,
|
&self,
|
||||||
src: &CudaSlice<T>,
|
src: &CudaSlice<T>,
|
||||||
@ -119,8 +93,39 @@ impl crate::CustomOp1 for ArgSort {
|
|||||||
Ok(S::U32(dst))
|
Ok(S::U32(dst))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl crate::CustomOp1 for ArgSort {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"argsort"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cpu_fwd(
|
||||||
|
&self,
|
||||||
|
storage: &crate::CpuStorage,
|
||||||
|
layout: &crate::Layout,
|
||||||
|
) -> Result<(crate::CpuStorage, crate::Shape)> {
|
||||||
|
let sort_indexes = match storage {
|
||||||
|
crate::CpuStorage::U8(vs) => self.asort(vs, layout),
|
||||||
|
crate::CpuStorage::U32(vs) => self.asort(vs, layout),
|
||||||
|
crate::CpuStorage::I64(vs) => self.asort(vs, layout),
|
||||||
|
crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
|
||||||
|
crate::CpuStorage::F16(vs) => self.asort(vs, layout),
|
||||||
|
crate::CpuStorage::F32(vs) => self.asort(vs, layout),
|
||||||
|
crate::CpuStorage::F64(vs) => self.asort(vs, layout),
|
||||||
|
};
|
||||||
|
let sort_indexes = crate::CpuStorage::U32(sort_indexes);
|
||||||
|
Ok((sort_indexes, layout.shape().into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn cuda_fwd(
|
||||||
|
&self,
|
||||||
|
storage: &crate::CudaStorage,
|
||||||
|
layout: &crate::Layout,
|
||||||
|
) -> Result<(crate::CudaStorage, crate::Shape)> {
|
||||||
use crate::backend::BackendStorage;
|
use crate::backend::BackendStorage;
|
||||||
|
use crate::cuda_backend::Map1Any;
|
||||||
let dev = storage.device();
|
let dev = storage.device();
|
||||||
let slice = self.map(&storage.slice, dev, layout)?;
|
let slice = self.map(&storage.slice, dev, layout)?;
|
||||||
let dst = crate::cuda_backend::CudaStorage {
|
let dst = crate::cuda_backend::CudaStorage {
|
||||||
|
Reference in New Issue
Block a user