mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Handle copying for the u32 type.
This commit is contained in:
@ -88,6 +88,23 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn copy_strided_src_<T: Copy>(
|
||||||
|
src: &[T],
|
||||||
|
dst: &mut [T],
|
||||||
|
src_shape: &Shape,
|
||||||
|
src_stride: &[usize],
|
||||||
|
dst_offset: usize,
|
||||||
|
) {
|
||||||
|
if src_shape.is_contiguous(src_stride) {
|
||||||
|
dst[dst_offset..].copy_from_slice(src)
|
||||||
|
} else {
|
||||||
|
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
|
||||||
|
for (dst_index, src_index) in src_indexes.enumerate() {
|
||||||
|
dst[dst_index + dst_offset] = src[src_index]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl CpuStorage {
|
impl CpuStorage {
|
||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
match self {
|
match self {
|
||||||
@ -280,25 +297,14 @@ impl CpuStorage {
|
|||||||
panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
|
panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
|
||||||
}
|
}
|
||||||
match (self, dst) {
|
match (self, dst) {
|
||||||
|
(Self::U32(src), Self::U32(dst)) => {
|
||||||
|
copy_strided_src_(src, dst, src_shape, src_stride, dst_offset)
|
||||||
|
}
|
||||||
(Self::F32(src), Self::F32(dst)) => {
|
(Self::F32(src), Self::F32(dst)) => {
|
||||||
if src_shape.is_contiguous(src_stride) {
|
copy_strided_src_(src, dst, src_shape, src_stride, dst_offset)
|
||||||
dst[dst_offset..].copy_from_slice(src)
|
|
||||||
} else {
|
|
||||||
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
|
|
||||||
for (dst_index, src_index) in src_indexes.enumerate() {
|
|
||||||
dst[dst_index + dst_offset] = src[src_index]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
(Self::F64(src), Self::F64(dst)) => {
|
(Self::F64(src), Self::F64(dst)) => {
|
||||||
if src_shape.is_contiguous(src_stride) {
|
copy_strided_src_(src, dst, src_shape, src_stride, dst_offset)
|
||||||
dst[dst_offset..].copy_from_slice(src)
|
|
||||||
} else {
|
|
||||||
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
|
|
||||||
for (dst_index, src_index) in src_indexes.enumerate() {
|
|
||||||
dst[dst_index + dst_offset] = src[src_index]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
(_, dst) => {
|
(_, dst) => {
|
||||||
// This should be covered by the dtype check above.
|
// This should be covered by the dtype check above.
|
||||||
|
Reference in New Issue
Block a user