candle/candle-pyo3/quant-llama.py
Lukas Kreussel 03e194123d Add return types to *.pyi stubs (#880)
* Start generating return types

* Finish tensor type hinting

* Add `save_gguf` to `utils`

* Typehint `quant-llama.py`
2023-09-17 22:11:01 +01:00

# This example shows how the candle Python api can be used to replicate llama.cpp.
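#
# A typical invocation looks like the following; the checkpoint filename is only
# illustrative, any local llama weights in GGUF (or legacy GGML) format should work:
#
#   python quant-llama.py llama-2-7b.Q4_K_M.gguf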
import sys
from typing import Dict, Tuple, Any

import candle
from candle import Tensor, QTensor, utils, nn

MAX_SEQ_LEN = 4096


def masked_fill(on_false: Tensor, mask: Tensor, on_true: float):
    shape = mask.shape
    on_true = candle.tensor(on_true).broadcast_as(shape)
    return mask.where_cond(on_true, on_false)
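

# RMS normalization with a pre-dequantized weight vector: the input is scaled by
# 1 / sqrt(mean(x^2) + 1e-5) over the hidden dimension, then multiplied
# element-wise by the weight.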
class RmsNorm:
    def __init__(self, qtensor: QTensor):
        self.weight = qtensor.dequantize()

    def __call__(self, x: Tensor):
        b_size, seq_len, hidden_size = x.shape
        norm_x = x.sqr().sum_keepdim(2) / hidden_size
        x_normed = x.broadcast_div((norm_x + 1e-5).sqrt())
        return x_normed.broadcast_mul(self.weight)
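

# One transformer block. The attention and feed-forward projections are kept as
# QTensors; `matmul_t` multiplies activations by the transposed quantized weight,
# so the large projection matrices stay in their quantized form.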
class QuantizedLayer:
    def __init__(
        self,
        layer_idx: int,
        hparams: Dict[str, Any],
        all_tensors: Dict[str, QTensor],
        cos_sin: Tuple[Tensor, Tensor],
    ):
        p = f"layers.{layer_idx}"
        self.attention_wq = all_tensors[f"{p}.attention.wq.weight"]
        self.attention_wk = all_tensors[f"{p}.attention.wk.weight"]
        self.attention_wv = all_tensors[f"{p}.attention.wv.weight"]
        self.attention_wo = all_tensors[f"{p}.attention.wo.weight"]
        self.ffw1 = all_tensors[f"{p}.feed_forward.w1.weight"]
        self.ffw2 = all_tensors[f"{p}.feed_forward.w2.weight"]
        self.ffw3 = all_tensors[f"{p}.feed_forward.w3.weight"]
        self.attn_norm = RmsNorm(all_tensors[f"{p}.attention_norm.weight"])
        self.ffn_norm = RmsNorm(all_tensors[f"{p}.ffn_norm.weight"])

        self.n_head = hparams["n_head"]
        self.n_kv_head = self.n_head
        self.head_dim = hparams["n_embd"] // self.n_head

        self.kv_cache = None
        self.cos = cos_sin[0]
        self.sin = cos_sin[1]

    def __call__(self, x: Tensor, mask: Tensor, index_pos: int):
        residual = x
        x = self.attn_norm(x)
        attn = self.forward_attn(x, mask, index_pos)
        x = attn + residual

        residual = x
        x = self.ffn_norm(x)
        w1 = self.ffw1.matmul_t(x)
        w3 = self.ffw3.matmul_t(x)
        mlp = self.ffw2.matmul_t(nn.silu(w1) * w3)
        return mlp + residual
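
    # Self-attention with rotary position embeddings and a per-layer KV cache:
    # keys and values from earlier positions are concatenated along the sequence
    # dimension, so each call only projects the newly supplied tokens.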
    def forward_attn(self, x: Tensor, mask: Tensor, index_pos: int):
        b_size, seq_len, n_embd = x.shape
        q = self.attention_wq.matmul_t(x)
        k = self.attention_wk.matmul_t(x)
        v = self.attention_wv.matmul_t(x)

        q = q.reshape((b_size, seq_len, self.n_head, self.head_dim)).transpose(1, 2)
        k = k.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
        v = v.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)

        q = self.apply_rotary_emb(q, index_pos)
        k = self.apply_rotary_emb(k, index_pos)

        if self.kv_cache is not None and index_pos > 0:
            prev_k, prev_v = self.kv_cache
            k = candle.cat([prev_k, k], 2).contiguous()
            v = candle.cat([prev_v, v], 2).contiguous()
        self.kv_cache = (k, v)

        # TODO: maybe repeat k/v here if we start supporting MQA.

        att = q.matmul(k.t()) / self.head_dim**0.5
        mask = mask.broadcast_as(att.shape)
        att = masked_fill(att, mask, float("-inf"))
        att = nn.softmax(att, -1)
        y = att.matmul(v.contiguous())
        y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd))
        return self.attention_wo.matmul_t(y)

    def apply_rotary_emb(self, x: Tensor, index_pos: int):
        (b_size, n_head, seq_len, n_embd) = x.shape
        cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd // 2, 1))
        sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd // 2, 1))
        x = x.reshape((b_size, n_head, seq_len, n_embd // 2, 2))
        x0 = x.narrow(-1, 0, 1)
        x1 = x.narrow(-1, 1, 1)
        y0 = x0.broadcast_mul(cos) - x1.broadcast_mul(sin)
        y1 = x0.broadcast_mul(sin) + x1.broadcast_mul(cos)
        rope = candle.cat([y0, y1], -1)
        return rope.flatten_from(-2)
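

# Precompute the cos/sin tables used by the rotary embeddings for up to
# MAX_SEQ_LEN positions, with frequencies 1 / freq_base**(i / head_dim) for even i.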
def precompute_freqs_cis(hparams, freq_base):
    head_dim = hparams["n_embd"] // hparams["n_head"]
    theta = [1.0 / freq_base ** (i / head_dim) for i in range(0, head_dim, 2)]
    theta = candle.tensor(theta)
    idx_theta = [float(i) for i in range(MAX_SEQ_LEN)]
    idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1))
    m = idx_theta.matmul(theta.unsqueeze(0))
    return (m.cos(), m.sin())
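

# The full model: dequantized token embeddings, `n_layer` quantized transformer
# blocks, a final RmsNorm, and the quantized output projection. Only the logits
# for the last position are returned.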
class QuantizedLlama:
    def __init__(self, hparams: Dict[str, Any], all_tensors: Dict[str, QTensor]):
        self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize()
        self.norm = RmsNorm(all_tensors["norm.weight"])
        self.output = all_tensors["output.weight"]
        self.layers = []
        rope_freq = hparams.get("rope_freq", 10000.)
        cos_sin = precompute_freqs_cis(hparams, rope_freq)
        for layer_idx in range(hparams["n_layer"]):
            layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
            self.layers.append(layer)

    def __call__(self, token: Tensor, index_pos: int):
        b_size, seq_len = token.shape
        vocab_size, hidden_size = self.tok_embeddings.shape
        token = token.reshape((b_size * seq_len,))
        x = self.tok_embeddings.index_select(token, 0)
        x = x.reshape((b_size, seq_len, hidden_size))
        mask = [int(j > i) for j in range(seq_len) for i in range(seq_len)]
        mask = candle.tensor(mask).reshape((seq_len, seq_len))
        for layer in self.layers:
            x = layer(x, mask, index_pos)
        x = self.norm(x)
        x = x.narrow(1, -1, 1).squeeze(1)
        x = self.output.matmul_t(x)
        return x
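

# Map GGUF tensor names (e.g. "blk.0.attn_q.weight") onto the GGML-style names
# used by the model above (e.g. "layers.0.attention.wq.weight").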
def gguf_rename(tensor_name: str):
    if tensor_name == 'token_embd.weight':
        return 'tok_embeddings.weight'
    if tensor_name == 'output_norm.weight':
        return 'norm.weight'
    tensor_name = tensor_name.replace('blk.', 'layers.')
    tensor_name = tensor_name.replace('.attn_q.', '.attention.wq.')
    tensor_name = tensor_name.replace('.attn_k.', '.attention.wk.')
    tensor_name = tensor_name.replace('.attn_v.', '.attention.wv.')
    tensor_name = tensor_name.replace('.attn_output.', '.attention.wo.')
    tensor_name = tensor_name.replace('.ffn_gate.', '.feed_forward.w1.')
    tensor_name = tensor_name.replace('.ffn_down.', '.feed_forward.w2.')
    tensor_name = tensor_name.replace('.ffn_up.', '.feed_forward.w3.')
    tensor_name = tensor_name.replace('.attn_norm.', '.attention_norm.')
    return tensor_name
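

# Load a GGUF (or legacy GGML) checkpoint passed on the command line, build the
# quantized model, and greedily sample 500 tokens starting from the BOS token (id 1).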
def main():
    if len(sys.argv) < 2:
        raise ValueError("missing weight file argument")
    filename = sys.argv[1]
    print(f"reading model file {filename}")
    if filename.endswith("gguf"):
        all_tensors, metadata = utils.load_gguf(sys.argv[1])
        vocab = metadata["tokenizer.ggml.tokens"]
        for i, v in enumerate(vocab):
            # "▁" is the sentencepiece word-boundary marker; map it back to a space.
            vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ')
        hparams = {k: v for (k, v) in metadata.items() if not k.startswith("tokenizer")}
        print(hparams)
        hparams = {
            'n_vocab': len(vocab),
            'n_embd': metadata['llama.embedding_length'],
            'n_mult': 256,
            'n_head': metadata['llama.attention.head_count'],
            'n_head_kv': metadata['llama.attention.head_count_kv'],
            'n_layer': metadata['llama.block_count'],
            'n_rot': metadata['llama.rope.dimension_count'],
            'rope_freq': metadata.get('llama.rope.freq_base', 10000.),
            'ftype': metadata['general.file_type'],
        }
        all_tensors = {gguf_rename(k): v for k, v in all_tensors.items()}
    else:
        all_tensors, hparams, vocab = utils.load_ggml(sys.argv[1])
    print(hparams)

    model = QuantizedLlama(hparams, all_tensors)
    print("model built, starting inference")

    tokens = [1]
    for token_idx in range(500):
        last_token = tokens[-1]
        lt = candle.tensor([last_token]).unsqueeze(0)
        logits = model(lt, len(tokens))
        # Greedy sampling for now
        # pr = candle.nn.softmax(logits, -1)
        m = logits.get(0).argmax_keepdim(-1)
        next_token = m.values()[0]
        print(vocab[next_token], end='', flush=True)
        tokens.append(next_token)


if __name__ == '__main__':
    main()