mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Actually copy the data around in cat (cpu only).
This commit is contained in:
@ -108,6 +108,38 @@ impl CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn copy_strided_src(
|
||||
&self,
|
||||
dst: &mut Self,
|
||||
src_shape: &Shape,
|
||||
src_stride: &[usize],
|
||||
dst_offset: usize,
|
||||
) -> Result<()> {
|
||||
// TODO: Optimize the contiguous case.
|
||||
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
|
||||
match (self, dst) {
|
||||
(Self::F32(src), Self::F32(dst)) => {
|
||||
for (dst_index, src_index) in src_indexes.enumerate() {
|
||||
dst[dst_index + dst_offset] = src[src_index]
|
||||
}
|
||||
}
|
||||
(Self::F64(src), Self::F64(dst)) => {
|
||||
for (dst_index, src_index) in src_indexes.enumerate() {
|
||||
dst[dst_index + dst_offset] = src[src_index]
|
||||
}
|
||||
}
|
||||
(_, dst) => {
|
||||
// This should be covered by the dtype check above.
|
||||
return Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: self.dtype(),
|
||||
rhs: dst.dtype(),
|
||||
op: "copy_strided",
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_impl(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
|
Reference in New Issue
Block a user