Add a synchronize method to devices. (#2055)

* Add a synchronize method to devices.

* Metal version.
This commit is contained in:
Laurent Mazare
2024-04-14 16:32:55 +02:00
committed by GitHub
parent 50e49ecc5f
commit 53e5380bf6
6 changed files with 24 additions and 0 deletions

View File

@ -142,4 +142,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn set_seed(&self, _: u64) -> Result<()>;
/// Synchronize should block until all the operations on the device are completed.
fn synchronize(&self) -> Result<()>;
}

View File

@ -2628,6 +2628,10 @@ impl BackendDevice for CpuDevice {
};
Ok(storage)
}
fn synchronize(&self) -> Result<()> {
Ok(())
}
}
#[macro_export]

View File

@ -407,4 +407,9 @@ impl BackendDevice for CudaDevice {
device: self.clone(),
})
}
fn synchronize(&self) -> Result<()> {
self.device.synchronize().map_err(crate::Error::wrap)?;
Ok(())
}
}

View File

@ -229,4 +229,8 @@ impl crate::backend::BackendDevice for CudaDevice {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn synchronize(&self) -> Result<()> {
Ok(())
}
}

View File

@ -241,4 +241,8 @@ impl crate::backend::BackendDevice for MetalDevice {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn synchronize(&self) -> Result<()> {
Ok(())
}
}

View File

@ -1790,6 +1790,10 @@ impl BackendDevice for MetalDevice {
Ok(())
}
fn synchronize(&self) -> Result<()> {
self.wait_until_completed()
}
}
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {