Extend stub.py to accept external typehinting (#1102)

This commit is contained in:
Lukas Kreussel
2023-10-17 12:07:26 +02:00
committed by GitHub
parent b355ab4e2e
commit f9e93f5b69
7 changed files with 146 additions and 4 deletions

View File

@ -0,0 +1,3 @@
This python module contains external typehinting for certain `candle` classes. This is only necessary for `magic` methodes e.g. `__add__` as their text signature cant be set via pyo3.
The classes in this module will be parsed by the `stub.py` script and interleafed with the signatures of the actual pyo3 `candle.candle` module.

View File

@ -0,0 +1,55 @@
from typing import Union, Sequence
class Tensor:
"""
This contains the type hints for the magic methodes of the `candle.Tensor` class.
"""
def __add__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Add a scalar to a tensor or two tensors together.
"""
pass
def __radd__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Add a scalar to a tensor or two tensors together.
"""
pass
def __sub__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Subtract a scalar from a tensor or one tensor from another.
"""
pass
def __truediv__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Divide a tensor by a scalar or one tensor by another.
"""
pass
def __mul__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Multiply a tensor by a scalar or one tensor by another.
"""
pass
def __rmul__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Multiply a tensor by a scalar or one tensor by another.
"""
pass
def __richcmp__(self, rhs: Union["Tensor", "Scalar"], op) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __getitem__(self, index: Union["Index", "Tensor", Sequence["Index"]]) -> "Tensor":
"""
Return a slice of a tensor.
"""
pass

View File

@ -1,7 +1,7 @@
# Generated content DO NOT EDIT # Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike from os import PathLike
from candle.typing import _ArrayLike, Device from candle.typing import _ArrayLike, Device, Scalar, Index
class bf16(DType): class bf16(DType):
pass pass
@ -119,6 +119,46 @@ class Tensor:
def __init__(self, data: _ArrayLike): def __init__(self, data: _ArrayLike):
pass pass
def __add__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Add a scalar to a tensor or two tensors together.
"""
pass
def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
"""
Return a slice of a tensor.
"""
pass
def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Multiply a tensor by a scalar or one tensor by another.
"""
pass
def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Add a scalar to a tensor or two tensors together.
"""
pass
def __richcmp__(self, rhs: Union[Tensor, Scalar], op) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __rmul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Multiply a tensor by a scalar or one tensor by another.
"""
pass
def __sub__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Subtract a scalar from a tensor or one tensor from another.
"""
pass
def __truediv__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Divide a tensor by a scalar or one tensor by another.
"""
pass
def argmax_keepdim(self, dim: int) -> Tensor: def argmax_keepdim(self, dim: int) -> Tensor:
""" """
Returns the indices of the maximum value(s) across the selected dimension. Returns the indices of the maximum value(s) across the selected dimension.

View File

@ -1,7 +1,7 @@
# Generated content DO NOT EDIT # Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike from os import PathLike
from candle.typing import _ArrayLike, Device from candle.typing import _ArrayLike, Device, Scalar, Index
from candle import Tensor, DType, QTensor from candle import Tensor, DType, QTensor
@staticmethod @staticmethod

View File

@ -14,3 +14,7 @@ CPU: str = "cpu"
CUDA: str = "cuda" CUDA: str = "cuda"
Device = TypeVar("Device", CPU, CUDA) Device = TypeVar("Device", CPU, CUDA)
Scalar = Union[int, float]
Index = Union[int, slice, None, "Ellipsis"]

View File

@ -1,7 +1,7 @@
# Generated content DO NOT EDIT # Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike from os import PathLike
from candle.typing import _ArrayLike, Device from candle.typing import _ArrayLike, Device, Scalar, Index
from candle import Tensor, DType, QTensor from candle import Tensor, DType, QTensor
@staticmethod @staticmethod

View File

@ -5,6 +5,7 @@ import os
from typing import Optional from typing import Optional
import black import black
from pathlib import Path from pathlib import Path
import re
INDENT = " " * 4 INDENT = " " * 4
@ -12,9 +13,11 @@ GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike from os import PathLike
""" """
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n" CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index\n"
CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n" CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n"
RETURN_TYPE_MARKER = "&RETURNS&: " RETURN_TYPE_MARKER = "&RETURNS&: "
ADDITIONAL_TYPEHINTS = {}
FORWARD_REF_PATTERN = re.compile(r"ForwardRef\('([^']+)'\)")
def do_indent(text: Optional[str], indent: str): def do_indent(text: Optional[str], indent: str):
@ -115,6 +118,27 @@ def pyi_file(obj, indent=""):
body += f"{indent+INDENT}pass\n" body += f"{indent+INDENT}pass\n"
body += "\n" body += "\n"
if obj.__name__ in ADDITIONAL_TYPEHINTS:
additional_members = inspect.getmembers(ADDITIONAL_TYPEHINTS[obj.__name__])
additional_functions = []
for name, member in additional_members:
if inspect.isfunction(member):
additional_functions.append((name, member))
def process_additional_function(fn):
signature = inspect.signature(fn)
cleaned_signature = re.sub(FORWARD_REF_PATTERN, r"\1", str(signature))
string = f"{indent}def {fn.__name__}{cleaned_signature}:\n"
string += (
f'{indent+INDENT}"""{indent+INDENT}{do_indent(fn.__doc__, indent+INDENT)}{indent+INDENT}"""\n'
)
string += f"{indent+INDENT}pass\n"
string += "\n"
return string
for name, fn in additional_functions:
body += process_additional_function(fn)
for name, fn in fns: for name, fn in fns:
body += pyi_file(fn, indent=indent) body += pyi_file(fn, indent=indent)
@ -215,6 +239,19 @@ def write(module, directory, origin, check=False):
write(submodule, os.path.join(directory, name), f"{name}", check=check) write(submodule, os.path.join(directory, name), f"{name}", check=check)
def extract_additional_types(module):
additional_types = {}
for name, member in inspect.getmembers(module):
if inspect.isclass(member):
if hasattr(member, "__name__"):
name = member.__name__
else:
name = str(member)
if name not in additional_types:
additional_types[name] = member
return additional_types
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--check", action="store_true") parser.add_argument("--check", action="store_true")
@ -228,5 +265,8 @@ if __name__ == "__main__":
directory = f"candle-pyo3/{directory}" directory = f"candle-pyo3/{directory}"
import candle import candle
import _additional_typing
ADDITIONAL_TYPEHINTS = extract_additional_types(_additional_typing)
write(candle.candle, directory, "candle", check=args.check) write(candle.candle, directory, "candle", check=args.check)