Make the Python Wrapper more Hackable and simplify Quantization (#1010)

* Some first `Module` implementations

* Add `state_dict` and `load_state_dict` functionality

* Move modules around and create `candle.nn.Linear`

* Add `nn.Embedding` and `nn.LayerNorm`

* Add BERT implementation

* Batch q-matmul

* Automatically dequantize `QTensors` if a `Tensor` is expected

* Add Module `.to()`, `.cuda()`, `.cpu()` and `.type()` functionality

* Unit tests for `Module`, `Tensor` and `candle.utils`

* Add PyTorch-like slicing to `Tensor`

* Cleanup and BERT fixes

* `black` formatting + unit-test for `nn.Linear`

* Refactor slicing implementation
This commit is contained in:
Lukas Kreussel
2023-10-06 20:01:07 +02:00
committed by GitHub
parent b0442eff8a
commit 904bbdae65
25 changed files with 2426 additions and 182 deletions

View File

@ -3,6 +3,7 @@ use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
use std::os::raw::c_long;
use std::sync::Arc;
use half::{bf16, f16};
@ -196,6 +197,12 @@ trait MapDType {
}
}
// A single parsed component of a Python indexing expression, as built and applied by
// `PyTensor::__getitem__`.
enum Indexer {
// An absolute, non-negative position; applied as a length-1 `narrow` followed by a
// `squeeze`, so the indexed dimension is removed from the result.
Index(usize),
// A `(start, stop)` range; applied as a `narrow` of length `stop - start` (saturating at
// zero), so the dimension is kept.
Slice(usize, usize),
// A leading `...` in a tuple index: skips ahead to the dimensions that the remaining
// indexers refer to. The variant's spelling is left as-is because other code in this file
// refers to it by this name.
Elipsis,
}
#[pymethods]
impl PyTensor {
#[new]
@ -436,6 +443,95 @@ impl PyTensor {
))
}
#[getter]
/// Index a tensor.
/// &RETURNS&: Tensor
fn __getitem__(&self, py: Python, idx: PyObject) -> PyResult<Self> {
let mut indexers: Vec<Indexer> = vec![];
let dims = self.0.shape().dims();
let to_absolute_index = |index: isize, current_dim: usize| {
// Convert a relative index to an absolute index e.g. tensor[-1] -> tensor[0]
let actual_index = if index < 0 {
dims[current_dim] as isize + index
} else {
index
};
// Check that the index is in range
if actual_index < 0 || actual_index >= dims[current_dim] as isize {
return Err(PyTypeError::new_err(format!(
"index out of range for dimension '{i}' with indexer '{value}'",
i = current_dim,
value = index
)));
}
Ok(actual_index as usize)
};
if let Ok(index) = idx.extract(py) {
// Handle a single index e.g. tensor[0] or tensor[-1]
indexers.push(Indexer::Index(to_absolute_index(index, 0)?));
} else if let Ok(slice) = idx.downcast::<pyo3::types::PySlice>(py) {
// Handle a single slice e.g. tensor[0:1] or tensor[0:-1]
let index = slice.indices(dims[0] as c_long)?;
indexers.push(Indexer::Slice(index.start as usize, index.stop as usize));
} else if let Ok(tuple) = idx.downcast::<pyo3::types::PyTuple>(py) {
// Handle multiple indices e.g. tensor[0,0] or tensor[0:1,0:1]
if tuple.len() > dims.len() {
return Err(PyTypeError::new_err("provided too many indices"));
}
for (i, item) in tuple.iter().enumerate() {
if item.is_ellipsis() {
// Handle '...' e.g. tensor[..., 0]
if i > 0 {
return Err(PyTypeError::new_err("Ellipsis ('...') can only be used at the start of an indexing operation"));
}
indexers.push(Indexer::Elipsis);
} else if let Ok(slice) = item.downcast::<pyo3::types::PySlice>() {
// Handle slice
let index = slice.indices(dims[i] as c_long)?;
indexers.push(Indexer::Slice(index.start as usize, index.stop as usize));
} else if let Ok(index) = item.extract::<isize>() {
indexers.push(Indexer::Index(to_absolute_index(index, i)?));
} else {
return Err(PyTypeError::new_err("unsupported index"));
}
}
} else {
return Err(PyTypeError::new_err("unsupported index"));
}
let mut x = self.0.clone();
let mut current_dim = 0;
// Apply the indexers
for indexer in indexers.iter() {
x = match indexer {
Indexer::Index(n) => x
.narrow(current_dim, *n, 1)
.map_err(wrap_err)?
.squeeze(current_dim)
.map_err(wrap_err)?,
Indexer::Slice(start, stop) => {
let out = x
.narrow(current_dim, *start, stop.saturating_sub(*start))
.map_err(wrap_err)?;
current_dim += 1;
out
}
Indexer::Elipsis => {
// Elipsis is a special case, it means that all remaining dimensions should be selected => advance the current_dim to the last dimension we have indexers for
current_dim += dims.len() - (indexers.len() - 1);
x
}
}
}
Ok(Self(x))
}
/// Add two tensors.
/// &RETURNS&: Tensor
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
@ -697,7 +793,7 @@ impl PyTensor {
/// &RETURNS&: QTensor
fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
use ::candle::quantized;
let res = match quantized_dtype {
let res = match quantized_dtype.to_lowercase().as_str() {
"q2k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ2K>(self),
"q3k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ3K>(self),
"q4_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_0>(self),
@ -1137,9 +1233,39 @@ fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
Ok(PyTensor(s))
}
fn candle_nn_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
/// Applies the Gaussian Error Linear Unit (GELU) function to a given tensor.
/// &RETURNS&: Tensor
fn gelu(tensor: PyTensor) -> PyResult<PyTensor> {
let s = tensor.0.gelu_erf().map_err(wrap_err)?;
Ok(PyTensor(s))
}
#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
/// Applies the Rectified Linear Unit (ReLU) function to a given tensor.
/// &RETURNS&: Tensor
fn relu(tensor: PyTensor) -> PyResult<PyTensor> {
let s = tensor.0.relu().map_err(wrap_err)?;
Ok(PyTensor(s))
}
#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
/// Applies the tanh function to a given tensor.
/// &RETURNS&: Tensor
fn tanh(tensor: PyTensor) -> PyResult<PyTensor> {
let s = tensor.0.tanh().map_err(wrap_err)?;
Ok(PyTensor(s))
}
/// Registers the tensor functions `silu`, `softmax`, `gelu`, `relu` and `tanh` on the Python
/// module `m`.
///
/// Returns an error as soon as wrapping or adding any one of the functions fails.
fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(silu, m)?)?;
m.add_function(wrap_pyfunction!(softmax, m)?)?;
m.add_function(wrap_pyfunction!(gelu, m)?)?;
m.add_function(wrap_pyfunction!(relu, m)?)?;
m.add_function(wrap_pyfunction!(tanh, m)?)?;
Ok(())
}
@ -1148,8 +1274,8 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let utils = PyModule::new(py, "utils")?;
candle_utils(py, utils)?;
m.add_submodule(utils)?;
let nn = PyModule::new(py, "nn")?;
candle_nn_m(py, nn)?;
let nn = PyModule::new(py, "functional")?;
candle_functional_m(py, nn)?;
m.add_submodule(nn)?;
m.add_class::<PyTensor>()?;
m.add_class::<PyQTensor>()?;