mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Generate *.pyi
stubs for PyO3 wrapper (#870)
* Begin to generate typehints. * generate correct stubs * Correctly include stubs * Add comments and typhints to static functions * ensure candle-pyo3 directory * Make `llama.rope.freq_base` optional * `fmt`
This commit is contained in:
217
candle-pyo3/stub.py
Normal file
217
candle-pyo3/stub.py
Normal file
@ -0,0 +1,217 @@
|
||||
#See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py
|
||||
import argparse
|
||||
import inspect
|
||||
import os
|
||||
from typing import Optional
|
||||
import black
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
INDENT = " " * 4
|
||||
GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
|
||||
TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||
from os import PathLike
|
||||
"""
|
||||
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n"
|
||||
CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType\n"
|
||||
|
||||
|
||||
|
||||
def do_indent(text: Optional[str], indent: str):
|
||||
if text is None:
|
||||
return ""
|
||||
return text.replace("\n", f"\n{indent}")
|
||||
|
||||
|
||||
def function(obj, indent:str, text_signature:str=None):
|
||||
if text_signature is None:
|
||||
text_signature = obj.__text_signature__
|
||||
|
||||
text_signature = text_signature.replace("$self", "self").lstrip().rstrip()
|
||||
string = ""
|
||||
string += f"{indent}def {obj.__name__}{text_signature}:\n"
|
||||
indent += INDENT
|
||||
string += f'{indent}"""\n'
|
||||
string += f"{indent}{do_indent(obj.__doc__, indent)}\n"
|
||||
string += f'{indent}"""\n'
|
||||
string += f"{indent}pass\n"
|
||||
string += "\n"
|
||||
string += "\n"
|
||||
return string
|
||||
|
||||
|
||||
def member_sort(member):
|
||||
if inspect.isclass(member):
|
||||
value = 10 + len(inspect.getmro(member))
|
||||
else:
|
||||
value = 1
|
||||
return value
|
||||
|
||||
|
||||
def fn_predicate(obj):
|
||||
value = inspect.ismethoddescriptor(obj) or inspect.isbuiltin(obj)
|
||||
if value:
|
||||
return obj.__text_signature__ and not obj.__name__.startswith("_")
|
||||
if inspect.isgetsetdescriptor(obj):
|
||||
return not obj.__name__.startswith("_")
|
||||
return False
|
||||
|
||||
|
||||
def get_module_members(module):
|
||||
members = [
|
||||
member
|
||||
for name, member in inspect.getmembers(module)
|
||||
if not name.startswith("_") and not inspect.ismodule(member)
|
||||
]
|
||||
members.sort(key=member_sort)
|
||||
return members
|
||||
|
||||
|
||||
def pyi_file(obj, indent=""):
|
||||
string = ""
|
||||
if inspect.ismodule(obj):
|
||||
string += GENERATED_COMMENT
|
||||
string += TYPING
|
||||
string += CANDLE_SPECIFIC_TYPING
|
||||
if obj.__name__ != "candle.candle":
|
||||
string += CANDLE_TENSOR_IMPORTS
|
||||
members = get_module_members(obj)
|
||||
for member in members:
|
||||
string += pyi_file(member, indent)
|
||||
|
||||
elif inspect.isclass(obj):
|
||||
indent += INDENT
|
||||
mro = inspect.getmro(obj)
|
||||
if len(mro) > 2:
|
||||
inherit = f"({mro[1].__name__})"
|
||||
else:
|
||||
inherit = ""
|
||||
string += f"class {obj.__name__}{inherit}:\n"
|
||||
|
||||
body = ""
|
||||
if obj.__doc__:
|
||||
body += f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n'
|
||||
|
||||
fns = inspect.getmembers(obj, fn_predicate)
|
||||
|
||||
# Init
|
||||
if obj.__text_signature__:
|
||||
body += f"{indent}def __init__{obj.__text_signature__}:\n"
|
||||
body += f"{indent+INDENT}pass\n"
|
||||
body += "\n"
|
||||
|
||||
for (name, fn) in fns:
|
||||
body += pyi_file(fn, indent=indent)
|
||||
|
||||
if not body:
|
||||
body += f"{indent}pass\n"
|
||||
|
||||
string += body
|
||||
string += "\n\n"
|
||||
|
||||
elif inspect.isbuiltin(obj):
|
||||
string += f"{indent}@staticmethod\n"
|
||||
string += function(obj, indent)
|
||||
|
||||
elif inspect.ismethoddescriptor(obj):
|
||||
string += function(obj, indent)
|
||||
|
||||
elif inspect.isgetsetdescriptor(obj):
|
||||
# TODO it would be interesing to add the setter maybe ?
|
||||
string += f"{indent}@property\n"
|
||||
string += function(obj, indent, text_signature="(self)")
|
||||
|
||||
elif obj.__class__.__name__ == "DType":
|
||||
string += f"class {str(obj).lower()}(DType):\n"
|
||||
string += f"{indent+INDENT}pass\n"
|
||||
else:
|
||||
raise Exception(f"Object {obj} is not supported")
|
||||
return string
|
||||
|
||||
|
||||
def py_file(module, origin):
|
||||
members = get_module_members(module)
|
||||
|
||||
string = GENERATED_COMMENT
|
||||
string += f"from .. import {origin}\n"
|
||||
string += "\n"
|
||||
for member in members:
|
||||
if hasattr(member, "__name__"):
|
||||
name = member.__name__
|
||||
else:
|
||||
name = str(member)
|
||||
string += f"{name} = {origin}.{name}\n"
|
||||
return string
|
||||
|
||||
|
||||
def do_black(content, is_pyi):
|
||||
mode = black.Mode(
|
||||
target_versions={black.TargetVersion.PY35},
|
||||
line_length=119,
|
||||
is_pyi=is_pyi,
|
||||
string_normalization=True,
|
||||
experimental_string_processing=False,
|
||||
)
|
||||
try:
|
||||
return black.format_file_contents(content, fast=True, mode=mode)
|
||||
except black.NothingChanged:
|
||||
return content
|
||||
|
||||
|
||||
def write(module, directory, origin, check=False):
|
||||
submodules = [(name, member) for name, member in inspect.getmembers(module) if inspect.ismodule(member)]
|
||||
|
||||
filename = os.path.join(directory, "__init__.pyi")
|
||||
pyi_content = pyi_file(module)
|
||||
pyi_content = do_black(pyi_content, is_pyi=True)
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
if check:
|
||||
with open(filename, "r") as f:
|
||||
data = f.read()
|
||||
assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`"
|
||||
else:
|
||||
with open(filename, "w") as f:
|
||||
f.write(pyi_content)
|
||||
|
||||
filename = os.path.join(directory, "__init__.py")
|
||||
py_content = py_file(module, origin)
|
||||
py_content = do_black(py_content, is_pyi=False)
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
is_auto = False
|
||||
if not os.path.exists(filename):
|
||||
is_auto = True
|
||||
else:
|
||||
with open(filename, "r") as f:
|
||||
line = f.readline()
|
||||
if line == GENERATED_COMMENT:
|
||||
is_auto = True
|
||||
|
||||
if is_auto:
|
||||
if check:
|
||||
with open(filename, "r") as f:
|
||||
data = f.read()
|
||||
assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`"
|
||||
else:
|
||||
with open(filename, "w") as f:
|
||||
f.write(py_content)
|
||||
|
||||
for name, submodule in submodules:
|
||||
write(submodule, os.path.join(directory, name), f"{name}", check=check)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--check", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
#Enable execution from the candle and candle-pyo3 directories
|
||||
cwd = Path.cwd()
|
||||
directory = "py_src/candle/"
|
||||
if cwd.name != "candle-pyo3":
|
||||
directory = f"candle-pyo3/{directory}"
|
||||
|
||||
import candle
|
||||
|
||||
write(candle.candle, directory, "candle", check=args.check)
|
Reference in New Issue
Block a user