Run the tensor tests for the CUDA backend too.

This commit is contained in:
laurent
2023-06-27 15:37:01 +01:00
parent b3622c972f
commit 07a682c2ff
2 changed files with 74 additions and 41 deletions

View File

@ -1,7 +1,7 @@
use crate::{CpuStorage, DType, Shape};
use candle_kernels as kernels;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig};
use half::{bf16, f16};
use std::sync::Arc;
@ -904,7 +904,9 @@ impl CudaStorage {
match (&self.slice, &mut dst.slice) {
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
let src = src.slice(..elem_to_copy);
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
@ -917,7 +919,9 @@ impl CudaStorage {
}
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
let src = src.slice(..elem_to_copy);
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
@ -930,7 +934,9 @@ impl CudaStorage {
}
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
let src = src.slice(..elem_to_copy);
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
@ -943,7 +949,9 @@ impl CudaStorage {
}
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
let src = src.slice(..elem_to_copy);
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
@ -956,7 +964,8 @@ impl CudaStorage {
}
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {