diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 949bded1..b6260108 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -197,8 +197,8 @@ impl Decoder {
         let (_, _, content_frames) = mel.shape().r3()?;
         let mut seek = 0;
         let mut segments = vec![];
-        let start = std::time::Instant::now();
         while seek < content_frames {
+            let start = std::time::Instant::now();
             let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
             let segment_size = usize::min(content_frames - seek, N_FRAMES);
             let mel_segment = mel.narrow(2, seek, segment_size)?;
@@ -214,7 +214,7 @@ impl Decoder {
                 duration: segment_duration,
                 dr,
             };
-            println!("{seek}: {segment:?} : Took {:?}", start.elapsed());
+            println!("{seek}: {segment:?}, in {:?}", start.elapsed());
             segments.push(segment)
         }
         Ok(segments)
@@ -236,8 +236,9 @@ struct Args {
     #[arg(long)]
     revision: Option<String>,
 
-    /// The input to be processed, in wav formats, will default to `jfk.wav`
-    /// https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
+    /// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively
+    /// this can be set to sample:jfk, sample:gb1, ... to fetch a sample from the following
+    /// repo: https://huggingface.co/datasets/Narsil/candle_demo/
     #[arg(long)]
     input: Option<String>,
 
@@ -286,19 +287,27 @@ fn main() -> Result<()> {
     } else {
         let repo = Repo::with_revision(model_id, RepoType::Model, revision);
         let api = Api::new()?;
+        let sample = if let Some(input) = args.input {
+            if let Some(sample) = input.strip_prefix("sample:") {
+                api.get(
+                    &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
+                    &format!("samples_{sample}.wav"),
+                )?
+            } else {
+                std::path::PathBuf::from(input)
+            }
+        } else {
+            println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
+            api.get(
+                &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
+                "samples_jfk.wav",
+            )?
+        };
         (
             api.get(&repo, "config.json")?,
             api.get(&repo, "tokenizer.json")?,
             api.get(&repo, "model.safetensors")?,
-            if let Some(input) = args.input {
-                std::path::PathBuf::from(input)
-            } else {
-                println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
-                api.get(
-                    &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
-                    "samples_jfk.wav",
-                )?
-            },
+            sample,
         )
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index 55ea7863..93ef56f3 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -1,7 +1,7 @@
 #include "cuda_utils.cuh"
 #include <stdint.h>
 
-template <typename T>
+template <typename T, typename A>
 __device__ void conv1d(
     const size_t src_numel,
     const size_t l_out,
@@ -30,7 +30,7 @@ __device__ void conv1d(
   const size_t dst_l = dst_i % l_out;
 
   const size_t src_idx0 = b_idx * src_s[0];
-  T d = 0;
+  A d = 0;
   for (size_t offset = 0; offset < k_size; ++offset) {
     const size_t src_l_plus = stride * dst_l + offset;
     if (k_over_2 <= src_l_plus && src_l_plus < l_in + k_over_2) {
@@ -38,15 +38,15 @@ __device__ void conv1d(
       for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
         const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_l * src_s[2];
         const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + offset * k_s[2];
-        d += src[src_idx] * kernel[k_idx];
+        d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
       }
     }
   }
-  dst[dst_i] = d;
+  dst[dst_i] = static_cast<T>(d);
 }
 
-#define CONV1D_OP(TYPENAME, FN_NAME) \
+#define CONV1D_OP(TYPENAME, TYPEACC, FN_NAME) \
 extern "C" __global__ void FN_NAME(  \
     const size_t src_numel, \
     const size_t num_dims, \
@@ -56,19 +56,19 @@ extern "C" __global__ void FN_NAME(  \
     const TYPENAME *kernel, \
     TYPENAME *dst \
 ) { \
-  conv1d(src_numel, num_dims, stride, info, src, kernel, dst); \
+  conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, info, src, kernel, dst); \
 } \
 
 #if __CUDA_ARCH__ >= 800
-CONV1D_OP(__nv_bfloat16, conv1d_bf16)
+CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
-CONV1D_OP(__half, conv1d_f16)
+CONV1D_OP(__half, float, conv1d_f16)
 #endif
 
-CONV1D_OP(float, conv1d_f32)
-CONV1D_OP(double, conv1d_f64)
-CONV1D_OP(uint8_t, conv1d_u8)
-CONV1D_OP(uint32_t, conv1d_u32)
+CONV1D_OP(float, float, conv1d_f32)
+CONV1D_OP(double, double, conv1d_f64)
+CONV1D_OP(uint8_t, uint8_t, conv1d_u8)
+CONV1D_OP(uint32_t, uint32_t, conv1d_u32)