Streamline the glm4 example. (#2694)
@@ -250,7 +250,11 @@ fn run(args: Args) -> Result<()> {
     };
     println!("img\n{img}");
     let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
-    candle_examples::save_image(&img.i(0)?, "out.jpg")?;
+    let filename = match args.seed {
+        None => "out.jpg".to_string(),
+        Some(s) => format!("out-{s}.jpg"),
+    };
+    candle_examples::save_image(&img.i(0)?, filename)?;
     Ok(())
 }
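The hunk above derives the output filename from the optional seed argument instead of always writing out.jpg. A minimal self-contained sketch of that naming logic, assuming `seed` is an `Option<u64>` command-line argument as the match suggests:

```rust
// Sketch of the seed-dependent output name introduced above.
// `seed` is assumed to be an optional CLI argument (Option<u64>).
fn output_filename(seed: Option<u64>) -> String {
    match seed {
        None => "out.jpg".to_string(),
        Some(s) => format!("out-{s}.jpg"),
    }
}

fn main() {
    assert_eq!(output_filename(None), "out.jpg");
    assert_eq!(output_filename(Some(42)), "out-42.jpg");
}
```

With a seed supplied, earlier outputs are left in place rather than being overwritten by a single out.jpg.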
@@ -7,48 +7,25 @@ GLM-4-9B is the open-source version of the latest generation of pre-trained mode
 ** Running with ~cuda~
 
 #+begin_src shell
-cargo run --example glm4 --release --features cuda
+cargo run --example glm4 --release --features cuda -- --prompt "Hello world"
 #+end_src
 
 ** Running with ~cpu~
 #+begin_src shell
-cargo run --example glm4 --release -- --cpu
+cargo run --example glm4 --release -- --cpu --prompt "Hello world"
 #+end_src
 
 ** Output Example
 #+begin_src shell
-cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache .
-Finished release [optimized] target(s) in 0.24s
-Running `/root/candle/target/release/examples/glm4 --sample-len 500 --cache .`
+cargo run --features cuda -r --example glm4 -- --prompt "Hello "
 avx: true, neon: false, simd128: false, f16c: true
 temp: 0.60 repeat-penalty: 1.20 repeat-last-n: 64
-cache path .
-retrieved the files in 6.88963ms
-loaded the model in 6.113752297s
+retrieved the files in 6.454375ms
+loaded the model in 3.652383779s
+
 starting the inference loop
-[欢迎使用GLM-4,请输入prompt]
-请你告诉我什么是FFT
-266 tokens generated (34.50 token/s)
-Result:
-。Fast Fourier Transform (FFT) 是一种快速计算离散傅里叶变换(DFT)的方法,它广泛应用于信号处理、图像处理和数据分析等领域。
-
-具体来说,FFT是一种将时域数据转换为频域数据的算法。在数字信号处理中,我们通常需要知道信号的频率成分,这就需要进行傅立叶变换。传统的傅立叶变换的计算复杂度较高,而 FFT 则大大提高了计算效率,使得大规模的 DFT 换成为可能。
-
-以下是使用 Python 中的 numpy 进行 FFT 的简单示例:
-
-```python
-import numpy as np
-
-# 创建一个时域信号
-t = np.linspace(0, 1, num=100)
-f = np.sin(2*np.pi*5*t) + 3*np.cos(2*np.pi*10*t)
-
-# 对该信号做FFT变换,并计算其幅值谱
-fft_result = np.fft.fftshift(np.abs(np.fft.fft(f)))
-
-```
-
-在这个例子中,我们首先创建了一个时域信号 f。然后我们对这个信号进行了 FFT 换,得到了一个频域结果 fft_result。
+Hello 2018, hello new year! I’m so excited to be back and sharing with you all my favorite things from the past month. This is a monthly series where I share what’s been inspiring me lately in hopes that it will inspire you too!
+...
 #+end_src
 
 This example will read prompt from stdin
@@ -12,59 +12,44 @@ struct TextGeneration {
     device: Device,
     tokenizer: Tokenizer,
     logits_processor: LogitsProcessor,
-    repeat_penalty: f32,
-    repeat_last_n: usize,
-    verbose_prompt: bool,
+    args: Args,
     dtype: DType,
 }
 
 impl TextGeneration {
     #[allow(clippy::too_many_arguments)]
-    fn new(
-        model: Model,
-        tokenizer: Tokenizer,
-        seed: u64,
-        temp: Option<f64>,
-        top_p: Option<f64>,
-        repeat_penalty: f32,
-        repeat_last_n: usize,
-        verbose_prompt: bool,
-        device: &Device,
-        dtype: DType,
-    ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+    fn new(model: Model, tokenizer: Tokenizer, args: Args, device: &Device, dtype: DType) -> Self {
+        let logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
         Self {
             model,
             tokenizer,
             logits_processor,
-            repeat_penalty,
-            repeat_last_n,
-            verbose_prompt,
+            args,
             device: device.clone(),
             dtype,
         }
     }
 
-    fn run(&mut self, sample_len: usize) -> anyhow::Result<()> {
-        use std::io::BufRead;
-        use std::io::BufReader;
+    fn run(&mut self) -> anyhow::Result<()> {
         use std::io::Write;
+        let args = &self.args;
         println!("starting the inference loop");
-        println!("[欢迎使用GLM-4,请输入prompt]");
-        let stdin = std::io::stdin();
-        let reader = BufReader::new(stdin);
-        for line in reader.lines() {
-            let line = line.expect("Failed to read line");
 
-            let tokens = self.tokenizer.encode(line, true).expect("tokens error");
+        let tokens = self
+            .tokenizer
+            .encode(args.prompt.to_string(), true)
+            .expect("tokens error");
         if tokens.is_empty() {
             panic!("Empty prompts are not supported in the chatglm model.")
         }
-        if self.verbose_prompt {
+        if args.verbose {
             for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
                 let token = token.replace('▁', " ").replace("<0x0A>", "\n");
                 println!("{id:7} -> '{token}'");
             }
+        } else {
+            print!("{}", &args.prompt);
+            std::io::stdout().flush()?;
         }
         let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
             Some(token) => *token,
@@ -76,22 +61,19 @@ impl TextGeneration {
         std::io::stdout().flush().expect("output flush error");
         let start_gen = std::time::Instant::now();
 
-        let mut count = 0;
-        let mut result = vec![];
-        for index in 0..sample_len {
-            count += 1;
+        for index in 0..args.sample_len {
             let context_size = if index > 0 { 1 } else { tokens.len() };
             let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
             let logits = self.model.forward(&input)?;
             let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
-            let logits = if self.repeat_penalty == 1. {
+            let logits = if args.repeat_penalty == 1. {
                 logits
             } else {
-                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+                let start_at = tokens.len().saturating_sub(args.repeat_last_n);
                 candle_transformers::utils::apply_repeat_penalty(
                     &logits,
-                    self.repeat_penalty,
+                    args.repeat_penalty,
                     &tokens[start_at..],
                 )?
             };
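For readers unfamiliar with the call above: `apply_repeat_penalty` damps the logits of tokens that already occur in the last `repeat_last_n` tokens of the context, which is what the `repeat_penalty` / `repeat_last_n` arguments control. A rough standalone sketch of that idea on plain `f32` logits (only an illustration, not candle's exact Tensor implementation):

```rust
use std::collections::HashSet;

// Illustrative repeat-penalty pass: make already-seen tokens less likely.
fn penalize_repeats(logits: &mut [f32], penalty: f32, recent_tokens: &[u32]) {
    let seen: HashSet<u32> = recent_tokens.iter().copied().collect();
    for &t in &seen {
        if let Some(l) = logits.get_mut(t as usize) {
            // Scale the logit towards lower probability, once per distinct token.
            if *l >= 0.0 {
                *l /= penalty;
            } else {
                *l *= penalty;
            }
        }
    }
}

fn main() {
    let mut logits = vec![2.0_f32, -1.0, 0.5];
    penalize_repeats(&mut logits, 1.2, &[0, 1, 1]);
    println!("{logits:?}"); // tokens 0 and 1 are penalized once each
}
```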
@@ -105,27 +87,22 @@ impl TextGeneration {
             let token = self
                 .tokenizer
                 .decode(&[next_token], true)
-                .expect("Token error");
-            if self.verbose_prompt {
+                .expect("token decode error");
+            if args.verbose {
                 println!(
                     "[Count: {}] [Raw Token: {}] [Decode Token: {}]",
-                    count, next_token, token
+                    generated_tokens, next_token, token
                 );
-            }
-            result.push(token);
+            } else {
+                print!("{token}");
                 std::io::stdout().flush()?;
             }
+        }
         let dt = start_gen.elapsed();
         println!(
             "\n{generated_tokens} tokens generated ({:.2} token/s)",
             generated_tokens as f64 / dt.as_secs_f64(),
         );
-        println!("Result:");
-        for tokens in result {
-            print!("{tokens}");
-        }
-        self.model.reset_kv_cache(); // clean the cache
-        }
         Ok(())
     }
 }
@@ -141,7 +118,11 @@ struct Args {
 
     /// Display the token for the specified prompt.
     #[arg(long)]
-    verbose_prompt: bool,
+    prompt: String,
 
+    /// Display the tokens for the specified prompt and outputs.
+    #[arg(long)]
+    verbose: bool,
+
     /// The temperature used to generate samples.
     #[arg(long)]
@@ -197,28 +178,29 @@ fn main() -> anyhow::Result<()> {
     );
 
     let start = std::time::Instant::now();
-    println!("cache path {}", args.cache_path);
-    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
+    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(
+        args.cache_path.to_string().into(),
+    ))
     .build()
     .map_err(anyhow::Error::msg)?;
 
-    let model_id = match args.model_id {
+    let model_id = match args.model_id.as_ref() {
         Some(model_id) => model_id.to_string(),
         None => "THUDM/glm-4-9b".to_string(),
     };
-    let revision = match args.revision {
+    let revision = match args.revision.as_ref() {
         Some(rev) => rev.to_string(),
         None => "main".to_string(),
     };
     let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
-    let tokenizer_filename = match args.tokenizer {
+    let tokenizer_filename = match args.tokenizer.as_ref() {
         Some(file) => std::path::PathBuf::from(file),
         None => api
             .model("THUDM/codegeex4-all-9b".to_string())
             .get("tokenizer.json")
             .map_err(anyhow::Error::msg)?,
     };
-    let filenames = match args.weight_file {
+    let filenames = match args.weight_file.as_ref() {
         Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
         None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
     };
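The hub access in this hunk reduces to: build a sync `Api` rooted at the user-supplied cache path, then resolve files from the chosen repo and revision. A condensed sketch using only the `hf_hub` calls that appear above (assumes the `hf-hub` and `anyhow` crates plus network access; the helper name is illustrative):

```rust
use hf_hub::{api::sync::ApiBuilder, Cache, Repo, RepoType};

// Fetch one file from the hub, caching it under `cache_path`.
fn fetch_from_hub(
    cache_path: &str,
    model_id: &str,
    revision: &str,
    filename: &str,
) -> anyhow::Result<std::path::PathBuf> {
    // Same builder chain as in the hunk above: an Api rooted at a custom cache dir.
    let api = ApiBuilder::from_cache(Cache::new(cache_path.to_string().into()))
        .build()
        .map_err(anyhow::Error::msg)?;
    let repo = api.repo(Repo::with_revision(
        model_id.to_string(),
        RepoType::Model,
        revision.to_string(),
    ));
    repo.get(filename).map_err(anyhow::Error::msg)
}

fn main() -> anyhow::Result<()> {
    // The example above pulls tokenizer.json from THUDM/codegeex4-all-9b.
    let path = fetch_from_hub(".", "THUDM/codegeex4-all-9b", "main", "tokenizer.json")?;
    println!("downloaded to {}", path.display());
    Ok(())
}
```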
@@ -238,18 +220,7 @@ fn main() -> anyhow::Result<()> {
 
     println!("loaded the model in {:?}", start.elapsed());
 
-    let mut pipeline = TextGeneration::new(
-        model,
-        tokenizer,
-        args.seed,
-        args.temperature,
-        args.top_p,
-        args.repeat_penalty,
-        args.repeat_last_n,
-        args.verbose_prompt,
-        &device,
-        dtype,
-    );
-    pipeline.run(args.sample_len)?;
+    let mut pipeline = TextGeneration::new(model, tokenizer, args, &device, dtype);
+    pipeline.run()?;
     Ok(())
 }
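Taken together, the main.rs hunks replace the long positional constructor with a single `Args` value stored on `TextGeneration`, and `run()` now uses `args.prompt` and `args.sample_len` instead of looping over stdin. A hypothetical, stripped-down sketch of that pattern (simplified stand-in types, not the real candle / clap definitions):

```rust
// The generator owns one Args value instead of copying every field
// through a many-parameter constructor.
struct Args {
    prompt: String,
    sample_len: usize,
    verbose: bool,
}

struct TextGeneration {
    args: Args,
}

impl TextGeneration {
    fn new(args: Args) -> Self {
        Self { args }
    }

    fn run(&mut self) {
        let args = &self.args;
        if args.verbose {
            println!("prompt tokens would be dumped here");
        } else {
            print!("{}", args.prompt);
        }
        for _ in 0..args.sample_len {
            // one decoding step per iteration in the real example
        }
        println!();
    }
}

fn main() {
    let args = Args {
        prompt: "Hello world".to_string(),
        sample_len: 8,
        verbose: false,
    };
    let mut pipeline = TextGeneration::new(args);
    pipeline.run();
}
```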