Add return types to *.pyi stubs (#880)

* Start generating return types
* Finish tensor type hinting
* Add `save_gguf` to `utils`
* Typehint `quant-llama.py`
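
For context on what these markers drive: the `&RETURNS&` annotations added to the doc comments below are presumably what the stub generation reads to emit return types. A hedged sketch of the kind of `.pyi` entry this produces (member names come from the bindings in this diff; the exact stub layout is an assumption):

    # Sketch of a generated candle.pyi entry, not the actual file.
    class Tensor:
        def __init__(self, data: _ArrayLike): ...
        @property
        def shape(self) -> Tuple[int]: ...
        def matmul(self, rhs: Tensor) -> Tensor: ...
        def to_dtype(self, dtype: Union[str, DType]) -> Tensor: ...
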
@@ -1,7 +1,7 @@
#![allow(clippy::redundant_closure_call)]
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyTuple};
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
use std::sync::Arc;

@@ -31,6 +31,7 @@ impl From<PyShape> for ::candle::Shape {

#[derive(Clone, Debug)]
#[pyclass(name = "Tensor")]
/// A `candle` tensor.
struct PyTensor(Tensor);

impl std::ops::Deref for PyTensor {
@@ -43,6 +44,7 @@ impl std::ops::Deref for PyTensor {

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[pyclass(name = "DType")]
/// A `candle` dtype.
struct PyDType(DType);

#[pymethods]
@@ -197,7 +199,7 @@ trait MapDType {
#[pymethods]
impl PyTensor {
    #[new]
    #[pyo3(text_signature = "(data:_ArrayLike)")]
    #[pyo3(text_signature = "(self, data:_ArrayLike)")]
    // TODO: Handle arbitrary input dtype and shape.
    /// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
    fn new(py: Python<'_>, data: PyObject) -> PyResult<Self> {
@@ -239,6 +241,7 @@ impl PyTensor {
    }

    /// Gets the tensor's data as a Python scalar or array-like object.
    /// &RETURNS&: _ArrayLike
    fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
        struct M<'a>(Python<'a>);
        impl<'a> MapDType for M<'a> {
@@ -282,27 +285,36 @@ impl PyTensor {
    }

    #[getter]
    /// Gets the tensor shape as a Python tuple.
    /// Gets the tensor's shape.
    /// &RETURNS&: Tuple[int]
    fn shape(&self, py: Python<'_>) -> PyObject {
        PyTuple::new(py, self.0.dims()).to_object(py)
    }

    #[getter]
    /// Gets the tensor's strides.
    /// &RETURNS&: Tuple[int]
    fn stride(&self, py: Python<'_>) -> PyObject {
        PyTuple::new(py, self.0.stride()).to_object(py)
    }

    #[getter]
    /// Gets the tensor's dtype.
    /// &RETURNS&: DType
    fn dtype(&self) -> PyDType {
        PyDType(self.0.dtype())
    }

    #[getter]
    /// Gets the tensor's device.
    /// &RETURNS&: Device
    fn device(&self, py: Python<'_>) -> PyObject {
        PyDevice::from_device(self.0.device()).to_object(py)
    }

    #[getter]
    /// Gets the tensor's rank.
    /// &RETURNS&: int
    fn rank(&self) -> usize {
        self.0.rank()
    }
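
For context, a hedged sketch of how these getters read from Python once the stubs carry return types (assuming a `candle` module built from this crate):

    import candle

    t = candle.Tensor([[1.0, 2.0], [3.0, 4.0]])
    print(t.shape)   # Tuple[int], (2, 2) here
    print(t.stride)  # Tuple[int]
    print(t.dtype)   # DType
    print(t.device)  # Device
    print(t.rank)    # int, 2 for a matrix
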
@@ -315,69 +327,117 @@ impl PyTensor {
        self.__repr__()
    }

    /// Performs the `sin` operation on the tensor.
    /// &RETURNS&: Tensor
    fn sin(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.sin().map_err(wrap_err)?))
    }

    /// Performs the `cos` operation on the tensor.
    /// &RETURNS&: Tensor
    fn cos(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.cos().map_err(wrap_err)?))
    }

    /// Performs the `log` operation on the tensor.
    /// &RETURNS&: Tensor
    fn log(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.log().map_err(wrap_err)?))
    }

    /// Squares the tensor.
    /// &RETURNS&: Tensor
    fn sqr(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.sqr().map_err(wrap_err)?))
    }

    /// Calculates the square root of the tensor.
    /// &RETURNS&: Tensor
    fn sqrt(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.sqrt().map_err(wrap_err)?))
    }

    /// Gets the `recip` of the tensor.
    /// &RETURNS&: Tensor
    fn recip(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.recip().map_err(wrap_err)?))
    }

    /// Performs the `exp` operation on the tensor.
    /// &RETURNS&: Tensor
    fn exp(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.exp().map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, p:float)")]
    /// Performs the `pow` operation on the tensor with the given exponent.
    /// &RETURNS&: Tensor
    fn powf(&self, p: f64) -> PyResult<Self> {
        Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, rhs:Tensor, dim:int)")]
    /// Selects values from the input tensor at the target indexes across the specified dimension.
    ///
    /// The `indexes` argument is an int tensor with a single dimension.
    /// The output has the same number of dimensions as the `self` input. The target dimension of
    /// the output has the same length as `indexes`, and the values are taken from `self` using
    /// the index from `indexes`. Other dimensions have the same number of elements as the input
    /// tensor.
    /// &RETURNS&: Tensor
    fn index_select(&self, rhs: &Self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, rhs:Tensor)")]
    /// Performs a matrix multiplication between the two tensors.
    /// &RETURNS&: Tensor
    fn matmul(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, rhs:Tensor)")]
    /// Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
    /// &RETURNS&: Tensor
    fn broadcast_add(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_add(rhs).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, rhs:Tensor)")]
    /// Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
    /// &RETURNS&: Tensor
    fn broadcast_sub(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_sub(rhs).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, rhs:Tensor)")]
    /// Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
    /// &RETURNS&: Tensor
    fn broadcast_mul(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_mul(rhs).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, rhs:Tensor)")]
    /// Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
    /// &RETURNS&: Tensor
    fn broadcast_div(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_div(rhs).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, on_true:Tensor, on_false:Tensor)")]
    /// Returns a tensor with the same shape as the input tensor; the values are taken from
    /// `on_true` if the input tensor value is not zero, and from `on_false` at the positions where the
    /// input tensor is equal to zero.
    /// &RETURNS&: Tensor
    fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> {
        Ok(PyTensor(
            self.0.where_cond(on_true, on_false).map_err(wrap_err)?,
        ))
    }
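
A hedged usage sketch for the matmul/broadcast family above (shapes are illustrative):

    a = candle.rand((2, 3))
    b = candle.rand((3, 4))
    c = a.matmul(b)            # shape (2, 4)

    row = candle.ones((3,))
    d = a.broadcast_add(row)   # row is broadcast over the leading dimension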

    /// Add two tensors.
    /// &RETURNS&: Tensor
    fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            (&self.0 + &rhs.0).map_err(wrap_err)?
@@ -393,6 +453,8 @@ impl PyTensor {
        self.__add__(rhs)
    }

    /// Multiply two tensors.
    /// &RETURNS&: Tensor
    fn __mul__(&self, rhs: &PyAny) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            (&self.0 * &rhs.0).map_err(wrap_err)?
@@ -408,6 +470,8 @@ impl PyTensor {
        self.__mul__(rhs)
    }

    /// Subtract two tensors.
    /// &RETURNS&: Tensor
    fn __sub__(&self, rhs: &PyAny) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            (&self.0 - &rhs.0).map_err(wrap_err)?
@@ -419,6 +483,8 @@ impl PyTensor {
        Ok(Self(tensor))
    }

    /// Divide two tensors.
    /// &RETURNS&: Tensor
    fn __truediv__(&self, rhs: &PyAny) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            (&self.0 / &rhs.0).map_err(wrap_err)?
@@ -430,62 +496,102 @@ impl PyTensor {
        Ok(Self(tensor))
    }
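
The dunder methods above mean the usual Python operators work, with either a tensor or a scalar on the right-hand side; the reflected variants appear to delegate back to `__add__`/`__mul__`. A minimal sketch:

    t = candle.Tensor([1.0, 2.0, 3.0])
    u = (t + t) * 2.0 - 1.0    # scalar rhs takes the non-Tensor branch
    v = u / 2.0
    w = 1.0 + t                # works via the reflected variant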

    #[pyo3(text_signature = "(self, shape:Sequence[int])")]
    /// Reshapes the tensor to the given shape.
    /// &RETURNS&: Tensor
    fn reshape(&self, shape: PyShape) -> PyResult<Self> {
        Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, shape:Sequence[int])")]
    /// Broadcasts the tensor to the given shape.
    /// &RETURNS&: Tensor
    fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, shape:Sequence[int])")]
    /// Broadcasts the tensor to the given shape, adding new dimensions on the left.
    /// &RETURNS&: Tensor
    fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim:int)")]
    /// Creates a new tensor with the specified dimension removed if its size was one.
    /// &RETURNS&: Tensor
    fn squeeze(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim:int)")]
    /// Creates a new tensor with a dimension of size one inserted at the specified position.
    /// &RETURNS&: Tensor
    fn unsqueeze(&self, dim: usize) -> PyResult<Self> {
        Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, index:int)")]
    /// Gets the value at the specified index.
    /// &RETURNS&: Tensor
    fn get(&self, index: i64) -> PyResult<Self> {
        let index = actual_index(self, 0, index).map_err(wrap_err)?;
        Ok(PyTensor(self.0.get(index).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim1:int, dim2:int)")]
    /// Returns a tensor that is a transposed version of the input; the given dimensions are swapped.
    /// &RETURNS&: Tensor
    fn transpose(&self, dim1: usize, dim2: usize) -> PyResult<Self> {
        Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim:int, start:int, len:int)")]
    /// Returns a new tensor that is a narrowed version of the input; the dimension `dim`
    /// ranges from `start` to `start + len`.
    /// &RETURNS&: Tensor
    fn narrow(&self, dim: i64, start: i64, len: usize) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        let start = actual_index(self, dim, start).map_err(wrap_err)?;
        Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim:int)")]
    /// Returns the indices of the maximum value(s) across the selected dimension.
    /// &RETURNS&: Tensor
    fn argmax_keepdim(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.argmax_keepdim(dim).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim:int)")]
    /// Returns the indices of the minimum value(s) across the selected dimension.
    /// &RETURNS&: Tensor
    fn argmin_keepdim(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.argmin_keepdim(dim).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim:int)")]
    /// Gathers the maximum value across the selected dimension.
    /// &RETURNS&: Tensor
    fn max_keepdim(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.max_keepdim(dim).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim:int)")]
    /// Gathers the minimum value across the selected dimension.
    /// &RETURNS&: Tensor
    fn min_keepdim(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.min_keepdim(dim).map_err(wrap_err)?))
    }
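
The `*_keepdim` reductions keep the reduced dimension with size one, and `actual_dim` appears to allow negative `dim` values that count from the end; a hedged sketch (see also `sum_keepdim` just below):

    t = candle.rand((2, 3))
    m = t.max_keepdim(-1)        # shape (2, 1)
    i = t.argmax_keepdim(1)      # integer indices, shape (2, 1)
    s = t.sum_keepdim([0, 1])    # shape (1, 1)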

    #[pyo3(text_signature = "(self, dim:Union[int, List[int]])")]
    /// Returns the sum of the tensor's elements over the given dimension(s), keeping the reduced dimension(s) with size one.
    /// &RETURNS&: Tensor
    fn sum_keepdim(&self, dims: PyObject, py: Python<'_>) -> PyResult<Self> {
        let dims = if let Ok(dim) = dims.extract::<usize>(py) {
            vec![dim]
@@ -497,10 +603,14 @@ impl PyTensor {
        ))
    }

    /// Returns the sum of all elements in the tensor.
    /// &RETURNS&: Tensor
    fn sum_all(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
    }

    /// Returns the mean of all elements in the tensor.
    /// &RETURNS&: Tensor
    fn mean_all(&self) -> PyResult<Self> {
        let elements = self.0.elem_count();
        let sum = self.0.sum_all().map_err(wrap_err)?;
@@ -508,54 +618,83 @@ impl PyTensor {
        Ok(PyTensor(mean))
    }

    #[pyo3(text_signature = "(self, dim:int)")]
    /// Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
    /// &RETURNS&: Tensor
    fn flatten_from(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.flatten_from(dim).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dim:int)")]
    /// Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
    /// &RETURNS&: Tensor
    fn flatten_to(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.flatten_to(dim).map_err(wrap_err)?))
    }

    /// Flattens the tensor into a 1D tensor.
    /// &RETURNS&: Tensor
    fn flatten_all(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
    }

    /// Transposes the tensor.
    /// &RETURNS&: Tensor
    fn t(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.t().map_err(wrap_err)?))
    }

    /// Makes the tensor contiguous in memory.
    /// &RETURNS&: Tensor
    fn contiguous(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.contiguous().map_err(wrap_err)?))
    }

    /// Returns true if the tensor is contiguous in C order.
    /// &RETURNS&: bool
    fn is_contiguous(&self) -> bool {
        self.0.is_contiguous()
    }

    /// Returns true if the tensor is contiguous in Fortran order.
    /// &RETURNS&: bool
    fn is_fortran_contiguous(&self) -> bool {
        self.0.is_fortran_contiguous()
    }

    /// Detaches the tensor from the computation graph.
    /// &RETURNS&: Tensor
    fn detach(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.detach().map_err(wrap_err)?))
    }

    /// Returns a copy of the tensor.
    /// &RETURNS&: Tensor
    fn copy(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, dtype:Union[str,DType])")]
    /// Converts the tensor to a new dtype.
    /// &RETURNS&: Tensor
    fn to_dtype(&self, dtype: PyObject, py: Python<'_>) -> PyResult<Self> {
        let dtype = PyDType::from_pyobject(dtype, py)?;
        Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
    }

    #[pyo3(text_signature = "(self, device:Union[str,Device])")]
    /// Moves the tensor to a new device.
    /// &RETURNS&: Tensor
    fn to_device(&self, device: PyDevice) -> PyResult<Self> {
        let device = device.as_device()?;
        Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?))
    }
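
A hedged sketch of the dtype/device conversions above; the string forms are assumed to be accepted per the `Union[str,DType]` and `Union[str,Device]` signatures, and `candle.utils` is assumed from the `candle_utils` module registered later in this file:

    t = candle.ones((2, 2))        # f32 on the CPU by default
    h = t.to_dtype("f16")
    if candle.utils.cuda_is_available():
        g = t.to_device("cuda")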

    #[pyo3(text_signature = "(self, quantized_dtype:str)")]
    /// Quantizes the tensor.
    /// &RETURNS&: QTensor
    fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
        use ::candle::quantized;
        let res = match quantized_dtype {
@@ -586,6 +725,7 @@ impl PyTensor {
#[pyfunction]
#[pyo3(text_signature = "(tensors:List[Tensor], dim:int)")]
/// Concatenates the tensors across one axis.
/// &RETURNS&: Tensor
fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
    if tensors.is_empty() {
        return Err(PyErr::new::<PyValueError, _>("empty input to cat"));
@@ -599,6 +739,7 @@ fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
#[pyfunction]
#[pyo3(text_signature = "(tensors:List[Tensor], dim:int)")]
/// Stacks the tensors along a new axis.
/// &RETURNS&: Tensor
fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
    let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
    let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
@@ -608,6 +749,7 @@ fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
#[pyfunction]
#[pyo3(text_signature = "(data:_ArrayLike)")]
/// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
/// &RETURNS&: Tensor
fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
    PyTensor::new(py, data)
}
@@ -615,6 +757,7 @@ fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
#[pyfunction]
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
/// Creates a new tensor with random values.
/// &RETURNS&: Tensor
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
    let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
    let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
@@ -623,6 +766,8 @@ fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<P

#[pyfunction]
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
/// Creates a new tensor with random values from a normal distribution.
/// &RETURNS&: Tensor
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
    let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
    let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
@@ -631,6 +776,8 @@ fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<

#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
/// Creates a new tensor filled with ones.
/// &RETURNS&: Tensor
fn ones(
    py: Python<'_>,
    shape: PyShape,
@@ -648,6 +795,8 @@ fn ones(

#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
/// Creates a new tensor filled with zeros.
/// &RETURNS&: Tensor
fn zeros(
    py: Python<'_>,
    shape: PyShape,
@@ -663,8 +812,9 @@ fn zeros(
    Ok(PyTensor(tensor))
}
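
A hedged sketch of the factory functions and the `cat`/`stack` helpers above:

    x = candle.rand((2, 3))
    n = candle.randn((2, 3))
    o = candle.ones((2, 3))
    z = candle.zeros((2, 3))
    both = candle.cat([x, n], 0)       # shape (4, 3)
    stacked = candle.stack([x, n], 0)  # shape (2, 2, 3)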

#[derive(Debug)]
#[derive(Debug, Clone)]
#[pyclass(name = "QTensor")]
/// A quantized tensor.
struct PyQTensor(Arc<QTensor>);

impl std::ops::Deref for PyQTensor {
@@ -678,16 +828,22 @@ impl std::ops::Deref for PyQTensor {
#[pymethods]
impl PyQTensor {
    #[getter]
    /// Gets the tensor's quantized dtype.
    /// &RETURNS&: str
    fn ggml_dtype(&self) -> String {
        format!("{:?}", self.0.dtype())
    }

    #[getter]
    /// Gets the rank of the tensor.
    /// &RETURNS&: int
    fn rank(&self) -> usize {
        self.0.rank()
    }

    #[getter]
    /// Gets the shape of the tensor.
    /// &RETURNS&: Tuple[int]
    fn shape(&self, py: Python<'_>) -> PyObject {
        PyTuple::new(py, self.0.shape().dims()).to_object(py)
    }
@@ -700,11 +856,16 @@ impl PyQTensor {
        self.__repr__()
    }

    /// Dequantizes the tensor.
    /// &RETURNS&: Tensor
    fn dequantize(&self) -> PyResult<PyTensor> {
        let tensor = self.0.dequantize(&Device::Cpu).map_err(wrap_err)?;
        Ok(PyTensor(tensor))
    }
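
A hedged round-trip sketch for the quantized API; "q4_0" is an assumption about the dtype names accepted by the `quantize` match above:

    w = candle.randn((64, 64))
    q = w.quantize("q4_0")       # QTensor
    print(q.ggml_dtype, q.shape)
    w2 = q.dequantize()          # back to a float Tensor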

    #[pyo3(text_signature = "(self, lhs:Tensor)")]
    /// Performs a quantized matrix multiplication, with the quantized tensor as the right-hand side.
    /// &RETURNS&: Tensor
    fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> {
        let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone());
        let res = qmatmul.forward(lhs).map_err(wrap_err)?;
@@ -715,6 +876,7 @@ impl PyQTensor {
#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
/// Loads a safetensors file. Returns a dictionary mapping tensor names to tensors.
/// &RETURNS&: Dict[str,Tensor]
fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
    let res = ::candle::safetensors::load(path, &Device::Cpu).map_err(wrap_err)?;
    let res = res
@@ -727,6 +889,7 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike], tensors:Dict[str,Tensor])")]
/// Saves a dictionary of tensors to a safetensors file.
/// &RETURNS&: None
fn save_safetensors(
    path: &str,
    tensors: std::collections::HashMap<String, PyTensor>,
@@ -742,6 +905,7 @@ fn save_safetensors(
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
/// Loads a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]]
fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
    let mut file = std::fs::File::open(path)?;
    let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
@@ -776,6 +940,7 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
/// and the second maps metadata keys to metadata values.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]]
fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
    use ::candle::quantized::gguf_file;
    fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
@@ -824,26 +989,118 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
    Ok((tensors, metadata))
}
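
A hedged sketch of the load/save helpers; they are registered on the `candle.utils` module below, and the file paths are purely illustrative:

    tensors = candle.utils.load_safetensors("weights.safetensors")
    candle.utils.save_safetensors("weights-copy.safetensors", tensors)

    qtensors, metadata = candle.utils.load_gguf("model.gguf")
    qtensors, hparams, vocab = candle.utils.load_ggml("model.bin")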

#[pyfunction]
#[pyo3(
    text_signature = "(path:Union[str,PathLike], tensors:Dict[str,QTensor], metadata:Dict[str,Any])"
)]
/// Saves quantized tensors and metadata to a GGUF file.
fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> {
    use ::candle::quantized::gguf_file;

    fn pyobject_to_gguf_value(v: &PyAny, py: Python<'_>) -> PyResult<gguf_file::Value> {
        let v: gguf_file::Value = if let Ok(x) = v.extract::<u8>() {
            gguf_file::Value::U8(x)
        } else if let Ok(x) = v.extract::<i8>() {
            gguf_file::Value::I8(x)
        } else if let Ok(x) = v.extract::<u16>() {
            gguf_file::Value::U16(x)
        } else if let Ok(x) = v.extract::<i16>() {
            gguf_file::Value::I16(x)
        } else if let Ok(x) = v.extract::<u32>() {
            gguf_file::Value::U32(x)
        } else if let Ok(x) = v.extract::<i32>() {
            gguf_file::Value::I32(x)
        } else if let Ok(x) = v.extract::<u64>() {
            gguf_file::Value::U64(x)
        } else if let Ok(x) = v.extract::<i64>() {
            gguf_file::Value::I64(x)
        } else if let Ok(x) = v.extract::<f32>() {
            gguf_file::Value::F32(x)
        } else if let Ok(x) = v.extract::<f64>() {
            gguf_file::Value::F64(x)
        } else if let Ok(x) = v.extract::<bool>() {
            gguf_file::Value::Bool(x)
        } else if let Ok(x) = v.extract::<String>() {
            gguf_file::Value::String(x)
        } else if let Ok(x) = v.extract::<Vec<PyObject>>() {
            let x = x
                .into_iter()
                .map(|f| pyobject_to_gguf_value(f.as_ref(py), py))
                .collect::<PyResult<Vec<_>>>()?;
            gguf_file::Value::Array(x)
        } else {
            return Err(PyErr::new::<PyValueError, _>(format!(
                "unsupported type {:?}",
                v
            )));
        };
        Ok(v)
    }
    let tensors = tensors
        .extract::<&PyDict>(py)
        .map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
        .iter()
        .map(|(key, value)| {
            Ok((
                key.extract::<String>()
                    .map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?,
                value.extract::<PyQTensor>()?.0,
            ))
        })
        .collect::<PyResult<Vec<_>>>()?;

    let metadata = metadata
        .extract::<&PyDict>(py)
        .map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
        .iter()
        .map(|(key, value)| {
            Ok((
                key.extract::<String>()
                    .map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?,
                pyobject_to_gguf_value(value, py)?,
            ))
        })
        .collect::<PyResult<Vec<_>>>()?;

    let converted_metadata: Vec<_> = metadata
        .iter()
        .map(|(name, value)| (name.as_str(), value))
        .collect();

    let converted_tensors: Vec<_> = tensors
        .iter()
        .map(|(name, tensor)| (name.as_str(), tensor.as_ref()))
        .collect();

    let mut file = std::fs::File::create(path)?;

    gguf_file::write(&mut file, &converted_metadata, &converted_tensors).map_err(wrap_err)
}
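
Mirroring `pyobject_to_gguf_value`, the metadata dict may hold ints, floats, bools, strings, and (nested) lists of those; a hedged sketch with illustrative keys:

    metadata = {"general.name": "my-model", "my.meta.list": [1, 2, 3]}
    candle.utils.save_gguf("out.gguf", qtensors, metadata)  # qtensors: Dict[str, QTensor]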

#[pyfunction]
/// Returns true if the 'cuda' backend is available.
/// &RETURNS&: bool
fn cuda_is_available() -> bool {
    ::candle::utils::cuda_is_available()
}

#[pyfunction]
/// Returns true if candle was compiled with 'accelerate' support.
/// &RETURNS&: bool
fn has_accelerate() -> bool {
    ::candle::utils::has_accelerate()
}

#[pyfunction]
/// Returns true if candle was compiled with MKL support.
/// &RETURNS&: bool
fn has_mkl() -> bool {
    ::candle::utils::has_mkl()
}

#[pyfunction]
/// Returns the number of threads used by candle.
/// &RETURNS&: int
fn get_num_threads() -> usize {
    ::candle::utils::get_num_threads()
}
@@ -855,6 +1112,7 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(has_mkl, m)?)?;
    m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
    m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
    m.add_function(wrap_pyfunction!(save_gguf, m)?)?;
    m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
    m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
    Ok(())
@@ -862,7 +1120,8 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
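
These utilities land on a `candle.utils` submodule (the name is assumed from `candle_utils`); a quick capability check:

    print(candle.utils.cuda_is_available())
    print(candle.utils.has_mkl(), candle.utils.has_accelerate())
    print(candle.utils.get_num_threads())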

#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor, dim:int)")]
/// Applies the Softmax function to a given tensor.
/// Applies the Softmax function to a given tensor.#
/// &RETURNS&: Tensor
fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
    let dim = actual_dim(&tensor, dim).map_err(wrap_err)?;
    let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_err)?;
@@ -872,6 +1131,7 @@ fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
/// &RETURNS&: Tensor
fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
    let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_err)?;
    Ok(PyTensor(s))
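
A hedged sketch of the two activation helpers, assuming they are exposed at the top level of the module:

    logits = candle.randn((2, 5))
    probs = candle.softmax(logits, -1)  # rows sum to 1
    act = candle.silu(logits)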