Start adding support for cuda.

This commit is contained in:
laurent
2023-06-21 18:11:56 +01:00
parent 7c317f9611
commit 2bfe8f18ab
5 changed files with 39 additions and 18 deletions

View File

@ -1,4 +1,4 @@
use crate::{DType, Device, Shape};
use crate::{DType, DeviceLocation, Shape};
/// Main library error type.
#[derive(thiserror::Error, Debug)]
@ -15,8 +15,8 @@ pub enum Error {
#[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
DeviceMismatchBinaryOp {
lhs: Device,
rhs: Device,
lhs: DeviceLocation,
rhs: DeviceLocation,
op: &'static str,
},
@ -33,6 +33,9 @@ pub enum Error {
got: usize,
shape: Shape,
},
#[error(transparent)]
Cudarc(#[from] cudarc::driver::DriverError),
}
pub type Result<T> = std::result::Result<T, Error>;