mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Extend stub.py
to accept external typehinting (#1102)
This commit is contained in:
3
candle-pyo3/_additional_typing/README.md
Normal file
3
candle-pyo3/_additional_typing/README.md
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
This python module contains external typehinting for certain `candle` classes. This is only necessary for `magic` methodes e.g. `__add__` as their text signature cant be set via pyo3.
|
||||||
|
|
||||||
|
The classes in this module will be parsed by the `stub.py` script and interleafed with the signatures of the actual pyo3 `candle.candle` module.
|
55
candle-pyo3/_additional_typing/__init__.py
Normal file
55
candle-pyo3/_additional_typing/__init__.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
from typing import Union, Sequence
|
||||||
|
|
||||||
|
|
||||||
|
class Tensor:
|
||||||
|
"""
|
||||||
|
This contains the type hints for the magic methodes of the `candle.Tensor` class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __add__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Add a scalar to a tensor or two tensors together.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __radd__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Add a scalar to a tensor or two tensors together.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __sub__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Subtract a scalar from a tensor or one tensor from another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __truediv__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Divide a tensor by a scalar or one tensor by another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __mul__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Multiply a tensor by a scalar or one tensor by another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __rmul__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Multiply a tensor by a scalar or one tensor by another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __richcmp__(self, rhs: Union["Tensor", "Scalar"], op) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Compare a tensor with a scalar or one tensor with another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __getitem__(self, index: Union["Index", "Tensor", Sequence["Index"]]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Return a slice of a tensor.
|
||||||
|
"""
|
||||||
|
pass
|
@ -1,7 +1,7 @@
|
|||||||
# Generated content DO NOT EDIT
|
# Generated content DO NOT EDIT
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||||
from os import PathLike
|
from os import PathLike
|
||||||
from candle.typing import _ArrayLike, Device
|
from candle.typing import _ArrayLike, Device, Scalar, Index
|
||||||
|
|
||||||
class bf16(DType):
|
class bf16(DType):
|
||||||
pass
|
pass
|
||||||
@ -119,6 +119,46 @@ class Tensor:
|
|||||||
|
|
||||||
def __init__(self, data: _ArrayLike):
|
def __init__(self, data: _ArrayLike):
|
||||||
pass
|
pass
|
||||||
|
def __add__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Add a scalar to a tensor or two tensors together.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Return a slice of a tensor.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Multiply a tensor by a scalar or one tensor by another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Add a scalar to a tensor or two tensors together.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
def __richcmp__(self, rhs: Union[Tensor, Scalar], op) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Compare a tensor with a scalar or one tensor with another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
def __rmul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Multiply a tensor by a scalar or one tensor by another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
def __sub__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Subtract a scalar from a tensor or one tensor from another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
def __truediv__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Divide a tensor by a scalar or one tensor by another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
def argmax_keepdim(self, dim: int) -> Tensor:
|
def argmax_keepdim(self, dim: int) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Returns the indices of the maximum value(s) across the selected dimension.
|
Returns the indices of the maximum value(s) across the selected dimension.
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Generated content DO NOT EDIT
|
# Generated content DO NOT EDIT
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||||
from os import PathLike
|
from os import PathLike
|
||||||
from candle.typing import _ArrayLike, Device
|
from candle.typing import _ArrayLike, Device, Scalar, Index
|
||||||
from candle import Tensor, DType, QTensor
|
from candle import Tensor, DType, QTensor
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -14,3 +14,7 @@ CPU: str = "cpu"
|
|||||||
CUDA: str = "cuda"
|
CUDA: str = "cuda"
|
||||||
|
|
||||||
Device = TypeVar("Device", CPU, CUDA)
|
Device = TypeVar("Device", CPU, CUDA)
|
||||||
|
|
||||||
|
Scalar = Union[int, float]
|
||||||
|
|
||||||
|
Index = Union[int, slice, None, "Ellipsis"]
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Generated content DO NOT EDIT
|
# Generated content DO NOT EDIT
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||||
from os import PathLike
|
from os import PathLike
|
||||||
from candle.typing import _ArrayLike, Device
|
from candle.typing import _ArrayLike, Device, Scalar, Index
|
||||||
from candle import Tensor, DType, QTensor
|
from candle import Tensor, DType, QTensor
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -5,6 +5,7 @@ import os
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
import black
|
import black
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
INDENT = " " * 4
|
INDENT = " " * 4
|
||||||
@ -12,9 +13,11 @@ GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
|
|||||||
TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||||
from os import PathLike
|
from os import PathLike
|
||||||
"""
|
"""
|
||||||
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n"
|
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index\n"
|
||||||
CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n"
|
CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n"
|
||||||
RETURN_TYPE_MARKER = "&RETURNS&: "
|
RETURN_TYPE_MARKER = "&RETURNS&: "
|
||||||
|
ADDITIONAL_TYPEHINTS = {}
|
||||||
|
FORWARD_REF_PATTERN = re.compile(r"ForwardRef\('([^']+)'\)")
|
||||||
|
|
||||||
|
|
||||||
def do_indent(text: Optional[str], indent: str):
|
def do_indent(text: Optional[str], indent: str):
|
||||||
@ -115,6 +118,27 @@ def pyi_file(obj, indent=""):
|
|||||||
body += f"{indent+INDENT}pass\n"
|
body += f"{indent+INDENT}pass\n"
|
||||||
body += "\n"
|
body += "\n"
|
||||||
|
|
||||||
|
if obj.__name__ in ADDITIONAL_TYPEHINTS:
|
||||||
|
additional_members = inspect.getmembers(ADDITIONAL_TYPEHINTS[obj.__name__])
|
||||||
|
additional_functions = []
|
||||||
|
for name, member in additional_members:
|
||||||
|
if inspect.isfunction(member):
|
||||||
|
additional_functions.append((name, member))
|
||||||
|
|
||||||
|
def process_additional_function(fn):
|
||||||
|
signature = inspect.signature(fn)
|
||||||
|
cleaned_signature = re.sub(FORWARD_REF_PATTERN, r"\1", str(signature))
|
||||||
|
string = f"{indent}def {fn.__name__}{cleaned_signature}:\n"
|
||||||
|
string += (
|
||||||
|
f'{indent+INDENT}"""{indent+INDENT}{do_indent(fn.__doc__, indent+INDENT)}{indent+INDENT}"""\n'
|
||||||
|
)
|
||||||
|
string += f"{indent+INDENT}pass\n"
|
||||||
|
string += "\n"
|
||||||
|
return string
|
||||||
|
|
||||||
|
for name, fn in additional_functions:
|
||||||
|
body += process_additional_function(fn)
|
||||||
|
|
||||||
for name, fn in fns:
|
for name, fn in fns:
|
||||||
body += pyi_file(fn, indent=indent)
|
body += pyi_file(fn, indent=indent)
|
||||||
|
|
||||||
@ -215,6 +239,19 @@ def write(module, directory, origin, check=False):
|
|||||||
write(submodule, os.path.join(directory, name), f"{name}", check=check)
|
write(submodule, os.path.join(directory, name), f"{name}", check=check)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_additional_types(module):
|
||||||
|
additional_types = {}
|
||||||
|
for name, member in inspect.getmembers(module):
|
||||||
|
if inspect.isclass(member):
|
||||||
|
if hasattr(member, "__name__"):
|
||||||
|
name = member.__name__
|
||||||
|
else:
|
||||||
|
name = str(member)
|
||||||
|
if name not in additional_types:
|
||||||
|
additional_types[name] = member
|
||||||
|
return additional_types
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--check", action="store_true")
|
parser.add_argument("--check", action="store_true")
|
||||||
@ -228,5 +265,8 @@ if __name__ == "__main__":
|
|||||||
directory = f"candle-pyo3/{directory}"
|
directory = f"candle-pyo3/{directory}"
|
||||||
|
|
||||||
import candle
|
import candle
|
||||||
|
import _additional_typing
|
||||||
|
|
||||||
|
ADDITIONAL_TYPEHINTS = extract_additional_types(_additional_typing)
|
||||||
|
|
||||||
write(candle.candle, directory, "candle", check=args.check)
|
write(candle.candle, directory, "candle", check=args.check)
|
||||||
|
Reference in New Issue
Block a user