Streamline the glm4 example. (#2694)

Laurent Mazare
2024-12-31 09:21:41 +01:00
committed by GitHub
parent e38e2a85dd
commit d60eba1408
3 changed files with 97 additions and 145 deletions

View File

@@ -250,7 +250,11 @@ fn run(args: Args) -> Result<()> {
};
println!("img\n{img}");
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
candle_examples::save_image(&img.i(0)?, "out.jpg")?;
let filename = match args.seed {
None => "out.jpg".to_string(),
Some(s) => format!("out-{s}.jpg"),
};
candle_examples::save_image(&img.i(0)?, filename)?;
Ok(())
}
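The filename change above writes out-<seed>.jpg when a seed was passed, so runs with different seeds no longer overwrite out.jpg. As a side note, the clamp-and-rescale line maps the model's [-1, 1] float output onto [0, 255] before the cast to U8. Below is a minimal sketch of that mapping on plain f32 values; it is illustrative only, and whether the tensor cast rounds or truncates is a detail of to_dtype that is not shown here.

```rust
// Illustrative only: the [-1, 1] -> [0, 255] pixel mapping used above,
// applied to plain f32 values instead of a candle tensor.
fn to_u8_pixel(x: f32) -> u8 {
    ((x.clamp(-1.0, 1.0) + 1.0) * 127.5) as u8
}

fn main() {
    assert_eq!(to_u8_pixel(-1.0), 0); // lower bound -> 0
    assert_eq!(to_u8_pixel(0.0), 127); // midpoint -> 127 (truncating cast)
    assert_eq!(to_u8_pixel(1.0), 255); // upper bound -> 255
}
```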

View File

@@ -7,48 +7,25 @@ GLM-4-9B is the open-source version of the latest generation of pre-trained mode
** Running with ~cuda~
#+begin_src shell
cargo run --example glm4 --release --features cuda
cargo run --example glm4 --release --features cuda -- --prompt "Hello world"
#+end_src
** Running with ~cpu~
#+begin_src shell
cargo run --example glm4 --release -- --cpu
cargo run --example glm4 --release -- --cpu --prompt "Hello world"
#+end_src
** Output Example
#+begin_src shell
cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache .
Finished release [optimized] target(s) in 0.24s
Running `/root/candle/target/release/examples/glm4 --sample-len 500 --cache .`
cargo run --features cuda -r --example glm4 -- --prompt "Hello "
avx: true, neon: false, simd128: false, f16c: true
temp: 0.60 repeat-penalty: 1.20 repeat-last-n: 64
cache path .
retrieved the files in 6.88963ms
loaded the model in 6.113752297s
retrieved the files in 6.454375ms
loaded the model in 3.652383779s
starting the inference loop
[欢迎使用GLM-4,请输入prompt]
Please tell me what the FFT is
266 tokens generated (34.50 token/s)
Result:
The Fast Fourier Transform (FFT) is a method for computing the discrete Fourier transform (DFT) quickly; it is widely used in signal processing, image processing, data analysis, and other fields.
More specifically, the FFT is an algorithm that converts time-domain data into frequency-domain data. In digital signal processing we usually need to know the frequency content of a signal, which requires a Fourier transform. The classical Fourier transform is computationally expensive, while the FFT greatly improves efficiency and makes large-scale DFTs practical.
Here is a simple example of computing an FFT with numpy in Python:
```python
import numpy as np
# Create a time-domain signal
t = np.linspace(0, 1, num=100)
f = np.sin(2*np.pi*5*t) + 3*np.cos(2*np.pi*10*t)
# Compute the FFT of the signal and its magnitude spectrum
fft_result = np.fft.fftshift(np.abs(np.fft.fft(f)))
```
In this example we first create a time-domain signal f, then apply the FFT to it to obtain the frequency-domain result fft_result.
Hello 2018, hello new year! I'm so excited to be back and sharing with you all my favorite things from the past month. This is a monthly series where I share what's been inspiring me lately in hopes that it will inspire you too!
...
#+end_src
This example will read the prompt from stdin

View File

@@ -12,59 +12,44 @@ struct TextGeneration {
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
args: Args,
dtype: DType,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
dtype: DType,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
fn new(model: Model, tokenizer: Tokenizer, args: Args, device: &Device, dtype: DType) -> Self {
let logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
args,
device: device.clone(),
dtype,
}
}
fn run(&mut self, sample_len: usize) -> anyhow::Result<()> {
use std::io::BufRead;
use std::io::BufReader;
fn run(&mut self) -> anyhow::Result<()> {
use std::io::Write;
let args = &self.args;
println!("starting the inference loop");
println!("[欢迎使用GLM-4,请输入prompt]");
let stdin = std::io::stdin();
let reader = BufReader::new(stdin);
for line in reader.lines() {
let line = line.expect("Failed to read line");
let tokens = self.tokenizer.encode(line, true).expect("tokens error");
let tokens = self
.tokenizer
.encode(args.prompt.to_string(), true)
.expect("tokens error");
if tokens.is_empty() {
panic!("Empty prompts are not supported in the chatglm model.")
}
if self.verbose_prompt {
if args.verbose {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
} else {
print!("{}", &args.prompt);
std::io::stdout().flush()?;
}
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
@@ -76,22 +61,19 @@ impl TextGeneration {
std::io::stdout().flush().expect("output flush error");
let start_gen = std::time::Instant::now();
let mut count = 0;
let mut result = vec![];
for index in 0..sample_len {
count += 1;
for index in 0..args.sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
let logits = if self.repeat_penalty == 1. {
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
args.repeat_penalty,
&tokens[start_at..],
)?
};
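For readers unfamiliar with apply_repeat_penalty, the usual repeat-penalty rule is sketched below on a plain slice of logits: tokens that already appear in the recent context get their logits pushed down (positive logits divided by the penalty, negative ones multiplied by it). This is a sketch of the common technique, not a copy of the candle_transformers implementation.

```rust
use std::collections::HashSet;

// Illustrative only: the usual repeat-penalty rule on plain f32 logits.
fn repeat_penalty_sketch(logits: &mut [f32], penalty: f32, context: &[u32]) {
    let mut seen = HashSet::new();
    for &token_id in context {
        if !seen.insert(token_id) {
            continue; // each context token is penalized at most once
        }
        if let Some(logit) = logits.get_mut(token_id as usize) {
            if *logit >= 0.0 {
                *logit /= penalty; // make repeated tokens less likely
            } else {
                *logit *= penalty;
            }
        }
    }
}

fn main() {
    let mut logits = vec![2.0, -1.0, 0.5];
    repeat_penalty_sketch(&mut logits, 1.2, &[0, 2, 0]);
    println!("{logits:?}"); // tokens 0 and 2 are penalized, token 1 is untouched
}
```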
@@ -105,27 +87,22 @@ impl TextGeneration {
let token = self
.tokenizer
.decode(&[next_token], true)
.expect("Token error");
if self.verbose_prompt {
.expect("token decode error");
if args.verbose {
println!(
"[Count: {}] [Raw Token: {}] [Decode Token: {}]",
count, next_token, token
generated_tokens, next_token, token
);
}
result.push(token);
} else {
print!("{token}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
println!("Result:");
for tokens in result {
print!("{tokens}");
}
self.model.reset_kv_cache(); // clean the cache
}
Ok(())
}
}
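The generation loop above feeds the whole prompt on the first iteration and then only the last sampled token, relying on the model's internal KV cache for all earlier positions (which is also why the cache had to be reset between prompts in the old stdin loop). A minimal sketch of that incremental-decoding pattern follows, with a hypothetical step function standing in for the GLM-4 forward pass and sampling.

```rust
// Illustrative only: incremental decoding against a model with a KV cache.
// `step` is a hypothetical stand-in for the forward pass plus sampling; a real
// model would append the new positions to its KV cache and return logits.
fn step(_ctxt: &[u32]) -> u32 {
    0 // pretend we always sample token id 0
}

fn generate(prompt: &[u32], sample_len: usize) -> Vec<u32> {
    let mut tokens = prompt.to_vec();
    for index in 0..sample_len {
        // First iteration: the full prompt. Later iterations: only the last
        // token, since earlier positions are already cached inside the model.
        let context_size = if index > 0 { 1 } else { tokens.len() };
        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
        tokens.push(step(ctxt));
    }
    tokens[prompt.len()..].to_vec()
}

fn main() {
    println!("{:?}", generate(&[1, 2, 3], 4));
}
```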
@@ -141,7 +118,11 @@ struct Args {
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
prompt: String,
/// Display the tokens for the specified prompt and the generated output.
#[arg(long)]
verbose: bool,
/// The temperature used to generate samples.
#[arg(long)]
@@ -197,28 +178,29 @@ fn main() -> anyhow::Result<()> {
);
let start = std::time::Instant::now();
println!("cache path {}", args.cache_path);
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(
args.cache_path.to_string().into(),
))
.build()
.map_err(anyhow::Error::msg)?;
let model_id = match args.model_id {
let model_id = match args.model_id.as_ref() {
Some(model_id) => model_id.to_string(),
None => "THUDM/glm-4-9b".to_string(),
};
let revision = match args.revision {
let revision = match args.revision.as_ref() {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
let tokenizer_filename = match args.tokenizer.as_ref() {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("THUDM/codegeex4-all-9b".to_string())
.get("tokenizer.json")
.map_err(anyhow::Error::msg)?,
};
let filenames = match args.weight_file {
let filenames = match args.weight_file.as_ref() {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
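For context, here is a minimal standalone sketch of how these files are resolved through the synchronous hf_hub API. The repo ids and file names match the ones above, but the example itself builds its Api through ApiBuilder::from_cache to honor the cache-path argument and uses candle_examples::hub_load_safetensors to also fetch every weight shard listed in the index, which this sketch skips.

```rust
use hf_hub::{api::sync::Api, Repo, RepoType};

fn main() -> anyhow::Result<()> {
    // Default cache location; the example builds its Api from a custom cache path instead.
    let api = Api::new().map_err(anyhow::Error::msg)?;
    let repo = api.repo(Repo::with_revision(
        "THUDM/glm-4-9b".to_string(),
        RepoType::Model,
        "main".to_string(),
    ));
    // Files are downloaded on first use and served from the local cache afterwards.
    let weights_index = repo
        .get("model.safetensors.index.json")
        .map_err(anyhow::Error::msg)?;
    let tokenizer = api
        .model("THUDM/codegeex4-all-9b".to_string())
        .get("tokenizer.json")
        .map_err(anyhow::Error::msg)?;
    println!("weights index: {weights_index:?}");
    println!("tokenizer: {tokenizer:?}");
    Ok(())
}
```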
@@ -238,18 +220,7 @@ fn main() -> anyhow::Result<()> {
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
dtype,
);
pipeline.run(args.sample_len)?;
let mut pipeline = TextGeneration::new(model, tokenizer, args, &device, dtype);
pipeline.run()?;
Ok(())
}