# This example shows how the candle Python API can be used to replicate llama.cpp.
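#
# Run it with a GGML or GGUF weight file as the only argument, e.g.
# (script and model file names here are just illustrative placeholders):
#
#   python quantized-llama.py llama-2-7b.Q4_0.gguf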
import sys
from typing import Dict, Tuple, Any

import candle
from candle.models.llama import QuantizedLlama
from candle import utils

MAX_SEQ_LEN = 4096


def gguf_rename(tensor_name: str):
    # Map GGUF tensor names onto the names expected by QuantizedLlama,
    # which follow the original llama naming scheme.
    if tensor_name == "token_embd.weight":
        return "tok_embeddings.weight"
    if tensor_name == "output_norm.weight":
        return "norm.weight"
    tensor_name = tensor_name.replace("blk.", "layers.")
    tensor_name = tensor_name.replace(".attn_q.", ".attention.wq.")
    tensor_name = tensor_name.replace(".attn_k.", ".attention.wk.")
    tensor_name = tensor_name.replace(".attn_v.", ".attention.wv.")
    tensor_name = tensor_name.replace(".attn_output.", ".attention.wo.")
    tensor_name = tensor_name.replace(".ffn_gate.", ".feed_forward.w1.")
    tensor_name = tensor_name.replace(".ffn_down.", ".feed_forward.w2.")
    tensor_name = tensor_name.replace(".ffn_up.", ".feed_forward.w3.")
    tensor_name = tensor_name.replace(".attn_norm.", ".attention_norm.")
    return tensor_name
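

# For example, gguf_rename("blk.0.attn_q.weight") returns
# "layers.0.attention.wq.weight" after the substitutions above.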


def main():
    if len(sys.argv) < 2:
        raise ValueError("missing weight file argument")

    filename = sys.argv[1]
    print(f"reading model file {filename}")
    if filename.endswith("gguf"):
        all_tensors, metadata = utils.load_gguf(filename)
        vocab = metadata["tokenizer.ggml.tokens"]
        # Decode the SentencePiece vocab: "<0x0A>" is the newline byte and
        # "▁" marks a leading space.
        for i, v in enumerate(vocab):
            vocab[i] = "\n" if v == "<0x0A>" else v.replace("▁", " ")
        hparams = {k: v for (k, v) in metadata.items() if not k.startswith("tokenizer")}
        print(hparams)
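        # Build llama.cpp-style hyperparameters from the GGUF metadata:
        # "n_embd" is the embedding size, "n_layer" the number of
        # transformer blocks, and "n_mult" a fixed feed-forward sizing
        # constant, mirroring the original llama.cpp fields.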
        hparams = {
            "n_vocab": len(vocab),
            "n_embd": metadata["llama.embedding_length"],
            "n_mult": 256,
            "n_head": metadata["llama.attention.head_count"],
            "n_head_kv": metadata["llama.attention.head_count_kv"],
            "n_layer": metadata["llama.block_count"],
            "n_rot": metadata["llama.rope.dimension_count"],
            "rope_freq": metadata.get("llama.rope.freq_base", 10000.0),
            "ftype": metadata["general.file_type"],
            "context_length": metadata["llama.context_length"],
        }
        all_tensors = {gguf_rename(k): v for k, v in all_tensors.items()}
    else:
        all_tensors, hparams, vocab = utils.load_ggml(filename)
        hparams["context_length"] = 2048

    print(hparams)
    model = QuantizedLlama(hparams, all_tensors)
    print("model built, starting inference")

    tokens = [1]  # start from the BOS token
    for token_idx in range(500):
        last_token = tokens[-1]
        # Feed only the last generated token; the second forward argument
        # is the current position in the sequence.
        lt = candle.tensor([last_token]).unsqueeze(0)
        logits = model.forward(lt, len(tokens))
        # Greedy sampling for now
        # pr = candle.nn.softmax(logits, -1)
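        # A possible alternative to greedy decoding (not part of the
        # original example): sample from the softmax distribution. This
        # sketch assumes `Tensor.values()` converts a tensor to Python
        # lists, as it does for the argmax result below, and it needs
        # `import random` at the top of the file.
        #
        #   pr = candle.nn.softmax(logits, -1)
        #   probs = pr.get(0).values()
        #   next_token = random.choices(range(len(probs)), weights=probs)[0]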
        m = logits.get(0).argmax_keepdim(-1)
        next_token = m.values()[0]
        print(vocab[next_token], end="", flush=True)
        tokens.append(next_token)


if __name__ == "__main__":
    main()