Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Generate *.pyi stubs for PyO3 wrapper (#870)

* Begin to generate type hints.
* Generate correct stubs.
* Correctly include stubs.
* Add comments and type hints to static functions.
* Ensure candle-pyo3 directory.
* Make `llama.rope.freq_base` optional.
* `fmt`
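For context, a compiled PyO3 extension is opaque to type checkers, so type information is shipped as `.pyi` stub files that sit next to the module. The snippet below is a hypothetical excerpt of what a generated stub could look like; the class and signatures are illustrative, not the exact output of this PR's generator.

```python
# candle.pyi -- hypothetical excerpt of a generated stub for a PyO3 module.
# Bodies are elided with `...`; only signatures, docstrings, and type hints matter.
from typing import List, Sequence

class Tensor:
    def __init__(self, data: Sequence[float]) -> None: ...
    @property
    def shape(self) -> List[int]:
        """The shape of the tensor."""
        ...
    @staticmethod
    def zeros(shape: Sequence[int]) -> "Tensor":
        """Create a zero-filled tensor (illustrative signature)."""
        ...
```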
@@ -1,6 +1,7 @@
 # This example shows how the candle Python api can be used to replicate llama.cpp.
 import sys
 import candle
+from candle.utils import load_ggml,load_gguf

 MAX_SEQ_LEN = 4096

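The added import reflects that the GGML/GGUF loaders are now reached through `candle.utils` rather than as attributes of the top-level `candle` module, which the later hunks switch over to. A minimal usage sketch, where `"model.gguf"` is a placeholder path rather than a file shipped with candle:

```python
from candle.utils import load_gguf

# Load the tensors and the metadata key/value pairs from a GGUF file.
all_tensors, metadata = load_gguf("model.gguf")  # placeholder path
print(f"{len(all_tensors)} tensors, {len(metadata)} metadata entries")
```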
@@ -154,7 +155,7 @@ def main():
     filename = sys.argv[1]
     print(f"reading model file {filename}")
     if filename.endswith("gguf"):
-        all_tensors, metadata = candle.load_gguf(sys.argv[1])
+        all_tensors, metadata = load_gguf(sys.argv[1])
         vocab = metadata["tokenizer.ggml.tokens"]
         for i, v in enumerate(vocab):
             vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ')
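The vocab loop in this hunk undoes SentencePiece's surface encoding: the byte token `<0x0A>` maps back to a literal newline, and the `▁` (U+2581) word-boundary marker maps back to a space. The same mapping, pulled out as a standalone check:

```python
# Same token-to-text mapping as the loop above, applied to sample tokens.
def detok(v: str) -> str:
    return '\n' if v == '<0x0A>' else v.replace('▁', ' ')

assert detok('<0x0A>') == '\n'
assert detok('▁hello') == ' hello'
```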
@@ -168,13 +169,13 @@ def main():
             'n_head_kv': metadata['llama.attention.head_count_kv'],
             'n_layer': metadata['llama.block_count'],
             'n_rot': metadata['llama.rope.dimension_count'],
-            'rope_freq': metadata['llama.rope.freq_base'],
+            'rope_freq': metadata.get('llama.rope.freq_base', 10000.),
             'ftype': metadata['general.file_type'],
         }
         all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() }

     else:
-        all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1])
+        all_tensors, hparams, vocab = load_ggml(sys.argv[1])
     print(hparams)
     model = QuantizedLlama(hparams, all_tensors)
     print("model built, starting inference")
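Switching to `metadata.get('llama.rope.freq_base', 10000.)` is the "Make `llama.rope.freq_base` optional" item from the commit message: GGUF files that omit the key now fall back to the conventional RoPE base of 10000 instead of raising a `KeyError`. For reference, this base is what rotary embeddings derive their per-dimension inverse frequencies from; a sketch of that standard computation (not code from this diff):

```python
# Standard RoPE inverse-frequency schedule driven by rope_freq (theta).
# n_rot corresponds to 'llama.rope.dimension_count' in the hparams above.
def rope_inv_freqs(n_rot: int, rope_freq: float = 10000.0) -> list:
    return [rope_freq ** (-2.0 * i / n_rot) for i in range(n_rot // 2)]
```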