Make the Python Wrapper more Hackable and simplify Quantization (#1010)

* Some first `Module` implementations

* Add `state_dict` and `load_state_dict` functionality

* Move modules around and create `candle.nn.Linear`

* Add `nn.Embedding` and `nn.LayerNorm`

* Add BERT implementation

* Batch q-matmul

* Automatically dequantize `QTensors` if a `Tensor` is expected

* Add Module `.to()`, `.cuda()`, `.cpu()` and `.type()` functionality

* Unittests for `Module`, `Tensor` and `candle.utils`

* Add PyTorch-like slicing to `Tensor`

* Cleanup and BERT fixes

* `black` formatting + unit-test for `nn.Linear`

* Refactor slicing implementation
Lukas Kreussel
2023-10-06 20:01:07 +02:00
committed by GitHub
parent b0442eff8a
commit 904bbdae65
25 changed files with 2426 additions and 182 deletions

@@ -1,5 +1,30 @@
-from .candle import *
+import logging
+
+try:
+    from .candle import *
+except ImportError as e:
+    # If we are in development mode, or we did not bundle the CUDA DLLs, we try to locate them here
+    logging.warning("CUDA DLLs were not bundled with this package. Trying to locate them...")
+    import os
+    import platform
+
+    # Try to locate CUDA_PATH environment variable
+    cuda_path = os.environ.get("CUDA_PATH", None)
+    if cuda_path:
+        logging.warning(f"Found CUDA_PATH environment variable: {cuda_path}")
+        if platform.system() == "Windows":
+            cuda_path = os.path.join(cuda_path, "bin")
+        else:
+            cuda_path = os.path.join(cuda_path, "lib64")
+
+        logging.warning(f"Adding {cuda_path} to DLL search path...")
+        os.add_dll_directory(cuda_path)
+
+    try:
+        from .candle import *
+    except ImportError as inner_e:
+        raise ImportError("Could not locate CUDA DLLs. Please check the documentation for more information.")

 __doc__ = candle.__doc__
 if hasattr(candle, "__all__"):
-    __all__ = candle.__all__
+    __all__ = candle.__all__
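
Note on the fallback in this hunk: `os.add_dll_directory` is documented as available only on Windows (Python 3.8+). On other platforms the `lib64` branch would therefore raise `AttributeError` whenever the first import fails and `CUDA_PATH` is set. The following is a minimal sketch of a guarded variant, not the code from this commit. The helper names `resolve_cuda_library_dir` and `register_cuda_library_dir` are hypothetical and exist only for illustration.

```python
import logging
import os
import platform
from typing import Optional


def resolve_cuda_library_dir(cuda_path: str, system: Optional[str] = None) -> str:
    """Return the directory under CUDA_PATH expected to hold the CUDA libraries.

    Hypothetical helper: mirrors the "bin" (Windows) / "lib64" (other) choice
    made in the hunk above.
    """
    system = system or platform.system()
    subdir = "bin" if system == "Windows" else "lib64"
    return os.path.join(cuda_path, subdir)


def register_cuda_library_dir(cuda_path: Optional[str] = None) -> Optional[str]:
    """Hypothetical helper: add the CUDA library directory to the DLL search path.

    Returns the registered directory, or None if nothing was registered.
    """
    cuda_path = cuda_path or os.environ.get("CUDA_PATH")
    if not cuda_path:
        return None

    library_dir = resolve_cuda_library_dir(cuda_path)

    # os.add_dll_directory only exists on Windows, so look it up defensively
    # instead of assuming it is present.
    add_dll_directory = getattr(os, "add_dll_directory", None)
    if add_dll_directory is None or not os.path.isdir(library_dir):
        logging.warning("Skipping DLL registration for %s", library_dir)
        return None

    add_dll_directory(library_dir)
    return library_dir
```

With a helper like this, the `except ImportError` branch could call `register_cuda_library_dir()` before retrying the import. Re-raising with `raise ImportError(...) from inner_e` would also keep the original failure in the traceback.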