Add the source offset when copying the data around.
The copy_strided_src functions gain a src_offset parameter (with dst_offset moved ahead of the shape/stride arguments), the CPU path now clamps the number of copied elements to what fits in the destination, and Tensor::narrow now actually performs the copy instead of leaving its zeroed storage untouched.
@@ -91,16 +91,23 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
 fn copy_strided_src_<T: Copy>(
     src: &[T],
     dst: &mut [T],
+    dst_offset: usize,
     src_shape: &Shape,
     src_stride: &[usize],
-    dst_offset: usize,
+    src_offset: usize,
 ) {
+    let src = &src[src_offset..];
     if src_shape.is_contiguous(src_stride) {
-        dst[dst_offset..].copy_from_slice(src)
+        let elem_to_copy = dst.len() - dst_offset;
+        dst[dst_offset..].copy_from_slice(&src[..elem_to_copy])
     } else {
         let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
         for (dst_index, src_index) in src_indexes.enumerate() {
-            dst[dst_index + dst_offset] = src[src_index]
+            let dst_index = dst_index + dst_offset;
+            if dst_index >= dst.len() {
+                break;
+            }
+            dst[dst_index] = src[src_index]
         }
     }
 }
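
To see what the new offsets do on the contiguous fast path, here is a minimal runnable sketch that isolates that branch (Shape and StridedIndex are left out, so this is an illustration rather than the crate's code): the source is first re-based at src_offset, then the copy length is clamped to what the destination can hold from dst_offset on.

    // Contiguous fast path of copy_strided_src_, isolated for illustration:
    // skip `src_offset` elements of the source, then copy exactly as many
    // elements as fit in `dst` starting at `dst_offset`.
    fn copy_contiguous<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_offset: usize) {
        let src = &src[src_offset..];
        let elem_to_copy = dst.len() - dst_offset;
        dst[dst_offset..].copy_from_slice(&src[..elem_to_copy]);
    }

    fn main() {
        let src = [1, 2, 3, 4, 5];
        let mut dst = [0; 3];
        // Read from src[2..], write from dst[1]: two elements fit.
        copy_contiguous(&src, &mut dst, 1, 2);
        assert_eq!(dst, [0, 3, 4]);
    }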
@@ -289,22 +296,23 @@ impl CpuStorage {
     pub(crate) fn copy_strided_src(
         &self,
         dst: &mut Self,
+        dst_offset: usize,
         src_shape: &Shape,
         src_stride: &[usize],
-        dst_offset: usize,
+        src_offset: usize,
     ) -> Result<()> {
         if src_shape.rank() != src_stride.len() {
             panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
         }
         match (self, dst) {
             (Self::U32(src), Self::U32(dst)) => {
-                copy_strided_src_(src, dst, src_shape, src_stride, dst_offset)
+                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
             }
             (Self::F32(src), Self::F32(dst)) => {
-                copy_strided_src_(src, dst, src_shape, src_stride, dst_offset)
+                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
             }
             (Self::F64(src), Self::F64(dst)) => {
-                copy_strided_src_(src, dst, src_shape, src_stride, dst_offset)
+                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
             }
             (_, dst) => {
                 // This should be covered by the dtype check above.
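
As a usage sketch of the new CpuStorage signature (in-crate code, since copy_strided_src is pub(crate); the Vec payload of the F32 variant and the single-dimension shape are assumptions made for illustration):

    // Hypothetical in-crate usage of the reordered signature:
    // (dst, dst_offset, src_shape, src_stride, src_offset).
    fn demo() -> Result<()> {
        let src = CpuStorage::F32(vec![1.0, 2.0, 3.0, 4.0]);
        let mut dst = CpuStorage::F32(vec![0.0; 2]);
        // dims [2] with stride [1] is contiguous, so this copies
        // src[1..3], i.e. [2.0, 3.0], into dst starting at index 0.
        src.copy_strided_src(&mut dst, 0, &Shape::from(vec![2]), &[1], 1)?;
        Ok(())
    }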
@@ -450,9 +450,10 @@ impl CudaStorage {
     pub(crate) fn copy_strided_src(
         &self,
         dst: &mut Self,
+        dst_offset: usize,
         src_shape: &Shape,
         src_stride: &[usize],
-        dst_offset: usize,
+        src_offset: usize,
     ) -> Result<()> {
         if src_shape.rank() != src_stride.len() {
             panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
@@ -464,26 +465,27 @@ impl CudaStorage {
         let ds = dev.htod_copy([dims, src_stride].concat())?;
         match (&self.slice, &mut dst.slice) {
             (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
+                let src = src.slice(src_offset..);
                 let mut dst = dst.slice_mut(dst_offset..);
                 if src_shape.is_contiguous(src_stride) {
-                    dev.dtod_copy(src, &mut dst)?
+                    dev.dtod_copy(&src, &mut dst)?
                 } else {
                     let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
                     // SAFETY: Set later by running the kernel.
-                    let params = (el_count, dims.len(), &ds, src, &mut dst);
+                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
                     // SAFETY: ffi.
                     unsafe { func.launch(cfg, params) }?
                 }
             }
             (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
+                let src = src.slice(src_offset..);
                 let mut dst = dst.slice_mut(dst_offset..);
                 if src_shape.is_contiguous(src_stride) {
-                    dev.dtod_copy(src, &mut dst)?
+                    dev.dtod_copy(&src, &mut dst)?
                 } else {
                     let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
-                    let mut dst = dst.slice_mut(dst_offset..);
                     // SAFETY: Set later by running the kernel.
-                    let params = (el_count, dims.len(), &ds, src, &mut dst);
+                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
                     // SAFETY: ffi.
                     unsafe { func.launch(cfg, params) }?;
                 }
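
The non-contiguous branch hands the packed dims-then-strides buffer ds to a ucopy kernel; the kernel source is not part of this diff. As an illustration of the per-element address arithmetic such a kernel performs, here is the same computation in plain Rust (the packed layout mirrors the [dims, src_stride].concat() call above, but the function is an assumption about the kernel's internals):

    // Map a linear (row-major) destination index to a strided source index,
    // given a packed info buffer of [dims..., strides...].
    fn strided_src_index(info: &[usize], num_dims: usize, dst_i: usize) -> usize {
        let (dims, strides) = info.split_at(num_dims);
        let mut rem = dst_i;
        let mut src_i = 0;
        // Peel off coordinates from the innermost dimension outwards.
        for (d, s) in dims.iter().zip(strides.iter()).rev() {
            src_i += (rem % d) * s;
            rem /= d;
        }
        src_i
    }

    fn main() {
        // A 2x2 transposed view has strides [1, 2]: output element 1
        // comes from source element 2.
        assert_eq!(strided_src_index(&[2, 2, 1, 2], 2, 1), 2);
    }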
@@ -101,6 +101,7 @@ impl CudaStorage {
     pub(crate) fn copy_strided_src(
         &self,
         _: &mut Self,
+        _: usize,
         _: &Shape,
         _: &[usize],
         _: usize,
@@ -196,16 +196,17 @@ impl Storage {
     pub(crate) fn copy_strided_src(
         &self,
         dst: &mut Self,
+        dst_offset: usize,
         src_shape: &Shape,
         src_stride: &[usize],
-        dst_offset: usize,
+        src_offset: usize,
     ) -> Result<()> {
         match (self, dst) {
             (Self::Cpu(src), Self::Cpu(dst)) => {
-                src.copy_strided_src(dst, src_shape, src_stride, dst_offset)
+                src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)
             }
             (Self::Cuda(src), Self::Cuda(dst)) => {
-                Ok(src.copy_strided_src(dst, src_shape, src_stride, dst_offset)?)
+                Ok(src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)?)
             }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
@@ -317,9 +317,10 @@ impl Tensor {
         let mut dims = dims.to_vec();
         dims[dim] = length;
         let shape = Shape::from(dims);
-        let storage = self.device().zeros(&shape, self.dtype())?;
-        // TODO: Actually copy the data, compared to copy_strided_src this requires a src start
-        // offset as well as a way to specify the number of elements to be copied.
+        let mut storage = self.device().zeros(&shape, self.dtype())?;
+        let src_offset = 0; // TODO
+        self.storage
+            .copy_strided_src(&mut storage, 0, &self.shape, &self.stride, src_offset)?;
         let op = if self.track_op() {
             Some(Op::Narrow(self.clone(), dim, start, length))
         } else {
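
The narrow path still passes src_offset = 0 and leaves a TODO. Purely as a sketch of how that TODO could later be filled in (not part of this commit): for a narrow keeping `length` elements of `dim` starting at `start`, the first element of the view sits start * stride[dim] elements into the source buffer.

    // Hypothetical helper for the TODO above (assumes non-negative,
    // row-major strides): offset of the first narrowed element.
    fn narrow_src_offset(stride: &[usize], dim: usize, start: usize) -> usize {
        start * stride[dim]
    }

    fn main() {
        // A contiguous 3x4 tensor has strides [4, 1]; narrowing rows
        // starting at row 1 skips one full row of 4 elements.
        assert_eq!(narrow_src_offset(&[4, 1], 0, 1), 4);
    }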
@@ -666,7 +667,7 @@ impl Tensor {
         let shape = self.shape();
         let mut storage = self.device().zeros(shape, self.dtype())?;
         self.storage
-            .copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
+            .copy_strided_src(&mut storage, 0, &self.shape, &self.stride, 0)?;
         Ok(from_storage(
             storage,
             shape.clone(),
@@ -709,7 +710,7 @@ impl Tensor {
         } else {
             let mut storage = self.device().zeros(&shape, self.dtype())?;
             self.storage
-                .copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
+                .copy_strided_src(&mut storage, 0, &self.shape, &self.stride, 0)?;
             Ok(from_storage(storage, shape, op, false))
         }
     }
@@ -786,7 +787,7 @@ impl Tensor {
         let mut storage = device.zeros(&shape, dtype)?;
         for (arg, &offset) in args.iter().zip(offsets.iter()) {
             arg.storage
-                .copy_strided_src(&mut storage, &arg.shape, &arg.stride, offset)?
+                .copy_strided_src(&mut storage, offset, &arg.shape, &arg.stride, 0)?
         }
         Ok(from_storage(storage, shape, op, false))
     }
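
In the concatenation loop above, each argument is now copied whole (src_offset stays 0) while its running position in the result goes in as dst_offset. A small sketch of that bookkeeping (the offsets vector is computed outside this hunk, so its exact shape is an assumption):

    // Running dst offsets for concatenation: each argument starts where
    // the previous ones end.
    fn cat_offsets(arg_lens: &[usize]) -> Vec<usize> {
        let mut offsets = Vec::with_capacity(arg_lens.len());
        let mut acc = 0;
        for len in arg_lens {
            offsets.push(acc);
            acc += len;
        }
        offsets
    }

    fn main() {
        // Tensors with 6, 2, and 4 elements land at offsets 0, 6, and 8.
        assert_eq!(cat_offsets(&[6, 2, 4]), vec![0, 6, 8]);
    }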