Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
More quantized llama in python. (#716)
* More quantized llama in python.
* Expose a couple more functions.
* Apply the last layer.
* Use the vocab from the ggml files.
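A note on the "apply the last layer" bullet, visible in the first hunk of the diff below: after the final norm, the forward pass now keeps only the hidden state of the last sequence position and projects it through the quantized output matrix to produce vocabulary logits. A minimal numpy sketch of the equivalent operations, with illustrative shapes and names (this is not the candle API):

    import numpy as np

    # Illustrative shapes: batch 1, 4 tokens, hidden size 8, vocab 32.
    b, t, d, v = 1, 4, 8, 32
    x = np.random.randn(b, t, d).astype(np.float32)   # hidden states after the final norm
    w_out = np.random.randn(v, d).astype(np.float32)  # output weight, stored as (vocab, dim)

    # x.narrow(1, -1, 1).squeeze(1) keeps only the last position along the
    # sequence dimension, going from (b, t, d) to (b, d).
    last = x[:, -1, :]

    # matmul_t multiplies by the transpose of the weight, so the logits
    # are last @ w_out.T, with shape (b, vocab).
    logits = last @ w_out.T
    print(logits.shape)  # (1, 32)

Only the last position is needed because generation consumes one next-token distribution per step.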
@@ -117,7 +117,6 @@ def precompute_freqs_cis(hparams, freq_base):
     idx_theta = [float(i) for i in range(MAX_SEQ_LEN)]
     idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1))
     m = idx_theta.matmul(theta.unsqueeze(0))
-    print(m.shape)
     return (m.cos(), m.sin())
 
 class QuantizedLlama:
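For context, the hunk above touches the standard rotary-embedding precomputation: an outer product of every position with the per-dimension frequencies theta, returned as cosine and sine tables. A rough numpy equivalent, assuming the usual llama values for head_dim and freq_base (both illustrative here; the real script takes them from the hparams):

    import numpy as np

    MAX_SEQ_LEN = 4096
    head_dim = 128        # illustrative
    freq_base = 10000.0   # illustrative

    # One inverse frequency per pair of channels.
    theta = 1.0 / freq_base ** (np.arange(0, head_dim, 2, dtype=np.float32) / head_dim)

    # Outer product of positions and frequencies, mirroring
    # idx_theta.matmul(theta.unsqueeze(0)) above.
    idx_theta = np.arange(MAX_SEQ_LEN, dtype=np.float32).reshape(MAX_SEQ_LEN, 1)
    m = idx_theta @ theta.reshape(1, -1)  # (MAX_SEQ_LEN, head_dim // 2)

    cos, sin = np.cos(m), np.sin(m)       # the (m.cos(), m.sin()) pair returned above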
@@ -143,28 +142,36 @@ class QuantizedLlama:
 
         for layer in self.layers:
             x = layer(x, mask, index_pos)
         x = self.norm(x)
+        x = x.narrow(1, -1, 1).squeeze(1)
+        x = self.output.matmul_t(x)
         return x
 
 def main():
     if len(sys.argv) < 2:
         raise ValueError("missing weight file argument")
     filename = sys.argv[1]
     print(f"reading model file {filename}")
     if filename.endswith("gguf"):
         all_tensors = candle.load_gguf(sys.argv[1])
         hparams = None
+        vocab = None
     else:
-        all_tensors, hparams = candle.load_ggml(sys.argv[1])
+        all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1])
     print(hparams)
     model = QuantizedLlama(hparams, all_tensors)
     print("model built, starting inference")
 
     tokens = [1]
-    for token_idx in range(1):
-        print(tokens)
+    for token_idx in range(500):
         last_token = tokens[-1]
         lt = candle.tensor([last_token]).unsqueeze(0)
         logits = model(lt, len(tokens))
-        print(logits)
-        next_token = "TODO: sample"
+        # Greedy sampling for now
+        # pr = candle.nn.softmax(logits, -1)
+        m = logits.get(0).argmax_keepdim(-1)
+        next_token = m.values()[0]
+        print(vocab[next_token], end='', flush=True)
         tokens.append(next_token)
 
 if __name__ == '__main__':
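The rewritten loop decodes greedily: argmax over the last-position logits, with the commented-out softmax line hinting at real sampling later. For comparison, a hedged sketch of temperature sampling in plain numpy (not the candle API; temperature 0 recovers the greedy behaviour from the diff):

    import numpy as np

    def sample(logits, temperature=0.8):
        # temperature <= 0 degenerates to the greedy argmax used in the diff.
        logits = np.asarray(logits, dtype=np.float32)
        if temperature <= 0.0:
            return int(np.argmax(logits))
        scaled = logits / temperature
        probs = np.exp(scaled - scaled.max())  # numerically stable softmax
        probs /= probs.sum()
        return int(np.random.choice(len(probs), p=probs))

    print(sample([0.1, 2.5, -1.0], temperature=0.0))  # 1, the greedy pick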
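One caveat on "use the vocab from the ggml files": the loop prints vocab[next_token] verbatim, and llama ggml vocabs come from sentencepiece, which typically marks word boundaries with the '\u2581' character rather than a space. If readable output were wanted, a small hypothetical helper could rewrite each piece before printing:

    # Hypothetical helper, assuming sentencepiece-style vocab pieces.
    def piece_to_text(piece: str) -> str:
        return piece.replace("\u2581", " ")

    print(piece_to_text("\u2581hello"), end="", flush=True)  # prints: " hello"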