mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations
* Add `state_dict` and `load_state_dict` functionality
* Move modules around and create `candle.nn.Linear`
* Add `nn.Embedding` and `nn.LayerNorm`
* Add BERT implementation
* Batch q-matmul
* Automatically dequantize `QTensors` if a `Tensor` is expected
* Add Module `.to()`, `.cuda()`, `.cpu()` and `.type()` functionality
* Unit tests for `Module`, `Tensor` and `candle.utils`
* Add PyTorch-like slicing to `Tensor`
* Cleanup and BERT fixes
* `black` formatting + unit test for `nn.Linear`
* Refactor slicing implementation
This commit is contained in:
@ -1,5 +1,30 @@
|
||||
# Package initializer: re-export everything from the compiled `.candle`
# extension module. If the first import fails, try to make the CUDA shared
# libraries discoverable via the CUDA_PATH environment variable and retry once.
import logging

try:
    from .candle import *
except ImportError as e:
    # If we are in development mode, or we did not bundle the CUDA DLLs, we try to locate them here
    logging.warning("CUDA DLLs were not bundled with this package. Trying to locate them...")
    import os
    import platform

    # Try to locate CUDA_PATH environment variable
    cuda_path = os.environ.get("CUDA_PATH", None)
    if cuda_path:
        logging.warning(f"Found CUDA_PATH environment variable: {cuda_path}")
        if platform.system() == "Windows":
            cuda_path = os.path.join(cuda_path, "bin")
        else:
            cuda_path = os.path.join(cuda_path, "lib64")

        # `os.add_dll_directory` is only available on Windows; calling it
        # unconditionally would raise AttributeError on other platforms and
        # hide the original import failure.
        if hasattr(os, "add_dll_directory"):
            logging.warning(f"Adding {cuda_path} to DLL search path...")
            try:
                os.add_dll_directory(cuda_path)
            except OSError as dll_error:
                # Raised when the directory does not exist; keep going so the
                # retry below reports the import problem itself.
                logging.warning(f"Could not add {cuda_path} to DLL search path: {dll_error}")

    try:
        from .candle import *
    except ImportError as inner_e:
        # Chain the underlying error so the real loader failure stays visible.
        raise ImportError(
            "Could not locate CUDA DLLs. Please check the documentation for more information."
        ) from inner_e

# Importing the `.candle` submodule binds it as `candle` in this namespace,
# which lets the package mirror the extension module's docstring and, when
# the extension defines one, its public export list.
__doc__ = candle.__doc__
if hasattr(candle, "__all__"):
    __all__ = candle.__all__
|
||||
|
Reference in New Issue
Block a user