mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Sketch the where_cond cuda kernel wrapper.
This commit is contained in:
@ -459,14 +459,59 @@ impl CudaStorage {
|
|||||||
|
|
||||||
pub(crate) fn where_cond(
|
pub(crate) fn where_cond(
|
||||||
&self,
|
&self,
|
||||||
_shape: &Shape,
|
shape: &Shape,
|
||||||
_stride: &[usize],
|
stride: &[usize],
|
||||||
_t: &Self,
|
t: &Self,
|
||||||
_stride_t: &[usize],
|
stride_t: &[usize],
|
||||||
_f: &Self,
|
f: &Self,
|
||||||
_stride_f: &[usize],
|
stride_f: &[usize],
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
Err(CudaError::InternalError("TODO: implement where_cond"))
|
let ids = match &self.slice {
|
||||||
|
CudaStorageSlice::U32(slice) => slice,
|
||||||
|
_ => Err(CudaError::UnexpectedDType {
|
||||||
|
msg: "embedding ids should be u32",
|
||||||
|
expected: DType::U32,
|
||||||
|
got: self.dtype(),
|
||||||
|
})?,
|
||||||
|
};
|
||||||
|
let dims = shape.dims();
|
||||||
|
let el = shape.elem_count();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
|
let dev = self.device();
|
||||||
|
let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?;
|
||||||
|
let slice = match (&t.slice, &f.slice) {
|
||||||
|
(CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
|
||||||
|
let func = dev.get_or_load_func("where_f32", kernels::BINARY)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<f32>(el) }?;
|
||||||
|
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
||||||
|
// SAFETY: ffi
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::F32(out)
|
||||||
|
}
|
||||||
|
(CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => {
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let func = dev.get_or_load_func("where_f64", kernels::BINARY)?;
|
||||||
|
let out = unsafe { dev.alloc::<f64>(el) }?;
|
||||||
|
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
||||||
|
// SAFETY: ffi
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::F64(out)
|
||||||
|
}
|
||||||
|
(CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => {
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let func = dev.get_or_load_func("where_u32", kernels::BINARY)?;
|
||||||
|
let out = unsafe { dev.alloc::<u32>(el) }?;
|
||||||
|
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
||||||
|
// SAFETY: ffi
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::U32(out)
|
||||||
|
}
|
||||||
|
// The dtypes should have been checked at this point so this is an internal error.
|
||||||
|
_ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
|
||||||
|
};
|
||||||
|
let device = dev.clone();
|
||||||
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn embedding_impl(
|
pub(crate) fn embedding_impl(
|
||||||
|
Reference in New Issue
Block a user