Bugfix for the strided copy + add some assertions.

This commit is contained in:
laurent
2023-06-23 16:28:18 +01:00
parent bcfbb1dca1
commit 1936a1f0a3
3 changed files with 7 additions and 1 deletions

View File

@ -182,6 +182,9 @@ impl CpuStorage {
src_stride: &[usize],
dst_offset: usize,
) -> Result<()> {
if src_shape.rank() != src_stride.len() {
panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
}
match (self, dst) {
(Self::F32(src), Self::F32(dst)) => {
if src_shape.is_contiguous(src_stride) {

View File

@ -446,6 +446,9 @@ impl CudaStorage {
src_stride: &[usize],
dst_offset: usize,
) -> Result<()> {
if src_shape.rank() != src_stride.len() {
panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
}
let dims = src_shape.dims();
let el_count = src_shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);

View File

@ -668,7 +668,7 @@ impl Tensor {
let shape = self.shape();
let mut storage = self.device().zeros(shape, self.dtype())?;
self.storage
.copy_strided_src(&mut storage, shape, &self.stride, 0)?;
.copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
let tensor_ = Tensor_ {
id: TensorId::new(),
storage,