mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Cleanup some todos. (#226)
* Cleanup some todos. * Fix more todo. * Optimize for the contiguous case. * Add the IntDType trait. * Handle the intdtype trait for more ops. * Remove a todo. * Remove a todo.
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{DType, Error, Layout, Result, Shape, WithDType};
|
||||
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
|
||||
use half::{bf16, f16};
|
||||
|
||||
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
|
||||
@ -133,9 +133,9 @@ impl Map2U8 for Cmp {
|
||||
}
|
||||
}
|
||||
|
||||
struct WCond<'a>(&'a [u32], &'a Layout);
|
||||
struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);
|
||||
|
||||
impl<'a> Map2 for WCond<'a> {
|
||||
impl<'a, I: IntDType> Map2 for WCond<'a, I> {
|
||||
const OP: &'static str = "where";
|
||||
#[inline(always)]
|
||||
fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
|
||||
@ -150,14 +150,20 @@ impl<'a> Map2 for WCond<'a> {
|
||||
let f = &f[o_f1..o_f2];
|
||||
pred.iter()
|
||||
.zip(t.iter().zip(f.iter()))
|
||||
.map(|(&p, (&t, &f))| if p > 0 { t } else { f })
|
||||
.map(|(p, (&t, &f))| if p.is_true() { t } else { f })
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
_ => self
|
||||
.1
|
||||
.strided_index()
|
||||
.zip(t_l.strided_index().zip(f_l.strided_index()))
|
||||
.map(|(i_p, (i_t, i_f))| if self.0[i_p] > 0 { t[i_t] } else { f[i_f] })
|
||||
.map(|(i_p, (i_t, i_f))| {
|
||||
if self.0[i_p].is_true() {
|
||||
t[i_t]
|
||||
} else {
|
||||
f[i_f]
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
};
|
||||
Ok(vs)
|
||||
@ -628,13 +634,13 @@ impl Map1 for Affine {
|
||||
}
|
||||
}
|
||||
|
||||
struct Gather<'a> {
|
||||
ids: &'a [u32],
|
||||
struct Gather<'a, I: IntDType> {
|
||||
ids: &'a [I],
|
||||
ids_l: &'a Layout,
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
impl<'a> Map1 for Gather<'a> {
|
||||
impl<'a, I: IntDType> Map1 for Gather<'a, I> {
|
||||
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
let ids = match self.ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &self.ids[a..b],
|
||||
@ -663,7 +669,7 @@ impl<'a> Map1 for Gather<'a> {
|
||||
let start_dst_idx = start_dst_idx + i * dst_right_len;
|
||||
for right_i in 0..dst_right_len {
|
||||
let dst_idx = start_dst_idx + right_i;
|
||||
let index = ids[dst_idx] as usize;
|
||||
let index = ids[dst_idx].as_usize();
|
||||
if index >= src_dim_len {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
@ -681,13 +687,13 @@ impl<'a> Map1 for Gather<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct IndexSelect<'a> {
|
||||
ids: &'a [u32],
|
||||
struct IndexSelect<'a, T: IntDType> {
|
||||
ids: &'a [T],
|
||||
ids_l: &'a Layout,
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
impl<'a> Map1 for IndexSelect<'a> {
|
||||
impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
|
||||
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||
let src = match layout.contiguous_offsets() {
|
||||
Some((a, b)) => &src[a..b],
|
||||
@ -714,7 +720,7 @@ impl<'a> Map1 for IndexSelect<'a> {
|
||||
let start_src_idx = left_i * right_len * src_dim;
|
||||
let start_dst_idx = left_i * right_len * n_ids;
|
||||
for i in 0..n_ids {
|
||||
let index = self.ids[self.ids_l.start_offset() + stride_ids * i] as usize;
|
||||
let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
|
||||
if index >= src_dim {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
@ -733,13 +739,13 @@ impl<'a> Map1 for IndexSelect<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct ScatterAdd<'a> {
|
||||
ids: &'a [u32],
|
||||
struct ScatterAdd<'a, I: IntDType> {
|
||||
ids: &'a [I],
|
||||
ids_l: &'a Layout,
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
impl<'a> Map2 for ScatterAdd<'a> {
|
||||
impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
|
||||
const OP: &'static str = "scatter-add";
|
||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
let dst_len = l1.shape().elem_count();
|
||||
@ -771,7 +777,7 @@ impl<'a> Map2 for ScatterAdd<'a> {
|
||||
let start_ids_idx = start_ids_idx + i * ids_right_len;
|
||||
for right_i in 0..dst_right_len {
|
||||
let ids_idx = start_ids_idx + right_i;
|
||||
let index = ids[ids_idx] as usize;
|
||||
let index = ids[ids_idx].as_usize();
|
||||
if index >= dst_dim_len {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
@ -790,12 +796,12 @@ impl<'a> Map2 for ScatterAdd<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct IndexAdd<'a> {
|
||||
ids: &'a [u32],
|
||||
struct IndexAdd<'a, I: IntDType> {
|
||||
ids: &'a [I],
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
impl<'a> Map2 for IndexAdd<'a> {
|
||||
impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
|
||||
const OP: &'static str = "index-add";
|
||||
// https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
|
||||
// v1, l1 -> self
|
||||
@ -811,8 +817,8 @@ impl<'a> Map2 for IndexAdd<'a> {
|
||||
let max_idx = l1.dims()[dim];
|
||||
let stride = src_l.stride()[dim];
|
||||
if dim == 0 {
|
||||
for (src_idx, &dst_idx) in self.ids.iter().enumerate() {
|
||||
let dst_idx = dst_idx as usize;
|
||||
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
|
||||
let dst_idx = dst_idx.as_usize();
|
||||
if dst_idx >= max_idx {
|
||||
Err(Error::InvalidIndex {
|
||||
index: dst_idx,
|
||||
@ -831,8 +837,8 @@ impl<'a> Map2 for IndexAdd<'a> {
|
||||
} else {
|
||||
let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
|
||||
let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
|
||||
for (src_idx, &dst_idx) in self.ids.iter().enumerate() {
|
||||
let dst_idx = dst_idx as usize;
|
||||
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
|
||||
let dst_idx = dst_idx.as_usize();
|
||||
if dst_idx >= max_idx {
|
||||
Err(Error::InvalidIndex {
|
||||
index: dst_idx,
|
||||
@ -856,31 +862,52 @@ impl<'a> Map2 for IndexAdd<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct Embedding<'a> {
|
||||
struct Embedding<'a, I: IntDType> {
|
||||
vocab_size: usize,
|
||||
hidden_size: usize,
|
||||
ids: &'a [u32],
|
||||
ids: &'a [I],
|
||||
ids_l: &'a Layout,
|
||||
}
|
||||
|
||||
impl<'a> Map1 for Embedding<'a> {
|
||||
impl<'a, I: IntDType> Map1 for Embedding<'a, I> {
|
||||
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||
// TODO: We assume that vs is contiguous here.
|
||||
if !layout.is_contiguous() {
|
||||
Err(Error::RequiresContiguous { op: "embedding" })?
|
||||
}
|
||||
let vs = &vs[layout.start_offset()..];
|
||||
let mut values = Vec::with_capacity(self.ids_l.shape().elem_count() * self.hidden_size);
|
||||
// TODO: Optimize for the case where ids are contiguous.
|
||||
for index in self.ids_l.strided_index() {
|
||||
let index = self.ids[index].try_into()?;
|
||||
if index >= self.vocab_size {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
size: self.vocab_size,
|
||||
op: "take",
|
||||
match self.ids_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => {
|
||||
for index in self.ids[o1..o2].iter() {
|
||||
let index = index.as_usize();
|
||||
if index >= self.vocab_size {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
size: self.vocab_size,
|
||||
op: "take",
|
||||
}
|
||||
.bt())?
|
||||
} else {
|
||||
let hidden_size = self.hidden_size;
|
||||
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
for index in self.ids_l.strided_index() {
|
||||
let index = self.ids[index].as_usize();
|
||||
if index >= self.vocab_size {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
size: self.vocab_size,
|
||||
op: "take",
|
||||
}
|
||||
.bt())?
|
||||
} else {
|
||||
let hidden_size = self.hidden_size;
|
||||
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
|
||||
}
|
||||
}
|
||||
.bt())?
|
||||
} else {
|
||||
let hidden_size = self.hidden_size;
|
||||
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
|
||||
}
|
||||
}
|
||||
Ok(values)
|
||||
@ -1671,9 +1698,11 @@ impl BackendStorage for CpuStorage {
|
||||
f: &Self,
|
||||
f_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
// TODO: Support types that could be casted to a boolean.
|
||||
let pred = self.as_slice::<u32>()?;
|
||||
WCond(pred, layout).map(t, t_l, f, f_l)
|
||||
match self {
|
||||
Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
|
||||
Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
|
||||
}
|
||||
}
|
||||
|
||||
fn conv1d(
|
||||
@ -1687,25 +1716,40 @@ impl BackendStorage for CpuStorage {
|
||||
}
|
||||
|
||||
fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||
let ids = self.as_slice::<u32>()?;
|
||||
let (vocab_size, hidden_size) = rhs_l.shape().dims2()?;
|
||||
Embedding {
|
||||
vocab_size,
|
||||
hidden_size,
|
||||
ids,
|
||||
ids_l,
|
||||
match self {
|
||||
Self::U8(ids) => Embedding {
|
||||
vocab_size,
|
||||
hidden_size,
|
||||
ids,
|
||||
ids_l,
|
||||
}
|
||||
.map(rhs, rhs_l),
|
||||
Self::U32(ids) => Embedding {
|
||||
vocab_size,
|
||||
hidden_size,
|
||||
ids,
|
||||
ids_l,
|
||||
}
|
||||
.map(rhs, rhs_l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "embedding")),
|
||||
}
|
||||
.map(rhs, rhs_l)
|
||||
}
|
||||
|
||||
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
let ids = ids.as_slice::<u32>()?;
|
||||
IndexSelect { ids, ids_l, dim }.map(self, l)
|
||||
match ids {
|
||||
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
||||
Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
|
||||
}
|
||||
}
|
||||
|
||||
fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
let ids = ids.as_slice::<u32>()?;
|
||||
Gather { ids, ids_l, dim }.map(self, l)
|
||||
match ids {
|
||||
Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
||||
Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
|
||||
}
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
@ -1717,8 +1761,11 @@ impl BackendStorage for CpuStorage {
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
let ids = ids.as_slice::<u32>()?;
|
||||
ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l)
|
||||
match ids {
|
||||
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
|
||||
}
|
||||
}
|
||||
|
||||
fn index_add(
|
||||
@ -1730,12 +1777,23 @@ impl BackendStorage for CpuStorage {
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
let ids = ids.as_slice::<u32>()?;
|
||||
let ids = match ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" })?,
|
||||
};
|
||||
IndexAdd { ids, dim }.map(self, l, src, src_l)
|
||||
match ids {
|
||||
Self::U8(ids) => {
|
||||
let ids = match ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" })?,
|
||||
};
|
||||
IndexAdd { ids, dim }.map(self, l, src, src_l)
|
||||
}
|
||||
Self::U32(ids) => {
|
||||
let ids = match ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" })?,
|
||||
};
|
||||
IndexAdd { ids, dim }.map(self, l, src, src_l)
|
||||
}
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
|
||||
}
|
||||
}
|
||||
|
||||
fn matmul(
|
||||
|
Reference in New Issue
Block a user