mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add Map2.
This commit is contained in:
@ -244,6 +244,7 @@ enum CudaStorageSlice {
|
|||||||
F32(CudaSlice<f32>),
|
F32(CudaSlice<f32>),
|
||||||
F64(CudaSlice<f64>),
|
F64(CudaSlice<f64>),
|
||||||
}
|
}
|
||||||
|
type S = CudaStorageSlice;
|
||||||
|
|
||||||
trait Map1 {
|
trait Map1 {
|
||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
@ -253,13 +254,36 @@ trait Map1 {
|
|||||||
layout: &Layout,
|
layout: &Layout,
|
||||||
) -> Result<CudaSlice<T>>;
|
) -> Result<CudaSlice<T>>;
|
||||||
|
|
||||||
fn map(&self, s: &CudaStorageSlice, d: &CudaDevice, l: &Layout) -> Result<CudaStorageSlice> {
|
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
|
||||||
let out = match s {
|
let out = match s {
|
||||||
CudaStorageSlice::U32(s) => CudaStorageSlice::U32(self.f(s, d, l)?),
|
S::U32(s) => S::U32(self.f(s, d, l)?),
|
||||||
CudaStorageSlice::BF16(s) => CudaStorageSlice::BF16(self.f(s, d, l)?),
|
S::BF16(s) => S::BF16(self.f(s, d, l)?),
|
||||||
CudaStorageSlice::F16(s) => CudaStorageSlice::F16(self.f(s, d, l)?),
|
S::F16(s) => S::F16(self.f(s, d, l)?),
|
||||||
CudaStorageSlice::F32(s) => CudaStorageSlice::F32(self.f(s, d, l)?),
|
S::F32(s) => S::F32(self.f(s, d, l)?),
|
||||||
CudaStorageSlice::F64(s) => CudaStorageSlice::F64(self.f(s, d, l)?),
|
S::F64(s) => S::F64(self.f(s, d, l)?),
|
||||||
|
};
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
trait Map2 {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
src1: &CudaSlice<T>,
|
||||||
|
layout1: &Layout,
|
||||||
|
src2: &CudaSlice<T>,
|
||||||
|
layout2: &Layout,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<CudaSlice<T>>;
|
||||||
|
|
||||||
|
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
|
||||||
|
let out = match (s1, s2) {
|
||||||
|
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
_ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
|
||||||
};
|
};
|
||||||
Ok(out)
|
Ok(out)
|
||||||
}
|
}
|
||||||
@ -411,6 +435,44 @@ impl<'a> Map1 for Embedding<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct WhereCond<'a>(&'a CudaStorage, &'a Layout);
|
||||||
|
impl<'a> Map2 for WhereCond<'a> {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
t: &CudaSlice<T>,
|
||||||
|
layout_t: &Layout,
|
||||||
|
f: &CudaSlice<T>,
|
||||||
|
layout_f: &Layout,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<CudaSlice<T>> {
|
||||||
|
let ids_l = &self.1;
|
||||||
|
let ids = match &self.0.slice {
|
||||||
|
CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
|
||||||
|
_ => Err(CudaError::UnexpectedDType {
|
||||||
|
msg: "where conditions should be u32",
|
||||||
|
expected: DType::U32,
|
||||||
|
got: self.0.dtype(),
|
||||||
|
})?,
|
||||||
|
};
|
||||||
|
let ids = &ids;
|
||||||
|
let shape = ids_l.shape();
|
||||||
|
let dims = shape.dims();
|
||||||
|
let el = shape.elem_count();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
|
let ds =
|
||||||
|
dev.htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())?;
|
||||||
|
let t = &t.slice(layout_t.start_offset()..);
|
||||||
|
let f = &f.slice(layout_f.start_offset()..);
|
||||||
|
let func = dev.get_or_load_func(&kernel_name::<T>("where"), kernels::TERNARY)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<T>(el) }?;
|
||||||
|
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
||||||
|
// SAFETY: ffi
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn slice_src_and_dst<'a, T>(
|
fn slice_src_and_dst<'a, T>(
|
||||||
src: &'a CudaSlice<T>,
|
src: &'a CudaSlice<T>,
|
||||||
src_l: &Layout,
|
src_l: &Layout,
|
||||||
@ -714,86 +776,12 @@ impl CudaStorage {
|
|||||||
&self,
|
&self,
|
||||||
layout: &Layout,
|
layout: &Layout,
|
||||||
t: &Self,
|
t: &Self,
|
||||||
layout_t: &Layout,
|
t_l: &Layout,
|
||||||
f: &Self,
|
f: &Self,
|
||||||
layout_f: &Layout,
|
f_l: &Layout,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let ids = match &self.slice {
|
let device = self.device().clone();
|
||||||
CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
|
let slice = WhereCond(self, layout).map(&t.slice, t_l, &f.slice, f_l, &device)?;
|
||||||
_ => Err(CudaError::UnexpectedDType {
|
|
||||||
msg: "where conditions should be u32",
|
|
||||||
expected: DType::U32,
|
|
||||||
got: self.dtype(),
|
|
||||||
})?,
|
|
||||||
};
|
|
||||||
let ids = &ids;
|
|
||||||
let shape = layout.shape();
|
|
||||||
let dims = shape.dims();
|
|
||||||
let el = shape.elem_count();
|
|
||||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
|
||||||
let dev = self.device();
|
|
||||||
let ds =
|
|
||||||
dev.htod_copy([dims, layout.stride(), layout_t.stride(), layout_f.stride()].concat())?;
|
|
||||||
let slice = match (&t.slice, &f.slice) {
|
|
||||||
(CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => {
|
|
||||||
let t = &t.slice(layout_t.start_offset()..);
|
|
||||||
let f = &f.slice(layout_f.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<bf16>(el) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
|
||||||
// SAFETY: ffi
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::BF16(out)
|
|
||||||
}
|
|
||||||
(CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => {
|
|
||||||
let t = &t.slice(layout_t.start_offset()..);
|
|
||||||
let f = &f.slice(layout_f.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f16>(el) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
|
||||||
// SAFETY: ffi
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F16(out)
|
|
||||||
}
|
|
||||||
(CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
|
|
||||||
let t = &t.slice(layout_t.start_offset()..);
|
|
||||||
let f = &f.slice(layout_f.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f32>(el) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
|
||||||
// SAFETY: ffi
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F32(out)
|
|
||||||
}
|
|
||||||
(CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => {
|
|
||||||
let t = &t.slice(layout_t.start_offset()..);
|
|
||||||
let f = &f.slice(layout_f.start_offset()..);
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?;
|
|
||||||
let out = unsafe { dev.alloc::<f64>(el) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
|
||||||
// SAFETY: ffi
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F64(out)
|
|
||||||
}
|
|
||||||
(CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => {
|
|
||||||
let t = &t.slice(layout_t.start_offset()..);
|
|
||||||
let f = &f.slice(layout_f.start_offset()..);
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?;
|
|
||||||
let out = unsafe { dev.alloc::<u32>(el) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
|
||||||
// SAFETY: ffi
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::U32(out)
|
|
||||||
}
|
|
||||||
// The dtypes should have been checked at this point so this is an internal error.
|
|
||||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
|
||||||
};
|
|
||||||
let device = dev.clone();
|
|
||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user