mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations * Add `state_dict` and `load_state_dict` functionality * Move modules around and create `candle.nn.Linear` * Add `nn.Embedding` and `nn.LayerNorm` * Add BERT implementation * Batch q-matmul * Automatically dequantize `QTensors` if a `Tensor` is expected * Add Module `.to()`, `.cuda()`, `cpu()` and `.type()` functionality * Unittests for `Module`, `Tensor` and `candle.utils` * Add `pytorch` like slicing to `Tensor` * Cleanup and BERT fixes * `black` formatting + unit-test for `nn.Linear` * Refactor slicing implementation
This commit is contained in:
@ -2,181 +2,59 @@
|
||||
import sys
|
||||
from typing import Dict, Tuple, Any
|
||||
import candle
|
||||
from candle import Tensor, QTensor, utils, nn
|
||||
from candle.models.llama import QuantizedLlama
|
||||
from candle import utils
|
||||
|
||||
MAX_SEQ_LEN = 4096
|
||||
|
||||
def masked_fill(on_false:Tensor, mask:Tensor, on_true:Tensor):
|
||||
shape = mask.shape
|
||||
on_true = candle.tensor(on_true).broadcast_as(shape)
|
||||
return mask.where_cond(on_true, on_false)
|
||||
|
||||
class RmsNorm:
|
||||
def __init__(self, qtensor:QTensor):
|
||||
self.weight = qtensor.dequantize()
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
b_size, seq_len, hidden_size = x.shape
|
||||
norm_x = x.sqr().sum_keepdim(2) / hidden_size
|
||||
x_normed = x.broadcast_div((norm_x + 1e-5).sqrt())
|
||||
return x_normed.broadcast_mul(self.weight)
|
||||
|
||||
class QuantizedLayer:
|
||||
def __init__(self, layer_idx:int, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor], cos_sin:Tuple[Tensor,Tensor]):
|
||||
p = f"layers.{layer_idx}"
|
||||
self.attention_wq = all_tensors[f"{p}.attention.wq.weight"]
|
||||
self.attention_wk = all_tensors[f"{p}.attention.wk.weight"]
|
||||
self.attention_wv = all_tensors[f"{p}.attention.wv.weight"]
|
||||
self.attention_wo = all_tensors[f"{p}.attention.wo.weight"]
|
||||
self.ffw1 = all_tensors[f"{p}.feed_forward.w1.weight"]
|
||||
self.ffw2 = all_tensors[f"{p}.feed_forward.w2.weight"]
|
||||
self.ffw3 = all_tensors[f"{p}.feed_forward.w3.weight"]
|
||||
self.attn_norm = RmsNorm(all_tensors[f"{p}.attention_norm.weight"])
|
||||
self.ffn_norm = RmsNorm(all_tensors[f"{p}.ffn_norm.weight"])
|
||||
|
||||
self.n_head = hparams["n_head"]
|
||||
self.n_kv_head = self.n_head
|
||||
self.head_dim = hparams["n_embd"] // self.n_head
|
||||
|
||||
self.kv_cache = None
|
||||
self.cos = cos_sin[0]
|
||||
self.sin = cos_sin[1]
|
||||
|
||||
def __call__(self, x:Tensor, mask:Tensor, index_pos:int):
|
||||
residual = x
|
||||
x = self.attn_norm(x)
|
||||
attn = self.forward_attn(x, mask, index_pos)
|
||||
x = attn + residual
|
||||
|
||||
residual = x
|
||||
x = self.ffn_norm(x)
|
||||
w1 = self.ffw1.matmul_t(x)
|
||||
w3 = self.ffw3.matmul_t(x)
|
||||
mlp = self.ffw2.matmul_t(nn.silu(w1) * w3)
|
||||
|
||||
return mlp + residual
|
||||
|
||||
def forward_attn(self, x:Tensor, mask:Tensor, index_pos:int):
|
||||
b_size, seq_len, n_embd = x.shape
|
||||
q = self.attention_wq.matmul_t(x)
|
||||
k = self.attention_wk.matmul_t(x)
|
||||
v = self.attention_wv.matmul_t(x)
|
||||
|
||||
q = q.reshape((b_size, seq_len, self.n_head, self.head_dim)).transpose(1, 2)
|
||||
k = k.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
|
||||
v = v.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
|
||||
|
||||
q = self.apply_rotary_emb(q, index_pos)
|
||||
k = self.apply_rotary_emb(k, index_pos)
|
||||
|
||||
if self.kv_cache is not None and index_pos > 0:
|
||||
prev_k, prev_v = self.kv_cache
|
||||
k = candle.cat([prev_k, k], 2).contiguous()
|
||||
v = candle.cat([prev_v, v], 2).contiguous()
|
||||
|
||||
self.kv_cache = (k, v)
|
||||
|
||||
# TODO: maybe repeat k/v here if we start supporting MQA.
|
||||
|
||||
att = q.matmul(k.t()) / self.head_dim**0.5
|
||||
mask = mask.broadcast_as(att.shape)
|
||||
att = masked_fill(att, mask, float("-inf"))
|
||||
att = nn.softmax(att, -1)
|
||||
y = att.matmul(v.contiguous())
|
||||
y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd))
|
||||
return self.attention_wo.matmul_t(y)
|
||||
|
||||
def apply_rotary_emb(self, x:Tensor, index_pos:int):
|
||||
(b_size, n_head, seq_len, n_embd) = x.shape
|
||||
cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
|
||||
sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
|
||||
x = x.reshape((b_size, n_head, seq_len, n_embd//2, 2))
|
||||
x0 = x.narrow(-1, 0, 1)
|
||||
x1 = x.narrow(-1, 1, 1)
|
||||
y0 = x0.broadcast_mul(cos) - x1.broadcast_mul(sin)
|
||||
y1 = x0.broadcast_mul(sin) + x1.broadcast_mul(cos)
|
||||
rope = candle.cat([y0, y1], -1)
|
||||
return rope.flatten_from(-2)
|
||||
|
||||
def precompute_freqs_cis(hparams, freq_base):
|
||||
head_dim = hparams["n_embd"] // hparams["n_head"]
|
||||
theta = [1.0 / freq_base ** (i / head_dim) for i in range(0, head_dim, 2)]
|
||||
theta = candle.tensor(theta)
|
||||
idx_theta = [float(i) for i in range(MAX_SEQ_LEN)]
|
||||
idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1))
|
||||
m = idx_theta.matmul(theta.unsqueeze(0))
|
||||
return (m.cos(), m.sin())
|
||||
|
||||
class QuantizedLlama:
|
||||
def __init__(self, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor]):
|
||||
self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize()
|
||||
self.norm = RmsNorm(all_tensors["norm.weight"])
|
||||
self.output = all_tensors["output.weight"]
|
||||
self.layers = []
|
||||
rope_freq = hparams.get("rope_freq", 10000.)
|
||||
cos_sin = precompute_freqs_cis(hparams, rope_freq)
|
||||
for layer_idx in range(hparams["n_layer"]):
|
||||
layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
|
||||
self.layers.append(layer)
|
||||
|
||||
def __call__(self, token:Tensor, index_pos:int):
|
||||
b_size, seq_len = token.shape
|
||||
vocab_size, hidden_size = self.tok_embeddings.shape
|
||||
token = token.reshape((b_size * seq_len,))
|
||||
x = self.tok_embeddings.index_select(token, 0)
|
||||
x = x.reshape((b_size, seq_len, hidden_size))
|
||||
|
||||
mask = [int(j > i) for j in range(seq_len) for i in range(seq_len)]
|
||||
mask = candle.tensor(mask).reshape((seq_len, seq_len))
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask, index_pos)
|
||||
x = self.norm(x)
|
||||
x = x.narrow(1, -1, 1).squeeze(1)
|
||||
x = self.output.matmul_t(x)
|
||||
return x
|
||||
|
||||
def gguf_rename(tensor_name:str):
|
||||
if tensor_name == 'token_embd.weight': return 'tok_embeddings.weight'
|
||||
if tensor_name == 'output_norm.weight': return 'norm.weight'
|
||||
tensor_name = tensor_name.replace('blk.', 'layers.')
|
||||
tensor_name = tensor_name.replace('.attn_q.', '.attention.wq.')
|
||||
tensor_name = tensor_name.replace('.attn_k.', '.attention.wk.')
|
||||
tensor_name = tensor_name.replace('.attn_v.', '.attention.wv.')
|
||||
tensor_name = tensor_name.replace('.attn_output.', '.attention.wo.')
|
||||
tensor_name = tensor_name.replace('.ffn_gate.', '.feed_forward.w1.')
|
||||
tensor_name = tensor_name.replace('.ffn_down.', '.feed_forward.w2.')
|
||||
tensor_name = tensor_name.replace('.ffn_up.', '.feed_forward.w3.')
|
||||
tensor_name = tensor_name.replace('.attn_norm.', '.attention_norm.')
|
||||
def gguf_rename(tensor_name: str):
|
||||
if tensor_name == "token_embd.weight":
|
||||
return "tok_embeddings.weight"
|
||||
if tensor_name == "output_norm.weight":
|
||||
return "norm.weight"
|
||||
tensor_name = tensor_name.replace("blk.", "layers.")
|
||||
tensor_name = tensor_name.replace(".attn_q.", ".attention.wq.")
|
||||
tensor_name = tensor_name.replace(".attn_k.", ".attention.wk.")
|
||||
tensor_name = tensor_name.replace(".attn_v.", ".attention.wv.")
|
||||
tensor_name = tensor_name.replace(".attn_output.", ".attention.wo.")
|
||||
tensor_name = tensor_name.replace(".ffn_gate.", ".feed_forward.w1.")
|
||||
tensor_name = tensor_name.replace(".ffn_down.", ".feed_forward.w2.")
|
||||
tensor_name = tensor_name.replace(".ffn_up.", ".feed_forward.w3.")
|
||||
tensor_name = tensor_name.replace(".attn_norm.", ".attention_norm.")
|
||||
return tensor_name
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
raise ValueError("missing weight file argument")
|
||||
|
||||
filename = sys.argv[1]
|
||||
print(f"reading model file {filename}")
|
||||
if filename.endswith("gguf"):
|
||||
all_tensors, metadata = utils.load_gguf(sys.argv[1])
|
||||
all_tensors, metadata = utils.load_gguf(filename)
|
||||
vocab = metadata["tokenizer.ggml.tokens"]
|
||||
for i, v in enumerate(vocab):
|
||||
vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ')
|
||||
vocab[i] = "\n" if v == "<0x0A>" else v.replace("▁", " ")
|
||||
hparams = {k: v for (k, v) in metadata.items() if not k.startswith("tokenizer")}
|
||||
print(hparams)
|
||||
hparams = {
|
||||
'n_vocab': len(vocab),
|
||||
'n_embd': metadata['llama.embedding_length'],
|
||||
'n_mult': 256,
|
||||
'n_head': metadata['llama.attention.head_count'],
|
||||
'n_head_kv': metadata['llama.attention.head_count_kv'],
|
||||
'n_layer': metadata['llama.block_count'],
|
||||
'n_rot': metadata['llama.rope.dimension_count'],
|
||||
'rope_freq': metadata.get('llama.rope.freq_base', 10000.),
|
||||
'ftype': metadata['general.file_type'],
|
||||
"n_vocab": len(vocab),
|
||||
"n_embd": metadata["llama.embedding_length"],
|
||||
"n_mult": 256,
|
||||
"n_head": metadata["llama.attention.head_count"],
|
||||
"n_head_kv": metadata["llama.attention.head_count_kv"],
|
||||
"n_layer": metadata["llama.block_count"],
|
||||
"n_rot": metadata["llama.rope.dimension_count"],
|
||||
"rope_freq": metadata.get("llama.rope.freq_base", 10000.0),
|
||||
"ftype": metadata["general.file_type"],
|
||||
"context_length": metadata["llama.context_length"],
|
||||
}
|
||||
all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() }
|
||||
|
||||
all_tensors = {gguf_rename(k): v for k, v in all_tensors.items()}
|
||||
else:
|
||||
all_tensors, hparams, vocab = utils.load_ggml(sys.argv[1])
|
||||
all_tensors, hparams, vocab = utils.load_ggml(filename)
|
||||
hparams["context_length"] = 2048
|
||||
|
||||
print(hparams)
|
||||
model = QuantizedLlama(hparams, all_tensors)
|
||||
print("model built, starting inference")
|
||||
@ -185,13 +63,14 @@ def main():
|
||||
for token_idx in range(500):
|
||||
last_token = tokens[-1]
|
||||
lt = candle.tensor([last_token]).unsqueeze(0)
|
||||
logits = model(lt, len(tokens))
|
||||
logits = model.forward(lt, len(tokens))
|
||||
# Greedy sampling for now
|
||||
# pr = candle.nn.softmax(logits, -1)
|
||||
m = logits.get(0).argmax_keepdim(-1)
|
||||
next_token = m.values()[0]
|
||||
print(vocab[next_token], end='', flush=True)
|
||||
print(vocab[next_token], end="", flush=True)
|
||||
tokens.append(next_token)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Reference in New Issue
Block a user