Commit Graph

51 Commits

Author SHA1 Message Date
b20acd622c Update for pyo3 0.21. (#1985)
* Update for pyo3 0.21.

* Also adapt the RL example.

* Fix for the pyo3-onnx bindings...

* Print details on failures.

* Revert pyi.
2024-04-01 17:07:02 +02:00
143c481c20 Expose candle gather op in pyo3. (#1870) 2024-03-18 21:54:15 +01:00
ad73e93da2 Detach the tensors on batch-norm eval. (#1702)
* Detach the tensors on batch-norm eval.

* Fix pyo3 bindings.

* Black tweak.

* Formatting.

* Also update the pyo3-onnx formatting.

* Apply black.
2024-02-13 14:26:32 +01:00
403680f17d Quantized GGUF style (#1523)
* Metal quantized modifications proposal.

- Add a device param, wherever needed.
- Create new QMetal storage thing that implements QuantizedType.
- Update everywhere needed.

Fix Python.

Fixing examples.

Fix: fmt + clippy + stub.

Moving everything around.

Only missing the actual implems.

Fixing everything + adding dequantized kernels.

More work.

Fixing matmul.

Fmt + Clippy

Some clippy fixes.

Working state.

Q2K Metal -> Bugged (also present in GGML).
Q4K CPU -> Bugged (present previously, new test catch it).
Q5K CPU -> Bugged (present previously).
Q8_1 Both -> Never really implemented it seems
Q8K metal -> Never implemented in metal

Fixing Q2K bug (present in ggml).

* Cleanup.

* Fix the rebase.

* Removing the fences speeds everything up and *is* correct this time...

* Cleanup the fence.

* After rebase.

* Bad code removal.

* Rebase after phi2 merge + fix replit default to CPU.

* Making the CI happy.

* More happy tests.

---------

Co-authored-by: Nicolas Patry <nicolas@Nicolass-MacBook-Pro.local>
2024-01-17 10:27:58 +01:00
1e86717bf2 Fix a couple typos (#1451)
* Mixtral quantized instruct.

* Fix a couple typos.
2023-12-17 05:20:05 -06:00
bfa7c8fc01 Implement the module trait directly for QMatMul. (#1372) 2023-11-25 10:09:45 +00:00
26c4e5bf1d Metal part 1 - Scaffolding for metal. (#1308)
* Metal part 1 - Scaffolding for metal.

* Remove tracing.
2023-11-10 08:35:48 +01:00
f3a4f3db76 PyO3: Add optional candle.onnx module (#1282)
* Start onnx integration

* Merge remote-tracking branch 'upstream/main' into feat/pyo3-onnx

* Implement ONNXModel

* `fmt`

* add `onnx` flag to python ci

* Pin `protoc` to `25.0`

* Setup `protoc` in wheel builds

* Build wheels with `onnx`

* Install `protoc` in manylinux containers

* `apt` -> `yum`

* Download `protoc` via bash script

* Back to `manylinux: auto`

* Disable `onnx` builds for linux
2023-11-08 06:37:50 +01:00
c05c0a8213 PyO3: Add equal and __richcmp__ to candle.Tensor (#1099)
* add `equal` to tensor

* add `__richcmp__` support  for tensors and scalars

* typo

* more typos

* Add `abs` + `candle.testing`

* remove duplicated `broadcast_shape_binary_op`

* `candle.i16` => `candle.i64`

* `tensor.nelements` -> `tensor.nelement`

* Cleanup `abs`
2023-10-30 15:17:28 +00:00
174b208052 PyO3: Better shape handling (#1143)
* Negative and `*args` shape handling

* Rename to `PyShapeWithHole` + validate that only one hole exists

* Regenerate stubs

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2023-10-29 15:41:44 +00:00
6a446d9d73 convert pytorch's tensor in Python API (#1172)
* convert pytorch's tensor

* separate tests for convert pytorch tensor
2023-10-25 19:39:14 +01:00
7bd0faba75 Add support for accelerate in the pyo3 bindings. (#1167) 2023-10-24 06:34:37 +01:00
eae94a451b PyO3: Add mkl support (#1159)
* Add `mkl` support

* Set `mkl` path on linux
2023-10-23 20:10:59 +01:00
b43ab6cd1d PyO3: Add None and Tensor indexing to candle.Tensor (#1098)
* Add proper `None` and `tensor` indexing

* Allow indexing via lists + allow tensor/list indexing outside of first dimension
2023-10-20 09:59:00 +01:00
6684b7127a PyO3: Add pytorch like .to() operator to candle.Tensor (#1100)
* add `.to()` operator

* Only allow each value to be provided once via `args` or `kwargs`
2023-10-19 21:46:21 +01:00
b355ab4e2e Always broadcast magic methods (#1101) 2023-10-17 10:57:12 +01:00
2c110ac7d9 Add the pooling operators to the pyo3 layer. (#1086) 2023-10-13 20:18:10 +01:00
904bbdae65 Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations

* Add `state_dict` and `load_state_dict` functionality

* Move modules around and create `candle.nn.Linear`

* Add `nn.Embedding` and `nn.LayerNorm`

* Add BERT implementation

* Batch q-matmul

* Automatically dequantize `QTensors` if a `Tensor` is expected

* Add Module `.to()`, `.cuda()`, `cpu()` and `.type()` functionality

* Unittests for `Module`, `Tensor` and `candle.utils`

* Add `pytorch` like slicing to `Tensor`

* Cleanup and BERT fixes

* `black` formatting + unit-test for `nn.Linear`

* Refactor slicing implementation
2023-10-06 19:01:07 +01:00
089fc3b584 Improve the quantized whisper setup. (#1018)
* Improve the quantized whisper setup.

* Fix the config file paths.

* Use the standard matmul where possible.
2023-10-02 17:17:46 +01:00
03e194123d Add return types to *.pyi stubs (#880)
* Start generating return types

* Finish tensor type hinting

* Add `save_gguf` to `utils`

* Typehint `quant-llama.py`
2023-09-17 22:11:01 +01:00
8658df3485 Generate *.pyi stubs for PyO3 wrapper (#870)
* Begin to generate typehints.

* generate correct stubs

* Correctly include stubs

* Add comments and typhints to static functions

* ensure candle-pyo3 directory

* Make `llama.rope.freq_base` optional

* `fmt`
2023-09-16 17:23:38 +01:00
000487c36f Add a python function to save as safetensors. (#740) 2023-09-04 20:32:14 +01:00
20512ba408 Return the metadata in the gguf pyo3 bindings. (#729)
* Return the metadata in the gguf pyo3 bindings.

* Read the metadata in the quantized llama example.

* Get inference to work on gguf files.
2023-09-04 07:07:00 +01:00
84d003ff53 Handle arbitrary shapes in Tensor::new. (#718) 2023-09-02 19:59:21 +01:00
ad796eb4be More quantized llama in python. (#716)
* More quantized llama in python.

* Expose a couple more functions.

* Apply the last layer.

* Use the vocab from the ggml files.
2023-09-02 13:41:48 +01:00
e8e33752f4 Sketch a quantized llama using the pyo3 api. (#715)
* Sketch a quantized llama using the pyo3 api.

* Add more ops.

* Expose a few more functions to use in the quantized model.

* Rope embeddings.

* Get the forward pass to work.
2023-09-02 11:26:05 +01:00
1e5b2cc1d5 Add some quantized functions to pyo3. (#708) 2023-09-01 19:45:36 +02:00
2ed78ab336 Support for quantized tensors in the python api. (#706)
* Add more pyo3 support.

* Add some support for quantized tensors in pyo3.

* Add an arc layer on qmatmul.

* Add the quantized matmul.

* Quantization support.

* More quantization support.

* Test the python quantization.
2023-09-01 15:53:42 +01:00
237323c2bc Cleanup the pyo3 setup. (#705) 2023-09-01 14:26:18 +01:00
e21c686cdc Fixes for clippy 1.72. (#587) 2023-08-24 17:46:17 +01:00
9a5c7db91a Add support for i64 (#563)
* Add the i64 dtype.

* Adapt the cuda kernels.
2023-08-23 10:42:19 +01:00
93cfe5642f Pyo3 dtype (#327)
* Better handling of dtypes in pyo3.

* More pyo3 dtype.
2023-08-06 10:17:43 +01:00
88bd3b604a Add some tensor creation functions to the pyo3 bindings. (#326) 2023-08-06 06:50:33 +01:00
2bfa791336 Use the same default as pytorch for sum. (#164) 2023-07-13 21:32:32 +01:00
50b0946a2d Tensor mutability (#154)
* Working towards tensor mutability.

* Use a ref-cell to provide tensor mutability.
2023-07-13 11:04:40 +01:00
5b0ee2e0ba Get cuda to work on pyo3. 2023-07-02 21:04:11 +01:00
fbfe74caab Preliminary pyo3 support for device. 2023-07-02 20:42:55 +01:00
bdb257ceab Add the tensor function. 2023-07-02 20:15:50 +01:00
78871ffe38 Add dtype support. 2023-07-02 20:12:26 +01:00
5b8c6764b0 Add matmul/where_cond. 2023-07-02 07:34:14 +01:00
9a9858bbe0 Expose a couple more ops. 2023-07-02 07:30:00 +01:00
dfe197f791 Handle more input types to create tensors. 2023-07-02 07:19:46 +01:00
4a28dcf828 Rename the method. 2023-07-02 07:08:11 +01:00
c62cb73a7f Support higher order shapes for conversions. 2023-07-02 07:07:22 +01:00
fa58c7643d Add a trait to avoid repeating the dtype matching. 2023-07-02 06:58:10 +01:00
2370b1675d More pyo3. 2023-07-01 22:15:58 +01:00
86df4ad79c Get shape to return a tuple. 2023-07-01 21:34:38 +01:00
fbbde5b02c Add some binary operators. 2023-07-01 21:27:35 +01:00
42d1a52d01 Add two methods. 2023-07-01 20:55:15 +01:00
52db2a6849 Apply rustfmt. 2023-07-01 20:37:28 +01:00