pyo3 update. (#2545)

* pyo3 update.

* Stub fix.
This commit is contained in:
Laurent Mazare
2024-10-06 10:09:38 +02:00
committed by GitHub
parent d2e432914e
commit f856b5c3a7
5 changed files with 22 additions and 27 deletions

View File

@ -6,7 +6,6 @@ use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::os::raw::c_long;
use std::sync::Arc;
use half::{bf16, f16};
@ -115,7 +114,7 @@ impl PyDevice {
}
impl<'source> FromPyObject<'source> for PyDevice {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
let device: String = ob.extract()?;
let device = match device.as_str() {
"cpu" => PyDevice::Cpu,
@ -217,11 +216,11 @@ enum Indexer {
IndexSelect(Tensor),
}
#[derive(Clone, Debug)]
#[derive(Debug)]
struct TorchTensor(PyObject);
impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
Ok(TorchTensor(numpy_value))
}
@ -540,7 +539,7 @@ impl PyTensor {
))
} else if let Ok(slice) = py_indexer.downcast::<pyo3::types::PySlice>() {
// Handle a single slice e.g. tensor[0:1] or tensor[0:-1]
let index = slice.indices(dims[current_dim] as c_long)?;
let index = slice.indices(dims[current_dim] as isize)?;
Ok((
Indexer::Slice(index.start as usize, index.stop as usize),
current_dim + 1,
@ -1284,7 +1283,7 @@ fn save_safetensors(
}
#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
#[pyo3(signature = (path, device = None))]
/// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]]
@ -1325,7 +1324,7 @@ fn load_ggml(
}
#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
#[pyo3(signature = (path, device = None))]
/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
/// and the second maps metadata keys to metadata values.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]]
@ -1384,7 +1383,7 @@ fn load_gguf(
#[pyfunction]
#[pyo3(
text_signature = "(path:Union[str,PathLike], tensors:Dict[str,QTensor], metadata:Dict[str,Any])"
signature = (path, tensors, metadata)
)]
/// Save quanitzed tensors and metadata to a GGUF file.
fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> {
@ -1430,7 +1429,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>)
Ok(v)
}
let tensors = tensors
.extract::<&PyDict>(py)
.downcast_bound::<PyDict>(py)
.map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
.iter()
.map(|(key, value)| {
@ -1443,7 +1442,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>)
.collect::<PyResult<Vec<_>>>()?;
let metadata = metadata
.extract::<&PyDict>(py)
.downcast_bound::<PyDict>(py)
.map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
.iter()
.map(|(key, value)| {