Add a Context trait similar to anyhow::Context. (#2676)

* Add a Context trait similar to anyhow::Context.

* Switch two unwrap to context.
This commit is contained in:
Laurent Mazare
2024-12-22 09:18:13 +01:00
committed by GitHub
parent 5c2f893e5a
commit 62ced44ea9
13 changed files with 97 additions and 41 deletions

View File

@ -9,8 +9,14 @@ pub struct MatMulUnexpectedStriding {
pub msg: &'static str,
}
impl std::fmt::Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self}")
}
}
/// Main library error type.
#[derive(thiserror::Error, Debug)]
#[derive(thiserror::Error)]
pub enum Error {
// === DType Errors ===
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
@ -199,8 +205,14 @@ pub enum Error {
UnsupportedSafeTensorDtype(safetensors::Dtype),
/// Arbitrary errors wrapping.
#[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>),
#[error("{0}")]
Wrapped(Box<dyn std::fmt::Display + Send + Sync>),
#[error("{context}\n{inner}")]
Context {
inner: Box<Self>,
context: Box<dyn std::fmt::Display + Send + Sync>,
},
/// Adding path information to an error.
#[error("path: {path:?} {inner}")]
@ -218,16 +230,19 @@ pub enum Error {
/// User generated error message, typically created via `bail!`.
#[error("{0}")]
Msg(String),
#[error("unwrap none")]
UnwrapNone,
}
pub type Result<T> = std::result::Result<T, Error>;
impl Error {
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
pub fn wrap(err: impl std::fmt::Display + Send + Sync + 'static) -> Self {
Self::Wrapped(Box::new(err)).bt()
}
pub fn msg(err: impl std::error::Error) -> Self {
pub fn msg(err: impl std::fmt::Display) -> Self {
Self::Msg(err.to_string()).bt()
}
@ -253,6 +268,13 @@ impl Error {
path: p.as_ref().to_path_buf(),
}
}
pub fn context(self, c: impl std::fmt::Display + Send + Sync + 'static) -> Self {
Self::Context {
inner: Box::new(self),
context: Box::new(c),
}
}
}
#[macro_export]
@ -275,3 +297,41 @@ pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
(_, Err(e)) => Err(e),
}
}
// Taken from anyhow.
pub trait Context<T> {
/// Wrap the error value with additional context.
fn context<C>(self, context: C) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static;
/// Wrap the error value with additional context that is evaluated lazily
/// only once an error does occur.
fn with_context<C, F>(self, f: F) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
F: FnOnce() -> C;
}
impl<T> Context<T> for Option<T> {
fn context<C>(self, context: C) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
{
match self {
Some(v) => Ok(v),
None => Err(Error::UnwrapNone.context(context).bt()),
}
}
fn with_context<C, F>(self, f: F) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
F: FnOnce() -> C,
{
match self {
Some(v) => Ok(v),
None => Err(Error::UnwrapNone.context(f()).bt()),
}
}
}

View File

@ -94,7 +94,7 @@ pub use cpu_backend::{CpuStorage, CpuStorageRef};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use error::{Context, Error, Result};
pub use indexer::{IndexOp, TensorIndexer};
pub use layout::Layout;
pub use shape::{Shape, D};

View File

@ -1,7 +1,7 @@
//! Just enough pickle support to be able to read PyTorch checkpoints.
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
// composable/tensor agnostic at some point.
use crate::{DType, Error as E, Layout, Result, Tensor};
use crate::{Context, DType, Error as E, Layout, Result, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
use std::io::BufRead;
@ -537,7 +537,7 @@ impl Stack {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().unwrap();
let key = objs.pop().context("empty objs")?;
d.push((key, value))
}
} else {
@ -557,7 +557,7 @@ impl Stack {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().unwrap();
let key = objs.pop().context("empty objs")?;
pydict.push((key, value))
}
self.push(Object::Dict(pydict))
@ -661,7 +661,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
if !file_name.ends_with("data.pkl") {
continue;
}
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").context("no .pkl")?);
let reader = zip.by_name(file_name)?;
let mut reader = std::io::BufReader::new(reader);
let mut stack = Stack::empty();

View File

@ -2,7 +2,7 @@
//!
use super::{GgmlDType, QTensor};
use crate::{Device, Result};
use crate::{Context, Device, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;
@ -338,7 +338,7 @@ impl Value {
if value_type.len() != 1 {
crate::bail!("multiple value-types in the same array {value_type:?}")
}
value_type.into_iter().next().unwrap()
value_type.into_iter().next().context("empty value_type")?
};
w.write_u32::<LittleEndian>(value_type.to_u32())?;
w.write_u64::<LittleEndian>(v.len() as u64)?;

View File

@ -1,5 +1,5 @@
//! Code for GGML and GGUF files
use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use crate::{Context, CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;
@ -481,7 +481,7 @@ impl crate::CustomOp1 for QTensor {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let mut dst_shape = src_shape.dims().to_vec();
let last_k = dst_shape.pop().unwrap();
let last_k = dst_shape.pop().context("empty dst_shape")?;
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
}

View File

@ -1,4 +1,4 @@
use crate::{shape::Dim, Error, Result, Shape, Tensor};
use crate::{shape::Dim, Context, Error, Result, Shape, Tensor};
impl Tensor {
/// Concatenates two or more tensors along a particular dimension.
@ -134,7 +134,7 @@ impl Tensor {
.bt())?
}
}
let next_offset = offsets.last().unwrap() + arg.elem_count();
let next_offset = offsets.last().context("empty offsets")? + arg.elem_count();
offsets.push(next_offset);
}
let shape = Shape::from(cat_dims);