Make the Python Wrapper more Hackable and simplify Quantization (#1010)

* Some first `Module` implementations

* Add `state_dict` and `load_state_dict` functionality

* Move modules around and create `candle.nn.Linear`

* Add `nn.Embedding` and `nn.LayerNorm`

* Add BERT implementation

* Batch q-matmul

* Automatically dequantize `QTensors` if a `Tensor` is expected

* Add Module `.to()`, `.cuda()`, `cpu()` and `.type()` functionality

* Unittests for `Module`, `Tensor` and `candle.utils`

* Add `pytorch` like slicing to `Tensor`

* Cleanup and BERT fixes

* `black` formatting + unit-test for `nn.Linear`

* Refactor slicing implementation
This commit is contained in:
Lukas Kreussel
2023-10-06 20:01:07 +02:00
committed by GitHub
parent b0442eff8a
commit 904bbdae65
25 changed files with 2426 additions and 182 deletions

View File

@ -0,0 +1,74 @@
import candle
from candle import Tensor
def test_tensor_can_be_constructed():
t = Tensor(42.0)
assert t.values() == 42.0
def test_tensor_can_be_constructed_from_list():
t = Tensor([3.0, 1, 4, 1, 5, 9, 2, 6])
assert t.values() == [3.0, 1, 4, 1, 5, 9, 2, 6]
def test_tensor_can_be_constructed_from_list_of_lists():
t = Tensor([[3.0, 1, 4, 1], [5, 9, 2, 6]])
assert t.values() == [[3.0, 1, 4, 1], [5, 9, 2, 6]]
def test_tensor_can_be_quantized():
t = candle.randn((16, 256))
for format in [
"q4_0",
"q4_1",
"q5_0",
"q5_1",
"q8_0",
"q2k",
"q3k",
"q4k",
"q5k",
"q8k",
]:
for formatted_format in [format.upper(), format.lower()]:
quant_t = t.quantize(formatted_format)
assert quant_t.ggml_dtype.lower() == format.lower()
assert quant_t.shape == t.shape
def test_tensor_can_be_indexed():
t = Tensor([[3.0, 1, 4, 1], [5, 9, 2, 6]])
assert t[0].values() == [3.0, 1.0, 4.0, 1.0]
assert t[1].values() == [5.0, 9.0, 2.0, 6.0]
assert t[-1].values() == [5.0, 9.0, 2.0, 6.0]
assert t[-2].values() == [3.0, 1.0, 4.0, 1.0]
def test_tensor_can_be_sliced():
t = Tensor([3.0, 1, 4, 10, 5, 9, 2, 6])
assert t[0:4].values() == [3.0, 1.0, 4.0, 10.0]
assert t[4:8].values() == [5.0, 9.0, 2.0, 6.0]
assert t[-4:].values() == [5.0, 9.0, 2.0, 6.0]
assert t[:-4].values() == [3.0, 1.0, 4.0, 10.0]
assert t[-4:-2].values() == [5.0, 9.0]
def test_tensor_can_be_sliced_2d():
t = Tensor([[3.0, 1, 4, 1], [5, 9, 2, 6]])
assert t[:, 0].values() == [3.0, 5]
assert t[:, 1].values() == [1.0, 9.0]
assert t[0, 0].values() == 3.0
assert t[:, -1].values() == [1.0, 6.0]
assert t[:, -4].values() == [3.0, 5]
assert t[..., 0].values() == [3.0, 5]
def test_tensor_can_be_scliced_3d():
t = Tensor([[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]]])
assert t[:, :, 0].values() == [[1, 5], [9, 13]]
assert t[:, :, 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
assert t[:, 0, 0].values() == [1, 9]
assert t[..., 0].values() == [[1, 5], [9, 13]]
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]

View File

@ -0,0 +1,51 @@
import candle
from candle import Tensor, QTensor
from candle.utils import load_safetensors, save_gguf, load_gguf, save_safetensors
from pathlib import Path
TEST_DIR = Path(__file__).parent.parent / "_workdir"
TEST_DIR.mkdir(exist_ok=True)
def test_can_roundtrip_safetensors():
tensors = {
"a": candle.randn((16, 256)),
"b": candle.randn((16, 16)),
}
file = str(TEST_DIR / "test.safetensors")
save_safetensors(file, tensors)
loaded_tensors = load_safetensors(file)
assert set(tensors.keys()) == set(loaded_tensors.keys())
for key in tensors.keys():
assert tensors[key].values() == loaded_tensors[key].values(), "Values are not equal"
assert tensors[key].shape == loaded_tensors[key].shape, "Shapes are not equal"
assert str(tensors[key].dtype) == str(loaded_tensors[key].dtype), "Dtypes are not equal"
def test_can_roundtrip_gguf():
metadata = {
"a": 1,
"b": "foo",
"c": [1, 2, 3],
"d": [[1, 2], [3, 4]],
}
tensors = {
"a": candle.randn((16, 256)).quantize("q4_0"),
"b": candle.randn((16, 16)).quantize("f32"),
}
file = str(TEST_DIR / "test.gguf")
save_gguf(file, tensors, metadata)
loaded_tensors, loaded_metadata = load_gguf(file)
assert set(metadata.keys()) == set(loaded_metadata.keys())
for key in metadata.keys():
assert metadata[key] == loaded_metadata[key]
assert set(tensors.keys()) == set(loaded_tensors.keys())
for key in tensors.keys():
assert tensors[key].dequantize().values() == loaded_tensors[key].dequantize().values(), "Values are not equal"
assert tensors[key].shape == loaded_tensors[key].shape, "Shapes are not equal"
assert str(tensors[key].ggml_dtype) == str(loaded_tensors[key].ggml_dtype), "Dtypes are not equal"