Add Map2.

This commit is contained in:
laurent
2023-06-29 10:05:06 +01:00
parent 367170da45
commit 83c7d660ca

View File

@ -244,6 +244,7 @@ enum CudaStorageSlice {
F32(CudaSlice<f32>), F32(CudaSlice<f32>),
F64(CudaSlice<f64>), F64(CudaSlice<f64>),
} }
type S = CudaStorageSlice;
trait Map1 { trait Map1 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@ -253,13 +254,36 @@ trait Map1 {
layout: &Layout, layout: &Layout,
) -> Result<CudaSlice<T>>; ) -> Result<CudaSlice<T>>;
/// Dispatches `f` on the element type held by `s` and wraps the result in the same storage
/// variant it came from. Errors returned by `f` are propagated with `?`.
fn map(&self, s: &CudaStorageSlice, d: &CudaDevice, l: &Layout) -> Result<CudaStorageSlice> { fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s { let out = match s {
CudaStorageSlice::U32(s) => CudaStorageSlice::U32(self.f(s, d, l)?), S::U32(s) => S::U32(self.f(s, d, l)?),
CudaStorageSlice::BF16(s) => CudaStorageSlice::BF16(self.f(s, d, l)?), S::BF16(s) => S::BF16(self.f(s, d, l)?),
CudaStorageSlice::F16(s) => CudaStorageSlice::F16(self.f(s, d, l)?), S::F16(s) => S::F16(self.f(s, d, l)?),
CudaStorageSlice::F32(s) => CudaStorageSlice::F32(self.f(s, d, l)?), S::F32(s) => S::F32(self.f(s, d, l)?),
CudaStorageSlice::F64(s) => CudaStorageSlice::F64(self.f(s, d, l)?), S::F64(s) => S::F64(self.f(s, d, l)?),
};
Ok(out)
}
}
trait Map2 {
/// Typed binary operation supplied by each implementor.
///
/// Receives two source slices of the same element type `T`, each paired with its layout,
/// plus the device, and returns a `CudaSlice<T>` on success. `map` calls this once it has
/// matched both operands to the same storage variant.
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>>;
/// Dispatches `f` on the element type shared by `s1` and `s2` and wraps the result in the
/// matching `S` variant.
///
/// Returns `CudaError::InternalError` when the two slices hold different variants; errors
/// returned by `f` are propagated with `?`.
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
let out = match (s1, s2) {
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
// Every remaining pairing has the two operands in different variants.
_ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
}; };
Ok(out) Ok(out)
} }
@ -411,6 +435,44 @@ impl<'a> Map1 for Embedding<'a> {
} }
} }
/// Condition operand for the `where` kernel: field 0 is the storage holding the condition
/// values (its `Map2` impl rejects anything but a `U32` slice) and field 1 is the layout
/// used to read that storage.
struct WhereCond<'a>(&'a CudaStorage, &'a Layout);
/// Runs the `where` ternary kernel through the `Map2` dispatch: `self` carries the condition
/// storage and its layout, while the two typed value operands arrive as `t` and `f`.
impl<'a> Map2 for WhereCond<'a> {
    /// Launches the `where` kernel for element type `T` and returns the output slice.
    ///
    /// The output holds as many elements as the condition layout's shape. The launch
    /// receives the element count, the number of dims, one device buffer containing the
    /// dims followed by the condition, `t` and `f` strides, then the three input views
    /// (each starting at its layout's offset) and the output.
    ///
    /// # Errors
    /// Returns `CudaError::UnexpectedDType` when the condition storage is not `U32`, and
    /// propagates failures from the host-to-device copy, kernel lookup, allocation and
    /// launch.
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        t: &CudaSlice<T>,
        layout_t: &Layout,
        f: &CudaSlice<T>,
        layout_f: &Layout,
        dev: &CudaDevice,
    ) -> Result<CudaSlice<T>> {
        let cond = self.0;
        let cond_l = self.1;

        // The condition buffer must be u32; every other variant is rejected up front.
        let cond_view = match &cond.slice {
            CudaStorageSlice::U32(ids) => ids.slice(cond_l.start_offset()..),
            _ => {
                return Err(CudaError::UnexpectedDType {
                    msg: "where conditions should be u32",
                    expected: DType::U32,
                    got: cond.dtype(),
                })
            }
        };

        let shape = cond_l.shape();
        let dims = shape.dims();
        let elem_count = shape.elem_count();

        // Dims and the three stride arrays are concatenated and copied over in one go.
        let dims_and_strides =
            [dims, cond_l.stride(), layout_t.stride(), layout_f.stride()].concat();
        let ds = dev.htod_copy(dims_and_strides)?;

        let t_view = t.slice(layout_t.start_offset()..);
        let f_view = f.slice(layout_f.start_offset()..);
        let func = dev.get_or_load_func(&kernel_name::<T>("where"), kernels::TERNARY)?;

        // SAFETY: the buffer is not initialised here; it is set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(elem_count) }?;

        // NOTE(review): `as u32` truncates when `elem_count` exceeds `u32::MAX` - confirm the
        // kernel copes with a launch config derived from a truncated count.
        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
        let params = (
            elem_count,
            dims.len(),
            &ds,
            &cond_view,
            &t_view,
            &f_view,
            &out,
        );
        // SAFETY: ffi
        unsafe { func.launch(cfg, params) }?;
        Ok(out)
    }
}
fn slice_src_and_dst<'a, T>( fn slice_src_and_dst<'a, T>(
src: &'a CudaSlice<T>, src: &'a CudaSlice<T>,
src_l: &Layout, src_l: &Layout,
@ -714,86 +776,12 @@ impl CudaStorage {
&self, &self,
layout: &Layout, layout: &Layout,
t: &Self, t: &Self,
layout_t: &Layout, t_l: &Layout,
f: &Self, f: &Self,
layout_f: &Layout, f_l: &Layout,
) -> Result<Self> { ) -> Result<Self> {
let ids = match &self.slice { let device = self.device().clone();
CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..), let slice = WhereCond(self, layout).map(&t.slice, t_l, &f.slice, f_l, &device)?;
_ => Err(CudaError::UnexpectedDType {
msg: "where conditions should be u32",
expected: DType::U32,
got: self.dtype(),
})?,
};
let ids = &ids;
let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
let ds =
dev.htod_copy([dims, layout.stride(), layout_t.stride(), layout_f.stride()].concat())?;
let slice = match (&t.slice, &f.slice) {
(CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => {
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
(CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => {
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
(CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => {
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?;
let out = unsafe { dev.alloc::<f64>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
(CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => {
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?;
let out = unsafe { dev.alloc::<u32>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
// The dtypes should have been checked at this point so this is an internal error.
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
};
let device = dev.clone();
Ok(Self { slice, device }) Ok(Self { slice, device })
} }