mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add a print command to tensor-tools. (#1967)
* Add a print command to tensor-tools. * Add some flags to tweak the formatting.
This commit is contained in:
@ -117,6 +117,24 @@ enum Command {
|
|||||||
verbose: bool,
|
verbose: bool,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
Print {
|
||||||
|
file: std::path::PathBuf,
|
||||||
|
|
||||||
|
names: Vec<String>,
|
||||||
|
|
||||||
|
/// The file format to use, if unspecified infer from the file extension.
|
||||||
|
#[arg(long, value_enum)]
|
||||||
|
format: Option<Format>,
|
||||||
|
|
||||||
|
/// Print the whole content of each tensor.
|
||||||
|
#[arg(long)]
|
||||||
|
full: bool,
|
||||||
|
|
||||||
|
/// Line width for printing the tensors.
|
||||||
|
#[arg(long)]
|
||||||
|
line_width: Option<usize>,
|
||||||
|
},
|
||||||
|
|
||||||
Quantize {
|
Quantize {
|
||||||
/// The input file(s), in safetensors format.
|
/// The input file(s), in safetensors format.
|
||||||
in_file: Vec<std::path::PathBuf>,
|
in_file: Vec<std::path::PathBuf>,
|
||||||
@ -150,6 +168,105 @@ struct Args {
|
|||||||
command: Command,
|
command: Command,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn run_print(
|
||||||
|
file: &std::path::PathBuf,
|
||||||
|
names: Vec<String>,
|
||||||
|
format: Option<Format>,
|
||||||
|
full: bool,
|
||||||
|
line_width: Option<usize>,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<()> {
|
||||||
|
if full {
|
||||||
|
candle_core::display::set_print_options_full();
|
||||||
|
}
|
||||||
|
if let Some(line_width) = line_width {
|
||||||
|
candle_core::display::set_line_width(line_width)
|
||||||
|
}
|
||||||
|
let format = match format {
|
||||||
|
Some(format) => format,
|
||||||
|
None => match Format::infer(file) {
|
||||||
|
Some(format) => format,
|
||||||
|
None => {
|
||||||
|
println!(
|
||||||
|
"{file:?}: cannot infer format from file extension, use the --format flag"
|
||||||
|
);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
match format {
|
||||||
|
Format::Npz => {
|
||||||
|
let tensors = candle_core::npy::NpzTensors::new(file)?;
|
||||||
|
for name in names.iter() {
|
||||||
|
println!("==== {name} ====");
|
||||||
|
match tensors.get(name)? {
|
||||||
|
Some(tensor) => println!("{tensor}"),
|
||||||
|
None => println!("not found"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Format::Safetensors => {
|
||||||
|
use candle_core::safetensors::Load;
|
||||||
|
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
|
||||||
|
let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect();
|
||||||
|
for name in names.iter() {
|
||||||
|
println!("==== {name} ====");
|
||||||
|
match tensors.get(name) {
|
||||||
|
Some(tensor_view) => {
|
||||||
|
let tensor = tensor_view.load(device)?;
|
||||||
|
println!("{tensor}")
|
||||||
|
}
|
||||||
|
None => println!("not found"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Format::Pth => {
|
||||||
|
let pth_file = candle_core::pickle::PthTensors::new(file, None)?;
|
||||||
|
for name in names.iter() {
|
||||||
|
println!("==== {name} ====");
|
||||||
|
match pth_file.get(name)? {
|
||||||
|
Some(tensor) => {
|
||||||
|
println!("{tensor}")
|
||||||
|
}
|
||||||
|
None => println!("not found"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Format::Pickle => {
|
||||||
|
candle_core::bail!("pickle format is not supported for print")
|
||||||
|
}
|
||||||
|
Format::Ggml => {
|
||||||
|
let mut file = std::fs::File::open(file)?;
|
||||||
|
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
|
||||||
|
for name in names.iter() {
|
||||||
|
println!("==== {name} ====");
|
||||||
|
match content.tensors.get(name) {
|
||||||
|
Some(tensor) => {
|
||||||
|
let tensor = tensor.dequantize(device)?;
|
||||||
|
println!("{tensor}")
|
||||||
|
}
|
||||||
|
None => println!("not found"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Format::Gguf => {
|
||||||
|
let mut file = std::fs::File::open(file)?;
|
||||||
|
let content = gguf_file::Content::read(&mut file)?;
|
||||||
|
for name in names.iter() {
|
||||||
|
println!("==== {name} ====");
|
||||||
|
match content.tensor(&mut file, name, device) {
|
||||||
|
Ok(tensor) => {
|
||||||
|
let tensor = tensor.dequantize(device)?;
|
||||||
|
println!("{tensor}")
|
||||||
|
}
|
||||||
|
Err(_) => println!("not found"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn run_ls(
|
fn run_ls(
|
||||||
file: &std::path::PathBuf,
|
file: &std::path::PathBuf,
|
||||||
format: Option<Format>,
|
format: Option<Format>,
|
||||||
@ -377,6 +494,13 @@ fn main() -> anyhow::Result<()> {
|
|||||||
run_ls(file, format.clone(), verbose, &device)?
|
run_ls(file, format.clone(), verbose, &device)?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Command::Print {
|
||||||
|
file,
|
||||||
|
names,
|
||||||
|
format,
|
||||||
|
full,
|
||||||
|
line_width,
|
||||||
|
} => run_print(&file, names, format, full, line_width, &device)?,
|
||||||
Command::Quantize {
|
Command::Quantize {
|
||||||
in_file,
|
in_file,
|
||||||
out_file,
|
out_file,
|
||||||
|
Reference in New Issue
Block a user