mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Run the tensor tests for the cuda backend too.
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
use crate::{CpuStorage, DType, Shape};
|
||||
use candle_kernels as kernels;
|
||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||
use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
|
||||
use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig};
|
||||
use half::{bf16, f16};
|
||||
use std::sync::Arc;
|
||||
|
||||
@ -904,7 +904,9 @@ impl CudaStorage {
|
||||
match (&self.slice, &mut dst.slice) {
|
||||
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
|
||||
let src = src.slice(src_offset..);
|
||||
let mut dst = dst.slice_mut(dst_offset..);
|
||||
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
|
||||
let src = src.slice(..elem_to_copy);
|
||||
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
@ -917,7 +919,9 @@ impl CudaStorage {
|
||||
}
|
||||
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
|
||||
let src = src.slice(src_offset..);
|
||||
let mut dst = dst.slice_mut(dst_offset..);
|
||||
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
|
||||
let src = src.slice(..elem_to_copy);
|
||||
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
@ -930,7 +934,9 @@ impl CudaStorage {
|
||||
}
|
||||
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
|
||||
let src = src.slice(src_offset..);
|
||||
let mut dst = dst.slice_mut(dst_offset..);
|
||||
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
|
||||
let src = src.slice(..elem_to_copy);
|
||||
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
@ -943,7 +949,9 @@ impl CudaStorage {
|
||||
}
|
||||
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
|
||||
let src = src.slice(src_offset..);
|
||||
let mut dst = dst.slice_mut(dst_offset..);
|
||||
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
|
||||
let src = src.slice(..elem_to_copy);
|
||||
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
@ -956,7 +964,8 @@ impl CudaStorage {
|
||||
}
|
||||
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
||||
let src = src.slice(src_offset..);
|
||||
let mut dst = dst.slice_mut(dst_offset..);
|
||||
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
|
||||
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
|
Reference in New Issue
Block a user