Add support for cuda streams. (#2532)

This commit is contained in:
Laurent Mazare
2024-10-02 21:30:58 +02:00
committed by GitHub
parent 936300678d
commit 7b60bda4ed
3 changed files with 24 additions and 0 deletions

View File

@ -144,6 +144,20 @@ impl CudaDevice {
}
}
impl CudaDevice {
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
Ok(Self {
id: DeviceId::new(),
device,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
})
}
}
impl BackendDevice for CudaDevice {
type Storage = CudaStorage;