mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Bugfix for the strided copy + add some assertions.
This commit is contained in:
@ -182,6 +182,9 @@ impl CpuStorage {
|
||||
src_stride: &[usize],
|
||||
dst_offset: usize,
|
||||
) -> Result<()> {
|
||||
if src_shape.rank() != src_stride.len() {
|
||||
panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
|
||||
}
|
||||
match (self, dst) {
|
||||
(Self::F32(src), Self::F32(dst)) => {
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
|
@ -446,6 +446,9 @@ impl CudaStorage {
|
||||
src_stride: &[usize],
|
||||
dst_offset: usize,
|
||||
) -> Result<()> {
|
||||
if src_shape.rank() != src_stride.len() {
|
||||
panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
|
||||
}
|
||||
let dims = src_shape.dims();
|
||||
let el_count = src_shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
|
@ -668,7 +668,7 @@ impl Tensor {
|
||||
let shape = self.shape();
|
||||
let mut storage = self.device().zeros(shape, self.dtype())?;
|
||||
self.storage
|
||||
.copy_strided_src(&mut storage, shape, &self.stride, 0)?;
|
||||
.copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage,
|
||||
|
Reference in New Issue
Block a user