mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add support for cuda streams. (#2532)
This commit is contained in:
@ -144,6 +144,20 @@ impl CudaDevice {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl CudaDevice {
|
||||||
|
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
|
||||||
|
let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
|
||||||
|
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
|
||||||
|
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
|
||||||
|
Ok(Self {
|
||||||
|
id: DeviceId::new(),
|
||||||
|
device,
|
||||||
|
blas: Arc::new(blas),
|
||||||
|
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl BackendDevice for CudaDevice {
|
impl BackendDevice for CudaDevice {
|
||||||
type Storage = CudaStorage;
|
type Storage = CudaStorage;
|
||||||
|
|
||||||
|
@ -130,6 +130,10 @@ impl Device {
|
|||||||
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
|
||||||
|
Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn new_metal(ordinal: usize) -> Result<Self> {
|
pub fn new_metal(ordinal: usize) -> Result<Self> {
|
||||||
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
|
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,12 @@ macro_rules! fail {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl CudaDevice {
|
||||||
|
pub fn new_with_stream(_: usize) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl crate::backend::BackendStorage for CudaStorage {
|
impl crate::backend::BackendStorage for CudaStorage {
|
||||||
type Device = CudaDevice;
|
type Device = CudaDevice;
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user