Generate *.pyi stubs for PyO3 wrapper (#870)

* Begin to generate typehints.

* generate correct stubs

* Correctly include stubs

* Add comments and typehints to static functions

* ensure candle-pyo3 directory

* Make `llama.rope.freq_base` optional

* `fmt`

Author: Lukas Kreussel
Date: 2023-09-16 18:23:38 +02:00
Committed by: GitHub
Parent: 7cafca835a
Commit: 8658df3485
15 changed files with 857 additions and 40 deletions
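The generated stubs advertised in the title mirror the `text_signature` annotations and doc comments added in this diff. A rough, hedged sketch of the kind of `*.pyi` entries this produces (illustrative only, not the verbatim generated file):

```python
# Illustrative sketch of a generated candle .pyi entry (not verbatim output of this PR);
# names and annotations are taken from the text_signature strings added below.
from typing import List, Optional, Sequence

class Tensor:
    def __init__(self, data: "_ArrayLike") -> None: ...
    @property
    def shape(self) -> tuple: ...
    def values(self) -> object: ...

def cat(tensors: List[Tensor], dim: int) -> Tensor:
    """Concatenate the tensors across one axis."""
    ...

def rand(shape: Sequence[int], *, device: Optional["Device"] = None) -> Tensor:
    """Creates a new tensor with random values."""
    ...
```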


@@ -197,38 +197,40 @@ trait MapDType {
 #[pymethods]
 impl PyTensor {
 #[new]
+#[pyo3(text_signature = "(data:_ArrayLike)")]
 // TODO: Handle arbitrary input dtype and shape.
-fn new(py: Python<'_>, vs: PyObject) -> PyResult<Self> {
+/// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
+fn new(py: Python<'_>, data: PyObject) -> PyResult<Self> {
 use Device::Cpu;
-let tensor = if let Ok(vs) = vs.extract::<u32>(py) {
+let tensor = if let Ok(vs) = data.extract::<u32>(py) {
 Tensor::new(vs, &Cpu).map_err(wrap_err)?
-} else if let Ok(vs) = vs.extract::<i64>(py) {
+} else if let Ok(vs) = data.extract::<i64>(py) {
 Tensor::new(vs, &Cpu).map_err(wrap_err)?
-} else if let Ok(vs) = vs.extract::<f32>(py) {
+} else if let Ok(vs) = data.extract::<f32>(py) {
 Tensor::new(vs, &Cpu).map_err(wrap_err)?
-} else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
+} else if let Ok(vs) = data.extract::<Vec<u32>>(py) {
 let len = vs.len();
 Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
-} else if let Ok(vs) = vs.extract::<Vec<i64>>(py) {
+} else if let Ok(vs) = data.extract::<Vec<i64>>(py) {
 let len = vs.len();
 Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
-} else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
+} else if let Ok(vs) = data.extract::<Vec<f32>>(py) {
 let len = vs.len();
 Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
-} else if let Ok(vs) = vs.extract::<Vec<Vec<u32>>>(py) {
+} else if let Ok(vs) = data.extract::<Vec<Vec<u32>>>(py) {
 Tensor::new(vs, &Cpu).map_err(wrap_err)?
-} else if let Ok(vs) = vs.extract::<Vec<Vec<i64>>>(py) {
+} else if let Ok(vs) = data.extract::<Vec<Vec<i64>>>(py) {
 Tensor::new(vs, &Cpu).map_err(wrap_err)?
-} else if let Ok(vs) = vs.extract::<Vec<Vec<f32>>>(py) {
+} else if let Ok(vs) = data.extract::<Vec<Vec<f32>>>(py) {
 Tensor::new(vs, &Cpu).map_err(wrap_err)?
-} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<u32>>>>(py) {
+} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<u32>>>>(py) {
 Tensor::new(vs, &Cpu).map_err(wrap_err)?
-} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<i64>>>>(py) {
+} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<i64>>>>(py) {
 Tensor::new(vs, &Cpu).map_err(wrap_err)?
-} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<f32>>>>(py) {
+} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
 Tensor::new(vs, &Cpu).map_err(wrap_err)?
 } else {
-let ty = vs.as_ref(py).get_type();
+let ty = data.as_ref(py).get_type();
 Err(PyTypeError::new_err(format!(
 "incorrect type {ty} for tensor"
 )))?
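As the chain of `extract` calls above shows, the constructor accepts scalars and nested lists (up to three levels deep) of unsigned/signed integers or floats; anything else raises a `TypeError`. A hedged usage sketch from Python, assuming the extension is imported as `candle` and the class is exposed as `Tensor`:

```python
import candle

t0 = candle.Tensor(3.14)                      # scalar -> 0-dim tensor
t1 = candle.Tensor([1, 2, 3])                 # 1D tensor from a flat list
t2 = candle.Tensor([[1.0, 2.0], [3.0, 4.0]])  # 2D tensor from nested lists (up to 3 levels supported)
# candle.Tensor(object())                     # would raise TypeError: incorrect type ... for tensor
```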
@@ -236,7 +238,7 @@ impl PyTensor {
 Ok(Self(tensor))
 }
-/// Gets the tensor data as a Python value/array/array of array/...
+/// Gets the tensor's data as a Python scalar or array-like object.
 fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
 struct M<'a>(Python<'a>);
 impl<'a> MapDType for M<'a> {
@@ -280,6 +282,7 @@ impl PyTensor {
 }
 #[getter]
+/// Gets the tensor shape as a Python tuple.
 fn shape(&self, py: Python<'_>) -> PyObject {
 PyTuple::new(py, self.0.dims()).to_object(py)
 }
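A short sketch of the two accessors documented above (same `candle`/`Tensor` naming assumption as before):

```python
import candle

t = candle.Tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(t.shape)     # a Python tuple, e.g. (2, 3)
print(t.values())  # the tensor's data as Python scalars / nested lists
```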
@@ -580,8 +583,9 @@ impl PyTensor {
 }
 }
-/// Concatenate the tensors across one axis.
 #[pyfunction]
+#[pyo3(text_signature = "(tensors:List[Tensor], dim:int)")]
+/// Concatenate the tensors across one axis.
 fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
 if tensors.is_empty() {
 return Err(PyErr::new::<PyValueError, _>("empty input to cat"));
@@ -593,6 +597,8 @@ fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
 }
 #[pyfunction]
+#[pyo3(text_signature = "(tensors:List[Tensor], dim:int)")]
+/// Stack the tensors along a new axis.
 fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
 let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
 let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
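Both functions take a list of tensors plus a dimension, matching the `List[Tensor], dim:int` text signatures above; `cat` joins along an existing axis while `stack` creates a new one. A hedged sketch:

```python
import candle

a = candle.Tensor([1.0, 2.0])
b = candle.Tensor([3.0, 4.0])

c = candle.cat([a, b], 0)    # shape (4,): concatenation along the existing axis 0
s = candle.stack([a, b], 0)  # shape (2, 2): the tensors are stacked along a new axis
```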
@@ -600,12 +606,15 @@ fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
 }
 #[pyfunction]
-fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> {
-PyTensor::new(py, vs)
+#[pyo3(text_signature = "(data:_ArrayLike)")]
+/// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
+fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
+PyTensor::new(py, data)
 }
 #[pyfunction]
-#[pyo3(signature = (shape, *, device=None))]
+#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
+/// Creates a new tensor with random values.
 fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
 let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
 let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
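The free function `tensor` simply forwards to the `Tensor` constructor, and `rand` samples uniformly in [0, 1), defaulting to the CPU device. A sketch; the string form of the `device` argument is an assumption:

```python
import candle

x = candle.tensor([1.0, 2.0, 3.0])       # same conversion rules as candle.Tensor(...)
r = candle.rand((2, 3))                   # uniform samples in [0, 1) on the CPU
# r = candle.rand((2, 3), device="cuda")  # assumption: devices can be named by string; requires CUDA
```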
@@ -613,7 +622,7 @@ fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
 }
 #[pyfunction]
-#[pyo3(signature = (shape, *, device=None))]
+#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
 fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
 let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
 let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
@@ -621,7 +630,7 @@ fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
 }
 #[pyfunction]
-#[pyo3(signature = (shape, *, dtype=None, device=None))]
+#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
 fn ones(
 py: Python<'_>,
 shape: PyShape,
@@ -638,7 +647,7 @@ fn ones(
 }
 #[pyfunction]
-#[pyo3(signature = (shape, *, dtype=None, device=None))]
+#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
 fn zeros(
 py: Python<'_>,
 shape: PyShape,
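`randn`, `ones` and `zeros` follow the same keyword-only `device`/`dtype` pattern; dtype objects such as `candle.f32` and `candle.f64` are the ones registered on the module further down. A hedged sketch:

```python
import candle

n = candle.randn((2, 3))                   # samples from a standard normal distribution
o = candle.ones((2, 3), dtype=candle.f32)  # dtype objects (e.g. candle.f32, candle.f64) are module attributes
z = candle.zeros((2, 3))
```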
@@ -704,6 +713,8 @@ impl PyQTensor {
 }
 #[pyfunction]
+#[pyo3(text_signature = "(path:Union[str,PathLike])")]
+/// Loads a safetensors file. Returns a dictionary mapping tensor names to tensors.
 fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
 let res = ::candle::safetensors::load(path, &Device::Cpu).map_err(wrap_err)?;
 let res = res
@@ -714,6 +725,8 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
 }
 #[pyfunction]
+#[pyo3(text_signature = "(path:Union[str,PathLike], tensors:Dict[str,Tensor])")]
+/// Saves a dictionary of tensors to a safetensors file.
 fn save_safetensors(
 path: &str,
 tensors: std::collections::HashMap<String, PyTensor>,
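Both helpers round-trip a plain `dict` of name to `Tensor`. A hedged usage sketch; the `candle.utils` import path is an assumption about how the `candle_utils` module (where the hunks below register these functions) ends up exposed:

```python
import candle

tensors = {"weight": candle.ones((2, 2)), "bias": candle.zeros((2,))}
# Assumption: candle_utils is exposed as candle.utils; adjust the path if the module is mounted elsewhere.
candle.utils.save_safetensors("model.safetensors", tensors)
loaded = candle.utils.load_safetensors("model.safetensors")  # dict: name -> Tensor
```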
@@ -726,6 +739,9 @@ fn save_safetensors(
 }
 #[pyfunction]
+#[pyo3(text_signature = "(path:Union[str,PathLike])")]
+/// Loads a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
+/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
 fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
 let mut file = std::fs::File::open(path)?;
 let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
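A sketch of unpacking the three-element result described in the new doc comment; same `candle.utils` path assumption, and the file name is only a placeholder:

```python
import candle

# Assumption: candle_utils is exposed as candle.utils; "model.bin" is a placeholder path.
tensors, hparams, vocab = candle.utils.load_ggml("model.bin")
print(hparams)     # dict of hyperparameter name -> value
print(len(vocab))  # number of vocabulary entries
```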
@@ -757,6 +773,9 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
 }
 #[pyfunction]
+#[pyo3(text_signature = "(path:Union[str,PathLike])")]
+/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
+/// and the second maps metadata keys to metadata values.
 fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
 use ::candle::quantized::gguf_file;
 fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
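`load_gguf` mirrors the GGML loader but returns two dictionaries (tensors and metadata). A hedged sketch with the same assumptions:

```python
import candle

# Assumption: candle_utils is exposed as candle.utils; "model.gguf" is a placeholder path.
tensors, metadata = candle.utils.load_gguf("model.gguf")
print(metadata)        # dict of metadata key -> value
print(tensors.keys())  # names of the quantized tensors
```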
@@ -806,21 +825,25 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
 }
 #[pyfunction]
+/// Returns true if the 'cuda' backend is available.
 fn cuda_is_available() -> bool {
 ::candle::utils::cuda_is_available()
 }
 #[pyfunction]
+/// Returns true if candle was compiled with 'accelerate' support.
 fn has_accelerate() -> bool {
 ::candle::utils::has_accelerate()
 }
 #[pyfunction]
+/// Returns true if candle was compiled with MKL support.
 fn has_mkl() -> bool {
 ::candle::utils::has_mkl()
 }
 #[pyfunction]
+/// Returns the number of threads used by candle.
 fn get_num_threads() -> usize {
 ::candle::utils::get_num_threads()
 }
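The four introspection helpers simply forward to `candle::utils` on the Rust side. A hedged sketch, again assuming they remain reachable through the utils module:

```python
import candle

# Assumption: these helpers are reachable as candle.utils.*.
print(candle.utils.cuda_is_available())  # True if the 'cuda' backend is available
print(candle.utils.has_accelerate())     # True if compiled with 'accelerate' support
print(candle.utils.has_mkl())            # True if compiled with MKL support
print(candle.utils.get_num_threads())    # number of threads candle uses
```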
@@ -830,19 +853,27 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
 m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
 m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
 m.add_function(wrap_pyfunction!(has_mkl, m)?)?;
+m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
+m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
+m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
+m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
 Ok(())
 }
 #[pyfunction]
-fn softmax(t: PyTensor, dim: i64) -> PyResult<PyTensor> {
-let dim = actual_dim(&t, dim).map_err(wrap_err)?;
-let sm = candle_nn::ops::softmax(&t.0, dim).map_err(wrap_err)?;
+#[pyo3(text_signature = "(tensor:Tensor, dim:int)")]
+/// Applies the Softmax function to a given tensor.
+fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
+let dim = actual_dim(&tensor, dim).map_err(wrap_err)?;
+let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_err)?;
 Ok(PyTensor(sm))
 }
 #[pyfunction]
-fn silu(t: PyTensor) -> PyResult<PyTensor> {
-let s = candle_nn::ops::silu(&t.0).map_err(wrap_err)?;
+#[pyo3(text_signature = "(tensor:Tensor)")]
+/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
+fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
+let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_err)?;
 Ok(PyTensor(s))
 }
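This hunk only renames the parameters and adds signatures/docs for `softmax` and `silu`; where these two functions are registered is outside the hunks shown, so the flat `candle.softmax`/`candle.silu` path below is an assumption:

```python
import candle

t = candle.Tensor([[1.0, 2.0, 3.0]])
# Assumption: softmax/silu are reachable directly on the candle module;
# their registration is not part of the hunks shown here.
p = candle.softmax(t, -1)  # dim may be negative; it is resolved via actual_dim
a = candle.silu(t)         # x * sigmoid(x), applied elementwise
```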
@@ -871,14 +902,10 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
 m.add("f32", PyDType(DType::F32))?;
 m.add("f64", PyDType(DType::F64))?;
 m.add_function(wrap_pyfunction!(cat, m)?)?;
-m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
-m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
-m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
 m.add_function(wrap_pyfunction!(ones, m)?)?;
 m.add_function(wrap_pyfunction!(rand, m)?)?;
 m.add_function(wrap_pyfunction!(randn, m)?)?;
 m.add_function(wrap_pyfunction!(tensor, m)?)?;
-m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
 m.add_function(wrap_pyfunction!(stack, m)?)?;
 m.add_function(wrap_pyfunction!(zeros, m)?)?;
 Ok(())