mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add the pooling operators to the pyo3 layer. (#1086)
This commit is contained in:
@ -1,7 +1,9 @@
|
||||
# Generated content DO NOT EDIT
|
||||
from .. import functional
|
||||
|
||||
avg_pool2d = functional.avg_pool2d
|
||||
gelu = functional.gelu
|
||||
max_pool2d = functional.max_pool2d
|
||||
relu = functional.relu
|
||||
silu = functional.silu
|
||||
softmax = functional.softmax
|
||||
|
@ -4,6 +4,13 @@ from os import PathLike
|
||||
from candle.typing import _ArrayLike, Device
|
||||
from candle import Tensor, DType, QTensor
|
||||
|
||||
@staticmethod
|
||||
def avg_pool2d(tensor: Tensor, ksize: int, stride: int = 1) -> Tensor:
|
||||
"""
|
||||
Applies the 2d avg-pool function to a given tensor.#
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def gelu(tensor: Tensor) -> Tensor:
|
||||
"""
|
||||
@ -11,6 +18,13 @@ def gelu(tensor: Tensor) -> Tensor:
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def max_pool2d(tensor: Tensor, ksize: int, stride: int = 1) -> Tensor:
|
||||
"""
|
||||
Applies the 2d max-pool function to a given tensor.#
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def relu(tensor: Tensor) -> Tensor:
|
||||
"""
|
||||
|
Reference in New Issue
Block a user