mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Update to cudarc 0.14 (breaking change). (#2858)
* Start updating to cudarc 0.14. * Adapt a couple more things. * And a couple more fixes. * More tweaks. * And a couple more fixes. * Bump the major version number. * Proper module system for the cuda kernels. * Proper ptx loading. * Launch the sort kernel. * Custom op. * Start using the builder pattern. * More builder. * More builder. * Get candle-core to compile. * Get the tests to pass. * Get candle-nn to work too. * Support for custom cuda functions. * cudnn fixes. * Get flash attn to run. * Switch the crate versions to be alpha. * Bump the ug dependency.
This commit is contained in:
@ -43,7 +43,7 @@ pub(crate) fn launch_conv2d<
|
||||
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
|
||||
return Ok(cudnn.clone());
|
||||
}
|
||||
let c = Cudnn::new(dev.cuda_device());
|
||||
let c = Cudnn::new(dev.cuda_stream());
|
||||
if let Ok(c) = &c {
|
||||
cudnn.borrow_mut().insert(device_id, c.clone());
|
||||
}
|
||||
@ -109,7 +109,7 @@ pub(crate) fn launch_conv2d<
|
||||
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
|
||||
};
|
||||
let workspace_size = conv2d.get_workspace_size(alg)?;
|
||||
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
|
||||
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
|
||||
unsafe {
|
||||
conv2d.launch::<CudaSlice<u8>, _, _, _>(
|
||||
alg,
|
||||
|
@ -2,8 +2,9 @@ use crate::backend::BackendDevice;
|
||||
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
||||
pub use candle_kernels as kernels;
|
||||
pub use cudarc;
|
||||
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
|
||||
use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
|
||||
use half::{bf16, f16};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
|
||||
@ -24,10 +25,17 @@ impl DeviceId {
|
||||
struct CudaRng(cudarc::curand::CudaRng);
|
||||
unsafe impl Send for CudaRng {}
|
||||
|
||||
pub struct ModuleStore {
|
||||
mdls: [Option<Arc<cudarc::driver::CudaModule>>; kernels::ALL_IDS.len()],
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CudaDevice {
|
||||
id: DeviceId,
|
||||
device: Arc<cudarc::driver::CudaDevice>,
|
||||
context: Arc<cudarc::driver::CudaContext>,
|
||||
modules: Arc<std::sync::RwLock<ModuleStore>>,
|
||||
custom_modules: Arc<std::sync::RwLock<HashMap<String, Arc<cudarc::driver::CudaModule>>>>,
|
||||
stream: Arc<cudarc::driver::CudaStream>,
|
||||
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
|
||||
curand: Arc<Mutex<CudaRng>>,
|
||||
}
|
||||
@ -39,16 +47,51 @@ impl std::fmt::Debug for CudaDevice {
|
||||
}
|
||||
|
||||
impl std::ops::Deref for CudaDevice {
|
||||
type Target = Arc<cudarc::driver::CudaDevice>;
|
||||
type Target = Arc<cudarc::driver::CudaStream>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.device
|
||||
&self.stream
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CudaFunc {
|
||||
func: CudaFunction,
|
||||
stream: Arc<cudarc::driver::CudaStream>,
|
||||
}
|
||||
|
||||
impl std::ops::Deref for CudaFunc {
|
||||
type Target = CudaFunction;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.func
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaFunc {
|
||||
pub fn into_cuda_function(self) -> CudaFunction {
|
||||
self.func
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! builder_arg {
|
||||
($b:ident, $($arg:expr),*) => {
|
||||
$(
|
||||
let __arg = $arg;
|
||||
$b.arg(&__arg);
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
impl CudaFunc {
|
||||
pub fn builder(&self) -> cudarc::driver::LaunchArgs<'_> {
|
||||
self.stream.launch_builder(&self.func)
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaDevice {
|
||||
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
|
||||
self.device.clone()
|
||||
pub fn cuda_stream(&self) -> Arc<cudarc::driver::CudaStream> {
|
||||
self.stream.clone()
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@ -56,7 +99,7 @@ impl CudaDevice {
|
||||
&self,
|
||||
func_name: &'static str,
|
||||
kernel: ug::lang::ssa::Kernel,
|
||||
) -> Result<CudaFunction> {
|
||||
) -> Result<CudaFunc> {
|
||||
let mut buf = vec![];
|
||||
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
|
||||
let cuda_code = String::from_utf8(buf)?;
|
||||
@ -65,12 +108,12 @@ impl CudaDevice {
|
||||
..Default::default()
|
||||
};
|
||||
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
|
||||
self.device.load_ptx(ptx, "ug", &[func_name]).w()?;
|
||||
let func = match self.device.get_func("ug", func_name) {
|
||||
Some(func) => func,
|
||||
None => crate::bail!("unknown function ug::{func_name}"),
|
||||
};
|
||||
Ok(func)
|
||||
let module = self.context.load_module(ptx).w()?;
|
||||
let func = module.load_function(func_name).w()?;
|
||||
Ok(CudaFunc {
|
||||
func,
|
||||
stream: self.stream.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn id(&self) -> DeviceId {
|
||||
@ -84,57 +127,84 @@ impl CudaDevice {
|
||||
DType::U8 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
|
||||
let params = (&data, v as u8, elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as u8;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
DType::U32 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
|
||||
let params = (&data, v as u32, elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as u32;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::I64 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
|
||||
let params = (&data, v as i64, elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as i64;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
|
||||
let params = (&data, bf16::from_f64(v), elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = bf16::from_f64(v);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
|
||||
let params = (&data, f16::from_f64(v), elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = f16::from_f64(v);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
|
||||
let params = (&data, v as f32, elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as f32;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
|
||||
let params = (&data, v, elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
@ -144,38 +214,69 @@ impl CudaDevice {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
|
||||
if !self.has_func(module_name, module_name) {
|
||||
// Leaking the string here is a bit sad but we need a &'static str and this is only
|
||||
// done once per kernel name.
|
||||
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
|
||||
self.load_ptx(ptx.into(), module_name, &[static_module_name])
|
||||
.map_err(|cuda| CudaError::Load {
|
||||
cuda,
|
||||
module_name: module_name.to_string(),
|
||||
})
|
||||
.w()?;
|
||||
pub fn get_or_load_custom_func(
|
||||
&self,
|
||||
fn_name: &str,
|
||||
module_name: &str,
|
||||
ptx: &str,
|
||||
) -> Result<CudaFunc> {
|
||||
let ms = self.custom_modules.read().unwrap();
|
||||
if let Some(mdl) = ms.get(module_name).as_ref() {
|
||||
let func = mdl.load_function(fn_name).w()?;
|
||||
return Ok(CudaFunc {
|
||||
func,
|
||||
stream: self.stream.clone(),
|
||||
});
|
||||
}
|
||||
self.get_func(module_name, module_name)
|
||||
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
||||
// able to only build the error value if needed.
|
||||
.ok_or(CudaError::MissingKernel {
|
||||
module_name: module_name.to_string(),
|
||||
})
|
||||
.w()
|
||||
drop(ms);
|
||||
let mut ms = self.custom_modules.write().unwrap();
|
||||
let cuda_module = self.context.load_module(ptx.into()).w()?;
|
||||
ms.insert(module_name.to_string(), cuda_module.clone());
|
||||
let func = cuda_module.load_function(fn_name).w()?;
|
||||
Ok(CudaFunc {
|
||||
func,
|
||||
stream: self.stream.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_or_load_func(&self, fn_name: &str, mdl: &kernels::Module) -> Result<CudaFunc> {
|
||||
let ms = self.modules.read().unwrap();
|
||||
if let Some(mdl) = ms.mdls[mdl.index()].as_ref() {
|
||||
let func = mdl.load_function(fn_name).w()?;
|
||||
return Ok(CudaFunc {
|
||||
func,
|
||||
stream: self.stream.clone(),
|
||||
});
|
||||
}
|
||||
drop(ms);
|
||||
let mut ms = self.modules.write().unwrap();
|
||||
let cuda_module = self.context.load_module(mdl.ptx().into()).w()?;
|
||||
ms.mdls[mdl.index()] = Some(cuda_module.clone());
|
||||
let func = cuda_module.load_function(fn_name).w()?;
|
||||
Ok(CudaFunc {
|
||||
func,
|
||||
stream: self.stream.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaDevice {
|
||||
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
|
||||
let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
|
||||
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
|
||||
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
|
||||
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
|
||||
let stream = context.new_stream().w()?;
|
||||
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
|
||||
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
|
||||
let module_store = ModuleStore {
|
||||
mdls: [const { None }; kernels::ALL_IDS.len()],
|
||||
};
|
||||
Ok(Self {
|
||||
id: DeviceId::new(),
|
||||
device,
|
||||
context,
|
||||
stream,
|
||||
blas: Arc::new(blas),
|
||||
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
||||
modules: Arc::new(std::sync::RwLock::new(module_store)),
|
||||
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -184,14 +285,21 @@ impl BackendDevice for CudaDevice {
|
||||
type Storage = CudaStorage;
|
||||
|
||||
fn new(ordinal: usize) -> Result<Self> {
|
||||
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
|
||||
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
|
||||
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
|
||||
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
|
||||
let stream = context.default_stream();
|
||||
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
|
||||
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
|
||||
let module_store = ModuleStore {
|
||||
mdls: [const { None }; kernels::ALL_IDS.len()],
|
||||
};
|
||||
Ok(Self {
|
||||
id: DeviceId::new(),
|
||||
device,
|
||||
context,
|
||||
stream,
|
||||
blas: Arc::new(blas),
|
||||
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
||||
modules: Arc::new(std::sync::RwLock::new(module_store)),
|
||||
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
|
||||
})
|
||||
}
|
||||
|
||||
@ -199,13 +307,13 @@ impl BackendDevice for CudaDevice {
|
||||
// We do not call set_seed but instead create a new curand object. This ensures that the
|
||||
// state will be identical and the same random numbers will be generated.
|
||||
let mut curand = self.curand.lock().unwrap();
|
||||
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
|
||||
curand.0 = cudarc::curand::CudaRng::new(seed, self.stream.clone()).w()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn location(&self) -> crate::DeviceLocation {
|
||||
crate::DeviceLocation::Cuda {
|
||||
gpu_id: self.device.ordinal(),
|
||||
gpu_id: self.context.ordinal(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -373,31 +481,31 @@ impl BackendDevice for CudaDevice {
|
||||
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||
let slice = match T::cpu_storage_ref(s) {
|
||||
CpuStorageRef::U8(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
CpuStorageRef::U32(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
CpuStorageRef::I64(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
CpuStorageRef::BF16(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
CpuStorageRef::F16(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
CpuStorageRef::F32(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
CpuStorageRef::F64(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
@ -410,31 +518,31 @@ impl BackendDevice for CudaDevice {
|
||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||
let slice = match storage {
|
||||
CpuStorage::U8(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
CpuStorage::U32(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
CpuStorage::I64(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
CpuStorage::BF16(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
CpuStorage::F16(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
CpuStorage::F32(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
CpuStorage::F64(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
@ -447,31 +555,31 @@ impl BackendDevice for CudaDevice {
|
||||
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
|
||||
let slice = match storage {
|
||||
CpuStorage::U8(storage) => {
|
||||
let data = self.htod_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
CpuStorage::U32(storage) => {
|
||||
let data = self.htod_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
CpuStorage::I64(storage) => {
|
||||
let data = self.htod_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
CpuStorage::BF16(storage) => {
|
||||
let data = self.htod_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
CpuStorage::F16(storage) => {
|
||||
let data = self.htod_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
CpuStorage::F32(storage) => {
|
||||
let data = self.htod_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
CpuStorage::F64(storage) => {
|
||||
let data = self.htod_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
@ -482,7 +590,7 @@ impl BackendDevice for CudaDevice {
|
||||
}
|
||||
|
||||
fn synchronize(&self) -> Result<()> {
|
||||
self.device.synchronize().map_err(crate::Error::wrap)?;
|
||||
self.stream.synchronize().map_err(crate::Error::wrap)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user