mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Bugfix for the strided copy + add some assertions.
This commit is contained in:
@ -182,6 +182,9 @@ impl CpuStorage {
|
|||||||
src_stride: &[usize],
|
src_stride: &[usize],
|
||||||
dst_offset: usize,
|
dst_offset: usize,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
|
if src_shape.rank() != src_stride.len() {
|
||||||
|
panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
|
||||||
|
}
|
||||||
match (self, dst) {
|
match (self, dst) {
|
||||||
(Self::F32(src), Self::F32(dst)) => {
|
(Self::F32(src), Self::F32(dst)) => {
|
||||||
if src_shape.is_contiguous(src_stride) {
|
if src_shape.is_contiguous(src_stride) {
|
||||||
|
@ -446,6 +446,9 @@ impl CudaStorage {
|
|||||||
src_stride: &[usize],
|
src_stride: &[usize],
|
||||||
dst_offset: usize,
|
dst_offset: usize,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
|
if src_shape.rank() != src_stride.len() {
|
||||||
|
panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
|
||||||
|
}
|
||||||
let dims = src_shape.dims();
|
let dims = src_shape.dims();
|
||||||
let el_count = src_shape.elem_count();
|
let el_count = src_shape.elem_count();
|
||||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||||
|
@ -668,7 +668,7 @@ impl Tensor {
|
|||||||
let shape = self.shape();
|
let shape = self.shape();
|
||||||
let mut storage = self.device().zeros(shape, self.dtype())?;
|
let mut storage = self.device().zeros(shape, self.dtype())?;
|
||||||
self.storage
|
self.storage
|
||||||
.copy_strided_src(&mut storage, shape, &self.stride, 0)?;
|
.copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
storage,
|
storage,
|
||||||
|
Reference in New Issue
Block a user