mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations * Add `state_dict` and `load_state_dict` functionality * Move modules around and create `candle.nn.Linear` * Add `nn.Embedding` and `nn.LayerNorm` * Add BERT implementation * Batch q-matmul * Automatically dequantize `QTensors` if a `Tensor` is expected * Add Module `.to()`, `.cuda()`, `cpu()` and `.type()` functionality * Unittests for `Module`, `Tensor` and `candle.utils` * Add `pytorch` like slicing to `Tensor` * Cleanup and BERT fixes * `black` formatting + unit-test for `nn.Linear` * Refactor slicing implementation
This commit is contained in:
@ -3,6 +3,7 @@ use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
|
||||
use pyo3::ToPyObject;
|
||||
use std::os::raw::c_long;
|
||||
use std::sync::Arc;
|
||||
|
||||
use half::{bf16, f16};
|
||||
@ -196,6 +197,12 @@ trait MapDType {
|
||||
}
|
||||
}
|
||||
|
||||
enum Indexer {
|
||||
Index(usize),
|
||||
Slice(usize, usize),
|
||||
Elipsis,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyTensor {
|
||||
#[new]
|
||||
@ -436,6 +443,95 @@ impl PyTensor {
|
||||
))
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// Index a tensor.
|
||||
/// &RETURNS&: Tensor
|
||||
fn __getitem__(&self, py: Python, idx: PyObject) -> PyResult<Self> {
|
||||
let mut indexers: Vec<Indexer> = vec![];
|
||||
let dims = self.0.shape().dims();
|
||||
|
||||
let to_absolute_index = |index: isize, current_dim: usize| {
|
||||
// Convert a relative index to an absolute index e.g. tensor[-1] -> tensor[0]
|
||||
let actual_index = if index < 0 {
|
||||
dims[current_dim] as isize + index
|
||||
} else {
|
||||
index
|
||||
};
|
||||
|
||||
// Check that the index is in range
|
||||
if actual_index < 0 || actual_index >= dims[current_dim] as isize {
|
||||
return Err(PyTypeError::new_err(format!(
|
||||
"index out of range for dimension '{i}' with indexer '{value}'",
|
||||
i = current_dim,
|
||||
value = index
|
||||
)));
|
||||
}
|
||||
Ok(actual_index as usize)
|
||||
};
|
||||
if let Ok(index) = idx.extract(py) {
|
||||
// Handle a single index e.g. tensor[0] or tensor[-1]
|
||||
indexers.push(Indexer::Index(to_absolute_index(index, 0)?));
|
||||
} else if let Ok(slice) = idx.downcast::<pyo3::types::PySlice>(py) {
|
||||
// Handle a single slice e.g. tensor[0:1] or tensor[0:-1]
|
||||
let index = slice.indices(dims[0] as c_long)?;
|
||||
indexers.push(Indexer::Slice(index.start as usize, index.stop as usize));
|
||||
} else if let Ok(tuple) = idx.downcast::<pyo3::types::PyTuple>(py) {
|
||||
// Handle multiple indices e.g. tensor[0,0] or tensor[0:1,0:1]
|
||||
|
||||
if tuple.len() > dims.len() {
|
||||
return Err(PyTypeError::new_err("provided too many indices"));
|
||||
}
|
||||
|
||||
for (i, item) in tuple.iter().enumerate() {
|
||||
if item.is_ellipsis() {
|
||||
// Handle '...' e.g. tensor[..., 0]
|
||||
|
||||
if i > 0 {
|
||||
return Err(PyTypeError::new_err("Ellipsis ('...') can only be used at the start of an indexing operation"));
|
||||
}
|
||||
indexers.push(Indexer::Elipsis);
|
||||
} else if let Ok(slice) = item.downcast::<pyo3::types::PySlice>() {
|
||||
// Handle slice
|
||||
let index = slice.indices(dims[i] as c_long)?;
|
||||
indexers.push(Indexer::Slice(index.start as usize, index.stop as usize));
|
||||
} else if let Ok(index) = item.extract::<isize>() {
|
||||
indexers.push(Indexer::Index(to_absolute_index(index, i)?));
|
||||
} else {
|
||||
return Err(PyTypeError::new_err("unsupported index"));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(PyTypeError::new_err("unsupported index"));
|
||||
}
|
||||
|
||||
let mut x = self.0.clone();
|
||||
let mut current_dim = 0;
|
||||
// Apply the indexers
|
||||
for indexer in indexers.iter() {
|
||||
x = match indexer {
|
||||
Indexer::Index(n) => x
|
||||
.narrow(current_dim, *n, 1)
|
||||
.map_err(wrap_err)?
|
||||
.squeeze(current_dim)
|
||||
.map_err(wrap_err)?,
|
||||
Indexer::Slice(start, stop) => {
|
||||
let out = x
|
||||
.narrow(current_dim, *start, stop.saturating_sub(*start))
|
||||
.map_err(wrap_err)?;
|
||||
current_dim += 1;
|
||||
out
|
||||
}
|
||||
Indexer::Elipsis => {
|
||||
// Elipsis is a special case, it means that all remaining dimensions should be selected => advance the current_dim to the last dimension we have indexers for
|
||||
current_dim += dims.len() - (indexers.len() - 1);
|
||||
x
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self(x))
|
||||
}
|
||||
|
||||
/// Add two tensors.
|
||||
/// &RETURNS&: Tensor
|
||||
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||
@ -697,7 +793,7 @@ impl PyTensor {
|
||||
/// &RETURNS&: QTensor
|
||||
fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
|
||||
use ::candle::quantized;
|
||||
let res = match quantized_dtype {
|
||||
let res = match quantized_dtype.to_lowercase().as_str() {
|
||||
"q2k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ2K>(self),
|
||||
"q3k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ3K>(self),
|
||||
"q4_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_0>(self),
|
||||
@ -1137,9 +1233,39 @@ fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
|
||||
Ok(PyTensor(s))
|
||||
}
|
||||
|
||||
fn candle_nn_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
#[pyfunction]
|
||||
#[pyo3(text_signature = "(tensor:Tensor)")]
|
||||
/// Applies the Gaussian Error Linear Unit (GELU) function to a given tensor.
|
||||
/// &RETURNS&: Tensor
|
||||
fn gelu(tensor: PyTensor) -> PyResult<PyTensor> {
|
||||
let s = tensor.0.gelu_erf().map_err(wrap_err)?;
|
||||
Ok(PyTensor(s))
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(text_signature = "(tensor:Tensor)")]
|
||||
/// Applies the Rectified Linear Unit (ReLU) function to a given tensor.
|
||||
/// &RETURNS&: Tensor
|
||||
fn relu(tensor: PyTensor) -> PyResult<PyTensor> {
|
||||
let s = tensor.0.relu().map_err(wrap_err)?;
|
||||
Ok(PyTensor(s))
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(text_signature = "(tensor:Tensor)")]
|
||||
/// Applies the tanh function to a given tensor.
|
||||
/// &RETURNS&: Tensor
|
||||
fn tanh(tensor: PyTensor) -> PyResult<PyTensor> {
|
||||
let s = tensor.0.tanh().map_err(wrap_err)?;
|
||||
Ok(PyTensor(s))
|
||||
}
|
||||
|
||||
fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(silu, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(softmax, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(gelu, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(relu, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(tanh, m)?)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1148,8 +1274,8 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
let utils = PyModule::new(py, "utils")?;
|
||||
candle_utils(py, utils)?;
|
||||
m.add_submodule(utils)?;
|
||||
let nn = PyModule::new(py, "nn")?;
|
||||
candle_nn_m(py, nn)?;
|
||||
let nn = PyModule::new(py, "functional")?;
|
||||
candle_functional_m(py, nn)?;
|
||||
m.add_submodule(nn)?;
|
||||
m.add_class::<PyTensor>()?;
|
||||
m.add_class::<PyQTensor>()?;
|
||||
|
Reference in New Issue
Block a user