Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 11:08:52 +00:00
Propagate the layout refactoring.
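The hunks below replace every `shape: &Shape` / `stride: &[usize]` parameter pair with a single `layout: &Layout`. The `Layout` type itself is not defined in this diff; judging only from the calls that appear here (`layout.shape()`, `layout.start_offset()`, `layout.is_contiguous()`, `StridedIndex::new(layout)`), a minimal stand-in with that surface could look like the sketch below. This is an illustration of the assumed API, not candle's actual definition; in particular the `stride` field is an assumption.

// Illustrative stand-in for the Layout surface used in this diff; candle's
// actual Layout lives in the crate and may hold more state.
#[derive(Clone, Debug)]
struct Shape(Vec<usize>);

impl Shape {
    fn dims(&self) -> &[usize] {
        &self.0
    }
    fn elem_count(&self) -> usize {
        self.0.iter().product()
    }
}

#[derive(Clone, Debug)]
struct Layout {
    shape: Shape,
    stride: Vec<usize>,
    start_offset: usize,
}

impl Layout {
    fn shape(&self) -> &Shape {
        &self.shape
    }
    fn start_offset(&self) -> usize {
        self.start_offset
    }
    // Contiguous means row-major strides with no gaps, so a view can be
    // processed as one flat slice starting at start_offset.
    fn is_contiguous(&self) -> bool {
        let mut expected = 1;
        for (&dim, &stride) in self.shape.dims().iter().zip(self.stride.iter()).rev() {
            if stride != expected {
                return false;
            }
            expected *= dim;
        }
        true
    }
}

fn main() {
    // A 2x3 row-major view starting at offset 4 of some larger buffer.
    let l = Layout { shape: Shape(vec![2, 3]), stride: vec![3, 1], start_offset: 4 };
    assert!(l.is_contiguous());
    assert_eq!(l.shape().elem_count(), 6);
    assert_eq!(l.start_offset(), 4);
}

The point of the refactor is that shape, stride and start offset now travel together in one value, so call sites can no longer pass an inconsistent pair.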
@@ -24,21 +24,22 @@ fn wcond<T: Copy>(
     f: &[T],
-    stride_f: &[usize],
+    layout_f: &Layout,
 ) -> Vec<T> {
-    if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f)
-    {
-        let elem_count = shape.elem_count();
-        let pred = &pred[..elem_count];
-        let t = &t[..elem_count];
-        let f = &f[..elem_count];
+    if layout.is_contiguous() && layout_t.is_contiguous() && layout_f.is_contiguous() {
+        let elem_count = layout.shape().elem_count();
+        let offset = layout.start_offset();
+        let offset_t = layout_t.start_offset();
+        let offset_f = layout_f.start_offset();
+        let pred = &pred[offset..offset + elem_count];
+        let t = &t[offset_t..offset_t + elem_count];
+        let f = &f[offset_f..offset_f + elem_count];
         pred.iter()
             .zip(t.iter().zip(f.iter()))
             .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
             .collect::<Vec<_>>()
     } else {
-        let dims = shape.dims();
-        let it_p = StridedIndex::new(dims, stride);
-        let it_t = StridedIndex::new(dims, stride_t);
-        let it_f = StridedIndex::new(dims, stride_f);
+        let it_p = StridedIndex::new(layout);
+        let it_t = StridedIndex::new(layout_t);
+        let it_f = StridedIndex::new(layout_f);
         it_p.zip(it_t.zip(it_f))
             .map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
             .collect::<Vec<_>>()
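StridedIndex previously took (dims, stride) and now takes the layout directly; its implementation is outside this diff. As a rough, hypothetical sketch of what such an iterator yields — the flat buffer offsets of a strided view in row-major logical order, starting from the layout's offset — consider:

// Illustrative only: enumerate the flat buffer offsets of a strided view in
// row-major logical order, which is what StridedIndex::new(layout) is used
// for in wcond above. Hypothetical helper, not taken from candle.
fn strided_offsets(dims: &[usize], stride: &[usize], start_offset: usize) -> Vec<usize> {
    let elem_count: usize = dims.iter().product();
    let mut offsets = Vec::with_capacity(elem_count);
    let mut index = vec![0usize; dims.len()];
    for _ in 0..elem_count {
        let offset: usize = index.iter().zip(stride.iter()).map(|(i, s)| i * s).sum();
        offsets.push(start_offset + offset);
        // Advance the multi-dimensional index, last dimension fastest.
        for d in (0..dims.len()).rev() {
            index[d] += 1;
            if index[d] < dims[d] {
                break;
            }
            index[d] = 0;
        }
    }
    offsets
}

fn main() {
    // dims [3, 2] with stride [1, 3] is the transpose of a contiguous 2x3
    // buffer, which is why the produced offsets interleave.
    assert_eq!(strided_offsets(&[3, 2], &[1, 3], 0), vec![0, 3, 1, 4, 2, 5]);
}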
@@ -107,13 +108,13 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
 fn take_impl1<T: Copy>(
     vs: &[T],
     ids: &[u32],
-    shape: &Shape,
-    stride: &[usize],
+    layout: &Layout,
     vocab_size: usize,
     hidden_size: usize,
 ) -> Result<Vec<T>> {
-    let mut values = Vec::with_capacity(shape.elem_count() * hidden_size);
-    for index in StridedIndex::new(shape.dims(), stride) {
+    // TODO: Optimize for the case where ids are contiguous.
+    let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size);
+    for index in StridedIndex::new(layout) {
         let index = ids[index].try_into()?;
         if index >= vocab_size {
             return Err(Error::InvalidIndex {
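The rest of take_impl1's body falls outside this hunk. Its role (it backs the embedding method further down) is to gather one hidden_size-sized row of `vs` per id, failing on ids that exceed vocab_size. A simplified sketch of that behavior for contiguous ids only, using plain std types rather than candle's:

// Simplified sketch of the embedding gather that take_impl1 performs: for each
// u32 id, copy the corresponding row of the [vocab_size, hidden_size] table.
// The bound check mirrors the Error::InvalidIndex path above.
fn gather_rows<T: Copy>(
    vs: &[T],
    ids: &[u32],
    vocab_size: usize,
    hidden_size: usize,
) -> Result<Vec<T>, String> {
    let mut values = Vec::with_capacity(ids.len() * hidden_size);
    for &id in ids {
        let id = id as usize;
        if id >= vocab_size {
            return Err(format!("index {id} out of bounds for vocab size {vocab_size}"));
        }
        values.extend_from_slice(&vs[id * hidden_size..(id + 1) * hidden_size]);
    }
    Ok(values)
}

fn main() {
    // A 3x2 table; ids [2, 0] pick the last and first rows.
    let table = [10, 11, 20, 21, 30, 31];
    assert_eq!(gather_rows(&table, &[2, 0], 3, 2).unwrap(), vec![30, 31, 10, 11]);
}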
@@ -132,16 +133,14 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
     src: &[T],
     dst: &mut [T],
     dst_offset: usize,
-    src_shape: &Shape,
-    src_stride: &[usize],
-    src_offset: usize,
+    src_l: &Layout,
 ) {
-    let src = &src[src_offset..];
-    if src_shape.is_contiguous(src_stride) {
+    let src = &src[src_l.start_offset()..];
+    if src_l.is_contiguous() {
         let elem_to_copy = (dst.len() - dst_offset).min(src.len());
         dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[..elem_to_copy])
     } else {
-        let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
+        let src_indexes = StridedIndex::new(src_l);
         for (dst_index, src_index) in src_indexes.enumerate() {
             let dst_index = dst_index + dst_offset;
             if dst_index >= dst.len() {
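The contiguous fast path now starts the source slice at the layout's start_offset instead of taking a separate src_offset argument. A small self-contained illustration of that branch (a stand-alone helper mirroring the logic above, not candle's function):

// Stand-alone illustration of the contiguous fast path: the source view starts
// at the layout's start_offset rather than at 0.
fn copy_contiguous<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, start_offset: usize) {
    let src = &src[start_offset..];
    let elem_to_copy = (dst.len() - dst_offset).min(src.len());
    dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[..elem_to_copy]);
}

fn main() {
    // A contiguous view of the last four elements of a ten-element buffer.
    let src: Vec<u32> = (0..10).collect();
    let mut dst = [0u32; 4];
    copy_contiguous(&src, &mut dst, 0, 6);
    assert_eq!(dst, [6, 7, 8, 9]);
}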
@@ -556,29 +555,14 @@ impl CpuStorage {
         &self,
         dst: &mut Self,
         dst_offset: usize,
-        src_shape: &Shape,
-        src_stride: &[usize],
-        src_offset: usize,
+        src_l: &Layout,
     ) -> Result<()> {
-        if src_shape.rank() != src_stride.len() {
-            panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
-        }
         match (self, dst) {
-            (Self::U32(src), Self::U32(dst)) => {
-                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
-            }
-            (Self::BF16(src), Self::BF16(dst)) => {
-                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
-            }
-            (Self::F16(src), Self::F16(dst)) => {
-                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
-            }
-            (Self::F32(src), Self::F32(dst)) => {
-                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
-            }
-            (Self::F64(src), Self::F64(dst)) => {
-                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
-            }
+            (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+            (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+            (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+            (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+            (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
             (_, dst) => {
                 // This should be covered by the dtype check above.
                 return Err(Error::DTypeMismatchBinaryOp {
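Every arm now fits on one line because each dtype delegates to the same generic helper with the same (dst_offset, src_l) arguments, and the shape/stride coherence panic disappears since a Layout carries both together. For illustration, the paired-variant dispatch pattern with hypothetical stand-in types (candle's CpuStorage has more variants and its own error type):

// Stand-in types to illustrate the paired-variant dtype dispatch used above.
enum Buf {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

// Generic helper shared by every arm, in the role of copy_strided_src_.
fn copy_prefix<T: Copy>(src: &[T], dst: &mut [T]) {
    let n = src.len().min(dst.len());
    dst[..n].copy_from_slice(&src[..n]);
}

fn copy(src: &Buf, dst: &mut Buf) -> Result<(), String> {
    match (src, dst) {
        (Buf::F32(src), Buf::F32(dst)) => copy_prefix(src, dst),
        (Buf::F64(src), Buf::F64(dst)) => copy_prefix(src, dst),
        // Mismatched dtypes are reported, mirroring Error::DTypeMismatchBinaryOp.
        _ => return Err("dtype mismatch".to_string()),
    }
    Ok(())
}

fn main() {
    let src = Buf::F32(vec![1.0, 2.0, 3.0]);
    let mut dst = Buf::F32(vec![0.0; 3]);
    copy(&src, &mut dst).unwrap();
    match dst {
        Buf::F32(v) => assert_eq!(v, vec![1.0, 2.0, 3.0]),
        _ => unreachable!(),
    }
}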
@@ -593,34 +577,33 @@ impl CpuStorage {
 
     pub(crate) fn where_cond(
         &self,
-        shape: &Shape,
-        stride: &[usize],
+        layout: &Layout,
         t: &Self,
-        stride_t: &[usize],
+        layout_t: &Layout,
         f: &Self,
-        stride_f: &[usize],
+        layout_f: &Layout,
     ) -> Result<Self> {
         // TODO: Support types that could be casted to a boolean.
         let pred = self.as_slice::<u32>()?;
         match (t, f) {
             (Self::BF16(t), Self::BF16(f)) => {
-                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                let data = wcond(pred, layout, t, layout_t, f, layout_f);
                 Ok(Self::BF16(data))
             }
             (Self::F16(t), Self::F16(f)) => {
-                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                let data = wcond(pred, layout, t, layout_t, f, layout_f);
                 Ok(Self::F16(data))
             }
             (Self::F32(t), Self::F32(f)) => {
-                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                let data = wcond(pred, layout, t, layout_t, f, layout_f);
                 Ok(Self::F32(data))
             }
             (Self::F64(t), Self::F64(f)) => {
-                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                let data = wcond(pred, layout, t, layout_t, f, layout_f);
                 Ok(Self::F64(data))
             }
             (Self::U32(t), Self::U32(f)) => {
-                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                let data = wcond(pred, layout, t, layout_t, f, layout_f);
                 Ok(Self::U32(data))
             }
             _ => Err(Error::DTypeMismatchBinaryOp {
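The selection rule itself is unchanged: a predicate value greater than zero picks the element from t, otherwise from f. A minimal self-contained check of that rule on plain slices, mirroring wcond's contiguous branch:

// Minimal check of the selection rule used by wcond's contiguous branch:
// p > 0 takes the value from `t`, otherwise from `f`.
fn main() {
    let pred: [u32; 4] = [1, 0, 3, 0];
    let t = [1.0f32, 2.0, 3.0, 4.0];
    let f = [-1.0f32, -2.0, -3.0, -4.0];
    let out: Vec<f32> = pred
        .iter()
        .zip(t.iter().zip(f.iter()))
        .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
        .collect();
    assert_eq!(out, vec![1.0, -2.0, 3.0, -4.0]);
}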
@@ -631,16 +614,15 @@ impl CpuStorage {
         }
     }
 
-    pub(crate) fn embedding_impl(
+    pub(crate) fn embedding(
         &self,
-        shape: &Shape,
-        stride: &[usize],
+        layout: &Layout,
         vs: &Self,
         hidden_size: usize,
         vocab_size: usize,
     ) -> Result<Self> {
         let ids = self.as_slice::<u32>()?;
-        map1!(vs, take_impl1, ids, shape, stride, vocab_size, hidden_size)
+        map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size)
     }
 
     pub(crate) fn matmul_impl(
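map1! is not defined in this diff; it presumably dispatches on the storage dtype and forwards the remaining arguments to take_impl1. A hypothetical sketch of such a macro follows — names, dtype coverage and error handling are assumptions, not candle's actual macro:

// Hypothetical sketch of a map1!-style dtype dispatch macro: match on the
// storage variant, run the generic function, rewrap in the same variant.
enum Storage {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

macro_rules! map1 {
    ($vs:expr, $fn:ident, $($args:expr),*) => {
        match $vs {
            Storage::F32(vs) => Ok(Storage::F32($fn(vs, $($args),*)?)),
            Storage::F64(vs) => Ok(Storage::F64($fn(vs, $($args),*)?)),
        }
    };
}

// A generic per-dtype function in the role of take_impl1.
fn repeat_first<T: Copy>(vs: &[T], n: usize) -> Result<Vec<T>, String> {
    match vs.first() {
        Some(&v) => Ok(vec![v; n]),
        None => Err("empty storage".to_string()),
    }
}

fn main() -> Result<(), String> {
    let vs = Storage::F32(vec![7.0, 8.0]);
    let out: Storage = map1!(&vs, repeat_first, 3)?;
    match out {
        Storage::F32(v) => assert_eq!(v, vec![7.0, 7.0, 7.0]),
        _ => unreachable!(),
    }
    Ok(())
}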