Make the Python Wrapper more Hackable and simplify Quantization (#1010)

* Some first `Module` implementations

* Add `state_dict` and `load_state_dict` functionality

* Move modules around and create `candle.nn.Linear`

* Add `nn.Embedding` and `nn.LayerNorm`

* Add BERT implementation

* Batch q-matmul

* Automatically dequantize `QTensors` if a `Tensor` is expected

* Add Module `.to()`, `.cuda()`, `cpu()` and `.type()` functionality

* Unittests for `Module`, `Tensor` and `candle.utils`

* Add `pytorch` like slicing to `Tensor`

* Cleanup and BERT fixes

* `black` formatting + unit-test for `nn.Linear`

* Refactor slicing implementation
This commit is contained in:
Lukas Kreussel
2023-10-06 20:01:07 +02:00
committed by GitHub
parent b0442eff8a
commit 904bbdae65
25 changed files with 2426 additions and 182 deletions

View File

@ -1,4 +1,4 @@
#See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py
# See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py
import argparse
import inspect
import os
@ -23,7 +23,7 @@ def do_indent(text: Optional[str], indent: str):
return text.replace("\n", f"\n{indent}")
def function(obj, indent:str, text_signature:str=None):
def function(obj, indent: str, text_signature: str = None):
if text_signature is None:
text_signature = obj.__text_signature__
@ -32,12 +32,12 @@ def function(obj, indent:str, text_signature:str=None):
if doc_string is None:
doc_string = ""
# Check if we have a return type annotation in the docstring
# Check if we have a return type annotation in the docstring
return_type = None
doc_lines = doc_string.split("\n")
if doc_lines[-1].lstrip().startswith(RETURN_TYPE_MARKER):
# Extract the return type and remove it from the docstring
return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER):].strip()
return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER) :].strip()
doc_string = "\n".join(doc_lines[:-1])
string = ""
@ -115,7 +115,7 @@ def pyi_file(obj, indent=""):
body += f"{indent+INDENT}pass\n"
body += "\n"
for (name, fn) in fns:
for name, fn in fns:
body += pyi_file(fn, indent=indent)
if not body:
@ -221,12 +221,12 @@ if __name__ == "__main__":
args = parser.parse_args()
#Enable execution from the candle and candle-pyo3 directories
# Enable execution from the candle and candle-pyo3 directories
cwd = Path.cwd()
directory = "py_src/candle/"
if cwd.name != "candle-pyo3":
directory = f"candle-pyo3/{directory}"
import candle
write(candle.candle, directory, "candle", check=args.check)