8 Commits

Author SHA1 Message Date
b20acd622c Update for pyo3 0.21. (#1985)
* Update for pyo3 0.21.

* Also adapt the RL example.

* Fix for the pyo3-onnx bindings...

* Print details on failures.

* Revert pyi.
2024-04-01 17:07:02 +02:00
ad73e93da2 Detach the tensors on batch-norm eval. (#1702)
* Detach the tensors on batch-norm eval.

* Fix pyo3 bindings.

* Black tweak.

* Formatting.

* Also update the pyo3-onnx formatting.

* Apply black.
2024-02-13 14:26:32 +01:00
1e86717bf2 Fix a couple typos (#1451)
* Mixtral quantized instruct.

* Fix a couple typos.
2023-12-17 05:20:05 -06:00
174b208052 PyO3: Better shape handling (#1143)
* Negative and `*args` shape handling

* Rename to `PyShapeWithHole` + validate that only one hole exists

* Regenerate stubs

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2023-10-29 15:41:44 +00:00
f9e93f5b69 Extend stub.py to accept external typehinting (#1102) 2023-10-17 11:07:26 +01:00
904bbdae65 Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations

* Add `state_dict` and `load_state_dict` functionality

* Move modules around and create `candle.nn.Linear`

* Add `nn.Embedding` and `nn.LayerNorm`

* Add BERT implementation

* Batch q-matmul

* Automatically dequantize `QTensors` if a `Tensor` is expected

* Add Module `.to()`, `.cuda()`, `cpu()` and `.type()` functionality

* Unittests for `Module`, `Tensor` and `candle.utils`

* Add `pytorch` like slicing to `Tensor`

* Cleanup and BERT fixes

* `black` formatting + unit-test for `nn.Linear`

* Refactor slicing implementation
2023-10-06 19:01:07 +01:00
03e194123d Add return types to *.pyi stubs (#880)
* Start generating return types

* Finish tensor type hinting

* Add `save_gguf` to `utils`

* Typehint `quant-llama.py`
2023-09-17 22:11:01 +01:00
8658df3485 Generate *.pyi stubs for PyO3 wrapper (#870)
* Begin to generate typehints.

* generate correct stubs

* Correctly include stubs

* Add comments and typhints to static functions

* ensure candle-pyo3 directory

* Make `llama.rope.freq_base` optional

* `fmt`
2023-09-16 17:23:38 +01:00