mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Make it easier to use whisper samples from the repo. (#112)
* Make it easier to use samples from the repo. * Use f32 for accumulation in the f16/bf16 kernels.
This commit is contained in:
@ -197,8 +197,8 @@ impl Decoder {
|
|||||||
let (_, _, content_frames) = mel.shape().r3()?;
|
let (_, _, content_frames) = mel.shape().r3()?;
|
||||||
let mut seek = 0;
|
let mut seek = 0;
|
||||||
let mut segments = vec![];
|
let mut segments = vec![];
|
||||||
let start = std::time::Instant::now();
|
|
||||||
while seek < content_frames {
|
while seek < content_frames {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
||||||
let segment_size = usize::min(content_frames - seek, N_FRAMES);
|
let segment_size = usize::min(content_frames - seek, N_FRAMES);
|
||||||
let mel_segment = mel.narrow(2, seek, segment_size)?;
|
let mel_segment = mel.narrow(2, seek, segment_size)?;
|
||||||
@ -214,7 +214,7 @@ impl Decoder {
|
|||||||
duration: segment_duration,
|
duration: segment_duration,
|
||||||
dr,
|
dr,
|
||||||
};
|
};
|
||||||
println!("{seek}: {segment:?} : Took {:?}", start.elapsed());
|
println!("{seek}: {segment:?}, in {:?}", start.elapsed());
|
||||||
segments.push(segment)
|
segments.push(segment)
|
||||||
}
|
}
|
||||||
Ok(segments)
|
Ok(segments)
|
||||||
@ -236,8 +236,9 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
|
|
||||||
/// The input to be processed, in wav formats, will default to `jfk.wav`
|
/// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively
|
||||||
/// https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
|
/// this can be set to sample:jfk, sample:gb1, ... to fetch a sample from the following
|
||||||
|
/// repo: https://huggingface.co/datasets/Narsil/candle_demo/
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
input: Option<String>,
|
input: Option<String>,
|
||||||
|
|
||||||
@ -286,19 +287,27 @@ fn main() -> Result<()> {
|
|||||||
} else {
|
} else {
|
||||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
|
let sample = if let Some(input) = args.input {
|
||||||
|
if let Some(sample) = input.strip_prefix("sample:") {
|
||||||
|
api.get(
|
||||||
|
&Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
|
||||||
|
&format!("samples_{sample}.wav"),
|
||||||
|
)?
|
||||||
|
} else {
|
||||||
|
std::path::PathBuf::from(input)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
|
||||||
|
api.get(
|
||||||
|
&Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
|
||||||
|
"samples_jfk.wav",
|
||||||
|
)?
|
||||||
|
};
|
||||||
(
|
(
|
||||||
api.get(&repo, "config.json")?,
|
api.get(&repo, "config.json")?,
|
||||||
api.get(&repo, "tokenizer.json")?,
|
api.get(&repo, "tokenizer.json")?,
|
||||||
api.get(&repo, "model.safetensors")?,
|
api.get(&repo, "model.safetensors")?,
|
||||||
if let Some(input) = args.input {
|
sample,
|
||||||
std::path::PathBuf::from(input)
|
|
||||||
} else {
|
|
||||||
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
|
|
||||||
api.get(
|
|
||||||
&Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
|
|
||||||
"samples_jfk.wav",
|
|
||||||
)?
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#include "cuda_utils.cuh"
|
#include "cuda_utils.cuh"
|
||||||
#include<stdint.h>
|
#include<stdint.h>
|
||||||
|
|
||||||
template <typename T>
|
template <typename T, typename A>
|
||||||
__device__ void conv1d(
|
__device__ void conv1d(
|
||||||
const size_t src_numel,
|
const size_t src_numel,
|
||||||
const size_t l_out,
|
const size_t l_out,
|
||||||
@ -30,7 +30,7 @@ __device__ void conv1d(
|
|||||||
const size_t dst_l = dst_i % l_out;
|
const size_t dst_l = dst_i % l_out;
|
||||||
|
|
||||||
const size_t src_idx0 = b_idx * src_s[0];
|
const size_t src_idx0 = b_idx * src_s[0];
|
||||||
T d = 0;
|
A d = 0;
|
||||||
for (size_t offset = 0; offset < k_size; ++offset) {
|
for (size_t offset = 0; offset < k_size; ++offset) {
|
||||||
const size_t src_l_plus = stride * dst_l + offset;
|
const size_t src_l_plus = stride * dst_l + offset;
|
||||||
if (k_over_2 <= src_l_plus && src_l_plus < l_in + k_over_2) {
|
if (k_over_2 <= src_l_plus && src_l_plus < l_in + k_over_2) {
|
||||||
@ -38,15 +38,15 @@ __device__ void conv1d(
|
|||||||
for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
|
for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
|
||||||
const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_l * src_s[2];
|
const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_l * src_s[2];
|
||||||
const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + offset * k_s[2];
|
const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + offset * k_s[2];
|
||||||
d += src[src_idx] * kernel[k_idx];
|
d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
dst[dst_i] = d;
|
dst[dst_i] = static_cast<T>(d);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#define CONV1D_OP(TYPENAME, FN_NAME) \
|
#define CONV1D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||||
extern "C" __global__ void FN_NAME( \
|
extern "C" __global__ void FN_NAME( \
|
||||||
const size_t src_numel, \
|
const size_t src_numel, \
|
||||||
const size_t num_dims, \
|
const size_t num_dims, \
|
||||||
@ -56,19 +56,19 @@ extern "C" __global__ void FN_NAME( \
|
|||||||
const TYPENAME *kernel, \
|
const TYPENAME *kernel, \
|
||||||
TYPENAME *dst \
|
TYPENAME *dst \
|
||||||
) { \
|
) { \
|
||||||
conv1d(src_numel, num_dims, stride, info, src, kernel, dst); \
|
conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, info, src, kernel, dst); \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= 800
|
#if __CUDA_ARCH__ >= 800
|
||||||
CONV1D_OP(__nv_bfloat16, conv1d_bf16)
|
CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= 530
|
#if __CUDA_ARCH__ >= 530
|
||||||
CONV1D_OP(__half, conv1d_f16)
|
CONV1D_OP(__half, float, conv1d_f16)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
CONV1D_OP(float, conv1d_f32)
|
CONV1D_OP(float, float, conv1d_f32)
|
||||||
CONV1D_OP(double, conv1d_f64)
|
CONV1D_OP(double, double, conv1d_f64)
|
||||||
CONV1D_OP(uint8_t, conv1d_u8)
|
CONV1D_OP(uint8_t, uint8_t, conv1d_u8)
|
||||||
CONV1D_OP(uint32_t, conv1d_u32)
|
CONV1D_OP(uint32_t, uint32_t, conv1d_u32)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user