mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add a print command to tensor-tools. (#1967)
* Add a print command to tensor-tools. * Add some flags to tweak the formatting.
This commit is contained in:
@ -117,6 +117,24 @@ enum Command {
|
||||
verbose: bool,
|
||||
},
|
||||
|
||||
Print {
|
||||
file: std::path::PathBuf,
|
||||
|
||||
names: Vec<String>,
|
||||
|
||||
/// The file format to use, if unspecified infer from the file extension.
|
||||
#[arg(long, value_enum)]
|
||||
format: Option<Format>,
|
||||
|
||||
/// Print the whole content of each tensor.
|
||||
#[arg(long)]
|
||||
full: bool,
|
||||
|
||||
/// Line width for printing the tensors.
|
||||
#[arg(long)]
|
||||
line_width: Option<usize>,
|
||||
},
|
||||
|
||||
Quantize {
|
||||
/// The input file(s), in safetensors format.
|
||||
in_file: Vec<std::path::PathBuf>,
|
||||
@ -150,6 +168,105 @@ struct Args {
|
||||
command: Command,
|
||||
}
|
||||
|
||||
fn run_print(
|
||||
file: &std::path::PathBuf,
|
||||
names: Vec<String>,
|
||||
format: Option<Format>,
|
||||
full: bool,
|
||||
line_width: Option<usize>,
|
||||
device: &Device,
|
||||
) -> Result<()> {
|
||||
if full {
|
||||
candle_core::display::set_print_options_full();
|
||||
}
|
||||
if let Some(line_width) = line_width {
|
||||
candle_core::display::set_line_width(line_width)
|
||||
}
|
||||
let format = match format {
|
||||
Some(format) => format,
|
||||
None => match Format::infer(file) {
|
||||
Some(format) => format,
|
||||
None => {
|
||||
println!(
|
||||
"{file:?}: cannot infer format from file extension, use the --format flag"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
},
|
||||
};
|
||||
match format {
|
||||
Format::Npz => {
|
||||
let tensors = candle_core::npy::NpzTensors::new(file)?;
|
||||
for name in names.iter() {
|
||||
println!("==== {name} ====");
|
||||
match tensors.get(name)? {
|
||||
Some(tensor) => println!("{tensor}"),
|
||||
None => println!("not found"),
|
||||
}
|
||||
}
|
||||
}
|
||||
Format::Safetensors => {
|
||||
use candle_core::safetensors::Load;
|
||||
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
|
||||
let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect();
|
||||
for name in names.iter() {
|
||||
println!("==== {name} ====");
|
||||
match tensors.get(name) {
|
||||
Some(tensor_view) => {
|
||||
let tensor = tensor_view.load(device)?;
|
||||
println!("{tensor}")
|
||||
}
|
||||
None => println!("not found"),
|
||||
}
|
||||
}
|
||||
}
|
||||
Format::Pth => {
|
||||
let pth_file = candle_core::pickle::PthTensors::new(file, None)?;
|
||||
for name in names.iter() {
|
||||
println!("==== {name} ====");
|
||||
match pth_file.get(name)? {
|
||||
Some(tensor) => {
|
||||
println!("{tensor}")
|
||||
}
|
||||
None => println!("not found"),
|
||||
}
|
||||
}
|
||||
}
|
||||
Format::Pickle => {
|
||||
candle_core::bail!("pickle format is not supported for print")
|
||||
}
|
||||
Format::Ggml => {
|
||||
let mut file = std::fs::File::open(file)?;
|
||||
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
|
||||
for name in names.iter() {
|
||||
println!("==== {name} ====");
|
||||
match content.tensors.get(name) {
|
||||
Some(tensor) => {
|
||||
let tensor = tensor.dequantize(device)?;
|
||||
println!("{tensor}")
|
||||
}
|
||||
None => println!("not found"),
|
||||
}
|
||||
}
|
||||
}
|
||||
Format::Gguf => {
|
||||
let mut file = std::fs::File::open(file)?;
|
||||
let content = gguf_file::Content::read(&mut file)?;
|
||||
for name in names.iter() {
|
||||
println!("==== {name} ====");
|
||||
match content.tensor(&mut file, name, device) {
|
||||
Ok(tensor) => {
|
||||
let tensor = tensor.dequantize(device)?;
|
||||
println!("{tensor}")
|
||||
}
|
||||
Err(_) => println!("not found"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_ls(
|
||||
file: &std::path::PathBuf,
|
||||
format: Option<Format>,
|
||||
@ -377,6 +494,13 @@ fn main() -> anyhow::Result<()> {
|
||||
run_ls(file, format.clone(), verbose, &device)?
|
||||
}
|
||||
}
|
||||
Command::Print {
|
||||
file,
|
||||
names,
|
||||
format,
|
||||
full,
|
||||
line_width,
|
||||
} => run_print(&file, names, format, full, line_width, &device)?,
|
||||
Command::Quantize {
|
||||
in_file,
|
||||
out_file,
|
||||
|
Reference in New Issue
Block a user