mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add some quantized functions to pyo3. (#708)
This commit is contained in:
@ -2,7 +2,8 @@
|
|||||||
// TODO: Handle negative dimension indexes.
|
// TODO: Handle negative dimension indexes.
|
||||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use pyo3::types::PyTuple;
|
use pyo3::types::{IntoPyDict, PyTuple};
|
||||||
|
use pyo3::ToPyObject;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
@ -583,6 +584,45 @@ impl PyQTensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
|
||||||
|
let res = ::candle::safetensors::load(path, &Device::Cpu).map_err(wrap_err)?;
|
||||||
|
let res = res
|
||||||
|
.into_iter()
|
||||||
|
.map(|(key, value)| (key, PyTensor(value).into_py(py)))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
Ok(res.into_py_dict(py).to_object(py))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
fn load_ggml(path: &str, py: Python<'_>) -> PyResult<PyObject> {
|
||||||
|
let mut file = std::fs::File::open(path)?;
|
||||||
|
let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
|
||||||
|
let res = ggml
|
||||||
|
.tensors
|
||||||
|
.into_iter()
|
||||||
|
.map(|(key, qtensor)| Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py))))
|
||||||
|
.collect::<::candle::Result<Vec<_>>>()
|
||||||
|
.map_err(wrap_err)?;
|
||||||
|
Ok(res.into_py_dict(py).to_object(py))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
fn load_gguf(path: &str, py: Python<'_>) -> PyResult<PyObject> {
|
||||||
|
let mut file = std::fs::File::open(path)?;
|
||||||
|
let gguf = ::candle::quantized::gguf_file::Content::read(&mut file).map_err(wrap_err)?;
|
||||||
|
let res = gguf
|
||||||
|
.tensor_infos
|
||||||
|
.keys()
|
||||||
|
.map(|key| {
|
||||||
|
let qtensor = gguf.tensor(&mut file, key)?;
|
||||||
|
Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py)))
|
||||||
|
})
|
||||||
|
.collect::<::candle::Result<Vec<_>>>()
|
||||||
|
.map_err(wrap_err)?;
|
||||||
|
Ok(res.into_py_dict(py).to_object(py))
|
||||||
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
fn cuda_is_available() -> bool {
|
fn cuda_is_available() -> bool {
|
||||||
::candle::utils::cuda_is_available()
|
::candle::utils::cuda_is_available()
|
||||||
@ -627,6 +667,9 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
|||||||
m.add("f32", PyDType(DType::F32))?;
|
m.add("f32", PyDType(DType::F32))?;
|
||||||
m.add("f64", PyDType(DType::F64))?;
|
m.add("f64", PyDType(DType::F64))?;
|
||||||
m.add_function(wrap_pyfunction!(cat, m)?)?;
|
m.add_function(wrap_pyfunction!(cat, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(ones, m)?)?;
|
m.add_function(wrap_pyfunction!(ones, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(rand, m)?)?;
|
m.add_function(wrap_pyfunction!(rand, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(randn, m)?)?;
|
m.add_function(wrap_pyfunction!(randn, m)?)?;
|
||||||
|
Reference in New Issue
Block a user