mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Relax the contiguous check for cuda kernels. (#2000)
* Relax the contiguous check for cuda kernels.
* Ensure contiguity for RNNs.
* Unrelated fix for segment anything.
* Better error message + allow concatenating empty slices.
This commit is contained in:
@ -99,7 +99,7 @@ pub trait WrapErr<O> {
|
|||||||
|
|
||||||
impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
|
impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
|
||||||
fn w(self) -> std::result::Result<O, crate::Error> {
|
fn w(self) -> std::result::Result<O, crate::Error> {
|
||||||
self.map_err(|e| crate::Error::Cuda(Box::new(e.into())))
|
self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1761,6 +1761,11 @@ impl BackendStorage for CudaStorage {
|
|||||||
let dev = &self.device;
|
let dev = &self.device;
|
||||||
let d1 = d1 as u32;
|
let d1 = d1 as u32;
|
||||||
let d2 = d2 as u32;
|
let d2 = d2 as u32;
|
||||||
|
// Nothing to copy so we exit early to avoid launching a kernel and some potential invalid
|
||||||
|
// argument with a null pointer.
|
||||||
|
if d1 == 0 || d2 == 0 {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
let dst_s = dst_s as u32;
|
let dst_s = dst_s as u32;
|
||||||
let src_s = src_s as u32;
|
let src_s = src_s as u32;
|
||||||
let (src, dst, kname) = match (&self.slice, &mut dst.slice) {
|
let (src, dst, kname) = match (&self.slice, &mut dst.slice) {
|
||||||
|
@ -14,7 +14,7 @@ __device__ bool is_contiguous(
|
|||||||
size_t acc = 1;
|
size_t acc = 1;
|
||||||
for (unsigned int d = 0; d < num_dims; d++) {
|
for (unsigned int d = 0; d < num_dims; d++) {
|
||||||
unsigned int dim_idx = num_dims - 1 - d;
|
unsigned int dim_idx = num_dims - 1 - d;
|
||||||
if (acc != strides[dim_idx]) {
|
if (dims[dim_idx] > 1 && acc != strides[dim_idx]) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
acc *= dims[dim_idx];
|
acc *= dims[dim_idx];
|
||||||
|
@ -31,7 +31,7 @@ pub trait RNN {
|
|||||||
let (_b_size, seq_len, _features) = input.dims3()?;
|
let (_b_size, seq_len, _features) = input.dims3()?;
|
||||||
let mut output = Vec::with_capacity(seq_len);
|
let mut output = Vec::with_capacity(seq_len);
|
||||||
for seq_index in 0..seq_len {
|
for seq_index in 0..seq_len {
|
||||||
let input = input.i((.., seq_index, ..))?;
|
let input = input.i((.., seq_index, ..))?.contiguous()?;
|
||||||
let state = if seq_index == 0 {
|
let state = if seq_index == 0 {
|
||||||
self.step(&input, init_state)?
|
self.step(&input, init_state)?
|
||||||
} else {
|
} else {
|
||||||
|
@ -218,7 +218,8 @@ impl PromptEncoder {
|
|||||||
(Some(se_points), None) => se_points,
|
(Some(se_points), None) => se_points,
|
||||||
(None, Some(se_boxes)) => se_boxes,
|
(None, Some(se_boxes)) => se_boxes,
|
||||||
(None, None) => {
|
(None, None) => {
|
||||||
Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)?
|
let dev = self.no_mask_embed.embeddings().device();
|
||||||
|
Tensor::zeros((1, 0, self.embed_dim), DType::F32, dev)?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user