mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Generate *.pyi
stubs for PyO3 wrapper (#870)
* Begin to generate typehints. * generate correct stubs * Correctly include stubs * Add comments and typhints to static functions * ensure candle-pyo3 directory * Make `llama.rope.freq_base` optional * `fmt`
This commit is contained in:
@ -197,38 +197,40 @@ trait MapDType {
|
||||
#[pymethods]
|
||||
impl PyTensor {
|
||||
#[new]
|
||||
#[pyo3(text_signature = "(data:_ArrayLike)")]
|
||||
// TODO: Handle arbitrary input dtype and shape.
|
||||
fn new(py: Python<'_>, vs: PyObject) -> PyResult<Self> {
|
||||
/// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
|
||||
fn new(py: Python<'_>, data: PyObject) -> PyResult<Self> {
|
||||
use Device::Cpu;
|
||||
let tensor = if let Ok(vs) = vs.extract::<u32>(py) {
|
||||
let tensor = if let Ok(vs) = data.extract::<u32>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<i64>(py) {
|
||||
} else if let Ok(vs) = data.extract::<i64>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<f32>(py) {
|
||||
} else if let Ok(vs) = data.extract::<f32>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
|
||||
} else if let Ok(vs) = data.extract::<Vec<u32>>(py) {
|
||||
let len = vs.len();
|
||||
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<i64>>(py) {
|
||||
} else if let Ok(vs) = data.extract::<Vec<i64>>(py) {
|
||||
let len = vs.len();
|
||||
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
|
||||
} else if let Ok(vs) = data.extract::<Vec<f32>>(py) {
|
||||
let len = vs.len();
|
||||
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<u32>>>(py) {
|
||||
} else if let Ok(vs) = data.extract::<Vec<Vec<u32>>>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<i64>>>(py) {
|
||||
} else if let Ok(vs) = data.extract::<Vec<Vec<i64>>>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<f32>>>(py) {
|
||||
} else if let Ok(vs) = data.extract::<Vec<Vec<f32>>>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<u32>>>>(py) {
|
||||
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<u32>>>>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<i64>>>>(py) {
|
||||
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<i64>>>>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<f32>>>>(py) {
|
||||
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else {
|
||||
let ty = vs.as_ref(py).get_type();
|
||||
let ty = data.as_ref(py).get_type();
|
||||
Err(PyTypeError::new_err(format!(
|
||||
"incorrect type {ty} for tensor"
|
||||
)))?
|
||||
@ -236,7 +238,7 @@ impl PyTensor {
|
||||
Ok(Self(tensor))
|
||||
}
|
||||
|
||||
/// Gets the tensor data as a Python value/array/array of array/...
|
||||
/// Gets the tensor's data as a Python scalar or array-like object.
|
||||
fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
|
||||
struct M<'a>(Python<'a>);
|
||||
impl<'a> MapDType for M<'a> {
|
||||
@ -280,6 +282,7 @@ impl PyTensor {
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// Gets the tensor shape as a Python tuple.
|
||||
fn shape(&self, py: Python<'_>) -> PyObject {
|
||||
PyTuple::new(py, self.0.dims()).to_object(py)
|
||||
}
|
||||
@ -580,8 +583,9 @@ impl PyTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Concatenate the tensors across one axis.
|
||||
#[pyfunction]
|
||||
#[pyo3(text_signature = "(tensors:List[Tensor], dim:int )")]
|
||||
/// Concatenate the tensors across one axis.
|
||||
fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
|
||||
if tensors.is_empty() {
|
||||
return Err(PyErr::new::<PyValueError, _>("empty input to cat"));
|
||||
@ -593,6 +597,8 @@ fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(text_signature = "(tensors:List[Tensor], dim:int)")]
|
||||
/// Stack the tensors along a new axis.
|
||||
fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
|
||||
let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
|
||||
let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
|
||||
@ -600,12 +606,15 @@ fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> {
|
||||
PyTensor::new(py, vs)
|
||||
#[pyo3(text_signature = "(data:_ArrayLike)")]
|
||||
/// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
|
||||
fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
|
||||
PyTensor::new(py, data)
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (shape, *, device=None))]
|
||||
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
|
||||
/// Creates a new tensor with random values.
|
||||
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
|
||||
@ -613,7 +622,7 @@ fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<P
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (shape, *, device=None))]
|
||||
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
|
||||
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
|
||||
@ -621,7 +630,7 @@ fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (shape, *, dtype=None, device=None))]
|
||||
#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
|
||||
fn ones(
|
||||
py: Python<'_>,
|
||||
shape: PyShape,
|
||||
@ -638,7 +647,7 @@ fn ones(
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (shape, *, dtype=None, device=None))]
|
||||
#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
|
||||
fn zeros(
|
||||
py: Python<'_>,
|
||||
shape: PyShape,
|
||||
@ -704,6 +713,8 @@ impl PyQTensor {
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
|
||||
/// Loads a safetensors file. Returns a dictionary mapping tensor names to tensors.
|
||||
fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
|
||||
let res = ::candle::safetensors::load(path, &Device::Cpu).map_err(wrap_err)?;
|
||||
let res = res
|
||||
@ -714,6 +725,8 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(text_signature = "(path:Union[str,PathLike], tensors:Dict[str,Tensor])")]
|
||||
/// Saves a dictionary of tensors to a safetensors file.
|
||||
fn save_safetensors(
|
||||
path: &str,
|
||||
tensors: std::collections::HashMap<String, PyTensor>,
|
||||
@ -726,6 +739,9 @@ fn save_safetensors(
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
|
||||
/// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
|
||||
/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
|
||||
fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
|
||||
let mut file = std::fs::File::open(path)?;
|
||||
let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
|
||||
@ -757,6 +773,9 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
|
||||
/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
|
||||
/// and the second maps metadata keys to metadata values.
|
||||
fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
|
||||
use ::candle::quantized::gguf_file;
|
||||
fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
|
||||
@ -806,21 +825,25 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
/// Returns true if the 'cuda' backend is available.
|
||||
fn cuda_is_available() -> bool {
|
||||
::candle::utils::cuda_is_available()
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
/// Returns true if candle was compiled with 'accelerate' support.
|
||||
fn has_accelerate() -> bool {
|
||||
::candle::utils::has_accelerate()
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
/// Returns true if candle was compiled with MKL support.
|
||||
fn has_mkl() -> bool {
|
||||
::candle::utils::has_mkl()
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
/// Returns the number of threads used by the candle.
|
||||
fn get_num_threads() -> usize {
|
||||
::candle::utils::get_num_threads()
|
||||
}
|
||||
@ -830,19 +853,27 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(has_mkl, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
fn softmax(t: PyTensor, dim: i64) -> PyResult<PyTensor> {
|
||||
let dim = actual_dim(&t, dim).map_err(wrap_err)?;
|
||||
let sm = candle_nn::ops::softmax(&t.0, dim).map_err(wrap_err)?;
|
||||
#[pyo3(text_signature = "(tensor:Tensor, dim:int)")]
|
||||
/// Applies the Softmax function to a given tensor.
|
||||
fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
|
||||
let dim = actual_dim(&tensor, dim).map_err(wrap_err)?;
|
||||
let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_err)?;
|
||||
Ok(PyTensor(sm))
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
fn silu(t: PyTensor) -> PyResult<PyTensor> {
|
||||
let s = candle_nn::ops::silu(&t.0).map_err(wrap_err)?;
|
||||
#[pyo3(text_signature = "(tensor:Tensor)")]
|
||||
/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
|
||||
fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
|
||||
let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_err)?;
|
||||
Ok(PyTensor(s))
|
||||
}
|
||||
|
||||
@ -871,14 +902,10 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add("f32", PyDType(DType::F32))?;
|
||||
m.add("f64", PyDType(DType::F64))?;
|
||||
m.add_function(wrap_pyfunction!(cat, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(ones, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(rand, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(randn, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(tensor, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(stack, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(zeros, m)?)?;
|
||||
Ok(())
|
||||
|
Reference in New Issue
Block a user