Add some quantized functions to pyo3. (#708)

This commit is contained in:
Laurent Mazare
2023-09-01 19:45:36 +02:00
committed by GitHub
parent 2ed78ab336
commit 1e5b2cc1d5

View File

@ -2,7 +2,8 @@
// TODO: Handle negative dimension indexes. // TODO: Handle negative dimension indexes.
use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::PyTuple; use pyo3::types::{IntoPyDict, PyTuple};
use pyo3::ToPyObject;
use std::sync::Arc; use std::sync::Arc;
use half::{bf16, f16}; use half::{bf16, f16};
@ -583,6 +584,45 @@ impl PyQTensor {
} }
} }
#[pyfunction]
fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
let res = ::candle::safetensors::load(path, &Device::Cpu).map_err(wrap_err)?;
let res = res
.into_iter()
.map(|(key, value)| (key, PyTensor(value).into_py(py)))
.collect::<Vec<_>>();
Ok(res.into_py_dict(py).to_object(py))
}
#[pyfunction]
fn load_ggml(path: &str, py: Python<'_>) -> PyResult<PyObject> {
let mut file = std::fs::File::open(path)?;
let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
let res = ggml
.tensors
.into_iter()
.map(|(key, qtensor)| Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py))))
.collect::<::candle::Result<Vec<_>>>()
.map_err(wrap_err)?;
Ok(res.into_py_dict(py).to_object(py))
}
#[pyfunction]
fn load_gguf(path: &str, py: Python<'_>) -> PyResult<PyObject> {
let mut file = std::fs::File::open(path)?;
let gguf = ::candle::quantized::gguf_file::Content::read(&mut file).map_err(wrap_err)?;
let res = gguf
.tensor_infos
.keys()
.map(|key| {
let qtensor = gguf.tensor(&mut file, key)?;
Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py)))
})
.collect::<::candle::Result<Vec<_>>>()
.map_err(wrap_err)?;
Ok(res.into_py_dict(py).to_object(py))
}
#[pyfunction] #[pyfunction]
fn cuda_is_available() -> bool { fn cuda_is_available() -> bool {
::candle::utils::cuda_is_available() ::candle::utils::cuda_is_available()
@ -627,6 +667,9 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add("f32", PyDType(DType::F32))?; m.add("f32", PyDType(DType::F32))?;
m.add("f64", PyDType(DType::F64))?; m.add("f64", PyDType(DType::F64))?;
m.add_function(wrap_pyfunction!(cat, m)?)?; m.add_function(wrap_pyfunction!(cat, m)?)?;
m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
m.add_function(wrap_pyfunction!(ones, m)?)?; m.add_function(wrap_pyfunction!(ones, m)?)?;
m.add_function(wrap_pyfunction!(rand, m)?)?; m.add_function(wrap_pyfunction!(rand, m)?)?;
m.add_function(wrap_pyfunction!(randn, m)?)?; m.add_function(wrap_pyfunction!(randn, m)?)?;