Add map2.

This commit is contained in:
laurent
2023-06-28 21:38:01 +01:00
parent 46c07b924c
commit c583ee0f2c

View File

@ -32,33 +32,66 @@ trait Map1 {
} }
} }
fn wcond<T: Copy>( type C = CpuStorage;
pred: &[u32], trait Map2 {
layout: &Layout, const OP: &'static str;
t: &[T], fn f<T: WithDType + Copy + num_traits::Num + 'static>(
layout_t: &Layout, &self,
f: &[T], v1: &[T],
layout_f: &Layout, l1: &Layout,
) -> Vec<T> { v2: &[T],
match ( l2: &Layout,
layout.contiguous_offsets(), ) -> Result<Vec<T>>;
layout_t.contiguous_offsets(),
layout_f.contiguous_offsets(), fn map(
) { &self,
(Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => { v1: &CpuStorage,
let pred = &pred[o1..o2]; l1: &Layout,
let t = &t[o_t1..o_t2]; v2: &CpuStorage,
let f = &f[o_f1..o_f2]; l2: &Layout,
pred.iter() ) -> Result<CpuStorage> {
.zip(t.iter().zip(f.iter())) match (v1, v2) {
.map(|(&p, (&t, &f))| if p > 0 { t } else { f }) (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
.collect::<Vec<_>>() (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}),
} }
_ => layout }
.strided_index() }
.zip(layout_t.strided_index().zip(layout_f.strided_index()))
.map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] }) struct WCond<'a>(&'a [u32], &'a Layout);
.collect::<Vec<_>>(),
impl<'a> Map2 for WCond<'a> {
const OP: &'static str = "where";
fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
let vs = match (
self.1.contiguous_offsets(),
t_l.contiguous_offsets(),
f_l.contiguous_offsets(),
) {
(Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
let pred = &self.0[o1..o2];
let t = &t[o_t1..o_t2];
let f = &f[o_f1..o_f2];
pred.iter()
.zip(t.iter().zip(f.iter()))
.map(|(&p, (&t, &f))| if p > 0 { t } else { f })
.collect::<Vec<_>>()
}
_ => self
.1
.strided_index()
.zip(t_l.strided_index().zip(f_l.strided_index()))
.map(|(i_p, (i_t, i_f))| if self.0[i_p] > 0 { t[i_t] } else { f[i_f] })
.collect::<Vec<_>>(),
};
Ok(vs)
} }
} }
@ -184,73 +217,79 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
} }
} }
fn matmul<T: 'static + num_traits::Num + Copy>( struct MatMul((usize, usize, usize, usize));
lhs: &[T],
rhs: &[T],
(b, m, n, k): (usize, usize, usize, usize),
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Vec<T>> {
let lhs = &lhs[lhs_l.start_offset()..];
let rhs = &rhs[rhs_l.start_offset()..];
let a_skip: usize = m * k;
let b_skip: usize = n * k;
let c_skip: usize = m * n;
let lhs_stride = lhs_l.stride(); impl Map2 for MatMul {
let rhs_stride = rhs_l.stride(); const OP: &'static str = "mat_mul";
let rank = lhs_stride.len(); fn f<T: 'static + num_traits::Num + Copy>(
let lhs_cs = lhs_stride[rank - 1]; &self,
let lhs_rs = lhs_stride[rank - 2]; lhs: &[T],
lhs_l: &Layout,
rhs: &[T],
rhs_l: &Layout,
) -> Result<Vec<T>> {
let (b, m, n, k) = self.0;
let lhs = &lhs[lhs_l.start_offset()..];
let rhs = &rhs[rhs_l.start_offset()..];
let a_skip: usize = m * k;
let b_skip: usize = n * k;
let c_skip: usize = m * n;
let rhs_cs = rhs_stride[rank - 1]; let lhs_stride = lhs_l.stride();
let rhs_rs = rhs_stride[rank - 2]; let rhs_stride = rhs_l.stride();
let rank = lhs_stride.len();
let lhs_cs = lhs_stride[rank - 1];
let lhs_rs = lhs_stride[rank - 2];
if lhs_stride.len() > 2 { let rhs_cs = rhs_stride[rank - 1];
let lhs_batch_stride = &lhs_stride[..rank - 2]; let rhs_rs = rhs_stride[rank - 2];
let rhs_batch_stride = &rhs_stride[..rank - 2];
if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] { if lhs_stride.len() > 2 {
// Temporary error before we support abitrary striding. let lhs_batch_stride = &lhs_stride[..rank - 2];
return Err(Error::UnexpectedStriding); let rhs_batch_stride = &rhs_stride[..rank - 2];
if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
// Temporary error before we support abitrary striding.
return Err(Error::UnexpectedStriding);
}
} }
}
let dst_shape: Shape = (m, n).into(); let dst_shape: Shape = (m, n).into();
let dst_strides = dst_shape.stride_contiguous(); let dst_strides = dst_shape.stride_contiguous();
let dst_rs = dst_strides[0]; let dst_rs = dst_strides[0];
let dst_cs = dst_strides[1]; let dst_cs = dst_strides[1];
let mut dst = vec![T::zero(); b * m * n]; let mut dst = vec![T::zero(); b * m * n];
for step in 0..b { for step in 0..b {
let lhs_p = &lhs[step * a_skip..]; let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..]; let rhs_p = &rhs[step * b_skip..];
let dst_p = &mut dst[step * c_skip..]; let dst_p = &mut dst[step * c_skip..];
unsafe { unsafe {
gemm( gemm(
/* m: usize = */ m, /* m: usize = */ m,
/* n: usize = */ n, /* n: usize = */ n,
/* k: usize = */ k, /* k: usize = */ k,
/* dst: *mut T = */ dst_p.as_mut_ptr(), /* dst: *mut T = */ dst_p.as_mut_ptr(),
/* dst_cs: isize = */ dst_cs as isize, /* dst_cs: isize = */ dst_cs as isize,
/* dst_rs: isize = */ dst_rs as isize, /* dst_rs: isize = */ dst_rs as isize,
/* read_dst: bool = */ false, /* read_dst: bool = */ false,
/* lhs: *const T = */ lhs_p.as_ptr(), /* lhs: *const T = */ lhs_p.as_ptr(),
/* lhs_cs: isize = */ lhs_cs as isize, /* lhs_cs: isize = */ lhs_cs as isize,
/* lhs_rs: isize = */ lhs_rs as isize, /* lhs_rs: isize = */ lhs_rs as isize,
/* rhs: *const T = */ rhs_p.as_ptr(), /* rhs: *const T = */ rhs_p.as_ptr(),
/* rhs_cs: isize = */ rhs_cs as isize, /* rhs_cs: isize = */ rhs_cs as isize,
/* rhs_rs: isize = */ rhs_rs as isize, /* rhs_rs: isize = */ rhs_rs as isize,
/* alpha: T = */ T::zero(), /* alpha: T = */ T::zero(),
/* beta: T = */ T::one(), /* beta: T = */ T::one(),
/* conj_dst: bool = */ false, /* conj_dst: bool = */ false,
/* conj_lhs: bool = */ false, /* conj_lhs: bool = */ false,
/* conj_rhs: bool = */ false, /* conj_rhs: bool = */ false,
Parallelism::Rayon(crate::utils::get_num_threads()), Parallelism::Rayon(crate::utils::get_num_threads()),
) )
}
} }
Ok(dst)
} }
Ok(dst)
} }
impl CpuStorage { impl CpuStorage {
@ -574,39 +613,13 @@ impl CpuStorage {
&self, &self,
layout: &Layout, layout: &Layout,
t: &Self, t: &Self,
layout_t: &Layout, t_l: &Layout,
f: &Self, f: &Self,
layout_f: &Layout, f_l: &Layout,
) -> Result<Self> { ) -> Result<Self> {
// TODO: Support types that could be casted to a boolean. // TODO: Support types that could be casted to a boolean.
let pred = self.as_slice::<u32>()?; let pred = self.as_slice::<u32>()?;
match (t, f) { WCond(pred, layout).map(t, t_l, f, f_l)
(Self::BF16(t), Self::BF16(f)) => {
let data = wcond(pred, layout, t, layout_t, f, layout_f);
Ok(Self::BF16(data))
}
(Self::F16(t), Self::F16(f)) => {
let data = wcond(pred, layout, t, layout_t, f, layout_f);
Ok(Self::F16(data))
}
(Self::F32(t), Self::F32(f)) => {
let data = wcond(pred, layout, t, layout_t, f, layout_f);
Ok(Self::F32(data))
}
(Self::F64(t), Self::F64(f)) => {
let data = wcond(pred, layout, t, layout_t, f, layout_f);
Ok(Self::F64(data))
}
(Self::U32(t), Self::U32(f)) => {
let data = wcond(pred, layout, t, layout_t, f, layout_f);
Ok(Self::U32(data))
}
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: t.dtype(),
rhs: f.dtype(),
op: "where_cond",
}),
}
} }
pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> { pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
@ -628,25 +641,7 @@ impl CpuStorage {
lhs_l: &Layout, lhs_l: &Layout,
rhs_l: &Layout, rhs_l: &Layout,
) -> Result<Self> { ) -> Result<Self> {
match (self, rhs) { MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
(CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
Ok(Self::F16(dst))
}
(CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
Ok(Self::F32(dst))
}
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
Ok(Self::F64(dst))
}
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: rhs.dtype(),
op: "matmul",
}),
}
} }
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self { pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {