Generate *.pyi stubs for PyO3 wrapper (#870)

* Begin to generate typehints.

* generate correct stubs

* Correctly include stubs

* Add comments and typehints to static functions

* ensure candle-pyo3 directory

* Make `llama.rope.freq_base` optional

* `fmt`
Lukas Kreussel
2023-09-16 18:23:38 +02:00
committed by GitHub
parent 7cafca835a
commit 8658df3485
15 changed files with 857 additions and 40 deletions
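For context, the stubs this commit generates give Python tooling type information for the PyO3-exposed classes. The snippet below is only an illustrative sketch of the rough shape such a generated `candle/__init__.pyi` entry might take; the class and signatures shown are assumptions, not the actual output of the stub generator.

# Illustrative sketch only: roughly what a generated .pyi stub entry looks like.
# The class and method signatures below are assumptions, not the real generated file.
from typing import List, Sequence, Union

class Tensor:
    def __init__(self, data: Union[int, float, Sequence[float]]): ...
    @property
    def shape(self) -> List[int]:
        """Gets the tensor's shape."""
        ...
    def matmul(self, rhs: "Tensor") -> "Tensor":
        """Performs a matrix multiplication between the two tensors."""
        ...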


@@ -1,6 +1,7 @@
# This example shows how the candle Python api can be used to replicate llama.cpp.
import sys
import candle
+from candle.utils import load_ggml, load_gguf
MAX_SEQ_LEN = 4096
@@ -154,7 +155,7 @@ def main():
    filename = sys.argv[1]
    print(f"reading model file {filename}")
    if filename.endswith("gguf"):
-        all_tensors, metadata = candle.load_gguf(sys.argv[1])
+        all_tensors, metadata = load_gguf(sys.argv[1])
        vocab = metadata["tokenizer.ggml.tokens"]
        for i, v in enumerate(vocab):
            vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ')
@@ -168,13 +169,13 @@ def main():
            'n_head_kv': metadata['llama.attention.head_count_kv'],
            'n_layer': metadata['llama.block_count'],
            'n_rot': metadata['llama.rope.dimension_count'],
-            'rope_freq': metadata['llama.rope.freq_base'],
+            'rope_freq': metadata.get('llama.rope.freq_base', 10000.),
            'ftype': metadata['general.file_type'],
        }
        all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() }
    else:
-        all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1])
+        all_tensors, hparams, vocab = load_ggml(sys.argv[1])
    print(hparams)
    model = QuantizedLlama(hparams, all_tensors)
    print("model built, starting inference")