Add dtype support.

This commit is contained in:
laurent
2023-07-02 20:12:26 +01:00
parent 65e069384c
commit 78871ffe38
3 changed files with 48 additions and 6 deletions

View File

@ -10,6 +10,24 @@ pub enum DType {
F64,
}
#[derive(Debug, PartialEq, Eq)]
pub struct DTypeParseError;
impl std::str::FromStr for DType {
type Err = DTypeParseError;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"u8" => Ok(Self::U8),
"u32" => Ok(Self::U32),
"bf16" => Ok(Self::BF16),
"f16" => Ok(Self::F16),
"f32" => Ok(Self::F32),
"f64" => Ok(Self::F64),
_ => Err(DTypeParseError),
}
}
}
impl DType {
pub fn as_str(&self) -> &'static str {
match self {