Add the pooling operators to the pyo3 layer. (#1086)

This commit is contained in:
Laurent Mazare
2023-10-13 21:18:10 +02:00
committed by GitHub
parent 75989fc3b7
commit 2c110ac7d9
3 changed files with 40 additions and 0 deletions

View File

@ -1224,6 +1224,28 @@ fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
Ok(PyTensor(sm))
}
#[pyfunction]
#[pyo3(signature = (tensor, ksize, *, stride=1), text_signature = "(tensor:Tensor, ksize:int, stride:int=1)")]
/// Applies the 2d avg-pool function to a given tensor.#
/// &RETURNS&: Tensor
fn avg_pool2d(tensor: PyTensor, ksize: usize, stride: usize) -> PyResult<PyTensor> {
let tensor = tensor
.avg_pool2d_with_stride(ksize, stride)
.map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
#[pyfunction]
#[pyo3(signature = (tensor, ksize, *, stride=1), text_signature = "(tensor:Tensor, ksize:int, stride:int=1)")]
/// Applies the 2d max-pool function to a given tensor.#
/// &RETURNS&: Tensor
fn max_pool2d(tensor: PyTensor, ksize: usize, stride: usize) -> PyResult<PyTensor> {
let tensor = tensor
.max_pool2d_with_stride(ksize, stride)
.map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
@ -1263,6 +1285,8 @@ fn tanh(tensor: PyTensor) -> PyResult<PyTensor> {
fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(silu, m)?)?;
m.add_function(wrap_pyfunction!(softmax, m)?)?;
m.add_function(wrap_pyfunction!(max_pool2d, m)?)?;
m.add_function(wrap_pyfunction!(avg_pool2d, m)?)?;
m.add_function(wrap_pyfunction!(gelu, m)?)?;
m.add_function(wrap_pyfunction!(relu, m)?)?;
m.add_function(wrap_pyfunction!(tanh, m)?)?;