mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add map2.
This commit is contained in:
@ -32,21 +32,51 @@ trait Map1 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn wcond<T: Copy>(
|
type C = CpuStorage;
|
||||||
pred: &[u32],
|
trait Map2 {
|
||||||
layout: &Layout,
|
const OP: &'static str;
|
||||||
t: &[T],
|
fn f<T: WithDType + Copy + num_traits::Num + 'static>(
|
||||||
layout_t: &Layout,
|
&self,
|
||||||
f: &[T],
|
v1: &[T],
|
||||||
layout_f: &Layout,
|
l1: &Layout,
|
||||||
) -> Vec<T> {
|
v2: &[T],
|
||||||
match (
|
l2: &Layout,
|
||||||
layout.contiguous_offsets(),
|
) -> Result<Vec<T>>;
|
||||||
layout_t.contiguous_offsets(),
|
|
||||||
layout_f.contiguous_offsets(),
|
fn map(
|
||||||
|
&self,
|
||||||
|
v1: &CpuStorage,
|
||||||
|
l1: &Layout,
|
||||||
|
v2: &CpuStorage,
|
||||||
|
l2: &Layout,
|
||||||
|
) -> Result<CpuStorage> {
|
||||||
|
match (v1, v2) {
|
||||||
|
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
|
||||||
|
_ => Err(Error::DTypeMismatchBinaryOp {
|
||||||
|
lhs: v1.dtype(),
|
||||||
|
rhs: v2.dtype(),
|
||||||
|
op: Self::OP,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct WCond<'a>(&'a [u32], &'a Layout);
|
||||||
|
|
||||||
|
impl<'a> Map2 for WCond<'a> {
|
||||||
|
const OP: &'static str = "where";
|
||||||
|
fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
|
||||||
|
let vs = match (
|
||||||
|
self.1.contiguous_offsets(),
|
||||||
|
t_l.contiguous_offsets(),
|
||||||
|
f_l.contiguous_offsets(),
|
||||||
) {
|
) {
|
||||||
(Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
|
(Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
|
||||||
let pred = &pred[o1..o2];
|
let pred = &self.0[o1..o2];
|
||||||
let t = &t[o_t1..o_t2];
|
let t = &t[o_t1..o_t2];
|
||||||
let f = &f[o_f1..o_f2];
|
let f = &f[o_f1..o_f2];
|
||||||
pred.iter()
|
pred.iter()
|
||||||
@ -54,11 +84,14 @@ fn wcond<T: Copy>(
|
|||||||
.map(|(&p, (&t, &f))| if p > 0 { t } else { f })
|
.map(|(&p, (&t, &f))| if p > 0 { t } else { f })
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
}
|
}
|
||||||
_ => layout
|
_ => self
|
||||||
|
.1
|
||||||
.strided_index()
|
.strided_index()
|
||||||
.zip(layout_t.strided_index().zip(layout_f.strided_index()))
|
.zip(t_l.strided_index().zip(f_l.strided_index()))
|
||||||
.map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
|
.map(|(i_p, (i_t, i_f))| if self.0[i_p] > 0 { t[i_t] } else { f[i_f] })
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
|
};
|
||||||
|
Ok(vs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -184,13 +217,18 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn matmul<T: 'static + num_traits::Num + Copy>(
|
struct MatMul((usize, usize, usize, usize));
|
||||||
|
|
||||||
|
impl Map2 for MatMul {
|
||||||
|
const OP: &'static str = "mat_mul";
|
||||||
|
fn f<T: 'static + num_traits::Num + Copy>(
|
||||||
|
&self,
|
||||||
lhs: &[T],
|
lhs: &[T],
|
||||||
rhs: &[T],
|
|
||||||
(b, m, n, k): (usize, usize, usize, usize),
|
|
||||||
lhs_l: &Layout,
|
lhs_l: &Layout,
|
||||||
|
rhs: &[T],
|
||||||
rhs_l: &Layout,
|
rhs_l: &Layout,
|
||||||
) -> Result<Vec<T>> {
|
) -> Result<Vec<T>> {
|
||||||
|
let (b, m, n, k) = self.0;
|
||||||
let lhs = &lhs[lhs_l.start_offset()..];
|
let lhs = &lhs[lhs_l.start_offset()..];
|
||||||
let rhs = &rhs[rhs_l.start_offset()..];
|
let rhs = &rhs[rhs_l.start_offset()..];
|
||||||
let a_skip: usize = m * k;
|
let a_skip: usize = m * k;
|
||||||
@ -251,6 +289,7 @@ fn matmul<T: 'static + num_traits::Num + Copy>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CpuStorage {
|
impl CpuStorage {
|
||||||
@ -574,39 +613,13 @@ impl CpuStorage {
|
|||||||
&self,
|
&self,
|
||||||
layout: &Layout,
|
layout: &Layout,
|
||||||
t: &Self,
|
t: &Self,
|
||||||
layout_t: &Layout,
|
t_l: &Layout,
|
||||||
f: &Self,
|
f: &Self,
|
||||||
layout_f: &Layout,
|
f_l: &Layout,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
// TODO: Support types that could be casted to a boolean.
|
// TODO: Support types that could be casted to a boolean.
|
||||||
let pred = self.as_slice::<u32>()?;
|
let pred = self.as_slice::<u32>()?;
|
||||||
match (t, f) {
|
WCond(pred, layout).map(t, t_l, f, f_l)
|
||||||
(Self::BF16(t), Self::BF16(f)) => {
|
|
||||||
let data = wcond(pred, layout, t, layout_t, f, layout_f);
|
|
||||||
Ok(Self::BF16(data))
|
|
||||||
}
|
|
||||||
(Self::F16(t), Self::F16(f)) => {
|
|
||||||
let data = wcond(pred, layout, t, layout_t, f, layout_f);
|
|
||||||
Ok(Self::F16(data))
|
|
||||||
}
|
|
||||||
(Self::F32(t), Self::F32(f)) => {
|
|
||||||
let data = wcond(pred, layout, t, layout_t, f, layout_f);
|
|
||||||
Ok(Self::F32(data))
|
|
||||||
}
|
|
||||||
(Self::F64(t), Self::F64(f)) => {
|
|
||||||
let data = wcond(pred, layout, t, layout_t, f, layout_f);
|
|
||||||
Ok(Self::F64(data))
|
|
||||||
}
|
|
||||||
(Self::U32(t), Self::U32(f)) => {
|
|
||||||
let data = wcond(pred, layout, t, layout_t, f, layout_f);
|
|
||||||
Ok(Self::U32(data))
|
|
||||||
}
|
|
||||||
_ => Err(Error::DTypeMismatchBinaryOp {
|
|
||||||
lhs: t.dtype(),
|
|
||||||
rhs: f.dtype(),
|
|
||||||
op: "where_cond",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||||
@ -628,25 +641,7 @@ impl CpuStorage {
|
|||||||
lhs_l: &Layout,
|
lhs_l: &Layout,
|
||||||
rhs_l: &Layout,
|
rhs_l: &Layout,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
match (self, rhs) {
|
MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
|
||||||
(CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
|
|
||||||
let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
|
|
||||||
Ok(Self::F16(dst))
|
|
||||||
}
|
|
||||||
(CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
|
|
||||||
let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
|
|
||||||
Ok(Self::F32(dst))
|
|
||||||
}
|
|
||||||
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
|
|
||||||
let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
|
|
||||||
Ok(Self::F64(dst))
|
|
||||||
}
|
|
||||||
_ => Err(Error::DTypeMismatchBinaryOp {
|
|
||||||
lhs: self.dtype(),
|
|
||||||
rhs: rhs.dtype(),
|
|
||||||
op: "matmul",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
|
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
|
||||||
|
Reference in New Issue
Block a user