mirror of https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Return the metadata in the gguf pyo3 bindings. (#729)
* Return the metadata in the gguf pyo3 bindings.
* Read the metadata in the quantized llama example.
* Get inference to work on gguf files.
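With this change, candle.load_gguf returns the file metadata alongside the tensors when given a .gguf file. A minimal sketch of how the new return value is consumed, based on the diff below (the model path is a placeholder):

import candle

# "model.gguf" is a placeholder path; any llama-style GGUF file works the same way.
all_tensors, metadata = candle.load_gguf("model.gguf")
# The metadata dict carries both tokenizer data and hyperparameters, e.g.:
print(metadata["llama.block_count"])
print(len(metadata["tokenizer.ggml.tokens"]))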
@@ -111,7 +111,8 @@ class QuantizedLlama:
         self.norm = RmsNorm(all_tensors["norm.weight"])
         self.output = all_tensors["output.weight"]
         self.layers = []
-        cos_sin = precompute_freqs_cis(hparams, 10000.)
+        rope_freq = hparams.get("rope_freq", 10000.)
+        cos_sin = precompute_freqs_cis(hparams, rope_freq)
         for layer_idx in range(hparams["n_layer"]):
             layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
             self.layers.append(layer)
@@ -133,15 +134,45 @@ class QuantizedLlama:
         x = self.output.matmul_t(x)
         return x
 
+def gguf_rename(tensor_name):
+    if tensor_name == 'token_embd.weight': return 'tok_embeddings.weight'
+    if tensor_name == 'output_norm.weight': return 'norm.weight'
+    tensor_name = tensor_name.replace('blk.', 'layers.')
+    tensor_name = tensor_name.replace('.attn_q.', '.attention.wq.')
+    tensor_name = tensor_name.replace('.attn_k.', '.attention.wk.')
+    tensor_name = tensor_name.replace('.attn_v.', '.attention.wv.')
+    tensor_name = tensor_name.replace('.attn_output.', '.attention.wo.')
+    tensor_name = tensor_name.replace('.ffn_gate.', '.feed_forward.w1.')
+    tensor_name = tensor_name.replace('.ffn_down.', '.feed_forward.w2.')
+    tensor_name = tensor_name.replace('.ffn_up.', '.feed_forward.w3.')
+    tensor_name = tensor_name.replace('.attn_norm.', '.attention_norm.')
+    return tensor_name
+
 def main():
     if len(sys.argv) < 2:
         raise ValueError("missing weight file argument")
     filename = sys.argv[1]
     print(f"reading model file {filename}")
     if filename.endswith("gguf"):
-        all_tensors = candle.load_gguf(sys.argv[1])
-        hparams = None
-        vocab = None
+        all_tensors, metadata = candle.load_gguf(sys.argv[1])
+        vocab = metadata["tokenizer.ggml.tokens"]
+        for i, v in enumerate(vocab):
+            vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ')
+        hparams = {k: v for (k, v) in metadata.items() if not k.startswith("tokenizer")}
+        print(hparams)
+        hparams = {
+            'n_vocab': len(vocab),
+            'n_embd': metadata['llama.embedding_length'],
+            'n_mult': 256,
+            'n_head': metadata['llama.attention.head_count'],
+            'n_head_kv': metadata['llama.attention.head_count_kv'],
+            'n_layer': metadata['llama.block_count'],
+            'n_rot': metadata['llama.rope.dimension_count'],
+            'rope_freq': metadata['llama.rope.freq_base'],
+            'ftype': metadata['general.file_type'],
+        }
+        all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() }
+
     else:
         all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1])
     print(hparams)
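For reference, gguf_rename maps the GGUF tensor naming scheme onto the ggml-style names the rest of the example expects. A few illustrative spot checks (the tensor names follow the standard llama GGUF layout and are not taken from a specific file):

# Illustrative spot checks for the rename logic in the diff above.
assert gguf_rename('token_embd.weight') == 'tok_embeddings.weight'
assert gguf_rename('blk.0.attn_q.weight') == 'layers.0.attention.wq.weight'
assert gguf_rename('blk.5.ffn_gate.weight') == 'layers.5.feed_forward.w1.weight'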