mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Handle arbitrary shapes in Tensor::new. (#718)
This commit is contained in:
@ -205,14 +205,29 @@ impl PyTensor {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<i64>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
|
||||
Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<i64>>(py) {
|
||||
Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<f32>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
|
||||
let len = vs.len();
|
||||
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<i64>>(py) {
|
||||
let len = vs.len();
|
||||
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
|
||||
Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
|
||||
let len = vs.len();
|
||||
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<u32>>>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<i64>>>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<f32>>>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<u32>>>>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<i64>>>>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<f32>>>>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else {
|
||||
let ty = vs.as_ref(py).get_type();
|
||||
Err(PyTypeError::new_err(format!(
|
||||
|
Reference in New Issue
Block a user