mirror of
https://github.com/huggingface/candle.git
synced 2025-06-14 09:57:10 +00:00
Tensor tools print all (#2543)
* Support whisper large-v3 turbo in the whisper-microphone example. * Print all tensors when no argument is provided.
This commit is contained in:
@ -197,6 +197,11 @@ fn run_print(
|
||||
match format {
|
||||
Format::Npz => {
|
||||
let tensors = candle::npy::NpzTensors::new(file)?;
|
||||
let names = if names.is_empty() {
|
||||
tensors.names().into_iter().map(|v| v.to_string()).collect()
|
||||
} else {
|
||||
names
|
||||
};
|
||||
for name in names.iter() {
|
||||
println!("==== {name} ====");
|
||||
match tensors.get(name)? {
|
||||
@ -209,6 +214,11 @@ fn run_print(
|
||||
use candle::safetensors::Load;
|
||||
let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
|
||||
let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect();
|
||||
let names = if names.is_empty() {
|
||||
tensors.keys().map(|v| v.to_string()).collect()
|
||||
} else {
|
||||
names
|
||||
};
|
||||
for name in names.iter() {
|
||||
println!("==== {name} ====");
|
||||
match tensors.get(name) {
|
||||
@ -222,6 +232,15 @@ fn run_print(
|
||||
}
|
||||
Format::Pth => {
|
||||
let pth_file = candle::pickle::PthTensors::new(file, None)?;
|
||||
let names = if names.is_empty() {
|
||||
pth_file
|
||||
.tensor_infos()
|
||||
.keys()
|
||||
.map(|v| v.to_string())
|
||||
.collect()
|
||||
} else {
|
||||
names
|
||||
};
|
||||
for name in names.iter() {
|
||||
println!("==== {name} ====");
|
||||
match pth_file.get(name)? {
|
||||
@ -238,6 +257,11 @@ fn run_print(
|
||||
Format::Ggml => {
|
||||
let mut file = std::fs::File::open(file)?;
|
||||
let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
|
||||
let names = if names.is_empty() {
|
||||
content.tensors.keys().map(|v| v.to_string()).collect()
|
||||
} else {
|
||||
names
|
||||
};
|
||||
for name in names.iter() {
|
||||
println!("==== {name} ====");
|
||||
match content.tensors.get(name) {
|
||||
@ -252,6 +276,11 @@ fn run_print(
|
||||
Format::Gguf => {
|
||||
let mut file = std::fs::File::open(file)?;
|
||||
let content = gguf_file::Content::read(&mut file)?;
|
||||
let names = if names.is_empty() {
|
||||
content.tensor_infos.keys().map(|v| v.to_string()).collect()
|
||||
} else {
|
||||
names
|
||||
};
|
||||
for name in names.iter() {
|
||||
println!("==== {name} ====");
|
||||
match content.tensor(&mut file, name, device) {
|
||||
|
Reference in New Issue
Block a user