From 1e5b2cc1d5144dcbb86356b99d1aec91dc416473 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 1 Sep 2023 19:45:36 +0200 Subject: [PATCH] Add some quantized functions to pyo3. (#708) --- candle-pyo3/src/lib.rs | 45 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 7c28eb52..2673d843 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -2,7 +2,8 @@ // TODO: Handle negative dimension indexes. use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; -use pyo3::types::PyTuple; +use pyo3::types::{IntoPyDict, PyTuple}; +use pyo3::ToPyObject; use std::sync::Arc; use half::{bf16, f16}; @@ -583,6 +584,45 @@ impl PyQTensor { } } +#[pyfunction] +fn load_safetensors(path: &str, py: Python<'_>) -> PyResult { + let res = ::candle::safetensors::load(path, &Device::Cpu).map_err(wrap_err)?; + let res = res + .into_iter() + .map(|(key, value)| (key, PyTensor(value).into_py(py))) + .collect::>(); + Ok(res.into_py_dict(py).to_object(py)) +} + +#[pyfunction] +fn load_ggml(path: &str, py: Python<'_>) -> PyResult { + let mut file = std::fs::File::open(path)?; + let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?; + let res = ggml + .tensors + .into_iter() + .map(|(key, qtensor)| Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py)))) + .collect::<::candle::Result>>() + .map_err(wrap_err)?; + Ok(res.into_py_dict(py).to_object(py)) +} + +#[pyfunction] +fn load_gguf(path: &str, py: Python<'_>) -> PyResult { + let mut file = std::fs::File::open(path)?; + let gguf = ::candle::quantized::gguf_file::Content::read(&mut file).map_err(wrap_err)?; + let res = gguf + .tensor_infos + .keys() + .map(|key| { + let qtensor = gguf.tensor(&mut file, key)?; + Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py))) + }) + .collect::<::candle::Result>>() + .map_err(wrap_err)?; + Ok(res.into_py_dict(py).to_object(py)) +} + #[pyfunction] fn cuda_is_available() -> bool { ::candle::utils::cuda_is_available() @@ -627,6 +667,9 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add("f32", PyDType(DType::F32))?; m.add("f64", PyDType(DType::F64))?; m.add_function(wrap_pyfunction!(cat, m)?)?; + m.add_function(wrap_pyfunction!(load_ggml, m)?)?; + m.add_function(wrap_pyfunction!(load_gguf, m)?)?; + m.add_function(wrap_pyfunction!(load_safetensors, m)?)?; m.add_function(wrap_pyfunction!(ones, m)?)?; m.add_function(wrap_pyfunction!(rand, m)?)?; m.add_function(wrap_pyfunction!(randn, m)?)?;