mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations * Add `state_dict` and `load_state_dict` functionality * Move modules around and create `candle.nn.Linear` * Add `nn.Embedding` and `nn.LayerNorm` * Add BERT implementation * Batch q-matmul * Automatically dequantize `QTensors` if a `Tensor` is expected * Add Module `.to()`, `.cuda()`, `cpu()` and `.type()` functionality * Unittests for `Module`, `Tensor` and `candle.utils` * Add `pytorch` like slicing to `Tensor` * Cleanup and BERT fixes * `black` formatting + unit-test for `nn.Linear` * Refactor slicing implementation
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
#See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py
|
||||
# See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py
|
||||
import argparse
|
||||
import inspect
|
||||
import os
|
||||
@ -23,7 +23,7 @@ def do_indent(text: Optional[str], indent: str):
|
||||
return text.replace("\n", f"\n{indent}")
|
||||
|
||||
|
||||
def function(obj, indent:str, text_signature:str=None):
|
||||
def function(obj, indent: str, text_signature: str = None):
|
||||
if text_signature is None:
|
||||
text_signature = obj.__text_signature__
|
||||
|
||||
@ -32,12 +32,12 @@ def function(obj, indent:str, text_signature:str=None):
|
||||
if doc_string is None:
|
||||
doc_string = ""
|
||||
|
||||
# Check if we have a return type annotation in the docstring
|
||||
# Check if we have a return type annotation in the docstring
|
||||
return_type = None
|
||||
doc_lines = doc_string.split("\n")
|
||||
if doc_lines[-1].lstrip().startswith(RETURN_TYPE_MARKER):
|
||||
# Extract the return type and remove it from the docstring
|
||||
return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER):].strip()
|
||||
return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER) :].strip()
|
||||
doc_string = "\n".join(doc_lines[:-1])
|
||||
|
||||
string = ""
|
||||
@ -115,7 +115,7 @@ def pyi_file(obj, indent=""):
|
||||
body += f"{indent+INDENT}pass\n"
|
||||
body += "\n"
|
||||
|
||||
for (name, fn) in fns:
|
||||
for name, fn in fns:
|
||||
body += pyi_file(fn, indent=indent)
|
||||
|
||||
if not body:
|
||||
@ -221,12 +221,12 @@ if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
#Enable execution from the candle and candle-pyo3 directories
|
||||
# Enable execution from the candle and candle-pyo3 directories
|
||||
cwd = Path.cwd()
|
||||
directory = "py_src/candle/"
|
||||
if cwd.name != "candle-pyo3":
|
||||
directory = f"candle-pyo3/{directory}"
|
||||
|
||||
|
||||
import candle
|
||||
|
||||
write(candle.candle, directory, "candle", check=args.check)
|
||||
|
Reference in New Issue
Block a user