Finished scaffolding, lots of TODOs

- Most kernels just copy themselves to get the shapes correct
- Matmul works in only one case and simply allocates an empty output otherwise
- Logits are randomized to make the demo finish by itself.

Performance is quite bad (30ms/token), but there are lots of prints and allocs and some actual sending to Metal.

Couldn't get it super high by removing the obvious blockers (println + the actual running matmuls).

Allocations take between 1us and 100us and seem very stable. Maybe Metal doesn't really have a smart allocator and we'll need to own it.
This commit is contained in:
Nicolas Patry
2023-11-02 15:32:28 +01:00
parent 82cce52e73
commit 7161002a34
11 changed files with 212 additions and 52 deletions

View File

@ -9,7 +9,7 @@ use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file};
use candle::{Device, Tensor};
use candle::{Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::quantized_llama as model;
@ -367,9 +367,11 @@ fn main() -> anyhow::Result<()> {
let start_prompt_processing = std::time::Instant::now();
let mut next_token = {
let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
// TODO Remove this once implementation is finished.
let logits = logits.ones_like()?;
logits_processor.sample(&logits)?
};
let prompt_dt = start_prompt_processing.elapsed();
@ -380,7 +382,7 @@ fn main() -> anyhow::Result<()> {
let start_post_prompt = std::time::Instant::now();
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
@ -393,6 +395,8 @@ fn main() -> anyhow::Result<()> {
&all_tokens[start_at..],
)?
};
// TODO Remove this once implementation is finished.
let logits = logits.ones_like()?;
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);
print_token(next_token, &tokenizer);