mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
PyO3: Add optional candle.onnx
module (#1282)
* Start onnx integration * Merge remote-tracking branch 'upstream/main' into feat/pyo3-onnx * Implement ONNXModel * `fmt` * add `onnx` flag to python ci * Pin `protoc` to `25.0` * Setup `protoc` in wheel builds * Build wheels with `onnx` * Install `protoc` in manylinux containers * `apt` -> `yum` * Download `protoc` via bash script * Back to `manylinux: auto` * Disable `onnx` builds for linux
This commit is contained in:
@ -19,12 +19,14 @@ extern crate accelerate_src;
|
||||
|
||||
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
|
||||
|
||||
mod utils;
|
||||
use utils::wrap_err;
|
||||
|
||||
mod shape;
|
||||
use shape::{PyShape, PyShapeWithHole};
|
||||
|
||||
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
||||
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
||||
}
|
||||
#[cfg(feature = "onnx")]
|
||||
mod onnx;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[pyclass(name = "Tensor")]
|
||||
@ -1559,6 +1561,14 @@ fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
use onnx::{PyONNXModel, PyONNXTensorDescriptor};
|
||||
m.add_class::<PyONNXModel>()?;
|
||||
m.add_class::<PyONNXTensorDescriptor>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
let utils = PyModule::new(py, "utils")?;
|
||||
@ -1567,6 +1577,12 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
let nn = PyModule::new(py, "functional")?;
|
||||
candle_functional_m(py, nn)?;
|
||||
m.add_submodule(nn)?;
|
||||
#[cfg(feature = "onnx")]
|
||||
{
|
||||
let onnx = PyModule::new(py, "onnx")?;
|
||||
candle_onnx_m(py, onnx)?;
|
||||
m.add_submodule(onnx)?;
|
||||
}
|
||||
m.add_class::<PyTensor>()?;
|
||||
m.add_class::<PyQTensor>()?;
|
||||
m.add_class::<PyDType>()?;
|
||||
|
Reference in New Issue
Block a user