mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add the pooling operators to the pyo3 layer. (#1086)
This commit is contained in:
@ -1,7 +1,9 @@
|
|||||||
# Generated content DO NOT EDIT
|
# Generated content DO NOT EDIT
|
||||||
from .. import functional
|
from .. import functional
|
||||||
|
|
||||||
|
avg_pool2d = functional.avg_pool2d
|
||||||
gelu = functional.gelu
|
gelu = functional.gelu
|
||||||
|
max_pool2d = functional.max_pool2d
|
||||||
relu = functional.relu
|
relu = functional.relu
|
||||||
silu = functional.silu
|
silu = functional.silu
|
||||||
softmax = functional.softmax
|
softmax = functional.softmax
|
||||||
|
@ -4,6 +4,13 @@ from os import PathLike
|
|||||||
from candle.typing import _ArrayLike, Device
|
from candle.typing import _ArrayLike, Device
|
||||||
from candle import Tensor, DType, QTensor
|
from candle import Tensor, DType, QTensor
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def avg_pool2d(tensor: Tensor, ksize: int, stride: int = 1) -> Tensor:
|
||||||
|
"""
|
||||||
|
Applies the 2d avg-pool function to a given tensor.#
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def gelu(tensor: Tensor) -> Tensor:
|
def gelu(tensor: Tensor) -> Tensor:
|
||||||
"""
|
"""
|
||||||
@ -11,6 +18,13 @@ def gelu(tensor: Tensor) -> Tensor:
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def max_pool2d(tensor: Tensor, ksize: int, stride: int = 1) -> Tensor:
|
||||||
|
"""
|
||||||
|
Applies the 2d max-pool function to a given tensor.#
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def relu(tensor: Tensor) -> Tensor:
|
def relu(tensor: Tensor) -> Tensor:
|
||||||
"""
|
"""
|
||||||
|
@ -1224,6 +1224,28 @@ fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
|
|||||||
Ok(PyTensor(sm))
|
Ok(PyTensor(sm))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
#[pyo3(signature = (tensor, ksize, *, stride=1), text_signature = "(tensor:Tensor, ksize:int, stride:int=1)")]
|
||||||
|
/// Applies the 2d avg-pool function to a given tensor.#
|
||||||
|
/// &RETURNS&: Tensor
|
||||||
|
fn avg_pool2d(tensor: PyTensor, ksize: usize, stride: usize) -> PyResult<PyTensor> {
|
||||||
|
let tensor = tensor
|
||||||
|
.avg_pool2d_with_stride(ksize, stride)
|
||||||
|
.map_err(wrap_err)?;
|
||||||
|
Ok(PyTensor(tensor))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
#[pyo3(signature = (tensor, ksize, *, stride=1), text_signature = "(tensor:Tensor, ksize:int, stride:int=1)")]
|
||||||
|
/// Applies the 2d max-pool function to a given tensor.#
|
||||||
|
/// &RETURNS&: Tensor
|
||||||
|
fn max_pool2d(tensor: PyTensor, ksize: usize, stride: usize) -> PyResult<PyTensor> {
|
||||||
|
let tensor = tensor
|
||||||
|
.max_pool2d_with_stride(ksize, stride)
|
||||||
|
.map_err(wrap_err)?;
|
||||||
|
Ok(PyTensor(tensor))
|
||||||
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
#[pyo3(text_signature = "(tensor:Tensor)")]
|
#[pyo3(text_signature = "(tensor:Tensor)")]
|
||||||
/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
|
/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
|
||||||
@ -1263,6 +1285,8 @@ fn tanh(tensor: PyTensor) -> PyResult<PyTensor> {
|
|||||||
fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||||
m.add_function(wrap_pyfunction!(silu, m)?)?;
|
m.add_function(wrap_pyfunction!(silu, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(softmax, m)?)?;
|
m.add_function(wrap_pyfunction!(softmax, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(max_pool2d, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(avg_pool2d, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(gelu, m)?)?;
|
m.add_function(wrap_pyfunction!(gelu, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(relu, m)?)?;
|
m.add_function(wrap_pyfunction!(relu, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(tanh, m)?)?;
|
m.add_function(wrap_pyfunction!(tanh, m)?)?;
|
||||||
|
Reference in New Issue
Block a user