mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Actually copy the data around in cat (cpu only).
This commit is contained in:
@ -108,6 +108,38 @@ impl CpuStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn copy_strided_src(
|
||||||
|
&self,
|
||||||
|
dst: &mut Self,
|
||||||
|
src_shape: &Shape,
|
||||||
|
src_stride: &[usize],
|
||||||
|
dst_offset: usize,
|
||||||
|
) -> Result<()> {
|
||||||
|
// TODO: Optimize the contiguous case.
|
||||||
|
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
|
||||||
|
match (self, dst) {
|
||||||
|
(Self::F32(src), Self::F32(dst)) => {
|
||||||
|
for (dst_index, src_index) in src_indexes.enumerate() {
|
||||||
|
dst[dst_index + dst_offset] = src[src_index]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(Self::F64(src), Self::F64(dst)) => {
|
||||||
|
for (dst_index, src_index) in src_indexes.enumerate() {
|
||||||
|
dst[dst_index + dst_offset] = src[src_index]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(_, dst) => {
|
||||||
|
// This should be covered by the dtype check above.
|
||||||
|
return Err(Error::DTypeMismatchBinaryOp {
|
||||||
|
lhs: self.dtype(),
|
||||||
|
rhs: dst.dtype(),
|
||||||
|
op: "copy_strided",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn matmul_impl(
|
pub(crate) fn matmul_impl(
|
||||||
&self,
|
&self,
|
||||||
rhs: &Self,
|
rhs: &Self,
|
||||||
|
@ -132,7 +132,7 @@ impl Storage {
|
|||||||
self.same_device(rhs, "matmul")?;
|
self.same_device(rhs, "matmul")?;
|
||||||
self.same_dtype(rhs, "matmul")?;
|
self.same_dtype(rhs, "matmul")?;
|
||||||
match (self, rhs) {
|
match (self, rhs) {
|
||||||
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
(Self::Cpu(lhs), Self::Cpu(rhs)) => {
|
||||||
let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
|
let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
|
||||||
Ok(Self::Cpu(storage))
|
Ok(Self::Cpu(storage))
|
||||||
}
|
}
|
||||||
@ -151,11 +151,23 @@ impl Storage {
|
|||||||
// self, the source can be strided whereas dst is contiguous.
|
// self, the source can be strided whereas dst is contiguous.
|
||||||
pub(crate) fn copy_strided_src(
|
pub(crate) fn copy_strided_src(
|
||||||
&self,
|
&self,
|
||||||
_dst: &mut Self,
|
dst: &mut Self,
|
||||||
_shape: &Shape,
|
src_shape: &Shape,
|
||||||
_stride: &[usize],
|
src_stride: &[usize],
|
||||||
_offset: usize,
|
dst_offset: usize,
|
||||||
) {
|
) -> Result<()> {
|
||||||
|
match (self, dst) {
|
||||||
|
(Self::Cpu(src), Self::Cpu(dst)) => {
|
||||||
|
src.copy_strided_src(dst, src_shape, src_stride, dst_offset)
|
||||||
|
}
|
||||||
|
(Self::Cuda(_src), Self::Cuda(_dst)) => {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||||
|
lhs: lhs.device().location(),
|
||||||
|
rhs: rhs.device().location(),
|
||||||
|
op: "copy",
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -648,7 +648,7 @@ impl Tensor {
|
|||||||
let mut storage = device.zeros(&shape, dtype)?;
|
let mut storage = device.zeros(&shape, dtype)?;
|
||||||
for (arg, &offset) in args.iter().zip(offsets.iter()) {
|
for (arg, &offset) in args.iter().zip(offsets.iter()) {
|
||||||
arg.storage
|
arg.storage
|
||||||
.copy_strided_src(&mut storage, &arg.shape, &arg.stride, offset)
|
.copy_strided_src(&mut storage, &arg.shape, &arg.stride, offset)?
|
||||||
}
|
}
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
|
Reference in New Issue
Block a user