mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Extend stub.py
to accept external typehinting (#1102)
This commit is contained in:
@ -5,6 +5,7 @@ import os
|
||||
from typing import Optional
|
||||
import black
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
|
||||
INDENT = " " * 4
|
||||
@ -12,9 +13,11 @@ GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
|
||||
TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||
from os import PathLike
|
||||
"""
|
||||
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n"
|
||||
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index\n"
|
||||
CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n"
|
||||
RETURN_TYPE_MARKER = "&RETURNS&: "
|
||||
ADDITIONAL_TYPEHINTS = {}
|
||||
FORWARD_REF_PATTERN = re.compile(r"ForwardRef\('([^']+)'\)")
|
||||
|
||||
|
||||
def do_indent(text: Optional[str], indent: str):
|
||||
@ -115,6 +118,27 @@ def pyi_file(obj, indent=""):
|
||||
body += f"{indent+INDENT}pass\n"
|
||||
body += "\n"
|
||||
|
||||
if obj.__name__ in ADDITIONAL_TYPEHINTS:
|
||||
additional_members = inspect.getmembers(ADDITIONAL_TYPEHINTS[obj.__name__])
|
||||
additional_functions = []
|
||||
for name, member in additional_members:
|
||||
if inspect.isfunction(member):
|
||||
additional_functions.append((name, member))
|
||||
|
||||
def process_additional_function(fn):
|
||||
signature = inspect.signature(fn)
|
||||
cleaned_signature = re.sub(FORWARD_REF_PATTERN, r"\1", str(signature))
|
||||
string = f"{indent}def {fn.__name__}{cleaned_signature}:\n"
|
||||
string += (
|
||||
f'{indent+INDENT}"""{indent+INDENT}{do_indent(fn.__doc__, indent+INDENT)}{indent+INDENT}"""\n'
|
||||
)
|
||||
string += f"{indent+INDENT}pass\n"
|
||||
string += "\n"
|
||||
return string
|
||||
|
||||
for name, fn in additional_functions:
|
||||
body += process_additional_function(fn)
|
||||
|
||||
for name, fn in fns:
|
||||
body += pyi_file(fn, indent=indent)
|
||||
|
||||
@ -215,6 +239,19 @@ def write(module, directory, origin, check=False):
|
||||
write(submodule, os.path.join(directory, name), f"{name}", check=check)
|
||||
|
||||
|
||||
def extract_additional_types(module):
|
||||
additional_types = {}
|
||||
for name, member in inspect.getmembers(module):
|
||||
if inspect.isclass(member):
|
||||
if hasattr(member, "__name__"):
|
||||
name = member.__name__
|
||||
else:
|
||||
name = str(member)
|
||||
if name not in additional_types:
|
||||
additional_types[name] = member
|
||||
return additional_types
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--check", action="store_true")
|
||||
@ -228,5 +265,8 @@ if __name__ == "__main__":
|
||||
directory = f"candle-pyo3/{directory}"
|
||||
|
||||
import candle
|
||||
import _additional_typing
|
||||
|
||||
ADDITIONAL_TYPEHINTS = extract_additional_types(_additional_typing)
|
||||
|
||||
write(candle.candle, directory, "candle", check=args.check)
|
||||
|
Reference in New Issue
Block a user