mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Generate *.pyi
stubs for PyO3 wrapper (#870)
* Begin to generate typehints. * generate correct stubs * Correctly include stubs * Add comments and typehints to static functions * ensure candle-pyo3 directory * Make `llama.rope.freq_base` optional * `fmt`
This commit is contained in:
160
candle-pyo3/.gitignore
vendored
Normal file
160
candle-pyo3/.gitignore
vendored
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/#use-with-ide
|
||||||
|
.pdm.toml
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
#.idea/
|
@ -12,7 +12,6 @@ readme = "README.md"
|
|||||||
[lib]
|
[lib]
|
||||||
name = "candle"
|
name = "candle"
|
||||||
crate-type = ["cdylib"]
|
crate-type = ["cdylib"]
|
||||||
doc = false
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", version = "0.2.2", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.2.2", package = "candle-core" }
|
||||||
|
@ -1,7 +1,26 @@
|
|||||||
|
## Installation
|
||||||
|
|
||||||
From the `candle-pyo3` directory, enable a virtual env where you will want the
|
From the `candle-pyo3` directory, enable a virtual env where you will want the
|
||||||
candle package to be installed then run.
|
candle package to be installed then run.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
maturin develop
|
maturin develop -r
|
||||||
python test.py
|
python test.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Generating Stub Files for Type Hinting
|
||||||
|
|
||||||
|
For type hinting support, the `candle-pyo3` package requires `*.pyi` files. You can automatically generate these files using the `stub.py` script.
|
||||||
|
|
||||||
|
### Steps:
|
||||||
|
1. Install the package using `maturin`.
|
||||||
|
2. Generate the stub files by running:
|
||||||
|
```
|
||||||
|
python stub.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### Validation:
|
||||||
|
To ensure that the stub files match the current implementation, execute:
|
||||||
|
```
|
||||||
|
python stub.py --check
|
||||||
|
```
|
||||||
|
1
candle-pyo3/py_src/candle/__init__.py
Normal file
1
candle-pyo3/py_src/candle/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .candle import *
|
248
candle-pyo3/py_src/candle/__init__.pyi
Normal file
248
candle-pyo3/py_src/candle/__init__.pyi
Normal file
@ -0,0 +1,248 @@
|
|||||||
|
# Generated content DO NOT EDIT
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||||
|
from os import PathLike
|
||||||
|
from candle.typing import _ArrayLike, Device
|
||||||
|
|
||||||
|
class bf16(DType):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cat(tensors: List[Tensor], dim: int):
|
||||||
|
"""
|
||||||
|
Concatenate the tensors across one axis.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class f16(DType):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class f32(DType):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class f64(DType):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class i64(DType):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def rand(shape: Sequence[int], device: Optional[Device] = None):
|
||||||
|
"""
|
||||||
|
Creates a new tensor with random values.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def randn(shape: Sequence[int], device: Optional[Device] = None):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def stack(tensors: List[Tensor], dim: int):
|
||||||
|
"""
|
||||||
|
Stack the tensors along a new axis.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tensor(data: _ArrayLike):
|
||||||
|
"""
|
||||||
|
Creates a new tensor from a Python value. The value can be a scalar or array-like object.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class u32(DType):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class u8(DType):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
|
||||||
|
class DType:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class QTensor:
|
||||||
|
def dequantize(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
@property
|
||||||
|
def ggml_dtype(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def matmul_t(self, lhs):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
@property
|
||||||
|
def rank(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Tensor:
|
||||||
|
def __init__(data: _ArrayLike):
|
||||||
|
pass
|
||||||
|
def argmax_keepdim(self, dim):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def argmin_keepdim(self, dim):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def broadcast_add(self, rhs):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def broadcast_as(self, shape):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def broadcast_div(self, rhs):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def broadcast_left(self, shape):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def broadcast_mul(self, rhs):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def broadcast_sub(self, rhs):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def contiguous(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def copy(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def cos(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def detach(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def exp(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def flatten_all(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def flatten_from(self, dim):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def flatten_to(self, dim):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def get(self, index):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def index_select(self, rhs, dim):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def is_contiguous(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def is_fortran_contiguous(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def log(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def matmul(self, rhs):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def max_keepdim(self, dim):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def mean_all(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def min_keepdim(self, dim):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def narrow(self, dim, start, len):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def powf(self, p):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def quantize(self, quantized_dtype):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
@property
|
||||||
|
def rank(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def recip(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def reshape(self, shape):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
"""
|
||||||
|
Gets the tensor shape as a Python tuple.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
def sin(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def sqr(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def sqrt(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def squeeze(self, dim):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
@property
|
||||||
|
def stride(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def sum_all(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def sum_keepdim(self, dims):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def t(self):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def to_device(self, device):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def to_dtype(self, dtype):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def transpose(self, dim1, dim2):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def unsqueeze(self, dim):
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
def values(self):
|
||||||
|
"""
|
||||||
|
Gets the tensor's data as a Python scalar or array-like object.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
def where_cond(self, on_true, on_false):
|
||||||
|
""" """
|
||||||
|
pass
|
5
candle-pyo3/py_src/candle/nn/__init__.py
Normal file
5
candle-pyo3/py_src/candle/nn/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# Generated content DO NOT EDIT
|
||||||
|
from .. import nn
|
||||||
|
|
||||||
|
silu = nn.silu
|
||||||
|
softmax = nn.softmax
|
19
candle-pyo3/py_src/candle/nn/__init__.pyi
Normal file
19
candle-pyo3/py_src/candle/nn/__init__.pyi
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# Generated content DO NOT EDIT
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||||
|
from os import PathLike
|
||||||
|
from candle.typing import _ArrayLike, Device
|
||||||
|
from candle import Tensor, DType
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def silu(tensor: Tensor):
|
||||||
|
"""
|
||||||
|
Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def softmax(tensor: Tensor, dim: int):
|
||||||
|
"""
|
||||||
|
Applies the Softmax function to a given tensor.
|
||||||
|
"""
|
||||||
|
pass
|
16
candle-pyo3/py_src/candle/typing/__init__.py
Normal file
16
candle-pyo3/py_src/candle/typing/__init__.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
"""Typing helpers imported by the candle type stubs (`*.pyi`)."""

import sys
from typing import Sequence, TypeVar, Union

# Generic element type used to parameterise `_ArrayLike`.
_T = TypeVar("_T")

# A bare value of type `_T`, or a sequence of `_T` nested one to four levels deep.
_ArrayLike = Union[
    _T,
    Sequence[_T],
    Sequence[Sequence[_T]],
    Sequence[Sequence[Sequence[_T]]],
    Sequence[Sequence[Sequence[Sequence[_T]]]],
]

# Device identifiers; the stubs annotate their `device` parameters with `Device`.
CPU: str = "cpu"
CUDA: str = "cuda"

# `Device` used to be `TypeVar("Device", CPU, CUDA)`. TypeVar constraints must be
# types, so the two strings were interpreted as forward references to
# non-existent types named `cpu` and `cuda`. A Literal alias states the intent
# (one of these two strings) in a form type checkers understand.
if sys.version_info >= (3, 8):
    from typing import Literal

    Device = Literal["cpu", "cuda"]
else:
    # `typing.Literal` does not exist on Python 3.7; fall back to plain `str`.
    Device = str
|
11
candle-pyo3/py_src/candle/utils/__init__.py
Normal file
11
candle-pyo3/py_src/candle/utils/__init__.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# Generated content DO NOT EDIT
|
||||||
|
from .. import utils
|
||||||
|
|
||||||
|
cuda_is_available = utils.cuda_is_available
|
||||||
|
get_num_threads = utils.get_num_threads
|
||||||
|
has_accelerate = utils.has_accelerate
|
||||||
|
has_mkl = utils.has_mkl
|
||||||
|
load_ggml = utils.load_ggml
|
||||||
|
load_gguf = utils.load_gguf
|
||||||
|
load_safetensors = utils.load_safetensors
|
||||||
|
save_safetensors = utils.save_safetensors
|
63
candle-pyo3/py_src/candle/utils/__init__.pyi
Normal file
63
candle-pyo3/py_src/candle/utils/__init__.pyi
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
# Generated content DO NOT EDIT
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||||
|
from os import PathLike
|
||||||
|
from candle.typing import _ArrayLike, Device
|
||||||
|
from candle import Tensor, DType
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cuda_is_available():
|
||||||
|
"""
|
||||||
|
Returns true if the 'cuda' backend is available.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_num_threads():
|
||||||
|
"""
|
||||||
|
Returns the number of threads used by the candle.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def has_accelerate():
|
||||||
|
"""
|
||||||
|
Returns true if candle was compiled with 'accelerate' support.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def has_mkl():
|
||||||
|
"""
|
||||||
|
Returns true if candle was compiled with MKL support.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_ggml(path: Union[str, PathLike]):
|
||||||
|
"""
|
||||||
|
Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
|
||||||
|
a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_gguf(path: Union[str, PathLike]):
|
||||||
|
"""
|
||||||
|
Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
|
||||||
|
and the second maps metadata keys to metadata values.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_safetensors(path: Union[str, PathLike]):
|
||||||
|
"""
|
||||||
|
Loads a safetensors file. Returns a dictionary mapping tensor names to tensors.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_safetensors(path: Union[str, PathLike], tensors: Dict[str, Tensor]):
|
||||||
|
"""
|
||||||
|
Saves a dictionary of tensors to a safetensors file.
|
||||||
|
"""
|
||||||
|
pass
|
30
candle-pyo3/pyproject.toml
Normal file
30
candle-pyo3/pyproject.toml
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
[project]
|
||||||
|
name = 'candle-pyo3'
|
||||||
|
requires-python = '>=3.7'
|
||||||
|
authors = [
|
||||||
|
{name = 'Laurent Mazare', email = ''},
|
||||||
|
]
|
||||||
|
|
||||||
|
dynamic = [
|
||||||
|
'description',
|
||||||
|
'license',
|
||||||
|
'readme',
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
Homepage = 'https://github.com/huggingface/candle'
|
||||||
|
Source = 'https://github.com/huggingface/candle'
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["maturin>=1.0,<2.0"]
|
||||||
|
build-backend = "maturin"
|
||||||
|
|
||||||
|
[tool.maturin]
|
||||||
|
python-source = "py_src"
|
||||||
|
module-name = "candle.candle"
|
||||||
|
bindings = 'pyo3'
|
||||||
|
features = ["pyo3/extension-module"]
|
||||||
|
|
||||||
|
[tool.black]
|
||||||
|
line-length = 119
|
||||||
|
target-version = ['py35']
|
@ -1,6 +1,7 @@
|
|||||||
# This example shows how the candle Python api can be used to replicate llama.cpp.
|
# This example shows how the candle Python api can be used to replicate llama.cpp.
|
||||||
import sys
|
import sys
|
||||||
import candle
|
import candle
|
||||||
|
from candle.utils import load_ggml,load_gguf
|
||||||
|
|
||||||
MAX_SEQ_LEN = 4096
|
MAX_SEQ_LEN = 4096
|
||||||
|
|
||||||
@ -154,7 +155,7 @@ def main():
|
|||||||
filename = sys.argv[1]
|
filename = sys.argv[1]
|
||||||
print(f"reading model file {filename}")
|
print(f"reading model file {filename}")
|
||||||
if filename.endswith("gguf"):
|
if filename.endswith("gguf"):
|
||||||
all_tensors, metadata = candle.load_gguf(sys.argv[1])
|
all_tensors, metadata = load_gguf(sys.argv[1])
|
||||||
vocab = metadata["tokenizer.ggml.tokens"]
|
vocab = metadata["tokenizer.ggml.tokens"]
|
||||||
for i, v in enumerate(vocab):
|
for i, v in enumerate(vocab):
|
||||||
vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ')
|
vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ')
|
||||||
@ -168,13 +169,13 @@ def main():
|
|||||||
'n_head_kv': metadata['llama.attention.head_count_kv'],
|
'n_head_kv': metadata['llama.attention.head_count_kv'],
|
||||||
'n_layer': metadata['llama.block_count'],
|
'n_layer': metadata['llama.block_count'],
|
||||||
'n_rot': metadata['llama.rope.dimension_count'],
|
'n_rot': metadata['llama.rope.dimension_count'],
|
||||||
'rope_freq': metadata['llama.rope.freq_base'],
|
'rope_freq': metadata.get('llama.rope.freq_base', 10000.),
|
||||||
'ftype': metadata['general.file_type'],
|
'ftype': metadata['general.file_type'],
|
||||||
}
|
}
|
||||||
all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() }
|
all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() }
|
||||||
|
|
||||||
else:
|
else:
|
||||||
all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1])
|
all_tensors, hparams, vocab = load_ggml(sys.argv[1])
|
||||||
print(hparams)
|
print(hparams)
|
||||||
model = QuantizedLlama(hparams, all_tensors)
|
model = QuantizedLlama(hparams, all_tensors)
|
||||||
print("model built, starting inference")
|
print("model built, starting inference")
|
||||||
|
@ -197,38 +197,40 @@ trait MapDType {
|
|||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl PyTensor {
|
impl PyTensor {
|
||||||
#[new]
|
#[new]
|
||||||
|
#[pyo3(text_signature = "(data:_ArrayLike)")]
|
||||||
// TODO: Handle arbitrary input dtype and shape.
|
// TODO: Handle arbitrary input dtype and shape.
|
||||||
fn new(py: Python<'_>, vs: PyObject) -> PyResult<Self> {
|
/// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
|
||||||
|
fn new(py: Python<'_>, data: PyObject) -> PyResult<Self> {
|
||||||
use Device::Cpu;
|
use Device::Cpu;
|
||||||
let tensor = if let Ok(vs) = vs.extract::<u32>(py) {
|
let tensor = if let Ok(vs) = data.extract::<u32>(py) {
|
||||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
} else if let Ok(vs) = vs.extract::<i64>(py) {
|
} else if let Ok(vs) = data.extract::<i64>(py) {
|
||||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
} else if let Ok(vs) = vs.extract::<f32>(py) {
|
} else if let Ok(vs) = data.extract::<f32>(py) {
|
||||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
} else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
|
} else if let Ok(vs) = data.extract::<Vec<u32>>(py) {
|
||||||
let len = vs.len();
|
let len = vs.len();
|
||||||
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
|
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
|
||||||
} else if let Ok(vs) = vs.extract::<Vec<i64>>(py) {
|
} else if let Ok(vs) = data.extract::<Vec<i64>>(py) {
|
||||||
let len = vs.len();
|
let len = vs.len();
|
||||||
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
|
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
|
||||||
} else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
|
} else if let Ok(vs) = data.extract::<Vec<f32>>(py) {
|
||||||
let len = vs.len();
|
let len = vs.len();
|
||||||
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
|
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
|
||||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<u32>>>(py) {
|
} else if let Ok(vs) = data.extract::<Vec<Vec<u32>>>(py) {
|
||||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<i64>>>(py) {
|
} else if let Ok(vs) = data.extract::<Vec<Vec<i64>>>(py) {
|
||||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<f32>>>(py) {
|
} else if let Ok(vs) = data.extract::<Vec<Vec<f32>>>(py) {
|
||||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<u32>>>>(py) {
|
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<u32>>>>(py) {
|
||||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<i64>>>>(py) {
|
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<i64>>>>(py) {
|
||||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<f32>>>>(py) {
|
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
|
||||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
} else {
|
} else {
|
||||||
let ty = vs.as_ref(py).get_type();
|
let ty = data.as_ref(py).get_type();
|
||||||
Err(PyTypeError::new_err(format!(
|
Err(PyTypeError::new_err(format!(
|
||||||
"incorrect type {ty} for tensor"
|
"incorrect type {ty} for tensor"
|
||||||
)))?
|
)))?
|
||||||
@ -236,7 +238,7 @@ impl PyTensor {
|
|||||||
Ok(Self(tensor))
|
Ok(Self(tensor))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gets the tensor data as a Python value/array/array of array/...
|
/// Gets the tensor's data as a Python scalar or array-like object.
|
||||||
fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
|
fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
|
||||||
struct M<'a>(Python<'a>);
|
struct M<'a>(Python<'a>);
|
||||||
impl<'a> MapDType for M<'a> {
|
impl<'a> MapDType for M<'a> {
|
||||||
@ -280,6 +282,7 @@ impl PyTensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
|
/// Gets the tensor shape as a Python tuple.
|
||||||
fn shape(&self, py: Python<'_>) -> PyObject {
|
fn shape(&self, py: Python<'_>) -> PyObject {
|
||||||
PyTuple::new(py, self.0.dims()).to_object(py)
|
PyTuple::new(py, self.0.dims()).to_object(py)
|
||||||
}
|
}
|
||||||
@ -580,8 +583,9 @@ impl PyTensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Concatenate the tensors across one axis.
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
|
#[pyo3(text_signature = "(tensors:List[Tensor], dim:int )")]
|
||||||
|
/// Concatenate the tensors across one axis.
|
||||||
fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
|
fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
|
||||||
if tensors.is_empty() {
|
if tensors.is_empty() {
|
||||||
return Err(PyErr::new::<PyValueError, _>("empty input to cat"));
|
return Err(PyErr::new::<PyValueError, _>("empty input to cat"));
|
||||||
@ -593,6 +597,8 @@ fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
|
#[pyo3(text_signature = "(tensors:List[Tensor], dim:int)")]
|
||||||
|
/// Stack the tensors along a new axis.
|
||||||
fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
|
fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
|
||||||
let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
|
let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
|
||||||
let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
|
let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
|
||||||
@ -600,12 +606,15 @@ fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> {
|
#[pyo3(text_signature = "(data:_ArrayLike)")]
|
||||||
PyTensor::new(py, vs)
|
/// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
|
||||||
|
fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
|
||||||
|
PyTensor::new(py, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
#[pyo3(signature = (shape, *, device=None))]
|
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
|
||||||
|
/// Creates a new tensor with random values.
|
||||||
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
|
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
|
||||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||||
let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
|
let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
|
||||||
@ -613,7 +622,7 @@ fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<P
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
#[pyo3(signature = (shape, *, device=None))]
|
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
|
||||||
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
|
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
|
||||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||||
let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
|
let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
|
||||||
@ -621,7 +630,7 @@ fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
#[pyo3(signature = (shape, *, dtype=None, device=None))]
|
#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
|
||||||
fn ones(
|
fn ones(
|
||||||
py: Python<'_>,
|
py: Python<'_>,
|
||||||
shape: PyShape,
|
shape: PyShape,
|
||||||
@ -638,7 +647,7 @@ fn ones(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
#[pyo3(signature = (shape, *, dtype=None, device=None))]
|
#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
|
||||||
fn zeros(
|
fn zeros(
|
||||||
py: Python<'_>,
|
py: Python<'_>,
|
||||||
shape: PyShape,
|
shape: PyShape,
|
||||||
@ -704,6 +713,8 @@ impl PyQTensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
|
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
|
||||||
|
/// Loads a safetensors file. Returns a dictionary mapping tensor names to tensors.
|
||||||
fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
|
fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
|
||||||
let res = ::candle::safetensors::load(path, &Device::Cpu).map_err(wrap_err)?;
|
let res = ::candle::safetensors::load(path, &Device::Cpu).map_err(wrap_err)?;
|
||||||
let res = res
|
let res = res
|
||||||
@ -714,6 +725,8 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
|
#[pyo3(text_signature = "(path:Union[str,PathLike], tensors:Dict[str,Tensor])")]
|
||||||
|
/// Saves a dictionary of tensors to a safetensors file.
|
||||||
fn save_safetensors(
|
fn save_safetensors(
|
||||||
path: &str,
|
path: &str,
|
||||||
tensors: std::collections::HashMap<String, PyTensor>,
|
tensors: std::collections::HashMap<String, PyTensor>,
|
||||||
@ -726,6 +739,9 @@ fn save_safetensors(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
|
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
|
||||||
|
/// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
|
||||||
|
/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
|
||||||
fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
|
fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
|
||||||
let mut file = std::fs::File::open(path)?;
|
let mut file = std::fs::File::open(path)?;
|
||||||
let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
|
let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
|
||||||
@ -757,6 +773,9 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
|
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
|
||||||
|
/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
|
||||||
|
/// and the second maps metadata keys to metadata values.
|
||||||
fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
|
fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
|
||||||
use ::candle::quantized::gguf_file;
|
use ::candle::quantized::gguf_file;
|
||||||
fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
|
fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
|
||||||
@ -806,21 +825,25 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
|
/// Returns true if the 'cuda' backend is available.
|
||||||
fn cuda_is_available() -> bool {
|
fn cuda_is_available() -> bool {
|
||||||
::candle::utils::cuda_is_available()
|
::candle::utils::cuda_is_available()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
|
/// Returns true if candle was compiled with 'accelerate' support.
|
||||||
fn has_accelerate() -> bool {
|
fn has_accelerate() -> bool {
|
||||||
::candle::utils::has_accelerate()
|
::candle::utils::has_accelerate()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
|
/// Returns true if candle was compiled with MKL support.
|
||||||
fn has_mkl() -> bool {
|
fn has_mkl() -> bool {
|
||||||
::candle::utils::has_mkl()
|
::candle::utils::has_mkl()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
|
/// Returns the number of threads used by candle.
|
||||||
fn get_num_threads() -> usize {
|
fn get_num_threads() -> usize {
|
||||||
::candle::utils::get_num_threads()
|
::candle::utils::get_num_threads()
|
||||||
}
|
}
|
||||||
@ -830,19 +853,27 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
|||||||
m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
|
m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
|
m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(has_mkl, m)?)?;
|
m.add_function(wrap_pyfunction!(has_mkl, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
fn softmax(t: PyTensor, dim: i64) -> PyResult<PyTensor> {
|
#[pyo3(text_signature = "(tensor:Tensor, dim:int)")]
|
||||||
let dim = actual_dim(&t, dim).map_err(wrap_err)?;
|
/// Applies the Softmax function to a given tensor.
|
||||||
let sm = candle_nn::ops::softmax(&t.0, dim).map_err(wrap_err)?;
|
fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
|
||||||
|
let dim = actual_dim(&tensor, dim).map_err(wrap_err)?;
|
||||||
|
let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_err)?;
|
||||||
Ok(PyTensor(sm))
|
Ok(PyTensor(sm))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
fn silu(t: PyTensor) -> PyResult<PyTensor> {
|
#[pyo3(text_signature = "(tensor:Tensor)")]
|
||||||
let s = candle_nn::ops::silu(&t.0).map_err(wrap_err)?;
|
/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
|
||||||
|
fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
|
||||||
|
let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_err)?;
|
||||||
Ok(PyTensor(s))
|
Ok(PyTensor(s))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -871,14 +902,10 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
|||||||
m.add("f32", PyDType(DType::F32))?;
|
m.add("f32", PyDType(DType::F32))?;
|
||||||
m.add("f64", PyDType(DType::F64))?;
|
m.add("f64", PyDType(DType::F64))?;
|
||||||
m.add_function(wrap_pyfunction!(cat, m)?)?;
|
m.add_function(wrap_pyfunction!(cat, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
|
|
||||||
m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
|
|
||||||
m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
|
|
||||||
m.add_function(wrap_pyfunction!(ones, m)?)?;
|
m.add_function(wrap_pyfunction!(ones, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(rand, m)?)?;
|
m.add_function(wrap_pyfunction!(rand, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(randn, m)?)?;
|
m.add_function(wrap_pyfunction!(randn, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(tensor, m)?)?;
|
m.add_function(wrap_pyfunction!(tensor, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
|
|
||||||
m.add_function(wrap_pyfunction!(stack, m)?)?;
|
m.add_function(wrap_pyfunction!(stack, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(zeros, m)?)?;
|
m.add_function(wrap_pyfunction!(zeros, m)?)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
217
candle-pyo3/stub.py
Normal file
217
candle-pyo3/stub.py
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
#See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py
|
||||||
|
import argparse
|
||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
import black
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
# One level of indentation in the generated stub code.
INDENT = " " * 4

# Banner written as the first line of every generated file. `write` compares
# the first line of an existing `__init__.py` against it to decide whether the
# file is generated (and may therefore be overwritten).
GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"

# Import preamble emitted at the top of every generated `.pyi` module.
TYPING = (
    "from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence\n"
    "from os import PathLike\n"
)

# candle-specific typing helpers, also emitted into every generated `.pyi` module.
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n"

# Emitted for every module except "candle.candle" itself (see `pyi_file`).
CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType\n"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def do_indent(text: Optional[str], indent: str):
    """Indent every line of *text* after the first by *indent*.

    The first line is left untouched because the caller writes its own prefix
    in front of it. ``None`` (e.g. a missing ``__doc__``) yields ``""``.
    """
    if text is None:
        return ""
    # Re-joining on "\n" + indent is the same as inserting `indent` after
    # every newline.
    return ("\n" + indent).join(text.split("\n"))
|
||||||
|
|
||||||
|
|
||||||
|
def function(obj, indent: str, text_signature: Optional[str] = None):
    """Render a stub ``def`` for *obj* and return it as source text.

    Args:
        obj: Object providing ``__name__`` and ``__doc__`` (and
            ``__text_signature__`` when *text_signature* is not given).
        indent: Prefix placed in front of the ``def`` line; the body is
            indented one further ``INDENT`` level.
        text_signature: Parameter list to use, e.g. ``"(self)"``. When
            ``None``, ``obj.__text_signature__`` is used instead.

    Returns:
        The ``def`` line, a docstring built from ``obj.__doc__``, a ``pass``
        body and two trailing blank lines.
    """
    if text_signature is None:
        text_signature = obj.__text_signature__

    # "$self" is replaced with "self" so the rendered signature is valid Python.
    text_signature = text_signature.replace("$self", "self").strip()

    body_indent = indent + INDENT
    lines = [
        f"{indent}def {obj.__name__}{text_signature}:",
        f'{body_indent}"""',
        f"{body_indent}{do_indent(obj.__doc__, body_indent)}",
        f'{body_indent}"""',
        f"{body_indent}pass",
        "",
        "",
    ]
    return "\n".join(lines) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
def member_sort(member):
    """Sort key for module members.

    Non-class members get key 1. Classes get ``10 + len(MRO)``, so they sort
    after every non-class member and a longer MRO sorts later.
    """
    if not inspect.isclass(member):
        return 1
    return 10 + len(inspect.getmro(member))
|
||||||
|
|
||||||
|
|
||||||
|
def fn_predicate(obj):
    """Tell whether a class member should be emitted into the stub.

    Method descriptors and built-ins qualify when they expose a truthy
    ``__text_signature__`` and their name does not start with ``_``.
    Get/set descriptors qualify when their name does not start with ``_``.
    Everything else is rejected.
    """
    if inspect.ismethoddescriptor(obj) or inspect.isbuiltin(obj):
        signature = obj.__text_signature__
        if not signature:
            # Hand back the falsy signature itself (e.g. None), exactly as an
            # `and` expression would.
            return signature
        return not obj.__name__.startswith("_")
    return inspect.isgetsetdescriptor(obj) and not obj.__name__.startswith("_")
|
||||||
|
|
||||||
|
|
||||||
|
def get_module_members(module):
    """Return the public, non-module members of *module*, ordered by ``member_sort``.

    Members whose name starts with ``_`` and members that are themselves
    modules are left out. The sort is stable, so members with equal keys keep
    the order produced by ``inspect.getmembers``.
    """
    public = []
    for name, member in inspect.getmembers(module):
        if name.startswith("_") or inspect.ismodule(member):
            continue
        public.append(member)
    return sorted(public, key=member_sort)
|
||||||
|
|
||||||
|
|
||||||
|
def pyi_file(obj, indent=""):
    """Render the ``.pyi`` stub source for *obj* and return it as a string.

    Dispatches on the kind of object:

    * module: generated-file banner, import preamble, then the stub of every
      member returned by ``get_module_members``;
    * class: a ``class`` statement with its docstring, an ``__init__`` built
      from ``__text_signature__`` when present, and every member accepted by
      ``fn_predicate``;
    * built-in function: a ``@staticmethod``-decorated ``def``;
    * method descriptor: a plain ``def``;
    * get/set descriptor: a ``@property`` with a ``(self)`` signature;
    * object whose class is named ``DType``: an empty subclass of ``DType``.

    Raises:
        Exception: if *obj* is none of the above.
    """
    if inspect.ismodule(obj):
        header = GENERATED_COMMENT + TYPING + CANDLE_SPECIFIC_TYPING
        # The Tensor/DType import is added for every module except
        # "candle.candle" itself.
        if obj.__name__ != "candle.candle":
            header += CANDLE_TENSOR_IMPORTS
        return header + "".join(pyi_file(member, indent) for member in get_module_members(obj))

    if inspect.isclass(obj):
        inner = indent + INDENT
        bases = inspect.getmro(obj)
        # A base class is only spelled out when the MRO has more than two entries.
        parent = f"({bases[1].__name__})" if len(bases) > 2 else ""

        pieces = []
        if obj.__doc__:
            pieces.append(f'{inner}"""\n{inner}{do_indent(obj.__doc__, inner)}\n{inner}"""\n')

        methods = inspect.getmembers(obj, fn_predicate)

        # Constructor, rendered from the class-level text signature.
        if obj.__text_signature__:
            pieces.append(f"{inner}def __init__{obj.__text_signature__}:\n{inner + INDENT}pass\n\n")

        pieces.extend(pyi_file(method, indent=inner) for _, method in methods)

        # An otherwise empty class still needs a body.
        body = "".join(pieces) or f"{inner}pass\n"
        return f"class {obj.__name__}{parent}:\n{body}\n\n"

    if inspect.isbuiltin(obj):
        return f"{indent}@staticmethod\n" + function(obj, indent)

    if inspect.ismethoddescriptor(obj):
        return function(obj, indent)

    if inspect.isgetsetdescriptor(obj):
        # TODO it would be interesting to add the setter maybe ?
        return f"{indent}@property\n" + function(obj, indent, text_signature="(self)")

    if obj.__class__.__name__ == "DType":
        return f"class {str(obj).lower()}(DType):\n{indent + INDENT}pass\n"

    raise Exception(f"Object {obj} is not supported")
|
||||||
|
|
||||||
|
|
||||||
|
def py_file(module, origin):
    """Render the ``__init__.py`` that re-exports the members of *module*.

    The result starts with the generated-file banner and
    ``from .. import <origin>``, followed by one ``name = <origin>.name`` line
    per member returned by ``get_module_members``. Members without a
    ``__name__`` attribute are named by their ``str()`` form.
    """
    members = get_module_members(module)

    parts = [GENERATED_COMMENT, f"from .. import {origin}\n", "\n"]
    for member in members:
        name = member.__name__ if hasattr(member, "__name__") else str(member)
        parts.append(f"{name} = {origin}.{name}\n")
    return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def do_black(content, is_pyi):
    """Format *content* with Black and return the formatted source.

    Args:
        content: Source text to format.
        is_pyi: Forwarded to ``black.Mode`` so stub files are formatted as stubs.

    Returns:
        The formatted text, or *content* unchanged when Black reports that
        nothing needed to change.
    """
    # `experimental_string_processing` is deliberately not passed: False is
    # Black's default, so the resulting mode is the same.
    # NOTE(review): newer Black releases appear to reject that keyword
    # outright -- confirm against the Black version pinned for this project.
    mode = black.Mode(
        target_versions={black.TargetVersion.PY35},
        line_length=119,
        is_pyi=is_pyi,
        string_normalization=True,
    )
    try:
        return black.format_file_contents(content, fast=True, mode=mode)
    except black.NothingChanged:
        # Black signals "already formatted" with an exception.
        return content
|
||||||
|
|
||||||
|
|
||||||
|
def _write_or_check(filename, content, check):
    """Overwrite *filename* with *content*, or verify it when *check* is true."""
    if not check:
        with open(filename, "w") as f:
            f.write(content)
        return

    with open(filename, "r") as f:
        data = f.read()
    # Raised explicitly rather than via `assert`, which is stripped under
    # `python -O` and would let an outdated file pass the check silently.
    if data != content:
        raise AssertionError(f"The content of {filename} seems outdated, please run `python stub.py`")


def write(module, directory, origin, check=False):
    """Generate (or verify) ``__init__.pyi`` and ``__init__.py`` for *module*, recursively.

    Args:
        module: Module to generate stubs for.
        directory: Target directory for the generated files.
        origin: Name the generated ``__init__.py`` imports its members from.
        check: When true, nothing is written; the existing files are compared
            with the freshly generated content instead.

    Raises:
        AssertionError: in check mode, when an existing file differs from the
            generated content.
        FileNotFoundError: in check mode, when an expected file is missing.
    """
    submodules = [(name, member) for name, member in inspect.getmembers(module) if inspect.ismodule(member)]

    pyi_content = do_black(pyi_file(module), is_pyi=True)
    # Check mode only reads, so the directory is created only when writing.
    if not check:
        os.makedirs(directory, exist_ok=True)
    _write_or_check(os.path.join(directory, "__init__.pyi"), pyi_content, check)

    py_filename = os.path.join(directory, "__init__.py")
    py_content = do_black(py_file(module, origin), is_pyi=False)

    # A hand-written __init__.py (one that does not start with the generated
    # banner) is never overwritten or checked.
    if os.path.exists(py_filename):
        with open(py_filename, "r") as f:
            is_auto = f.readline() == GENERATED_COMMENT
    else:
        is_auto = True

    if is_auto:
        _write_or_check(py_filename, py_content, check)

    for name, submodule in submodules:
        write(submodule, os.path.join(directory, name), name, check=check)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command line: `--check` verifies the existing stubs instead of rewriting them.
    cli = argparse.ArgumentParser()
    cli.add_argument("--check", action="store_true")
    args = cli.parse_args()

    # Enable execution from the candle and candle-pyo3 directories
    directory = "py_src/candle/"
    if Path.cwd().name != "candle-pyo3":
        directory = f"candle-pyo3/{directory}"

    # Imported here, after argument parsing, so `--help` works without the
    # compiled extension.
    import candle

    write(candle.candle, directory, "candle", check=args.check)
|
@ -1,4 +1,5 @@
|
|||||||
import candle
|
import candle
|
||||||
|
from candle import Tensor, QTensor
|
||||||
|
|
||||||
t = candle.Tensor(42.0)
|
t = candle.Tensor(42.0)
|
||||||
print(t)
|
print(t)
|
||||||
@ -9,7 +10,7 @@ t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6])
|
|||||||
print(t)
|
print(t)
|
||||||
print(t+t)
|
print(t+t)
|
||||||
|
|
||||||
t = t.reshape([2, 4])
|
t:Tensor = t.reshape([2, 4])
|
||||||
print(t.matmul(t.t()))
|
print(t.matmul(t.t()))
|
||||||
|
|
||||||
print(t.to_dtype(candle.u8))
|
print(t.to_dtype(candle.u8))
|
||||||
@ -20,7 +21,7 @@ print(t)
|
|||||||
print(t.dtype)
|
print(t.dtype)
|
||||||
|
|
||||||
t = candle.randn((16, 256))
|
t = candle.randn((16, 256))
|
||||||
quant_t = t.quantize("q6k")
|
quant_t:QTensor = t.quantize("q6k")
|
||||||
dequant_t = quant_t.dequantize()
|
dequant_t:Tensor = quant_t.dequantize()
|
||||||
diff2 = (t - dequant_t).sqr()
|
diff2:Tensor = (t - dequant_t).sqr()
|
||||||
print(diff2.mean_all())
|
print(diff2.mean_all())
|
||||||
|
Reference in New Issue
Block a user