mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Compare commits
786 Commits
metal4.5
...
0.9.0-alph
Author | SHA1 | Date | |
---|---|---|---|
fb660b8d43 | |||
2f9606b187 | |||
f3a73f80d1 | |||
b44d38de0e | |||
d9198deb37 | |||
15ed0b11ce | |||
34505fdf3a | |||
d7b7ce16e4 | |||
19fb6dac1f | |||
acc5bd335f | |||
eb478ece92 | |||
d339b01726 | |||
2f3bf42bcb | |||
e3370c6316 | |||
338f6a102e | |||
bc33df77e1 | |||
cf9d7bf24c | |||
9d31361c4f | |||
648596c073 | |||
d9904a3baf | |||
d6db305829 | |||
b4daa03e59 | |||
9541467d6b | |||
6429609090 | |||
ba473290da | |||
59c26195db | |||
cb02b389d5 | |||
0d4097031c | |||
10853b803c | |||
f3d472952f | |||
67b85f79f1 | |||
0b24f7f0a4 | |||
3afb04925a | |||
cbf5fc80c2 | |||
468d1d525f | |||
c930ab7e1a | |||
111edbc4ea | |||
e286cf7cc9 | |||
e4ffb85228 | |||
37db86ff79 | |||
add3a714aa | |||
26c16923b9 | |||
9e8bf70333 | |||
ac9cdbd448 | |||
e6cc76fc37 | |||
fd7f7242a1 | |||
3ddd20a5aa | |||
2423d633fc | |||
7c2449f623 | |||
0af3e428ec | |||
43017539ab | |||
e142bf9530 | |||
d2c53f4f2f | |||
2a2852d1c1 | |||
8f20f2a722 | |||
ab9019425a | |||
da02b59516 | |||
27996a1a9e | |||
1a32107fab | |||
333d94a19a | |||
3164a19a5d | |||
e6cd499e98 | |||
77db8396d0 | |||
85f0aaefe5 | |||
e4c3a71f11 | |||
17cbbe4286 | |||
6fd2f63a15 | |||
efd0e6822f | |||
158817f230 | |||
309cd0f7c7 | |||
ab7ff7081e | |||
461e8c1685 | |||
2344c4e4b8 | |||
32defdb7d5 | |||
236c35e578 | |||
6f8351dfda | |||
57f41da13b | |||
cbaa0ad46f | |||
b12c7c2888 | |||
94ffc2ec6f | |||
7354afc673 | |||
2a705e6f37 | |||
a594ef669c | |||
71cd6d5533 | |||
d60eba1408 | |||
e38e2a85dd | |||
460616fc84 | |||
91f1f019b1 | |||
cd639131f0 | |||
11aa30be10 | |||
1be6b090c7 | |||
62ced44ea9 | |||
5c2f893e5a | |||
67cab7d6b8 | |||
1807be84f4 | |||
145aa7193c | |||
6f715f9256 | |||
dba7a9c93e | |||
b52c2c6050 | |||
4f59ed38b0 | |||
54e7fc3c97 | |||
23ed8a9ded | |||
21c686387c | |||
b4deb5c5a9 | |||
c12db594e3 | |||
f86f4d6224 | |||
3159f91b90 | |||
1a0f9ccf16 | |||
e86565624b | |||
386fd8abb4 | |||
12d7e7b145 | |||
a3f200e369 | |||
00d8a0c178 | |||
f689ce5d39 | |||
0ed24b9852 | |||
06350c31c7 | |||
9453cc3095 | |||
3769206583 | |||
e2b6b367fa | |||
6454597943 | |||
3fba2b5fc4 | |||
530ab96036 | |||
7ac0de15a9 | |||
d232e132f6 | |||
139ff56aeb | |||
498bc2cdc9 | |||
0e2c8c17fb | |||
594d984f9c | |||
37e0ab8c64 | |||
07849aa595 | |||
3699c1a053 | |||
a2e9d41b20 | |||
7c09215ef4 | |||
dcd83336b6 | |||
a01aa89799 | |||
3d1dc06cdb | |||
f553ab5eb4 | |||
41ade774e8 | |||
6eab6b57f5 | |||
ca7cf5cb3b | |||
0d96ec31e8 | |||
937e8eda74 | |||
edf7668291 | |||
e4a96f9e7c | |||
f856b5c3a7 | |||
d2e432914e | |||
410c89f72a | |||
56aacb05da | |||
6faecaa616 | |||
90d04ff622 | |||
7b60bda4ed | |||
936300678d | |||
f479840ce6 | |||
fd08d3d0a4 | |||
a2bcc227df | |||
def4c6cdee | |||
888d886dd8 | |||
6110ad8d4f | |||
aa35bf2ff5 | |||
724650446c | |||
dfe9a00683 | |||
683ab698de | |||
2f49e1b534 | |||
0ebb38813b | |||
3a3c48b14b | |||
261ed65f36 | |||
62525e8352 | |||
2c25754281 | |||
ed48f54b54 | |||
ad8a4c5e5a | |||
c3c392f45c | |||
a0184a4fe4 | |||
10d47183c0 | |||
d01207dbf3 | |||
8097559c1a | |||
829dcfa8dc | |||
c2fca0ca11 | |||
844d45cde4 | |||
af2104078f | |||
5fc4f17727 | |||
c58c5d5b01 | |||
382c6b51af | |||
6eea45a761 | |||
ebf722b446 | |||
c09afc211c | |||
b60faebea4 | |||
72d649058b | |||
0cb0bd1dfa | |||
afb6575835 | |||
5635650d38 | |||
13b2a8a4a0 | |||
e3261216b1 | |||
c02b7c3272 | |||
86613c00e2 | |||
29e25c458d | |||
aafa24ed93 | |||
fdc2622686 | |||
ccdbe87639 | |||
2ec8729d51 | |||
e3c146ada6 | |||
1e96b8b695 | |||
a8288b7a72 | |||
6070278a31 | |||
b47c0bc475 | |||
14fd2d97e0 | |||
31a1075f4b | |||
236b29ff15 | |||
58197e1896 | |||
736d8eb752 | |||
7cff5898ec | |||
b75ef051cf | |||
c1b9e07e35 | |||
69fdcfe96a | |||
2b75dd9551 | |||
53ce65f706 | |||
68aa9c7320 | |||
35e5f31397 | |||
d3fe989d08 | |||
14db029494 | |||
6e6c1c99b0 | |||
b7d9af00cc | |||
59bbc0d287 | |||
dfdce2b602 | |||
500c9f2882 | |||
2be9bd211e | |||
89eae41efd | |||
c0a559d427 | |||
aa7ac1832d | |||
19db6b9723 | |||
0fcb40b229 | |||
6991a37b94 | |||
9ca277a9d7 | |||
2e9c010609 | |||
ac51f477eb | |||
d4b6f6eef6 | |||
957d604a78 | |||
ce90287f45 | |||
1ba87a9450 | |||
bd80078acf | |||
fea46cb719 | |||
8696cf6494 | |||
4a52aeb437 | |||
24d54d0ff9 | |||
636eff652a | |||
0f5cbb08b3 | |||
ddafc61055 | |||
a925ae6bc6 | |||
6056fd5c90 | |||
ebc9aa60bc | |||
2489a606fe | |||
3c815b1dca | |||
42891cc613 | |||
f25173d68b | |||
6a4741bbf9 | |||
30cdd769f9 | |||
d74fbed334 | |||
c63048d374 | |||
a226a9736b | |||
25960676ca | |||
9cd54aa5d4 | |||
eec11ce2ce | |||
9182f9f5c2 | |||
ecff05d72b | |||
7f1ba8038c | |||
74e9e41911 | |||
e27aac0a06 | |||
a3dd87f15e | |||
242e006bbb | |||
6baa1d486b | |||
36cf54525d | |||
2b10aaa05d | |||
9f804af29d | |||
54ff971e35 | |||
b9fac7ec00 | |||
f65e90e7ef | |||
d39462856b | |||
cb180eb23a | |||
9182c828e6 | |||
3f13ad3d79 | |||
cd4d941ed1 | |||
03344d3c19 | |||
1ec3b2cc18 | |||
f7773d498a | |||
7abc3b8cd7 | |||
46012ed31f | |||
f3fade3b03 | |||
ea260aeffd | |||
0814dfd148 | |||
3ceca9901a | |||
1df2bddccf | |||
6f0b807ffd | |||
d54e02d73d | |||
45e235a747 | |||
31cf64147b | |||
77ea479a18 | |||
72e7ca529a | |||
7ff921c538 | |||
9b8537a62f | |||
7ebc3548e1 | |||
eefc1c77ef | |||
01545f7303 | |||
349c3e806a | |||
bdaa34216a | |||
cc80e065e5 | |||
13c64f6828 | |||
21f82a5155 | |||
9cff7bc3f4 | |||
d9bc5ec151 | |||
84328e2b60 | |||
82b641fd27 | |||
01794dc16e | |||
a75cd8164f | |||
b13a82a438 | |||
59b18d974e | |||
89f53b9d7b | |||
a09d451d11 | |||
fa06f5f5f9 | |||
09d4845aa8 | |||
a0d03aded1 | |||
3bbb88fcb4 | |||
ed7b99f525 | |||
287013ef28 | |||
eb26e2467e | |||
c68ed8963f | |||
e5c8b88f90 | |||
805f3be8e1 | |||
3b429f3023 | |||
96a48e5cc4 | |||
6cf82fd7a3 | |||
cfab6e7616 | |||
11d4a3c588 | |||
9d3f1c8af5 | |||
7211009179 | |||
6fadaf2eff | |||
8a05743a21 | |||
b2e816752b | |||
618ecf5e23 | |||
267601eec1 | |||
08a15cb79e | |||
c388be93e7 | |||
d22f1d4f4e | |||
0067fe00a8 | |||
587ee3bb6f | |||
dd78422701 | |||
9215e9ce8c | |||
52ae332910 | |||
8b390ddd29 | |||
c97d639fa0 | |||
b45c710dbf | |||
9c532aef47 | |||
f7a6468238 | |||
2b93dffe64 | |||
e6ee7ba4d4 | |||
1690ab45d2 | |||
8de0ce6cba | |||
ce6d08df94 | |||
2817643db9 | |||
4d14777673 | |||
f135b7963d | |||
af955f260c | |||
8ad822a983 | |||
e198bb0816 | |||
f7d5bf5b97 | |||
c119600d6e | |||
c449f65b12 | |||
db7dbf3071 | |||
4ecedb1598 | |||
53e5380bf6 | |||
50e49ecc5f | |||
4c88c3ce06 | |||
8b8fb630df | |||
fb805b8ca2 | |||
79e3bec789 | |||
e6d412b156 | |||
26cbbf8d84 | |||
2bf413caa3 | |||
3ad4770eb6 | |||
a0460cd2b1 | |||
b81ecf712d | |||
a4d5a414e3 | |||
798e0335cd | |||
718671a0d5 | |||
c5fe4a7f89 | |||
7f354473cf | |||
33c9b66554 | |||
9fd52b3b71 | |||
e662431acf | |||
ab892274d1 | |||
b869a659ec | |||
88f7793598 | |||
2ac302a5d1 | |||
ace282e5c2 | |||
c87381fc96 | |||
c5626b8271 | |||
e6a5b82ba6 | |||
5aebe53dd2 | |||
f76bb7794a | |||
30b145150f | |||
f48c07e242 | |||
8967c46563 | |||
1e46cf8b19 | |||
bd8db2a771 | |||
318d143224 | |||
2be1a35710 | |||
26226068a4 | |||
cd6b9e317c | |||
08c049def3 | |||
d17b2cdad9 | |||
fb918a23c8 | |||
b23436bf90 | |||
be9c200cbb | |||
ea0d8d3753 | |||
308ea070ed | |||
b20acd622c | |||
5522bbc57c | |||
888c09a3db | |||
318cb82f16 | |||
c7557b65dc | |||
cd29c7ccd4 | |||
f9954b73ba | |||
eead1dcead | |||
92f81d2fcb | |||
3144150b8d | |||
b190fd8592 | |||
efe4a0c84b | |||
665da30487 | |||
356a170ae9 | |||
7ecbc6d50b | |||
8ad12a0e81 | |||
eb1b27abcd | |||
708e422456 | |||
c5092f2c29 | |||
cdc8b57b5c | |||
b0340d72ec | |||
b3484e7a5e | |||
ada5d7c096 | |||
13ae5a34c7 | |||
ab86cd37c8 | |||
a9abde5f93 | |||
75b6d4b0da | |||
66f0a4eeea | |||
4523ecfb2a | |||
f5dfe883d7 | |||
196765e995 | |||
60676780a9 | |||
d3a8d291d5 | |||
cd254074f3 | |||
e7f8e72588 | |||
1b98f84a2b | |||
cf7d7fcf2f | |||
8c0db87992 | |||
e2b4829531 | |||
5e70821dd0 | |||
a62a97340c | |||
fdfe8fd129 | |||
790037390c | |||
6f877592a7 | |||
cc856db9ce | |||
fc1fe5e45b | |||
32f567bac4 | |||
fee33b45c2 | |||
6708870e63 | |||
a00e24d752 | |||
c07e4057ab | |||
c0bdd9c7a6 | |||
9563a5fee4 | |||
ec97c98e81 | |||
bb3ee48039 | |||
0c11e055be | |||
18036c6ccb | |||
0fddec762e | |||
74b7f59261 | |||
af7f8b87d3 | |||
b219903d0f | |||
469635a3eb | |||
455c42aa72 | |||
2a8679509e | |||
143c481c20 | |||
f115895b9e | |||
90fc82211f | |||
6a966cf9e0 | |||
04a61a9c72 | |||
58605252e8 | |||
d365ef32d9 | |||
754fa1e813 | |||
184105792f | |||
a15f859ab4 | |||
e316cb6997 | |||
ce9fbc3682 | |||
db8b24ae92 | |||
74bf6994b1 | |||
cdc4c172c4 | |||
e1f9c3776d | |||
3318fe30fb | |||
2bb9c683b9 | |||
ff03fd3fb3 | |||
df5f69444e | |||
0c5eecbc0f | |||
56c9d3ee7b | |||
dd00482ea3 | |||
936f6a4840 | |||
3440cec3a0 | |||
e7fc1daa21 | |||
be5b68cd0b | |||
ea984d0421 | |||
9634583781 | |||
758366160e | |||
0a3487a776 | |||
0c09d10f32 | |||
8a99cf7dd2 | |||
bd9ab9bc04 | |||
8cc0a183ba | |||
6530932285 | |||
924ccae30c | |||
60dc72b96b | |||
20abb72fec | |||
ca5d727ba2 | |||
09e0148cce | |||
de11623752 | |||
21f1d04976 | |||
4fff5b51f5 | |||
314630638d | |||
3e3def4134 | |||
6980774a91 | |||
64d4038e4f | |||
979deaca07 | |||
b485e4b6ee | |||
2c95b7394a | |||
4fd00b8900 | |||
57267cd536 | |||
60ee5cfd4d | |||
56e44aabe3 | |||
d0aca6c3c6 | |||
15e8644149 | |||
0c49e95dfb | |||
205767f9de | |||
5e526abc8c | |||
6400e1b0a0 | |||
32544a2ad6 | |||
badf886583 | |||
918136ba46 | |||
1a6043af51 | |||
2f22afd80e | |||
8d04f70f4d | |||
eeb7e2b683 | |||
11ea7aac4d | |||
32eb56d6b3 | |||
28057781aa | |||
544018b6d0 | |||
c753f72c85 | |||
8013b50829 | |||
45d5322d62 | |||
a2cb2edead | |||
fc67d878bb | |||
3ba37443e5 | |||
1fb728772d | |||
cb86b0c82c | |||
6284ad784c | |||
678d44a7f6 | |||
41416d2376 | |||
5ebcfeaf0f | |||
7c7400fb63 | |||
058a910d0e | |||
26fe162ab5 | |||
121a71e01f | |||
2d5f2a728d | |||
68f7655895 | |||
b60064780d | |||
14010a8498 | |||
0de0795220 | |||
c1b418586c | |||
ad73e93da2 | |||
13c67226e6 | |||
d0aa197b07 | |||
274bf11633 | |||
1e26d539d9 | |||
74497e6bf7 | |||
8ab384e63d | |||
27ffd644a9 | |||
bf20cc854c | |||
42ce593ec6 | |||
67589791d2 | |||
1c8d61f051 | |||
90447bc993 | |||
40ce16001b | |||
5657e596cd | |||
0dee8ea19b | |||
9cadd4e644 | |||
020a979de2 | |||
cdc3823d8f | |||
e5eb9602d0 | |||
b75e8945bc | |||
a90fc5ca5a | |||
adfae2460a | |||
678f64dd27 | |||
b545f54a19 | |||
1ba11f22d6 | |||
982722019b | |||
a83ca2ece0 | |||
153c940a9c | |||
50be8a98ba | |||
58cc896e69 | |||
5cdd84e0f6 | |||
a510ddec4e | |||
d32abbce53 | |||
dfab45e1c8 | |||
96bc704d17 | |||
a52d407ae6 | |||
9e824ec810 | |||
beadb1b434 | |||
6d83d42efb | |||
b6afb46601 | |||
fd7c856564 | |||
73d79e6092 | |||
b1879f17f6 | |||
4f79f5df8a | |||
1cf34368b7 | |||
17e6e2d7ee | |||
80b1c689f9 | |||
db923517b3 | |||
403680f17d | |||
86a8e58897 | |||
5270224f40 | |||
7e3349d7c3 | |||
1257fc6719 | |||
ea36f3b11f | |||
79478ff5a1 | |||
86b7c01b30 | |||
bdd8107fda | |||
ecf88a6d38 | |||
e6d86b0819 | |||
88618255cb | |||
539ead927a | |||
a46864bd56 | |||
bafe95b660 | |||
a3d92ab226 | |||
e90bcdcc7c | |||
8e06bfb4fd | |||
6242276c09 | |||
e06e8d0dbe | |||
e63bb8661b | |||
41915184bb | |||
c1876b8041 | |||
85e5680277 | |||
1327419776 | |||
402349d120 | |||
9f0c99f0c1 | |||
0fc95c9f0c | |||
2480c5dbdd | |||
63944714f2 | |||
d3bdd788cf | |||
ae06cb74bb | |||
a897fda74e | |||
1f1179913a | |||
6e98cf2a92 | |||
2cc1247999 | |||
edf3fcd1c4 | |||
53e4755015 | |||
87efb5d8eb | |||
ad181f9cdc | |||
88945f2c22 | |||
12b2a337f3 | |||
fb05af4c42 | |||
ad075a5f7e | |||
0eb90ed783 | |||
89b5a06858 | |||
3f04a79ada | |||
30313c3081 | |||
e72d52b1a2 | |||
b4cb982e49 | |||
6ebe043273 | |||
6bf52b9fdf | |||
84250bf52f | |||
8d1a57c9a0 | |||
955e63c803 | |||
3a7304cb0d | |||
fa3ea98ba9 | |||
135ae5f3eb | |||
41614b4a9b | |||
03ce8caf40 | |||
b0fe5e4453 | |||
1fb2dd905c | |||
a0facd0e67 | |||
4290b81244 | |||
51e577a682 | |||
0a245e6fa4 | |||
87d7f81b43 | |||
4373534d59 | |||
f4a2787217 | |||
488e02a3f6 | |||
adc95ca2bf | |||
4907c63ea1 | |||
d76ac20e0e | |||
f5c98f22c7 | |||
5b12fbb143 | |||
cc06ba2294 | |||
a6bd0b47a5 | |||
b59b1b2bb6 | |||
3922b42c18 | |||
1e442d4bb9 | |||
cd889c0f8a | |||
8e93e76a91 | |||
b3e838f3e2 | |||
8bf892403a | |||
d35f0a1376 | |||
65cb90bd40 | |||
996a7f2e24 | |||
3071ea6c3e | |||
37c539f2b7 | |||
eae3a20d43 | |||
13a5d15ebc | |||
1505d85276 | |||
95e18ef675 | |||
7135791dd5 | |||
88589d8815 | |||
5b35fd0fcf | |||
ba1fae590e | |||
78d982e1bd | |||
d8b9a727fc | |||
ceb78d3e28 | |||
f6408a3779 | |||
10d94659c3 | |||
563a79afa1 | |||
8ede5f4210 | |||
9fc210fae8 | |||
9b5e4843a6 | |||
03641293ee | |||
064ba17bd7 | |||
e8ee253ee0 | |||
8bd3d6b94b | |||
6a3ca7da0c | |||
96f1a28e39 | |||
586b6f6fff | |||
e4b0cc59f5 | |||
0a6e0a8c9a | |||
972903021c | |||
94817dac56 | |||
1e86717bf2 | |||
c630622a07 | |||
c4cfcf1539 | |||
1782e93de6 | |||
cfdf9640a3 | |||
e12cbfd73b | |||
30a958e5dd | |||
614842b311 | |||
79eab519fd | |||
6bc92e63cb | |||
aa04015098 | |||
8b5059e951 | |||
26540641c1 | |||
34d83377f6 | |||
77197379cc | |||
916a8c5464 | |||
243e83f2b9 | |||
cf27868b57 | |||
40c3e1bd5a | |||
ece4c69a68 | |||
4eeaf205d6 | |||
f419a38e1a | |||
361f2ad2af | |||
e60f9b5dfc | |||
7be982f6f7 | |||
104e196d46 | |||
5e33c85c8f | |||
2b3a018be7 | |||
931432ed55 | |||
0404a3eb5b | |||
a9d0657432 | |||
4cb443d00a | |||
87dc559817 | |||
77252ffb82 | |||
18eb87f25f | |||
da0af3cb3e | |||
9bd94c1ffa | |||
803ac8405b | |||
6e25822d4f | |||
236b820e28 | |||
2648e797c2 | |||
b5c283e86f | |||
8418154ee0 | |||
99b7273b03 | |||
16161145ae | |||
0738df5290 | |||
37bf1ed012 | |||
dd40edfe73 | |||
5aa1a65dab |
7
.github/dependabot.yml
vendored
Normal file
7
.github/dependabot.yml
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "cargo"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 5
|
40
.github/workflows/book-cd.yml
vendored
40
.github/workflows/book-cd.yml
vendored
@ -1,40 +0,0 @@
|
||||
name: Deploy Rust book
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write # To push a branch
|
||||
pull-requests: write # To create a PR from that branch
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Install latest mdbook
|
||||
run: |
|
||||
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
|
||||
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
|
||||
mkdir mdbook
|
||||
curl -sSL $url | tar -xz --directory=./mdbook
|
||||
echo `pwd`/mdbook >> $GITHUB_PATH
|
||||
- name: Deploy GitHub Pages
|
||||
run: |
|
||||
# This assumes your book is in the root of your repository.
|
||||
# Just add a `cd` here if you need to change to another directory.
|
||||
cd candle-book
|
||||
mdbook build
|
||||
git worktree add gh-pages
|
||||
git config user.name "Deploy from CI"
|
||||
git config user.email ""
|
||||
cd gh-pages
|
||||
# Delete the ref to avoid keeping history.
|
||||
git update-ref -d refs/heads/gh-pages
|
||||
rm -rf *
|
||||
mv ../book/* .
|
||||
git add .
|
||||
git commit -m "Deploy $GITHUB_SHA to gh-pages"
|
||||
git push --force --set-upstream origin gh-pages
|
29
.github/workflows/book.yml
vendored
29
.github/workflows/book.yml
vendored
@ -1,29 +0,0 @@
|
||||
name: CI
|
||||
on:
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Test candle-book
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write # To push a branch
|
||||
pull-requests: write # To create a PR from that branch
|
||||
steps:
|
||||
- uses: actions/checkout@master
|
||||
- name: Install Rust
|
||||
run: |
|
||||
rustup set profile minimal
|
||||
rustup toolchain install stable
|
||||
rustup default stable
|
||||
- name: Install latest mdbook
|
||||
run: |
|
||||
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
|
||||
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
|
||||
mkdir bin
|
||||
curl -sSL $url | tar -xz --directory=bin
|
||||
echo "$(pwd)/bin" >> $GITHUB_PATH
|
||||
- name: Run tests
|
||||
run: cd candle-book && cargo build && mdbook test -L ../target/debug/deps/
|
||||
|
||||
|
73
.github/workflows/ci_cuda.yaml
vendored
73
.github/workflows/ci_cuda.yaml
vendored
@ -5,47 +5,16 @@ on:
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
start-runner:
|
||||
name: Start self-hosted EC2 runner
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AWS_REGION: us-east-1
|
||||
EC2_AMI_ID: ami-03cfed9ea28f4b002
|
||||
EC2_INSTANCE_TYPE: g5.xlarge
|
||||
EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
|
||||
EC2_SECURITY_GROUP: sg-030175c435ac141d6
|
||||
outputs:
|
||||
label: ${{ steps.start-ec2-runner.outputs.label }}
|
||||
ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
|
||||
steps:
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v1
|
||||
with:
|
||||
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
- name: Start EC2 runner
|
||||
id: start-ec2-runner
|
||||
uses: philschmid/philschmid-ec2-github-runner@main
|
||||
with:
|
||||
mode: start
|
||||
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
|
||||
ec2-image-id: ${{ env.EC2_AMI_ID }}
|
||||
ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
|
||||
subnet-id: ${{ env.EC2_SUBNET_ID }}
|
||||
security-group-id: ${{ env.EC2_SECURITY_GROUP }}
|
||||
aws-resource-tags: > # optional, requires additional permissions
|
||||
[
|
||||
{"Key": "Name", "Value": "ec2-tgi-github-runner"},
|
||||
{"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
|
||||
]
|
||||
|
||||
test-cuda:
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
needs: start-runner # required to start the main job when the runner is ready
|
||||
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
|
||||
options: --gpus 0
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
|
||||
permissions:
|
||||
contents: write
|
||||
packages: write
|
||||
@ -56,32 +25,10 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
- name: Install dependencies
|
||||
run: apt-get update && apt install curl build-essential libssl-dev protobuf-compiler pkg-config -y
|
||||
- name: Install Rust Stable
|
||||
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
|
||||
uses: actions-rust-lang/setup-rust-toolchain@v1
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
|
||||
- name: Test (cuda)
|
||||
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
|
||||
stop-runner:
|
||||
name: Stop self-hosted EC2 runner
|
||||
needs:
|
||||
- start-runner
|
||||
- test-cuda
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AWS_REGION: us-east-1
|
||||
if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
|
||||
steps:
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v1
|
||||
with:
|
||||
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
- name: Stop EC2 runner
|
||||
uses: philschmid/philschmid-ec2-github-runner@main
|
||||
with:
|
||||
mode: stop
|
||||
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
|
||||
label: ${{ needs.start-runner.outputs.label }}
|
||||
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
|
||||
run: cargo test --features cuda
|
||||
|
BIN
.github/workflows/maturin.yml
vendored
BIN
.github/workflows/maturin.yml
vendored
Binary file not shown.
6
.github/workflows/python.yml
vendored
6
.github/workflows/python.yml
vendored
@ -18,9 +18,9 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest] # For now, only test on Linux
|
||||
steps:
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Rust
|
||||
uses: actions-rs/toolchain@v1
|
||||
@ -65,4 +65,4 @@ jobs:
|
||||
working-directory: ./candle-pyo3
|
||||
run: |
|
||||
source .env/bin/activate
|
||||
python -m pytest -s -v tests
|
||||
python -m pytest -s -v tests
|
||||
|
21
.github/workflows/rust-ci.yml
vendored
21
.github/workflows/rust-ci.yml
vendored
@ -1,6 +1,6 @@
|
||||
on:
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
@ -15,7 +15,10 @@ jobs:
|
||||
os: [ubuntu-latest, windows-latest, macOS-latest]
|
||||
rust: [stable]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
@ -34,7 +37,13 @@ jobs:
|
||||
os: [ubuntu-latest, windows-latest, macOS-latest]
|
||||
rust: [stable]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Delete huge unnecessary tools folder
|
||||
if: runner.os == 'Linux'
|
||||
run: rm -rf /opt/hostedtoolcache
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
@ -49,7 +58,7 @@ jobs:
|
||||
name: Rustfmt
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
@ -65,7 +74,7 @@ jobs:
|
||||
name: Clippy
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
|
15
.github/workflows/trufflehog.yml
vendored
Normal file
15
.github/workflows/trufflehog.yml
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
on:
|
||||
push:
|
||||
|
||||
name: Secret Leaks
|
||||
|
||||
jobs:
|
||||
trufflehog:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Secret Scanning
|
||||
uses: trufflesecurity/trufflehog@main
|
10
.gitignore
vendored
10
.gitignore
vendored
@ -9,6 +9,10 @@ target/
|
||||
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
|
||||
Cargo.lock
|
||||
|
||||
# editor config
|
||||
.helix
|
||||
.vscode
|
||||
|
||||
# These are backup files generated by rustfmt
|
||||
**/*.rs.bk
|
||||
|
||||
@ -36,3 +40,9 @@ candle-wasm-examples/*/package-lock.json
|
||||
candle-wasm-examples/**/config*.json
|
||||
.DS_Store
|
||||
.idea/*
|
||||
__pycache__
|
||||
out.safetensors
|
||||
out.wav
|
||||
bria.mp3
|
||||
bria.safetensors
|
||||
bria.wav
|
||||
|
@ -63,7 +63,7 @@ This documents the main changes to the `candle` crate.
|
||||
[760](https://github.com/huggingface/candle/pull/760).
|
||||
- Add the Segment-Anything Model (SAM) as an example
|
||||
[773](https://github.com/huggingface/candle/pull/773).
|
||||
- TinyViT backbone for the segemnt anything example
|
||||
- TinyViT backbone for the segment anything example
|
||||
[787](https://github.com/huggingface/candle/pull/787).
|
||||
- Shape with holes support
|
||||
[770](https://github.com/huggingface/candle/pull/770).
|
||||
|
50
Cargo.toml
50
Cargo.toml
@ -3,14 +3,15 @@ members = [
|
||||
"candle-core",
|
||||
"candle-datasets",
|
||||
"candle-examples",
|
||||
"candle-book",
|
||||
"candle-nn",
|
||||
"candle-pyo3",
|
||||
"candle-transformers",
|
||||
"candle-wasm-examples/*",
|
||||
"candle-wasm-tests",
|
||||
"tensor-tools",
|
||||
]
|
||||
exclude = [
|
||||
"candle-book",
|
||||
"candle-flash-attn",
|
||||
"candle-kernels",
|
||||
"candle-metal-kernels",
|
||||
@ -19,7 +20,7 @@ exclude = [
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.3.1"
|
||||
version = "0.9.0-alpha.2"
|
||||
edition = "2021"
|
||||
description = "Minimalist ML framework."
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
@ -28,40 +29,53 @@ categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
ab_glyph = "0.2.23"
|
||||
accelerate-src = { version = "0.3.2" }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = "1.4.3"
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.2" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.2" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.2" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.2" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.2" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.2" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.2" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.2" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
cudarc = { version = "0.9.14", features = ["f16"] }
|
||||
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
|
||||
hf-hub = "0.3.0"
|
||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
|
||||
imageproc = { version = "0.23.0", default-features = false }
|
||||
criterion = { version = "0.5.1", default-features=false }
|
||||
cudarc = { version = "0.15.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||
fancy-regex = "0.13.0"
|
||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||
hf-hub = "0.4.1"
|
||||
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
hound = "3.5.1"
|
||||
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
|
||||
imageproc = { version = "0.24.0", default-features = false }
|
||||
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
||||
libc = { version = "0.2.147" }
|
||||
log = "0.4"
|
||||
memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
|
||||
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
|
||||
num_cpus = "1.15.0"
|
||||
num-traits = "0.2.15"
|
||||
parquet = { version = "45.0.0" }
|
||||
rand = "0.8.5"
|
||||
rand_distr = "0.4.3"
|
||||
parquet = { version = "51.0.0" }
|
||||
rand = "0.9.0"
|
||||
rand_distr = "0.5.1"
|
||||
rayon = "1.7.0"
|
||||
rusttype = { version = "0.9", default-features = false }
|
||||
safetensors = "0.3.1"
|
||||
safetensors = "0.4.1"
|
||||
serde = { version = "1.0.171", features = ["derive"] }
|
||||
serde_plain = "1.0.2"
|
||||
serde_json = "1.0.99"
|
||||
thiserror = "1"
|
||||
tokenizers = { version = "0.13.4", default-features = false }
|
||||
tokenizers = { version = "0.21.0", default-features = false }
|
||||
tracing = "0.1.37"
|
||||
tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
wav = "1.0.0"
|
||||
ug = "0.3.1"
|
||||
ug-cuda = "0.3.1"
|
||||
ug-metal = "0.3.1"
|
||||
yoke = { version = "0.7.2", features = ["derive"] }
|
||||
zip = { version = "0.6.6", default-features = false }
|
||||
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
|
||||
zip = { version = "1.1.1", default-features = false }
|
||||
metal = { version = "0.27.0", features = ["mps"]}
|
||||
|
||||
[profile.release-with-debug]
|
||||
inherits = "release"
|
||||
|
102
README.md
102
README.md
@ -2,7 +2,8 @@
|
||||
[](https://discord.gg/hugging-face-879548962464493619)
|
||||
[](https://crates.io/crates/candle-core)
|
||||
[](https://docs.rs/candle-core)
|
||||

|
||||
[](https://github.com/huggingface/candle/blob/main/LICENSE-MIT)
|
||||
[](https://github.com/huggingface/candle/blob/main/LICENSE-APACHE)
|
||||
|
||||
Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
|
||||
and ease of use. Try our online demos:
|
||||
@ -54,20 +55,37 @@ These online demos run entirely in your browser:
|
||||
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
|
||||
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
|
||||
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
|
||||
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
|
||||
- [Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
|
||||
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
|
||||
- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.
|
||||
|
||||
We also provide a some command line based examples using state of the art models:
|
||||
|
||||
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
|
||||
- [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes
|
||||
the SOLAR-10.7B variant.
|
||||
- [Falcon](./candle-examples/examples/falcon/): general LLM.
|
||||
- [Phi-v1 and Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
|
||||
- [Codegeex4](./candle-examples/examples/codegeex4-9b/): Code completion,code interpreter,web search,fuction calling,repository-level
|
||||
- [GLM4](./candle-examples/examples/glm4/): Open Multilingual Multimodal Chat LMs by THUDM
|
||||
- [Gemma v1 and v2](./candle-examples/examples/gemma/): 2b and 7b+/9b general LLMs from Google Deepmind.
|
||||
- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
|
||||
Griffin based models from Google that mix attention with a RNN like state.
|
||||
- [Phi-1, Phi-1.5, Phi-2, and Phi-3](./candle-examples/examples/phi/): 1.3b,
|
||||
2.7b, and 3.8b general LLMs with performance on par with 7b models.
|
||||
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
|
||||
pre-trained on 1T tokens of English and code datasets.
|
||||
pre-trained on 1T tokens of English and code datasets. Also supports
|
||||
StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
|
||||
- [Mamba](./candle-examples/examples/mamba/): an inference only
|
||||
implementation of the Mamba state space model.
|
||||
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
|
||||
performance larger than all publicly available 13b models as of 2023-09-28.
|
||||
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
|
||||
better performance than all publicly available 13b models as of 2023-09-28.
|
||||
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
|
||||
experts 8x7b general LLM with better performance than a Llama 2 70B model with
|
||||
much faster inference.
|
||||
- [StarCoder](./candle-examples/examples/bigcode/) and
|
||||
[StarCoder2](./candle-examples/examples/starcoder2/): LLM specialized to code generation.
|
||||
- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
|
||||
- [RWKV v5 and v6](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
|
||||
performance.
|
||||
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
|
||||
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
|
||||
(English/Chinese) general LLMs with 6b and 34b parameters.
|
||||
@ -78,7 +96,7 @@ We also provide a some command line based examples using state of the art models
|
||||
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">
|
||||
|
||||
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
|
||||
image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
|
||||
image generative model, support for the 1.5, 2.1, SDXL 1.0 and Turbo versions.
|
||||
|
||||
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">
|
||||
|
||||
@ -97,16 +115,31 @@ We also provide a some command line based examples using state of the art models
|
||||
|
||||
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
|
||||
|
||||
- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmentation model.
|
||||
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
|
||||
- [EnCodec](./candle-examples/examples/encodec/): high-quality audio compression
|
||||
model using residual vector quantization.
|
||||
- [MetaVoice](./candle-examples/examples/metavoice/): foundational model for
|
||||
text-to-speech.
|
||||
- [Parler-TTS](./candle-examples/examples/parler-tts/): large text-to-speech
|
||||
model.
|
||||
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
|
||||
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
|
||||
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
|
||||
using self-supervision (can be used for imagenet classification, depth
|
||||
evaluation, segmentation).
|
||||
- [VGG](./candle-examples/examples/vgg/),
|
||||
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
|
||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||
generate captions for an image.
|
||||
- [CLIP](./candle-examples/examples/clip/): multi-model vision and language
|
||||
model.
|
||||
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
|
||||
dedicated submodels for hand-writing and printed recognition.
|
||||
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
|
||||
model, generates the translated text from the input text.
|
||||
- [Moondream](./candle-examples/examples/moondream/): tiny computer-vision model
|
||||
that can answer real-world questions about images.
|
||||
|
||||
Run them using commands like:
|
||||
```
|
||||
@ -122,7 +155,7 @@ There are also some wasm examples for whisper and
|
||||
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
|
||||
[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
|
||||
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
|
||||
[Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
|
||||
[Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
|
||||
[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
|
||||
|
||||
For LLaMA2, run the following command to retrieve the weight files and start a
|
||||
@ -141,15 +174,22 @@ And then head over to
|
||||
## Useful External Resources
|
||||
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
|
||||
very detailed tutorial showing how to convert a PyTorch model to Candle.
|
||||
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and ergonomic LoRA implemenation for Candle. `candle-lora` has
|
||||
out-of-the-box LoRA support for many models from Candle, which can be found [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
|
||||
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and
|
||||
ergonomic LoRA implementation for Candle. `candle-lora` has
|
||||
out-of-the-box LoRA support for many models from Candle, which can be found
|
||||
[here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
|
||||
- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
|
||||
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
|
||||
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
|
||||
serving local LLMs including an OpenAI compatible API server.
|
||||
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
|
||||
- [`candle-coursera-ml`](https://github.com/vishpat/candle-coursera-ml): Implementation of ML algorithms from Coursera's [Machine Learning Specialization](https://www.coursera.org/specializations/machine-learning-introduction) course.
|
||||
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
|
||||
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
|
||||
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
|
||||
- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.
|
||||
- [`atoma-infer`](https://github.com/atoma-network/atoma-infer): A Rust library for fast inference at scale, leveraging FlashAttention2 for efficient attention computation, PagedAttention for efficient KV-cache memory management, and multi-GPU support. It is OpenAI api compatible.
|
||||
- [`llms-from-scratch-rs`](https://github.com/nerdai/llms-from-scratch-rs): A comprehensive Rust translation of the code from Sebastian Raschka's Build an LLM from Scratch book.
|
||||
|
||||
If you have an addition to this list, please submit a pull request.
|
||||
|
||||
@ -168,33 +208,46 @@ If you have an addition to this list, please submit a pull request.
|
||||
- WASM support, run your models in a browser.
|
||||
- Included models.
|
||||
- Language Models.
|
||||
- LLaMA v1 and v2.
|
||||
- LLaMA v1, v2, and v3 with variants such as SOLAR-10.7B.
|
||||
- Falcon.
|
||||
- StarCoder.
|
||||
- Phi v1.5.
|
||||
- StarCoder, StarCoder2.
|
||||
- Phi 1, 1.5, 2, and 3.
|
||||
- Mamba, Minimal Mamba
|
||||
- Gemma v1 2b and 7b+, v2 2b and 9b.
|
||||
- Mistral 7b v0.1.
|
||||
- StableLM-3B-4E1T.
|
||||
- Mixtral 8x7b v0.1.
|
||||
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
|
||||
- Replit-code-v1.5-3B.
|
||||
- Bert.
|
||||
- Yi-6B and Yi-34B.
|
||||
- Qwen1.5, Qwen1.5 MoE.
|
||||
- RWKV v5 and v6.
|
||||
- Quantized LLMs.
|
||||
- Llama 7b, 13b, 70b, as well as the chat and code variants.
|
||||
- Mistral 7b, and 7b instruct.
|
||||
- Zephyr 7b a and b (Mistral based).
|
||||
- OpenChat 3.5 (Mistral based).
|
||||
- Mixtral 8x7b.
|
||||
- Zephyr 7b a and b (Mistral-7b based).
|
||||
- OpenChat 3.5 (Mistral-7b based).
|
||||
- Text to text.
|
||||
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
|
||||
- Marian MT (Machine Translation).
|
||||
- Whisper (multi-lingual support).
|
||||
- Text to image.
|
||||
- Stable Diffusion v1.5, v2.1, XL v1.0.
|
||||
- Wurstchen v2.
|
||||
- Image to text.
|
||||
- BLIP.
|
||||
- TrOCR.
|
||||
- Audio.
|
||||
- Whisper, multi-lingual speech-to-text.
|
||||
- EnCodec, audio compression model.
|
||||
- MetaVoice-1B, text-to-speech model.
|
||||
- Parler-TTS, text-to-speech model.
|
||||
- Computer Vision Models.
|
||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
|
||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
|
||||
ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera, FastViT.
|
||||
- yolo-v3, yolo-v8.
|
||||
- Segment-Anything Model (SAM).
|
||||
- SegFormer.
|
||||
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
|
||||
- Serverless (on CPU), small and fast deployments.
|
||||
- Quantization support using the llama.cpp quantized types.
|
||||
@ -331,9 +384,9 @@ git submodule update --init
|
||||
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
|
||||
```
|
||||
|
||||
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
|
||||
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the NVCC_CCBIN environment variable.
|
||||
```
|
||||
env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
|
||||
env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
|
||||
```
|
||||
|
||||
#### Linking error on windows when running rustdoc or mdbook tests
|
||||
@ -363,3 +416,10 @@ This may be caused by the models being loaded from `/mnt/c`, more details on
|
||||
|
||||
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
|
||||
error is generated.
|
||||
|
||||
#### CudaRC error
|
||||
|
||||
If you encounter an error like this one `called `Result::unwrap()` on an `Err` value: LoadLibraryExW { source: Os { code: 126, kind: Uncategorized, message: "The specified module could not be found." } }` on windows. To fix copy and rename these 3 files (make sure they are in path). The paths depend on your cuda version.
|
||||
`c:\Windows\System32\nvcuda.dll` -> `cuda.dll`
|
||||
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\cublas64_12.dll` -> `cublas.dll`
|
||||
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\curand64_10.dll` -> `curand.dll`
|
||||
|
@ -11,11 +11,11 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.1" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
|
||||
candle = { workspace = true }
|
||||
candle-datasets = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
candle-flash-attn = { workspace = true, optional = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
@ -25,7 +25,7 @@ cudarc = { workspace = true, optional = true }
|
||||
half = { workspace = true, optional = true }
|
||||
image = { workspace = true, optional = true }
|
||||
anyhow = { workspace = true }
|
||||
tokio = "1.29.1"
|
||||
tokio = "1.43.0"
|
||||
|
||||
[dev-dependencies]
|
||||
byteorder = { workspace = true }
|
||||
@ -37,7 +37,6 @@ tokenizers = { workspace = true, features = ["onig"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-chrome = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
wav = { workspace = true }
|
||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||
parquet = { workspace = true }
|
||||
image = { workspace = true }
|
||||
|
@ -28,6 +28,7 @@ let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap()
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_hub_2() {
|
||||
{
|
||||
// ANCHOR: book_hub_2
|
||||
use candle::Device;
|
||||
use hf_hub::api::sync::Api;
|
||||
@ -45,9 +46,10 @@ let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap()
|
||||
assert_eq!(weights.len(), 206);
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_hub_3() {
|
||||
// #[rustfmt::skip]
|
||||
// #[test]
|
||||
// fn book_hub_3() {
|
||||
{
|
||||
// ANCHOR: book_hub_3
|
||||
use candle::{DType, Device, Tensor};
|
||||
use hf_hub::api::sync::Api;
|
||||
@ -79,7 +81,7 @@ let mut tp_shape = view.shape().to_vec();
|
||||
let size = tp_shape[0];
|
||||
|
||||
if size % world_size != 0 {
|
||||
panic!("The dimension is not divisble by `world_size`");
|
||||
panic!("The dimension is not divisible by `world_size`");
|
||||
}
|
||||
let block_size = size / world_size;
|
||||
let start = rank * block_size;
|
||||
@ -102,9 +104,10 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
|
||||
assert_eq!(view.shape(), &[768, 768]);
|
||||
assert_eq!(tp_tensor.dims(), &[192, 768]);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_training_1() -> Result<()>{
|
||||
// ANCHOR: book_training_1
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
|
@ -12,9 +12,9 @@ readme = "README.md"
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
byteorder = { workspace = true }
|
||||
candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true }
|
||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true }
|
||||
metal = { workspace = true, optional = true}
|
||||
candle-kernels = { workspace = true, optional = true }
|
||||
candle-metal-kernels = { workspace = true, optional = true }
|
||||
metal = { workspace = true, optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
gemm = { workspace = true }
|
||||
half = { workspace = true }
|
||||
@ -28,17 +28,35 @@ rand_distr = { workspace = true }
|
||||
rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
ug-cuda = { workspace = true, optional = true }
|
||||
ug-metal = { workspace = true, optional = true }
|
||||
yoke = { workspace = true }
|
||||
zip = { workspace = true }
|
||||
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
||||
ug = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
criterion = { workspace = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
cuda = ["cudarc", "dep:candle-kernels"]
|
||||
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
|
||||
cudnn = ["cuda", "cudarc/cudnn"]
|
||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||
accelerate = ["dep:libc", "dep:accelerate-src"]
|
||||
metal = ["dep:metal", "dep:candle-metal-kernels"]
|
||||
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:ug-metal"]
|
||||
|
||||
[[bench]]
|
||||
name = "bench_main"
|
||||
harness = false
|
||||
|
||||
[[example]]
|
||||
name = "metal_basics"
|
||||
required-features = ["metal"]
|
||||
|
||||
[[example]]
|
||||
name = "cuda_basics"
|
||||
required-features = ["cuda"]
|
||||
|
14
candle-core/benches/bench_main.rs
Normal file
14
candle-core/benches/bench_main.rs
Normal file
@ -0,0 +1,14 @@
|
||||
mod benchmarks;
|
||||
|
||||
use criterion::criterion_main;
|
||||
|
||||
criterion_main!(
|
||||
benchmarks::affine::benches,
|
||||
benchmarks::matmul::benches,
|
||||
benchmarks::random::benches,
|
||||
benchmarks::reduce::benches,
|
||||
benchmarks::where_cond::benches,
|
||||
benchmarks::conv_transpose2d::benches,
|
||||
benchmarks::qmatmul::benches,
|
||||
benchmarks::unary::benches
|
||||
);
|
43
candle-core/benches/benchmarks/affine.rs
Normal file
43
candle-core/benches/benchmarks/affine.rs
Normal file
@ -0,0 +1,43 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(a: &Tensor) {
|
||||
a.affine(12.34, 56.78).unwrap();
|
||||
}
|
||||
|
||||
fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let tensor = Tensor::zeros((b, m, k), dtype, device).unwrap();
|
||||
|
||||
let flops = b * m * k * dtype.size_in_bytes();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_affine_benchmark(c, &device, DType::F32, "affine_f32");
|
||||
run_affine_benchmark(c, &device, DType::F16, "affine_f16");
|
||||
run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
59
candle-core/benches/benchmarks/conv_transpose2d.rs
Normal file
59
candle-core/benches/benchmarks/conv_transpose2d.rs
Normal file
@ -0,0 +1,59 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(
|
||||
x: &Tensor,
|
||||
k: &Tensor,
|
||||
padding: usize,
|
||||
output_padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
) {
|
||||
x.conv_transpose2d(k, padding, output_padding, stride, dilation)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn run_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||
let t = Tensor::arange(0.0f32, 10000.0, device)
|
||||
.unwrap()
|
||||
.reshape((1, 4, 50, 50))
|
||||
.unwrap()
|
||||
.to_dtype(dtype)
|
||||
.unwrap();
|
||||
|
||||
let kernel = Tensor::arange(0.0f32, 100.0, device)
|
||||
.unwrap()
|
||||
.reshape((4, 1, 5, 5))
|
||||
.unwrap()
|
||||
.to_dtype(dtype)
|
||||
.unwrap();
|
||||
|
||||
let flops = t.dims().iter().product::<usize>() * dtype.size_in_bytes();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(black_box(&t), black_box(&kernel), 1, 0, 1, 2);
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_benchmark(c, &device, DType::F32, "conv_transpose2d_f32");
|
||||
run_benchmark(c, &device, DType::F16, "conv_transpose2d_f16");
|
||||
run_benchmark(c, &device, DType::BF16, "conv_transpose2d_bf16");
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
44
candle-core/benches/benchmarks/matmul.rs
Normal file
44
candle-core/benches/benchmarks/matmul.rs
Normal file
@ -0,0 +1,44 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(a: &Tensor, b: &Tensor) {
|
||||
a.matmul(&b.t().unwrap()).unwrap();
|
||||
}
|
||||
|
||||
fn run_bench(c: &mut Criterion, device: &Device) {
|
||||
let b = 1;
|
||||
let m = 1;
|
||||
let n = 2048;
|
||||
let k = 2048;
|
||||
|
||||
let dtype = DType::F32;
|
||||
let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
|
||||
let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
|
||||
|
||||
let flops = b * m * n * k;
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name("matmul"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(black_box(&lhs), black_box(&rhs));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_bench(c, &device);
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
72
candle-core/benches/benchmarks/mod.rs
Normal file
72
candle-core/benches/benchmarks/mod.rs
Normal file
@ -0,0 +1,72 @@
|
||||
pub(crate) mod affine;
|
||||
pub(crate) mod conv_transpose2d;
|
||||
pub(crate) mod matmul;
|
||||
pub(crate) mod qmatmul;
|
||||
pub(crate) mod random;
|
||||
pub(crate) mod reduce;
|
||||
pub(crate) mod unary;
|
||||
pub(crate) mod where_cond;
|
||||
|
||||
use candle_core::{Device, Result};
|
||||
|
||||
pub(crate) trait BenchDevice {
|
||||
fn sync(&self) -> Result<()>;
|
||||
|
||||
fn bench_name<S: Into<String>>(&self, name: S) -> String;
|
||||
}
|
||||
|
||||
impl BenchDevice for Device {
|
||||
fn sync(&self) -> Result<()> {
|
||||
match self {
|
||||
Device::Cpu => Ok(()),
|
||||
Device::Cuda(device) => {
|
||||
#[cfg(feature = "cuda")]
|
||||
return Ok(device
|
||||
.synchronize()
|
||||
.map_err(|e| candle_core::Error::Cuda(Box::new(e)))?);
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
panic!("Cuda device without cuda feature enabled: {:?}", device)
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
#[cfg(feature = "metal")]
|
||||
return Ok(device.wait_until_completed()?);
|
||||
#[cfg(not(feature = "metal"))]
|
||||
panic!("Metal device without metal feature enabled: {:?}", device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn bench_name<S: Into<String>>(&self, name: S) -> String {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let cpu_type = if cfg!(feature = "accelerate") {
|
||||
"accelerate"
|
||||
} else if cfg!(feature = "mkl") {
|
||||
"mkl"
|
||||
} else {
|
||||
"cpu"
|
||||
};
|
||||
format!("{}_{}", cpu_type, name.into())
|
||||
}
|
||||
Device::Cuda(_) => format!("cuda_{}", name.into()),
|
||||
Device::Metal(_) => format!("metal_{}", name.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct BenchDeviceHandler {
|
||||
devices: Vec<Device>,
|
||||
}
|
||||
|
||||
impl BenchDeviceHandler {
|
||||
pub fn new() -> Result<Self> {
|
||||
let mut devices = Vec::new();
|
||||
if cfg!(feature = "metal") {
|
||||
devices.push(Device::new_metal(0)?);
|
||||
} else if cfg!(feature = "cuda") {
|
||||
devices.push(Device::new_cuda(0)?);
|
||||
}
|
||||
devices.push(Device::Cpu);
|
||||
Ok(Self { devices })
|
||||
}
|
||||
}
|
72
candle-core/benches/benchmarks/qmatmul.rs
Normal file
72
candle-core/benches/benchmarks/qmatmul.rs
Normal file
@ -0,0 +1,72 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{
|
||||
quantized::{self, GgmlDType, QMatMul},
|
||||
Device, Module, Tensor,
|
||||
};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(matmul: &QMatMul, x: &Tensor) {
|
||||
matmul.forward(x).unwrap();
|
||||
}
|
||||
|
||||
fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
|
||||
let b = 1;
|
||||
let m = 1;
|
||||
let n = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let lhs = (0..(m * k))
|
||||
.map(|v| v as f32 / (m * k) as f32)
|
||||
.collect::<Vec<_>>();
|
||||
let rhs = (0..(k * n))
|
||||
.map(|v| v as f32 / (n * k) as f32)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let lhs = Tensor::from_slice(&lhs, (m, k), device).unwrap();
|
||||
let rhs = Tensor::from_slice(&rhs, (k, n), device).unwrap();
|
||||
|
||||
let qtensor = quantized::QTensor::quantize(&rhs.t().unwrap(), dtype).unwrap();
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor).unwrap();
|
||||
|
||||
let flops = b * m * n * k;
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype)));
|
||||
group.sample_size(200);
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(black_box(&matmul), black_box(&lhs));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
for dtype in [
|
||||
GgmlDType::F32,
|
||||
GgmlDType::F16,
|
||||
GgmlDType::Q4_0,
|
||||
GgmlDType::Q4_1,
|
||||
GgmlDType::Q5_0,
|
||||
GgmlDType::Q5_1,
|
||||
GgmlDType::Q8_0,
|
||||
GgmlDType::Q2K,
|
||||
GgmlDType::Q3K,
|
||||
GgmlDType::Q4K,
|
||||
GgmlDType::Q5K,
|
||||
GgmlDType::Q6K,
|
||||
] {
|
||||
run_bench(c, &device, dtype);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
63
candle-core/benches/benchmarks/random.rs
Normal file
63
candle-core/benches/benchmarks/random.rs
Normal file
@ -0,0 +1,63 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn rand_uniform(a: &Tensor) {
|
||||
a.rand_like(-1.0, 123.0).unwrap();
|
||||
}
|
||||
|
||||
fn rand_normal(a: &Tensor) {
|
||||
a.randn_like(100.0, 15.0).unwrap();
|
||||
}
|
||||
|
||||
fn run_random_bench(c: &mut Criterion, device: &Device) {
|
||||
let b = 1;
|
||||
|
||||
let rows = 2048;
|
||||
let cols = 2048;
|
||||
|
||||
let dtype = DType::F32;
|
||||
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
|
||||
|
||||
let flops = b * rows * cols * dtype.size_in_bytes();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name("random_uniform"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |benches| {
|
||||
benches.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
rand_uniform(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
|
||||
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name("random_normal"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |benches| {
|
||||
benches.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
rand_normal(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_random_bench(c, &device);
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
158
candle-core/benches/benchmarks/reduce.rs
Normal file
158
candle-core/benches/benchmarks/reduce.rs
Normal file
@ -0,0 +1,158 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use half::{bf16, f16};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run_sum(a: &Tensor) {
|
||||
a.sum_keepdim(2).unwrap();
|
||||
}
|
||||
fn run_arg_min(a: &Tensor) {
|
||||
a.argmin_keepdim(2).unwrap();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
let (lo, up) = (-1000.0f32, 1000.0f32);
|
||||
for device in handler.devices {
|
||||
run_reduce(c, &device, (lo, up), false);
|
||||
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
|
||||
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
|
||||
|
||||
run_arg_reduce(c, &device, (lo, up), false);
|
||||
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
|
||||
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
|
||||
|
||||
run_reduce(c, &device, (lo, up), true);
|
||||
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
|
||||
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
|
||||
|
||||
run_arg_reduce(c, &device, (lo, up), true);
|
||||
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
|
||||
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
|
||||
}
|
||||
}
|
||||
|
||||
fn run_reduce<T: candle_core::FloatDType>(
|
||||
c: &mut Criterion,
|
||||
device: &Device,
|
||||
(lo, up): (T, T),
|
||||
strided: bool,
|
||||
) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let a = if strided {
|
||||
Tensor::rand(lo, up, (b, m, k), &device)
|
||||
.unwrap()
|
||||
.transpose(0, 2)
|
||||
.unwrap()
|
||||
} else {
|
||||
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
|
||||
};
|
||||
|
||||
let flops = b * m * k * T::DTYPE.size_in_bytes();
|
||||
|
||||
let name = match T::DTYPE {
|
||||
DType::F32 => {
|
||||
if strided {
|
||||
"reduce_f32_strided"
|
||||
} else {
|
||||
"reduce_f32"
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
if strided {
|
||||
"reduce_f16_strided"
|
||||
} else {
|
||||
"reduce_f16"
|
||||
}
|
||||
}
|
||||
DType::BF16 => {
|
||||
if strided {
|
||||
"reduce_bf16_strided"
|
||||
} else {
|
||||
"reduce_bf16"
|
||||
}
|
||||
}
|
||||
_ => "unknown",
|
||||
};
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run_sum(black_box(&a));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn run_arg_reduce<T: candle_core::FloatDType>(
|
||||
c: &mut Criterion,
|
||||
device: &Device,
|
||||
(lo, up): (T, T),
|
||||
strided: bool,
|
||||
) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let a = if strided {
|
||||
Tensor::rand(lo, up, (b, m, k), &device)
|
||||
.unwrap()
|
||||
.transpose(0, 2)
|
||||
.unwrap()
|
||||
} else {
|
||||
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
|
||||
};
|
||||
|
||||
let flops = b * m * k * T::DTYPE.size_in_bytes();
|
||||
|
||||
let name = match T::DTYPE {
|
||||
DType::F32 => {
|
||||
if strided {
|
||||
"arg_reduce_f32_strided"
|
||||
} else {
|
||||
"arg_reduce_f32"
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
if strided {
|
||||
"arg_reduce_f16_strided"
|
||||
} else {
|
||||
"arg_reduce_f16"
|
||||
}
|
||||
}
|
||||
DType::BF16 => {
|
||||
if strided {
|
||||
"arg_reduce_bf16_strided"
|
||||
} else {
|
||||
"arg_reduce_bf16"
|
||||
}
|
||||
}
|
||||
_ => "unknown",
|
||||
};
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run_arg_min(black_box(&a));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
49
candle-core/benches/benchmarks/unary.rs
Normal file
49
candle-core/benches/benchmarks/unary.rs
Normal file
@ -0,0 +1,49 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(a: &Tensor) {
|
||||
a.sqrt().unwrap();
|
||||
}
|
||||
|
||||
fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, device)
|
||||
.unwrap()
|
||||
.to_dtype(dtype)
|
||||
.unwrap()
|
||||
.reshape((b, m, k))
|
||||
.unwrap();
|
||||
|
||||
let flops = b * m * k * dtype.size_in_bytes();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
for dtype in [DType::F32, DType::BF16, DType::F16] {
|
||||
let name = format!("sqrt_{:?}", dtype);
|
||||
run_unary_benchmark(c, &device, dtype, &name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
64
candle-core/benches/benchmarks/where_cond.rs
Normal file
64
candle-core/benches/benchmarks/where_cond.rs
Normal file
@ -0,0 +1,64 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(a: &Tensor, b: &Tensor, c: &Tensor) {
|
||||
a.where_cond(b, c).unwrap();
|
||||
}
|
||||
|
||||
const fn create_cond_arr<const N: usize>() -> [u8; N] {
|
||||
let mut arr = [0u8; N];
|
||||
let mut i = 0;
|
||||
while i < N {
|
||||
arr[i] = (i % 2) as u8;
|
||||
i += 1;
|
||||
}
|
||||
arr
|
||||
}
|
||||
|
||||
const B: usize = 1;
|
||||
const M: usize = 1024;
|
||||
const K: usize = 1024;
|
||||
const SIZE: usize = B * M * K;
|
||||
|
||||
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
|
||||
|
||||
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap();
|
||||
let on_true = Tensor::ones((B, M, K), dtype, device).unwrap();
|
||||
let on_false = Tensor::zeros((B, M, K), dtype, device).unwrap();
|
||||
|
||||
let elements = B * M * K;
|
||||
// E.g. 2 f32 tensors + 1 u8 tensor
|
||||
let flops = (2 * elements * dtype.size_in_bytes()) + elements;
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(
|
||||
black_box(&tensor),
|
||||
black_box(&on_true),
|
||||
black_box(&on_false),
|
||||
);
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let device = BenchDeviceHandler::new().unwrap();
|
||||
for d in device.devices {
|
||||
run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32");
|
||||
run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
|
||||
run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -9,21 +9,25 @@ use candle_core::{Device, Tensor};
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let device = Device::new_cuda(0)?;
|
||||
let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
|
||||
let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
|
||||
let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
|
||||
println!("{out_t}");
|
||||
let in_t = in_t.to_device(&Device::Cpu)?;
|
||||
let k_t = k_t.to_device(&Device::Cpu)?;
|
||||
let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
|
||||
let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
|
||||
.sqr()?
|
||||
.sum_all()?;
|
||||
println!("{diff}");
|
||||
|
||||
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
|
||||
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
|
||||
let res = t.conv2d(&w, 1, 1, 1, 1)?;
|
||||
println!("{res:?}");
|
||||
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
||||
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
|
||||
let _x1 = x.matmul(&x)?;
|
||||
drop(_x1);
|
||||
let start_time = std::time::Instant::now();
|
||||
let _x1 = x.matmul(&x)?;
|
||||
device.synchronize()?;
|
||||
println!("fp32: {:?}", start_time.elapsed());
|
||||
drop(_x1);
|
||||
candle_core::cuda::set_gemm_reduced_precision_f32(true);
|
||||
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
|
||||
let _x1 = x.matmul(&x)?;
|
||||
drop(_x1);
|
||||
let start_time = std::time::Instant::now();
|
||||
let _x1 = x.matmul(&x)?;
|
||||
device.synchronize()?;
|
||||
println!("tf32: {:?}", start_time.elapsed());
|
||||
drop(_x1);
|
||||
Ok(())
|
||||
}
|
||||
|
28
candle-core/examples/metal_basics.rs
Normal file
28
candle-core/examples/metal_basics.rs
Normal file
@ -0,0 +1,28 @@
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle_core::{Device, Tensor};
|
||||
|
||||
fn main() -> Result<()> {
|
||||
// This requires the code to be run with MTL_CAPTURE_ENABLED=1
|
||||
let device = Device::new_metal(0)?;
|
||||
let metal_device = match &device {
|
||||
Device::Metal(m) => m,
|
||||
_ => anyhow::bail!("unexpected device"),
|
||||
};
|
||||
metal_device.capture("/tmp/candle.gputrace")?;
|
||||
// This first synchronize ensures that a new command buffer gets created after setting up the
|
||||
// capture scope.
|
||||
device.synchronize()?;
|
||||
let x = Tensor::randn(0f32, 1.0, (128, 128), &device)?;
|
||||
let x1 = x.add(&x)?;
|
||||
println!("{x1:?}");
|
||||
// This second synchronize ensures that the command buffer gets commited before the end of the
|
||||
// capture scope.
|
||||
device.synchronize()?;
|
||||
Ok(())
|
||||
}
|
@ -380,6 +380,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_exp_inplace(y: &mut [f32]) {
|
||||
unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_exp_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
@ -402,6 +412,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = -v
|
||||
}
|
||||
vs_exp_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = v / (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = -v
|
||||
}
|
||||
vd_exp_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = v / (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! binary_op {
|
||||
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
|
||||
#[inline]
|
||||
|
@ -1,3 +1,5 @@
|
||||
//! Traits to Define Backend Behavior
|
||||
//!
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
|
||||
@ -98,6 +100,19 @@ pub trait BackendStorage: Sized {
|
||||
) -> Result<Self>;
|
||||
|
||||
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
// Similar to cudaMemcpy2D, though values are in elements and not in bytes.
|
||||
fn copy2d(
|
||||
&self,
|
||||
_: &mut Self,
|
||||
_d1: usize,
|
||||
_d2: usize,
|
||||
_src_stride1: usize,
|
||||
_dst_stride1: usize,
|
||||
_src_offset: usize,
|
||||
_dst_offset: usize,
|
||||
) -> Result<()>;
|
||||
}
|
||||
|
||||
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
||||
@ -114,11 +129,24 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
||||
|
||||
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
||||
|
||||
/// # Safety
|
||||
/// This function is unsafe as it doesn't initialize the underlying data store.
|
||||
/// The caller should ensure that the data is properly initialized as early as possible
|
||||
/// after this call.
|
||||
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
||||
|
||||
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage>;
|
||||
|
||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
|
||||
|
||||
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
|
||||
|
||||
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
||||
|
||||
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
||||
|
||||
fn set_seed(&self, _: u64) -> Result<()>;
|
||||
|
||||
/// Synchronize should block until all the operations on the device are completed.
|
||||
fn synchronize(&self) -> Result<()>;
|
||||
}
|
||||
|
@ -1,3 +1,4 @@
|
||||
//! Methods for backpropagation of gradients.
|
||||
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
|
||||
use crate::{Error, Result, Tensor, TensorId};
|
||||
use std::collections::HashMap;
|
||||
@ -31,7 +32,7 @@ impl Tensor {
|
||||
/// elements having dependencies on the latter ones, e.g. the first element if any is the
|
||||
/// argument.
|
||||
/// This assumes that the op graph is a DAG.
|
||||
fn sorted_nodes(&self) -> Vec<&Tensor> {
|
||||
pub fn sorted_nodes(&self) -> Vec<&Tensor> {
|
||||
// The vec of sorted nodes is passed as an owned value rather than a mutable reference
|
||||
// to get around some lifetime limitations.
|
||||
fn walk<'a>(
|
||||
@ -111,10 +112,11 @@ impl Tensor {
|
||||
}
|
||||
Op::Unary(_node, UnaryOp::Ceil)
|
||||
| Op::Unary(_node, UnaryOp::Floor)
|
||||
| Op::Unary(_node, UnaryOp::Round) => nodes,
|
||||
| Op::Unary(_node, UnaryOp::Round)
|
||||
| Op::Unary(_node, UnaryOp::Sign) => nodes,
|
||||
Op::Reshape(node)
|
||||
| Op::UpsampleNearest1D(node)
|
||||
| Op::UpsampleNearest2D(node)
|
||||
| Op::UpsampleNearest1D { arg: node, .. }
|
||||
| Op::UpsampleNearest2D { arg: node, .. }
|
||||
| Op::AvgPool2D { arg: node, .. }
|
||||
| Op::MaxPool2D { arg: node, .. }
|
||||
| Op::Copy(node)
|
||||
@ -175,7 +177,7 @@ impl Tensor {
|
||||
// the backprop graph of the backprop itself. This would be an issue for second order
|
||||
// derivatives but these are out of scope at the moment.
|
||||
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
|
||||
let grad = if do_not_detach { grad } else { grad.detach()? };
|
||||
let grad = if do_not_detach { grad } else { grad.detach() };
|
||||
if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::Binary(lhs, rhs, BinaryOp::Add) => {
|
||||
@ -250,6 +252,7 @@ impl Tensor {
|
||||
out_padding,
|
||||
*stride,
|
||||
*dilation,
|
||||
/* groups */ 1,
|
||||
)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||
@ -309,9 +312,32 @@ impl Tensor {
|
||||
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "conv-transpose1d",
|
||||
})?,
|
||||
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "conv-transpose2d",
|
||||
})?,
|
||||
Op::ConvTranspose2D {
|
||||
arg,
|
||||
kernel,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
output_padding: _output_padding,
|
||||
} => {
|
||||
let grad_arg = grad.conv2d(kernel, *padding, *stride, *dilation, 1)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||
|
||||
let grad_kernel = grad
|
||||
.transpose(0, 1)?
|
||||
.conv2d(&arg.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
|
||||
.transpose(0, 1)?;
|
||||
let sum_grad = grads.or_insert(kernel)?;
|
||||
let (_, _, k0, k1) = kernel.dims4()?;
|
||||
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
|
||||
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
|
||||
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
|
||||
} else {
|
||||
grad_kernel
|
||||
};
|
||||
*sum_grad = sum_grad.add(&grad_kernel)?;
|
||||
}
|
||||
Op::AvgPool2D {
|
||||
arg,
|
||||
kernel_size,
|
||||
@ -347,12 +373,39 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||
}
|
||||
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "upsample-nearest1d",
|
||||
})?,
|
||||
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "upsample-nearest2d",
|
||||
})?,
|
||||
Op::UpsampleNearest1D { arg, target_size } => {
|
||||
let (_n, c, size) = arg.dims3()?;
|
||||
if target_size % size != 0 {
|
||||
crate::bail!("backward not supported for non integer upscaling factors")
|
||||
}
|
||||
let scale = target_size / size;
|
||||
|
||||
let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
|
||||
let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = conv_sum;
|
||||
}
|
||||
Op::UpsampleNearest2D {
|
||||
arg,
|
||||
target_h,
|
||||
target_w,
|
||||
} => {
|
||||
let (_n, c, h, w) = arg.dims4()?;
|
||||
if target_h % h != 0 || target_w % w != 0 {
|
||||
crate::bail!("backward not supported for non integer upscaling factors")
|
||||
}
|
||||
let scale_h = target_h / h;
|
||||
let scale_w = target_w / w;
|
||||
|
||||
if scale_h != scale_w {
|
||||
crate::bail!("backward not supported for non uniform upscaling factors")
|
||||
};
|
||||
let kernel =
|
||||
Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
|
||||
let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = conv_sum;
|
||||
}
|
||||
Op::SliceScatter0(lhs, rhs, start_rhs) => {
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
|
||||
@ -436,7 +489,6 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad)?;
|
||||
}
|
||||
Op::Cmp(_args, _) => {}
|
||||
Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
|
||||
let node = broadcast_back(arg, node, reduced_dims)?;
|
||||
let grad = broadcast_back(arg, &grad, reduced_dims)?;
|
||||
@ -526,20 +578,18 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Reduce(_, ReduceOp::ArgMin, _) => {}
|
||||
Op::Reduce(_, ReduceOp::ArgMax, _) => {}
|
||||
Op::Unary(_, UnaryOp::Floor)
|
||||
| Op::Unary(_, UnaryOp::Round)
|
||||
| Op::Reduce(_, ReduceOp::ArgMin, _)
|
||||
| Op::Reduce(_, ReduceOp::ArgMax, _)
|
||||
| Op::Unary(_, UnaryOp::Sign)
|
||||
| Op::Cmp(_, _) => {}
|
||||
Op::Reshape(arg) => {
|
||||
let arg_grad = grad.reshape(arg.dims())?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
|
||||
Op::Unary(_, UnaryOp::Floor) => {
|
||||
Err(Error::BackwardNotSupported { op: "floor" })?
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Round) => {
|
||||
Err(Error::BackwardNotSupported { op: "round" })?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Gelu) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let cube = arg.powf(3.)?;
|
||||
@ -571,13 +621,21 @@ impl Tensor {
|
||||
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
|
||||
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Silu) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) = sigmoid(x) * (1 - node) + node
|
||||
let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
|
||||
let silu_grad = &sigmoid_arg * (1. - *node) + *node;
|
||||
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
|
||||
}
|
||||
Op::Elu(arg, alpha) => {
|
||||
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let zeros = arg.zeros_like()?;
|
||||
let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
|
||||
let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
|
||||
let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
|
||||
// node == alpha * (e^x - 1) for x <= 0, reuse it
|
||||
let negative_exp_mask = (negative_mask * (*node + *alpha))?;
|
||||
let combined_mask = (positive_mask + negative_exp_mask)?;
|
||||
*sum_grad = sum_grad.add(&(grad * combined_mask)?)?
|
||||
}
|
||||
@ -655,30 +713,38 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// A store for gradients, associating a tensor id to the corresponding gradient tensor, used for back propagation.
|
||||
#[derive(Debug)]
|
||||
pub struct GradStore(HashMap<TensorId, Tensor>);
|
||||
|
||||
impl GradStore {
|
||||
/// Create a new gradient store
|
||||
fn new() -> Self {
|
||||
GradStore(HashMap::new())
|
||||
}
|
||||
|
||||
/// Get the gradient tensor corresponding to the given tensor id
|
||||
pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
|
||||
self.0.get(&id)
|
||||
}
|
||||
|
||||
/// Get the gradient tensor associated with the given tensor
|
||||
pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
|
||||
self.0.get(&tensor.id())
|
||||
}
|
||||
|
||||
/// Remove the gradient tensor associated with the given tensor, returning it if it exists
|
||||
pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
|
||||
self.0.remove(&tensor.id())
|
||||
}
|
||||
|
||||
/// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed
|
||||
pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
|
||||
self.0.insert(tensor.id(), grad)
|
||||
}
|
||||
|
||||
/// Get the gradient tensor associated with the given tensor, or, if it does not exist,
|
||||
/// insert a tensor of zeroes, with the same shape and type as the given tensors and return it
|
||||
fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
|
||||
use std::collections::hash_map::Entry;
|
||||
let grad = match self.0.entry(tensor.id()) {
|
||||
@ -690,4 +756,9 @@ impl GradStore {
|
||||
};
|
||||
Ok(grad)
|
||||
}
|
||||
|
||||
/// Get the tensor ids of the stored gradient tensors
|
||||
pub fn get_ids(&self) -> impl Iterator<Item = &TensorId> {
|
||||
self.0.keys()
|
||||
}
|
||||
}
|
||||
|
@ -1,3 +1,5 @@
|
||||
//! 1D and 2D Convolutions
|
||||
//!
|
||||
use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@ -12,6 +14,7 @@ pub struct ParamsConv1D {
|
||||
pub(crate) padding: usize,
|
||||
pub(crate) stride: usize,
|
||||
pub(crate) dilation: usize,
|
||||
pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||
}
|
||||
|
||||
impl ParamsConv1D {
|
||||
@ -172,6 +175,7 @@ impl Tensor {
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
if groups == 1 {
|
||||
self.conv1d_single_group(kernel, ¶ms)
|
||||
@ -187,36 +191,16 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies a 1D transposed convolution over the input tensor.
|
||||
pub fn conv_transpose1d(
|
||||
fn conv_transpose1d_single_group(
|
||||
&self,
|
||||
kernel: &Self,
|
||||
padding: usize,
|
||||
output_padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
params: &ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
let (b_size, c_in, l_in) = self.dims3()?;
|
||||
let (c_in_k, c_out, k_size) = kernel.dims3()?;
|
||||
if c_in != c_in_k {
|
||||
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
|
||||
}
|
||||
let params = ParamsConvTranspose1D {
|
||||
b_size,
|
||||
l_in,
|
||||
k_size,
|
||||
c_out,
|
||||
c_in,
|
||||
padding,
|
||||
output_padding,
|
||||
stride,
|
||||
dilation,
|
||||
};
|
||||
let storage = self.storage().conv_transpose1d(
|
||||
self.layout(),
|
||||
&kernel.storage(),
|
||||
kernel.layout(),
|
||||
¶ms,
|
||||
params,
|
||||
)?;
|
||||
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
|
||||
arg,
|
||||
@ -230,6 +214,49 @@ impl Tensor {
|
||||
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
|
||||
}
|
||||
|
||||
/// Applies a 1D transposed convolution over the input tensor.
|
||||
pub fn conv_transpose1d(
|
||||
&self,
|
||||
kernel: &Self,
|
||||
padding: usize,
|
||||
output_padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
) -> Result<Self> {
|
||||
let (c_in_k, c_out, k_size) = kernel.dims3()?;
|
||||
let (b_size, c_in, l_in) = self.dims3()?;
|
||||
if c_in != c_in_k {
|
||||
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
|
||||
}
|
||||
if c_in % groups != 0 {
|
||||
crate::bail!("in_channel {c_in} is not divisible by the number of groups")
|
||||
}
|
||||
let params = ParamsConvTranspose1D {
|
||||
b_size,
|
||||
l_in,
|
||||
k_size,
|
||||
c_out,
|
||||
c_in: c_in / groups,
|
||||
padding,
|
||||
output_padding,
|
||||
stride,
|
||||
dilation,
|
||||
};
|
||||
if groups == 1 {
|
||||
self.conv_transpose1d_single_group(kernel, ¶ms)
|
||||
} else {
|
||||
let blocks = self.chunk(groups, 1)?;
|
||||
let kernel = kernel.chunk(groups, 0)?;
|
||||
let blocks = blocks
|
||||
.iter()
|
||||
.zip(&kernel)
|
||||
.map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, ¶ms))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Tensor::cat(&blocks, 1)
|
||||
}
|
||||
}
|
||||
|
||||
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
|
||||
let storage =
|
||||
self.storage()
|
||||
|
@ -1,6 +1,9 @@
|
||||
//! Traits and methods for CPU-backed Tensors
|
||||
|
||||
pub mod erf;
|
||||
pub mod kernels;
|
||||
|
||||
#[allow(unused)]
|
||||
trait Cpu<const ARR: usize> {
|
||||
type Unit;
|
||||
type Array;
|
||||
@ -18,6 +21,7 @@ trait Cpu<const ARR: usize> {
|
||||
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
trait CpuF16<const ARR: usize> {
|
||||
type Unit;
|
||||
type Array;
|
||||
|
File diff suppressed because it is too large
Load Diff
360
candle-core/src/cpu_backend/utils.rs
Normal file
360
candle-core/src/cpu_backend/utils.rs
Normal file
@ -0,0 +1,360 @@
|
||||
/// Helper functions to write CPU kernels.
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::{Error, Layout, Result, WithDType};
|
||||
|
||||
type C = super::CpuStorage;
|
||||
pub trait Map1 {
|
||||
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
|
||||
|
||||
fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
|
||||
match vs {
|
||||
C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
|
||||
C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
|
||||
C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
|
||||
C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
|
||||
C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
|
||||
C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
|
||||
C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map1Any {
|
||||
fn f<T: WithDType, W: Fn(Vec<T>) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result<C>;
|
||||
|
||||
fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
|
||||
match vs {
|
||||
C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
|
||||
C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
|
||||
C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
|
||||
C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
|
||||
C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
|
||||
C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
|
||||
C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map2 {
|
||||
const OP: &'static str;
|
||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
|
||||
|
||||
fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
|
||||
match (v1, v2) {
|
||||
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
|
||||
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
|
||||
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
|
||||
_ => Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: v1.dtype(),
|
||||
rhs: v2.dtype(),
|
||||
op: Self::OP,
|
||||
}
|
||||
.bt()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map2U8 {
|
||||
const OP: &'static str;
|
||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
|
||||
|
||||
fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
|
||||
match (v1, v2) {
|
||||
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
_ => Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: v1.dtype(),
|
||||
rhs: v2.dtype(),
|
||||
op: Self::OP,
|
||||
}
|
||||
.bt()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
lhs: &[T],
|
||||
rhs: &[T],
|
||||
mut f: F,
|
||||
) -> Vec<U> {
|
||||
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
|
||||
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
|
||||
.iter()
|
||||
.zip(rhs[o_r1..o_r2].iter())
|
||||
.map(|(&l, &r)| f(l, r))
|
||||
.collect(),
|
||||
(Some((o_l1, o_l2)), None) => {
|
||||
// TODO: Maybe we want to avoid going through the layout twice.
|
||||
match rhs_l.offsets_b() {
|
||||
Some(ob) => {
|
||||
let mut i_in_block = 0;
|
||||
let mut i_right_broadcast = 0;
|
||||
lhs[o_l1..o_l2]
|
||||
.iter()
|
||||
.map(|&l| {
|
||||
let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
|
||||
i_right_broadcast += 1;
|
||||
if i_right_broadcast >= ob.right_broadcast {
|
||||
i_in_block += 1;
|
||||
i_right_broadcast = 0;
|
||||
}
|
||||
if i_in_block >= ob.len {
|
||||
i_in_block = 0
|
||||
}
|
||||
f(l, *r)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
None => lhs_l
|
||||
.strided_index()
|
||||
.zip(rhs_l.strided_index())
|
||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
(None, Some((o_r1, o_r2))) => {
|
||||
// TODO: Maybe we want to avoid going through the layout twice.
|
||||
match lhs_l.offsets_b() {
|
||||
Some(ob) => {
|
||||
let mut i_in_block = 0;
|
||||
let mut i_right_broadcast = 0;
|
||||
rhs[o_r1..o_r2]
|
||||
.iter()
|
||||
.map(|&r| {
|
||||
let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
|
||||
i_right_broadcast += 1;
|
||||
if i_right_broadcast >= ob.right_broadcast {
|
||||
i_in_block += 1;
|
||||
i_right_broadcast = 0;
|
||||
}
|
||||
if i_in_block >= ob.len {
|
||||
i_in_block = 0
|
||||
}
|
||||
f(*l, r)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
None => lhs_l
|
||||
.strided_index()
|
||||
.zip(rhs_l.strided_index())
|
||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
_ => lhs_l
|
||||
.strided_index()
|
||||
.zip(rhs_l.strided_index())
|
||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
// Similar to binary_map but with vectorized variants.
|
||||
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
lhs: &[T],
|
||||
rhs: &[T],
|
||||
mut f: F,
|
||||
mut f_vec: FV,
|
||||
) -> Vec<T> {
|
||||
let el_count = lhs_l.shape().elem_count();
|
||||
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
|
||||
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
|
||||
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||
let ys_to_set = ys.spare_capacity_mut();
|
||||
let ys_to_set = unsafe {
|
||||
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
|
||||
};
|
||||
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
|
||||
// SAFETY: values are all set by f_vec.
|
||||
unsafe { ys.set_len(el_count) };
|
||||
ys
|
||||
}
|
||||
(Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
|
||||
Some(ob) if ob.right_broadcast == 1 => {
|
||||
let rhs = &rhs[ob.start..ob.start + ob.len];
|
||||
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||
let ys_to_set = ys.spare_capacity_mut();
|
||||
let ys_to_set = unsafe {
|
||||
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
|
||||
};
|
||||
let mut dst_i = 0;
|
||||
for src_i in (o_l1..o_l2).step_by(ob.len) {
|
||||
f_vec(
|
||||
&lhs[src_i..src_i + ob.len],
|
||||
rhs,
|
||||
&mut ys_to_set[dst_i..dst_i + ob.len],
|
||||
);
|
||||
dst_i += ob.len;
|
||||
}
|
||||
// SAFETY: values are all set by f_vec.
|
||||
unsafe { ys.set_len(el_count) };
|
||||
ys
|
||||
}
|
||||
Some(ob) => {
|
||||
let rhs = &rhs[ob.start..ob.start + ob.len];
|
||||
let mut ys = lhs[o_l1..o_l2].to_vec();
|
||||
for idx_l in 0..ob.left_broadcast {
|
||||
let start = idx_l * ob.len * ob.right_broadcast;
|
||||
for (i, &r) in rhs.iter().enumerate() {
|
||||
let start = start + i * ob.right_broadcast;
|
||||
for v in ys[start..start + ob.right_broadcast].iter_mut() {
|
||||
*v = f(*v, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
ys
|
||||
}
|
||||
None => lhs_l
|
||||
.strided_index()
|
||||
.zip(rhs_l.strided_index())
|
||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||
.collect(),
|
||||
},
|
||||
(None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
|
||||
Some(ob) if ob.right_broadcast == 1 => {
|
||||
let lhs = &lhs[ob.start..ob.start + ob.len];
|
||||
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||
let ys_to_set = ys.spare_capacity_mut();
|
||||
let ys_to_set = unsafe {
|
||||
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
|
||||
};
|
||||
let mut dst_i = 0;
|
||||
for src_i in (o_r1..o_r2).step_by(ob.len) {
|
||||
f_vec(
|
||||
lhs,
|
||||
&rhs[src_i..src_i + ob.len],
|
||||
&mut ys_to_set[dst_i..dst_i + ob.len],
|
||||
);
|
||||
dst_i += ob.len;
|
||||
}
|
||||
// SAFETY: values are all set by f_vec.
|
||||
unsafe { ys.set_len(el_count) };
|
||||
ys
|
||||
}
|
||||
Some(ob) => {
|
||||
let lhs = &lhs[ob.start..ob.start + ob.len];
|
||||
let mut ys = rhs[o_r1..o_r2].to_vec();
|
||||
for idx_l in 0..ob.left_broadcast {
|
||||
let start = idx_l * ob.len * ob.right_broadcast;
|
||||
for (i, &l) in lhs.iter().enumerate() {
|
||||
let start = start + i * ob.right_broadcast;
|
||||
for v in ys[start..start + ob.right_broadcast].iter_mut() {
|
||||
*v = f(l, *v)
|
||||
}
|
||||
}
|
||||
}
|
||||
ys
|
||||
}
|
||||
None => lhs_l
|
||||
.strided_index()
|
||||
.zip(rhs_l.strided_index())
|
||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||
.collect(),
|
||||
},
|
||||
_ => lhs_l
|
||||
.strided_index()
|
||||
.zip(rhs_l.strided_index())
|
||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
|
||||
vs: &[T],
|
||||
layout: &Layout,
|
||||
mut f: F,
|
||||
) -> Vec<U> {
|
||||
match layout.strided_blocks() {
|
||||
crate::StridedBlocks::SingleBlock { start_offset, len } => vs
|
||||
[start_offset..start_offset + len]
|
||||
.iter()
|
||||
.map(|&v| f(v))
|
||||
.collect(),
|
||||
crate::StridedBlocks::MultipleBlocks {
|
||||
block_start_index,
|
||||
block_len,
|
||||
} => {
|
||||
let mut result = Vec::with_capacity(layout.shape().elem_count());
|
||||
// Specialize the case where block_len is one to avoid the second loop.
|
||||
if block_len == 1 {
|
||||
for index in block_start_index {
|
||||
let v = unsafe { vs.get_unchecked(index) };
|
||||
result.push(f(*v))
|
||||
}
|
||||
} else {
|
||||
for index in block_start_index {
|
||||
for offset in 0..block_len {
|
||||
let v = unsafe { vs.get_unchecked(index + offset) };
|
||||
result.push(f(*v))
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
|
||||
vs: &[T],
|
||||
layout: &Layout,
|
||||
mut f: F,
|
||||
mut f_vec: FV,
|
||||
) -> Vec<U> {
|
||||
match layout.strided_blocks() {
|
||||
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||
let mut ys: Vec<U> = Vec::with_capacity(len);
|
||||
let ys_to_set = ys.spare_capacity_mut();
|
||||
let ys_to_set = unsafe {
|
||||
std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
|
||||
};
|
||||
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
|
||||
// SAFETY: values are all set by f_vec.
|
||||
unsafe { ys.set_len(len) };
|
||||
ys
|
||||
}
|
||||
crate::StridedBlocks::MultipleBlocks {
|
||||
block_start_index,
|
||||
block_len,
|
||||
} => {
|
||||
let el_count = layout.shape().elem_count();
|
||||
// Specialize the case where block_len is one to avoid the second loop.
|
||||
if block_len == 1 {
|
||||
let mut result = Vec::with_capacity(el_count);
|
||||
for index in block_start_index {
|
||||
let v = unsafe { vs.get_unchecked(index) };
|
||||
result.push(f(*v))
|
||||
}
|
||||
result
|
||||
} else {
|
||||
let mut ys: Vec<U> = Vec::with_capacity(el_count);
|
||||
let ys_to_set = ys.spare_capacity_mut();
|
||||
let ys_to_set = unsafe {
|
||||
std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
|
||||
};
|
||||
let mut dst_index = 0;
|
||||
for src_index in block_start_index {
|
||||
let vs = &vs[src_index..src_index + block_len];
|
||||
let ys = &mut ys_to_set[dst_index..dst_index + block_len];
|
||||
f_vec(vs, ys);
|
||||
dst_index += block_len;
|
||||
}
|
||||
// SAFETY: values are all set by f_vec.
|
||||
unsafe { ys.set_len(el_count) };
|
||||
ys
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
225
candle-core/src/cuda_backend/cudnn.rs
Normal file
225
candle-core/src/cuda_backend/cudnn.rs
Normal file
@ -0,0 +1,225 @@
|
||||
use crate::WithDType;
|
||||
use cudarc;
|
||||
use cudarc::cudnn::safe::{ConvForward, Cudnn};
|
||||
use cudarc::driver::{CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
// The cudnn handles are stored per thread here rather than on the CudaDevice as they are neither
|
||||
// send nor sync.
|
||||
thread_local! {
|
||||
static CUDNN: RefCell<HashMap<crate::cuda_backend::DeviceId, Arc<Cudnn>>> = HashMap::new().into();
|
||||
}
|
||||
|
||||
impl From<cudarc::cudnn::CudnnError> for crate::Error {
|
||||
fn from(err: cudarc::cudnn::CudnnError) -> Self {
|
||||
crate::Error::wrap(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<cudarc::driver::DriverError> for crate::Error {
|
||||
fn from(err: cudarc::driver::DriverError) -> Self {
|
||||
crate::Error::wrap(err)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn launch_conv2d<
|
||||
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
|
||||
Y: cudarc::cudnn::CudnnDataType,
|
||||
>(
|
||||
src: &CudaView<T>,
|
||||
src_l: &crate::Layout,
|
||||
filter: &CudaView<T>,
|
||||
dst: &mut CudaSlice<T>,
|
||||
params: &crate::conv::ParamsConv2D,
|
||||
dev: &crate::cuda_backend::CudaDevice,
|
||||
) -> crate::Result<()> {
|
||||
use crate::conv::CudnnFwdAlgo as CandleAlgo;
|
||||
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
|
||||
|
||||
let device_id = dev.id();
|
||||
let cudnn = CUDNN.with(|cudnn| {
|
||||
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
|
||||
return Ok(cudnn.clone());
|
||||
}
|
||||
let c = Cudnn::new(dev.cuda_stream());
|
||||
if let Ok(c) = &c {
|
||||
cudnn.borrow_mut().insert(device_id, c.clone());
|
||||
}
|
||||
c
|
||||
})?;
|
||||
let conv = cudnn.create_conv2d::<Y>(
|
||||
/* pad */ [params.padding as i32, params.padding as i32],
|
||||
/* stride */ [params.stride as i32, params.stride as i32],
|
||||
/* dilation */ [params.dilation as i32, params.dilation as i32],
|
||||
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
|
||||
)?;
|
||||
let x_shape = [
|
||||
params.b_size as i32,
|
||||
params.c_in as i32,
|
||||
params.i_h as i32,
|
||||
params.i_w as i32,
|
||||
];
|
||||
// Note that `src` already starts at the proper offset.
|
||||
let x = if src_l.is_contiguous() {
|
||||
cudnn.create_4d_tensor::<T>(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
x_shape,
|
||||
)?
|
||||
} else {
|
||||
let s = src_l.stride();
|
||||
cudnn.create_4d_tensor_ex::<T>(
|
||||
x_shape,
|
||||
[s[0] as i32, s[1] as i32, s[2] as i32, s[3] as i32],
|
||||
)?
|
||||
};
|
||||
let w = cudnn.create_4d_filter::<T>(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
[
|
||||
params.c_out as i32,
|
||||
params.c_in as i32,
|
||||
params.k_h as i32,
|
||||
params.k_w as i32,
|
||||
],
|
||||
)?;
|
||||
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
|
||||
let y = cudnn.create_4d_tensor::<T>(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
[params.b_size as i32, params.c_out as i32, h_out, w_out],
|
||||
)?;
|
||||
let conv2d = ConvForward {
|
||||
conv: &conv,
|
||||
x: &x,
|
||||
w: &w,
|
||||
y: &y,
|
||||
};
|
||||
let alg = match params.cudnn_fwd_algo {
|
||||
None => conv2d.pick_algorithm()?,
|
||||
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||
Some(CandleAlgo::ImplicitPrecompGemm) => {
|
||||
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
|
||||
}
|
||||
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
|
||||
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
|
||||
};
|
||||
let workspace_size = conv2d.get_workspace_size(alg)?;
|
||||
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
|
||||
unsafe {
|
||||
conv2d.launch::<CudaSlice<u8>, _, _, _>(
|
||||
alg,
|
||||
Some(&mut workspace),
|
||||
(T::one(), T::zero()),
|
||||
src,
|
||||
filter,
|
||||
dst,
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn launch_conv1d<
|
||||
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
|
||||
Y: cudarc::cudnn::CudnnDataType,
|
||||
>(
|
||||
src: &CudaView<T>,
|
||||
src_l: &crate::Layout,
|
||||
filter: &CudaView<T>,
|
||||
dst: &mut CudaSlice<T>,
|
||||
params: &crate::conv::ParamsConv1D,
|
||||
dev: &crate::cuda_backend::CudaDevice,
|
||||
) -> crate::Result<()> {
|
||||
use crate::conv::CudnnFwdAlgo as CandleAlgo;
|
||||
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
|
||||
|
||||
let device_id = dev.id();
|
||||
let cudnn = CUDNN.with(|cudnn| {
|
||||
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
|
||||
return Ok(cudnn.clone());
|
||||
}
|
||||
let c = Cudnn::new(dev.cuda_stream());
|
||||
if let Ok(c) = &c {
|
||||
cudnn.borrow_mut().insert(device_id, c.clone());
|
||||
}
|
||||
c
|
||||
})?;
|
||||
let conv = cudnn.create_conv2d::<Y>(
|
||||
/* pad */ [params.padding as i32, 0],
|
||||
/* stride */ [params.stride as i32, 1],
|
||||
/* dilation */ [params.dilation as i32, 1],
|
||||
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
|
||||
)?;
|
||||
// https://docs.nvidia.com/deeplearning/cudnn/backend/latest/api/cudnn-ops-library.html#cudnnsettensornddescriptor
|
||||
// > Tensors are restricted to having at least 4 dimensions, and at most CUDNN_DIM_MAX
|
||||
// > dimensions (defined in cudnn.h). When working with lower dimensional data, it is
|
||||
// > recommended that the user create a 4D tensor, and set the size along unused dimensions
|
||||
// > to 1.
|
||||
let x_shape = [
|
||||
params.b_size as i32,
|
||||
params.c_in as i32,
|
||||
params.l_in as i32,
|
||||
1,
|
||||
];
|
||||
// Note that `src` already starts at the proper offset.
|
||||
let x = if src_l.is_contiguous() {
|
||||
cudnn.create_4d_tensor::<T>(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
x_shape,
|
||||
)?
|
||||
} else {
|
||||
let s = src_l.stride();
|
||||
cudnn.create_4d_tensor_ex::<T>(x_shape, [s[0] as i32, s[1] as i32, s[2] as i32, 1i32])?
|
||||
};
|
||||
let w = cudnn.create_4d_filter::<T>(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
[
|
||||
params.c_out as i32,
|
||||
params.c_in as i32,
|
||||
params.k_size as i32,
|
||||
1,
|
||||
],
|
||||
)?;
|
||||
let l_out = params.l_out() as i32;
|
||||
let y = cudnn.create_4d_tensor::<T>(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
[params.b_size as i32, params.c_out as i32, l_out, 1],
|
||||
)?;
|
||||
let conv1d = ConvForward {
|
||||
conv: &conv,
|
||||
x: &x,
|
||||
w: &w,
|
||||
y: &y,
|
||||
};
|
||||
let alg = match params.cudnn_fwd_algo {
|
||||
None => conv1d.pick_algorithm()?,
|
||||
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||
Some(CandleAlgo::ImplicitPrecompGemm) => {
|
||||
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
|
||||
}
|
||||
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
|
||||
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
|
||||
};
|
||||
let workspace_size = conv1d.get_workspace_size(alg)?;
|
||||
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
|
||||
unsafe {
|
||||
conv1d.launch::<CudaSlice<u8>, _, _, _>(
|
||||
alg,
|
||||
Some(&mut workspace),
|
||||
(T::one(), T::zero()),
|
||||
src,
|
||||
filter,
|
||||
dst,
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
646
candle-core/src/cuda_backend/device.rs
Normal file
646
candle-core/src/cuda_backend/device.rs
Normal file
@ -0,0 +1,646 @@
|
||||
use crate::backend::BackendDevice;
|
||||
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
||||
pub use candle_kernels as kernels;
|
||||
pub use cudarc;
|
||||
use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
|
||||
use half::{bf16, f16};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
|
||||
|
||||
/// Unique identifier for cuda devices.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct DeviceId(usize);
|
||||
|
||||
impl DeviceId {
|
||||
fn new() -> Self {
|
||||
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
|
||||
use std::sync::atomic;
|
||||
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
|
||||
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
|
||||
}
|
||||
}
|
||||
|
||||
struct CudaRng(cudarc::curand::CudaRng);
|
||||
unsafe impl Send for CudaRng {}
|
||||
|
||||
pub struct ModuleStore {
|
||||
mdls: [Option<Arc<cudarc::driver::CudaModule>>; kernels::ALL_IDS.len()],
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CudaDevice {
|
||||
id: DeviceId,
|
||||
context: Arc<cudarc::driver::CudaContext>,
|
||||
modules: Arc<std::sync::RwLock<ModuleStore>>,
|
||||
custom_modules: Arc<std::sync::RwLock<HashMap<String, Arc<cudarc::driver::CudaModule>>>>,
|
||||
stream: Arc<cudarc::driver::CudaStream>,
|
||||
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
|
||||
curand: Arc<Mutex<CudaRng>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for CudaDevice {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "CudaDevice({:?})", self.id)
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaDevice {
|
||||
#[allow(clippy::missing_safety_doc)]
|
||||
pub unsafe fn alloc<T: cudarc::driver::DeviceRepr>(
|
||||
&self,
|
||||
len: usize,
|
||||
) -> Result<cudarc::driver::CudaSlice<T>> {
|
||||
self.stream.alloc::<T>(len).w()
|
||||
}
|
||||
|
||||
pub fn alloc_zeros<T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits>(
|
||||
&self,
|
||||
len: usize,
|
||||
) -> Result<cudarc::driver::CudaSlice<T>> {
|
||||
self.stream.alloc_zeros::<T>(len).w()
|
||||
}
|
||||
|
||||
pub fn memcpy_htod<
|
||||
T: cudarc::driver::DeviceRepr,
|
||||
Src: cudarc::driver::HostSlice<T> + ?Sized,
|
||||
Dst: cudarc::driver::DevicePtrMut<T>,
|
||||
>(
|
||||
&self,
|
||||
src: &Src,
|
||||
dst: &mut Dst,
|
||||
) -> Result<()> {
|
||||
self.stream.memcpy_htod(src, dst).w()
|
||||
}
|
||||
|
||||
pub fn memcpy_dtov<T: cudarc::driver::DeviceRepr, Src: cudarc::driver::DevicePtr<T>>(
|
||||
&self,
|
||||
src: &Src,
|
||||
) -> Result<Vec<T>> {
|
||||
self.stream.memcpy_dtov(src).w()
|
||||
}
|
||||
|
||||
pub fn memcpy_dtod<
|
||||
T,
|
||||
Src: cudarc::driver::DevicePtr<T>,
|
||||
Dst: cudarc::driver::DevicePtrMut<T>,
|
||||
>(
|
||||
&self,
|
||||
src: &Src,
|
||||
dst: &mut Dst,
|
||||
) -> Result<()> {
|
||||
self.stream.memcpy_dtod(src, dst).w()
|
||||
}
|
||||
|
||||
pub fn memcpy_stod<
|
||||
T: cudarc::driver::DeviceRepr,
|
||||
Src: cudarc::driver::HostSlice<T> + ?Sized,
|
||||
>(
|
||||
&self,
|
||||
src: &Src,
|
||||
) -> Result<cudarc::driver::CudaSlice<T>> {
|
||||
self.stream.memcpy_stod(src).w()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CudaFunc {
|
||||
func: CudaFunction,
|
||||
stream: Arc<cudarc::driver::CudaStream>,
|
||||
}
|
||||
|
||||
impl std::ops::Deref for CudaFunc {
|
||||
type Target = CudaFunction;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.func
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaFunc {
|
||||
pub fn into_cuda_function(self) -> CudaFunction {
|
||||
self.func
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! builder_arg {
|
||||
($b:ident, $($arg:expr),*) => {
|
||||
$(
|
||||
let __arg = $arg;
|
||||
$b.arg(&__arg);
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
impl CudaFunc {
|
||||
pub fn builder(&self) -> cudarc::driver::LaunchArgs<'_> {
|
||||
self.stream.launch_builder(&self.func)
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaDevice {
|
||||
pub fn cuda_stream(&self) -> Arc<cudarc::driver::CudaStream> {
|
||||
self.stream.clone()
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub fn compile(
|
||||
&self,
|
||||
func_name: &'static str,
|
||||
kernel: ug::lang::ssa::Kernel,
|
||||
) -> Result<CudaFunc> {
|
||||
let mut buf = vec![];
|
||||
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
|
||||
let cuda_code = String::from_utf8(buf)?;
|
||||
let opts = cudarc::nvrtc::CompileOptions {
|
||||
use_fast_math: Some(true),
|
||||
..Default::default()
|
||||
};
|
||||
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
|
||||
let module = self.context.load_module(ptx).w()?;
|
||||
let func = module.load_function(func_name).w()?;
|
||||
Ok(CudaFunc {
|
||||
func,
|
||||
stream: self.stream.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn id(&self) -> DeviceId {
|
||||
self.id
|
||||
}
|
||||
|
||||
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
let slice = match dtype {
|
||||
DType::U8 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u8>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as u8;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
DType::U32 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u32>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as u32;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::I64 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<i64>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as i64;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<bf16>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = bf16::from_f64(v);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f16>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = f16::from_f64(v);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f32>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as f32;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f64>(elem_count) }?;
|
||||
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_or_load_custom_func(
|
||||
&self,
|
||||
fn_name: &str,
|
||||
module_name: &str,
|
||||
ptx: &str,
|
||||
) -> Result<CudaFunc> {
|
||||
let ms = self.custom_modules.read().unwrap();
|
||||
if let Some(mdl) = ms.get(module_name).as_ref() {
|
||||
let func = mdl.load_function(fn_name).w()?;
|
||||
return Ok(CudaFunc {
|
||||
func,
|
||||
stream: self.stream.clone(),
|
||||
});
|
||||
}
|
||||
drop(ms);
|
||||
let mut ms = self.custom_modules.write().unwrap();
|
||||
let cuda_module = self.context.load_module(ptx.into()).w()?;
|
||||
ms.insert(module_name.to_string(), cuda_module.clone());
|
||||
let func = cuda_module.load_function(fn_name).w()?;
|
||||
Ok(CudaFunc {
|
||||
func,
|
||||
stream: self.stream.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_or_load_func(&self, fn_name: &str, mdl: &kernels::Module) -> Result<CudaFunc> {
|
||||
let ms = self.modules.read().unwrap();
|
||||
if let Some(mdl) = ms.mdls[mdl.index()].as_ref() {
|
||||
let func = mdl.load_function(fn_name).w()?;
|
||||
return Ok(CudaFunc {
|
||||
func,
|
||||
stream: self.stream.clone(),
|
||||
});
|
||||
}
|
||||
drop(ms);
|
||||
let mut ms = self.modules.write().unwrap();
|
||||
let cuda_module = self.context.load_module(mdl.ptx().into()).w()?;
|
||||
ms.mdls[mdl.index()] = Some(cuda_module.clone());
|
||||
let func = cuda_module.load_function(fn_name).w()?;
|
||||
Ok(CudaFunc {
|
||||
func,
|
||||
stream: self.stream.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaDevice {
|
||||
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
|
||||
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
|
||||
let stream = context.new_stream().w()?;
|
||||
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
|
||||
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
|
||||
let module_store = ModuleStore {
|
||||
mdls: [const { None }; kernels::ALL_IDS.len()],
|
||||
};
|
||||
Ok(Self {
|
||||
id: DeviceId::new(),
|
||||
context,
|
||||
stream,
|
||||
blas: Arc::new(blas),
|
||||
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
||||
modules: Arc::new(std::sync::RwLock::new(module_store)),
|
||||
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendDevice for CudaDevice {
|
||||
type Storage = CudaStorage;
|
||||
|
||||
fn new(ordinal: usize) -> Result<Self> {
|
||||
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
|
||||
let stream = context.default_stream();
|
||||
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
|
||||
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
|
||||
let module_store = ModuleStore {
|
||||
mdls: [const { None }; kernels::ALL_IDS.len()],
|
||||
};
|
||||
Ok(Self {
|
||||
id: DeviceId::new(),
|
||||
context,
|
||||
stream,
|
||||
blas: Arc::new(blas),
|
||||
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
||||
modules: Arc::new(std::sync::RwLock::new(module_store)),
|
||||
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
|
||||
})
|
||||
}
|
||||
|
||||
fn set_seed(&self, seed: u64) -> Result<()> {
|
||||
// We do not call set_seed but instead create a new curand object. This ensures that the
|
||||
// state will be identical and the same random numbers will be generated.
|
||||
let mut curand = self.curand.lock().unwrap();
|
||||
curand.0 = cudarc::curand::CudaRng::new(seed, self.stream.clone()).w()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn location(&self) -> crate::DeviceLocation {
|
||||
crate::DeviceLocation::Cuda {
|
||||
gpu_id: self.context.ordinal(),
|
||||
}
|
||||
}
|
||||
|
||||
fn same_device(&self, rhs: &Self) -> bool {
|
||||
self.id == rhs.id
|
||||
}
|
||||
|
||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let slice = match dtype {
|
||||
DType::U8 => {
|
||||
let data = self.alloc_zeros::<u8>(elem_count)?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
DType::U32 => {
|
||||
let data = self.alloc_zeros::<u32>(elem_count)?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::I64 => {
|
||||
let data = self.alloc_zeros::<i64>(elem_count)?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data = self.alloc_zeros::<bf16>(elem_count)?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
let data = self.alloc_zeros::<f16>(elem_count)?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
let data = self.alloc_zeros::<f32>(elem_count)?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
let data = self.alloc_zeros::<f64>(elem_count)?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let curand = self.curand.lock().unwrap();
|
||||
let slice = match dtype {
|
||||
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
||||
// cudarc changes.
|
||||
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
||||
Err(CudaError::UnsupportedDtype {
|
||||
dtype,
|
||||
op: "rand_uniform",
|
||||
})
|
||||
.w()?
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data = unsafe { self.alloc::<f32>(elem_count)? };
|
||||
curand.0.fill_with_uniform(&mut data).w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
let mut data = unsafe { self.alloc::<f64>(elem_count)? };
|
||||
curand.0.fill_with_uniform(&mut data).w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
let slice = if lo == 0. && up == 1.0 {
|
||||
slice
|
||||
} else {
|
||||
use super::utils::Map1;
|
||||
let layout = Layout::contiguous(shape);
|
||||
super::Affine(up - lo, lo).map(&slice, self, &layout)?
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
|
||||
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
||||
// cudarc changes.
|
||||
let elem_count = shape.elem_count();
|
||||
let curand = self.curand.lock().unwrap();
|
||||
// curand can only generate an odd number of values.
|
||||
// https://github.com/huggingface/candle/issues/734
|
||||
let elem_count_round = if elem_count % 2 == 1 {
|
||||
elem_count + 1
|
||||
} else {
|
||||
elem_count
|
||||
};
|
||||
let slice = match dtype {
|
||||
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
||||
Err(CudaError::UnsupportedDtype {
|
||||
dtype,
|
||||
op: "rand_normal",
|
||||
})
|
||||
.w()?
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data = unsafe { self.alloc::<f32>(elem_count_round)? };
|
||||
curand
|
||||
.0
|
||||
.fill_with_normal(&mut data, mean as f32, std as f32)
|
||||
.w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
let mut data = unsafe { self.alloc::<f64>(elem_count_round)? };
|
||||
curand.0.fill_with_normal(&mut data, mean, std).w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
self.const_impl(1., shape, dtype)
|
||||
}
|
||||
|
||||
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let slice = match dtype {
|
||||
DType::U8 => {
|
||||
let data = self.alloc::<u8>(elem_count)?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
DType::U32 => {
|
||||
let data = self.alloc::<u32>(elem_count)?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::I64 => {
|
||||
let data = self.alloc::<i64>(elem_count)?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data = self.alloc::<bf16>(elem_count)?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
let data = self.alloc::<f16>(elem_count)?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
let data = self.alloc::<f32>(elem_count)?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
let data = self.alloc::<f64>(elem_count)?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||
let slice = match T::cpu_storage_ref(s) {
|
||||
CpuStorageRef::U8(storage) => {
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
CpuStorageRef::U32(storage) => {
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
CpuStorageRef::I64(storage) => {
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
CpuStorageRef::BF16(storage) => {
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
CpuStorageRef::F16(storage) => {
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
CpuStorageRef::F32(storage) => {
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
CpuStorageRef::F64(storage) => {
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||
let slice = match storage {
|
||||
CpuStorage::U8(storage) => {
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
CpuStorage::U32(storage) => {
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
CpuStorage::I64(storage) => {
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
CpuStorage::BF16(storage) => {
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
CpuStorage::F16(storage) => {
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
CpuStorage::F32(storage) => {
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
CpuStorage::F64(storage) => {
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
|
||||
let slice = match storage {
|
||||
CpuStorage::U8(storage) => {
|
||||
let data = self.memcpy_stod(&storage)?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
CpuStorage::U32(storage) => {
|
||||
let data = self.memcpy_stod(&storage)?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
CpuStorage::I64(storage) => {
|
||||
let data = self.memcpy_stod(&storage)?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
CpuStorage::BF16(storage) => {
|
||||
let data = self.memcpy_stod(&storage)?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
CpuStorage::F16(storage) => {
|
||||
let data = self.memcpy_stod(&storage)?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
CpuStorage::F32(storage) => {
|
||||
let data = self.memcpy_stod(&storage)?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
CpuStorage::F64(storage) => {
|
||||
let data = self.memcpy_stod(&storage)?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn synchronize(&self) -> Result<()> {
|
||||
self.stream.synchronize().map_err(crate::Error::wrap)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
62
candle-core/src/cuda_backend/error.rs
Normal file
62
candle-core/src/cuda_backend/error.rs
Normal file
@ -0,0 +1,62 @@
|
||||
use crate::{DType, Layout};
|
||||
|
||||
/// cudarc related errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum CudaError {
|
||||
#[error(transparent)]
|
||||
Cuda(#[from] cudarc::driver::DriverError),
|
||||
|
||||
#[error(transparent)]
|
||||
Compiler(#[from] cudarc::nvrtc::CompileError),
|
||||
|
||||
#[error(transparent)]
|
||||
Cublas(#[from] cudarc::cublas::result::CublasError),
|
||||
|
||||
#[error(transparent)]
|
||||
Curand(#[from] cudarc::curand::result::CurandError),
|
||||
|
||||
#[error("missing kernel '{module_name}'")]
|
||||
MissingKernel { module_name: String },
|
||||
|
||||
#[error("unsupported dtype {dtype:?} for {op}")]
|
||||
UnsupportedDtype { dtype: DType, op: &'static str },
|
||||
|
||||
#[error("internal error '{0}'")]
|
||||
InternalError(&'static str),
|
||||
|
||||
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
||||
MatMulNonContiguous {
|
||||
lhs_stride: Layout,
|
||||
rhs_stride: Layout,
|
||||
mnk: (usize, usize, usize),
|
||||
},
|
||||
|
||||
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
|
||||
UnexpectedDType {
|
||||
msg: &'static str,
|
||||
expected: DType,
|
||||
got: DType,
|
||||
},
|
||||
|
||||
#[error("{cuda} when loading {module_name}")]
|
||||
Load {
|
||||
cuda: cudarc::driver::DriverError,
|
||||
module_name: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl From<CudaError> for crate::Error {
|
||||
fn from(val: CudaError) -> Self {
|
||||
crate::Error::Cuda(Box::new(val)).bt()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait WrapErr<O> {
|
||||
fn w(self) -> std::result::Result<O, crate::Error>;
|
||||
}
|
||||
|
||||
impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
|
||||
fn w(self) -> std::result::Result<O, crate::Error> {
|
||||
self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
172
candle-core/src/cuda_backend/utils.rs
Normal file
172
candle-core/src/cuda_backend/utils.rs
Normal file
@ -0,0 +1,172 @@
|
||||
/// Helper functions to plug cuda kernels in candle.
|
||||
use crate::{Layout, Result, Shape, WithDType};
|
||||
pub use cudarc;
|
||||
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};
|
||||
|
||||
use super::{CudaDevice, CudaError, WrapErr};
|
||||
|
||||
pub type S = super::CudaStorageSlice;
|
||||
|
||||
pub trait Map1 {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
layout: &Layout,
|
||||
) -> Result<CudaSlice<T>>;
|
||||
|
||||
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
|
||||
let out = match s {
|
||||
S::U8(s) => S::U8(self.f(s, d, l)?),
|
||||
S::U32(s) => S::U32(self.f(s, d, l)?),
|
||||
S::I64(s) => S::I64(self.f(s, d, l)?),
|
||||
S::BF16(s) => S::BF16(self.f(s, d, l)?),
|
||||
S::F16(s) => S::F16(self.f(s, d, l)?),
|
||||
S::F32(s) => S::F32(self.f(s, d, l)?),
|
||||
S::F64(s) => S::F64(self.f(s, d, l)?),
|
||||
};
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map2 {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src1: &CudaSlice<T>,
|
||||
layout1: &Layout,
|
||||
src2: &CudaSlice<T>,
|
||||
layout2: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaSlice<T>>;
|
||||
|
||||
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
|
||||
let out = match (s1, s2) {
|
||||
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
||||
};
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map3 {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src1: &CudaSlice<T>,
|
||||
layout1: &Layout,
|
||||
src2: &CudaSlice<T>,
|
||||
layout2: &Layout,
|
||||
src3: &CudaSlice<T>,
|
||||
layout3: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaSlice<T>>;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn map(
|
||||
&self,
|
||||
s1: &S,
|
||||
l1: &Layout,
|
||||
s2: &S,
|
||||
l2: &Layout,
|
||||
s3: &S,
|
||||
l3: &Layout,
|
||||
d: &CudaDevice,
|
||||
) -> Result<S> {
|
||||
let out = match (s1, s2, s3) {
|
||||
(S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?,
|
||||
};
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map2InPlace {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_shape: &Shape,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<()>;
|
||||
|
||||
fn map(
|
||||
&self,
|
||||
dst: &mut S,
|
||||
dst_s: &Shape,
|
||||
src: &S,
|
||||
src_l: &Layout,
|
||||
d: &CudaDevice,
|
||||
) -> Result<()> {
|
||||
match (dst, src) {
|
||||
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map1Any {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
layout: &Layout,
|
||||
wrap: W,
|
||||
) -> Result<S>;
|
||||
|
||||
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
|
||||
let out = match s {
|
||||
S::U8(s) => self.f(s, d, l, S::U8)?,
|
||||
S::U32(s) => self.f(s, d, l, S::U32)?,
|
||||
S::I64(s) => self.f(s, d, l, S::I64)?,
|
||||
S::BF16(s) => self.f(s, d, l, S::BF16)?,
|
||||
S::F16(s) => self.f(s, d, l, S::F16)?,
|
||||
S::F32(s) => self.f(s, d, l, S::F32)?,
|
||||
S::F64(s) => self.f(s, d, l, S::F64)?,
|
||||
};
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map2Any {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src1: &CudaSlice<T>,
|
||||
layout1: &Layout,
|
||||
src2: &CudaSlice<T>,
|
||||
layout2: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<S>;
|
||||
|
||||
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
|
||||
let out = match (s1, s2) {
|
||||
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
|
||||
};
|
||||
Ok(out)
|
||||
}
|
||||
}
|
@ -1,123 +0,0 @@
|
||||
use crate::WithDType;
|
||||
use cudarc;
|
||||
use cudarc::cudnn::safe::{Conv2dForward, Cudnn};
|
||||
use cudarc::driver::{CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
// The cudnn handles are stored per thread here rather than on the CudaDevice as they are neither
|
||||
// send nor sync.
|
||||
thread_local! {
|
||||
static CUDNN: RefCell<HashMap<crate::cuda_backend::DeviceId, Arc<Cudnn>>> = HashMap::new().into();
|
||||
}
|
||||
|
||||
impl From<cudarc::cudnn::CudnnError> for crate::Error {
|
||||
fn from(err: cudarc::cudnn::CudnnError) -> Self {
|
||||
crate::Error::wrap(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<cudarc::driver::DriverError> for crate::Error {
|
||||
fn from(err: cudarc::driver::DriverError) -> Self {
|
||||
crate::Error::wrap(err)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn launch_conv2d<
|
||||
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
|
||||
>(
|
||||
src: &CudaView<T>,
|
||||
src_l: &crate::Layout,
|
||||
filter: &CudaView<T>,
|
||||
dst: &mut CudaSlice<T>,
|
||||
params: &crate::conv::ParamsConv2D,
|
||||
dev: &crate::cuda_backend::CudaDevice,
|
||||
) -> crate::Result<()> {
|
||||
use crate::conv::CudnnFwdAlgo as CandleAlgo;
|
||||
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
|
||||
|
||||
let device_id = dev.id();
|
||||
let cudnn = CUDNN.with(|cudnn| {
|
||||
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
|
||||
return Ok(cudnn.clone());
|
||||
}
|
||||
let c = Cudnn::new(dev.cuda_device());
|
||||
if let Ok(c) = &c {
|
||||
cudnn.borrow_mut().insert(device_id, c.clone());
|
||||
}
|
||||
c
|
||||
})?;
|
||||
let conv = cudnn.create_conv2d::<T>(
|
||||
/* pad */ [params.padding as i32, params.padding as i32],
|
||||
/* stride */ [params.stride as i32, params.stride as i32],
|
||||
/* dilation */ [params.dilation as i32, params.dilation as i32],
|
||||
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
|
||||
)?;
|
||||
let x_shape = [
|
||||
params.b_size as i32,
|
||||
params.c_in as i32,
|
||||
params.i_h as i32,
|
||||
params.i_w as i32,
|
||||
];
|
||||
// Note that `src` already starts at the proper offset.
|
||||
let x = if src_l.is_contiguous() {
|
||||
cudnn.create_4d_tensor(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
x_shape,
|
||||
)?
|
||||
} else {
|
||||
let s = src_l.stride();
|
||||
cudnn.create_4d_tensor_ex(
|
||||
x_shape,
|
||||
[s[0] as i32, s[1] as i32, s[2] as i32, s[3] as i32],
|
||||
)?
|
||||
};
|
||||
let w = cudnn.create_4d_filter(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
[
|
||||
params.c_out as i32,
|
||||
params.c_in as i32,
|
||||
params.k_h as i32,
|
||||
params.k_w as i32,
|
||||
],
|
||||
)?;
|
||||
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
|
||||
let y = cudnn.create_4d_tensor(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
[params.b_size as i32, params.c_out as i32, h_out, w_out],
|
||||
)?;
|
||||
let conv2d = Conv2dForward {
|
||||
conv: &conv,
|
||||
x: &x,
|
||||
w: &w,
|
||||
y: &y,
|
||||
};
|
||||
let alg = match params.cudnn_fwd_algo {
|
||||
None => conv2d.pick_algorithm()?,
|
||||
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||
Some(CandleAlgo::ImplicitPrecompGemm) => {
|
||||
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
|
||||
}
|
||||
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
|
||||
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
|
||||
};
|
||||
let workspace_size = conv2d.get_workspace_size(alg)?;
|
||||
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
|
||||
unsafe {
|
||||
conv2d.launch::<CudaSlice<u8>, _, _, _>(
|
||||
alg,
|
||||
Some(&mut workspace),
|
||||
(T::one(), T::zero()),
|
||||
src,
|
||||
filter,
|
||||
dst,
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
490
candle-core/src/custom_op.rs
Normal file
490
candle-core/src/custom_op.rs
Normal file
@ -0,0 +1,490 @@
|
||||
use crate::op::{BackpropOp, Op};
|
||||
use crate::tensor::from_storage;
|
||||
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Unary ops that can be defined in user-land.
|
||||
pub trait CustomOp1 {
|
||||
// Box<dyn> does not support const yet, so use a function to get the name.
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
|
||||
|
||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
|
||||
Err(crate::Error::Cuda(
|
||||
format!("no cuda implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
_storage: &MetalStorage,
|
||||
_layout: &Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// This function takes as argument the argument `arg` used in the forward pass, the result
|
||||
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
|
||||
/// The function should return the gradient of the argument.
|
||||
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
|
||||
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CustomOp2 {
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
s1: &CpuStorage,
|
||||
l1: &Layout,
|
||||
s2: &CpuStorage,
|
||||
l2: &Layout,
|
||||
) -> Result<(CpuStorage, Shape)>;
|
||||
|
||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
_: &CudaStorage,
|
||||
_: &Layout,
|
||||
_: &CudaStorage,
|
||||
_: &Layout,
|
||||
) -> Result<(CudaStorage, Shape)> {
|
||||
Err(crate::Error::Cuda(
|
||||
format!("no cuda implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
fn bwd(
|
||||
&self,
|
||||
_arg1: &Tensor,
|
||||
_arg2: &Tensor,
|
||||
_res: &Tensor,
|
||||
_grad_res: &Tensor,
|
||||
) -> Result<(Option<Tensor>, Option<Tensor>)> {
|
||||
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CustomOp3 {
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
s1: &CpuStorage,
|
||||
l1: &Layout,
|
||||
s2: &CpuStorage,
|
||||
l2: &Layout,
|
||||
s3: &CpuStorage,
|
||||
l3: &Layout,
|
||||
) -> Result<(CpuStorage, Shape)>;
|
||||
|
||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
_: &CudaStorage,
|
||||
_: &Layout,
|
||||
_: &CudaStorage,
|
||||
_: &Layout,
|
||||
_: &CudaStorage,
|
||||
_: &Layout,
|
||||
) -> Result<(CudaStorage, Shape)> {
|
||||
Err(crate::Error::Cuda(
|
||||
format!("no cuda implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
fn bwd(
|
||||
&self,
|
||||
_arg1: &Tensor,
|
||||
_arg2: &Tensor,
|
||||
_arg3: &Tensor,
|
||||
_res: &Tensor,
|
||||
_grad_res: &Tensor,
|
||||
) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
|
||||
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
||||
}
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
/// Applies a unary custom op without backward support
|
||||
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
|
||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||
}
|
||||
|
||||
/// Applies a binary custom op without backward support
|
||||
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
|
||||
let (storage, shape) =
|
||||
self.storage()
|
||||
.apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
|
||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||
}
|
||||
|
||||
/// Applies a ternary custom op without backward support
|
||||
pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().apply_op3(
|
||||
self.layout(),
|
||||
&t2.storage(),
|
||||
t2.layout(),
|
||||
&t3.storage(),
|
||||
t3.layout(),
|
||||
c,
|
||||
)?;
|
||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||
}
|
||||
|
||||
/// Applies a unary custom op.
|
||||
pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
|
||||
let (storage, shape) = self
|
||||
.storage()
|
||||
.apply_op1(self.layout(), c.as_ref().as_ref())?;
|
||||
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
|
||||
self.apply_op1_arc(Arc::new(Box::new(c)))
|
||||
}
|
||||
|
||||
/// Applies a binary custom op.
|
||||
pub fn apply_op2_arc(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
|
||||
) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().apply_op2(
|
||||
self.layout(),
|
||||
&rhs.storage(),
|
||||
rhs.layout(),
|
||||
c.as_ref().as_ref(),
|
||||
)?;
|
||||
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
|
||||
self.apply_op2_arc(r, Arc::new(Box::new(c)))
|
||||
}
|
||||
|
||||
/// Applies a ternary custom op.
|
||||
pub fn apply_op3_arc(
|
||||
&self,
|
||||
t2: &Self,
|
||||
t3: &Self,
|
||||
c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
|
||||
) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().apply_op3(
|
||||
self.layout(),
|
||||
&t2.storage(),
|
||||
t2.layout(),
|
||||
&t3.storage(),
|
||||
t3.layout(),
|
||||
c.as_ref().as_ref(),
|
||||
)?;
|
||||
let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
|
||||
Op::CustomOp3(t1, t2, t3, c.clone())
|
||||
});
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
|
||||
&self,
|
||||
t2: &Self,
|
||||
t3: &Self,
|
||||
c: C,
|
||||
) -> Result<Self> {
|
||||
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
|
||||
}
|
||||
}
|
||||
|
||||
// In place ops.
|
||||
|
||||
/// Unary ops that can be defined in user-land.
|
||||
/// These ops work in place and as such back-prop is unsupported.
|
||||
pub trait InplaceOp1 {
|
||||
// Box<dyn> does not support const yet, so use a function to get the name.
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>;
|
||||
|
||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> {
|
||||
Err(crate::Error::Cuda(
|
||||
format!("no cuda implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait InplaceOp2 {
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout)
|
||||
-> Result<()>;
|
||||
|
||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> {
|
||||
Err(crate::Error::Cuda(
|
||||
format!("no cuda implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
_: &mut MetalStorage,
|
||||
_: &Layout,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
) -> Result<()> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait InplaceOp3 {
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
s1: &mut CpuStorage,
|
||||
l1: &Layout,
|
||||
s2: &CpuStorage,
|
||||
l2: &Layout,
|
||||
s3: &CpuStorage,
|
||||
l3: &Layout,
|
||||
) -> Result<()>;
|
||||
|
||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
_: &mut CudaStorage,
|
||||
_: &Layout,
|
||||
_: &CudaStorage,
|
||||
_: &Layout,
|
||||
_: &CudaStorage,
|
||||
_: &Layout,
|
||||
) -> Result<()> {
|
||||
Err(crate::Error::Cuda(
|
||||
format!("no cuda implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
_: &mut MetalStorage,
|
||||
_: &Layout,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
) -> Result<()> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
/// Applies a unary custom op in place.
|
||||
pub fn inplace_op1<C: InplaceOp1>(&self, c: &C) -> Result<()> {
|
||||
self.storage_mut().inplace_op1(self.layout(), c)
|
||||
}
|
||||
|
||||
/// Applies a unary custom op in place (for the first tensor).
|
||||
pub fn inplace_op2<C: InplaceOp2>(&self, rhs: &Self, c: &C) -> Result<()> {
|
||||
self.storage_mut()
|
||||
.inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c)
|
||||
}
|
||||
|
||||
/// Applies a ternary custom op in place (for the first tensor).
|
||||
pub fn inplace_op3<C: InplaceOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> {
|
||||
self.storage_mut().inplace_op3(
|
||||
self.layout(),
|
||||
&t2.storage(),
|
||||
t2.layout(),
|
||||
&t3.storage(),
|
||||
t3.layout(),
|
||||
c,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UgIOp1 {
|
||||
name: &'static str,
|
||||
#[cfg(feature = "cuda")]
|
||||
func: cudarc::driver::CudaFunction,
|
||||
#[cfg(feature = "metal")]
|
||||
func: metal::ComputePipelineState,
|
||||
}
|
||||
|
||||
impl UgIOp1 {
|
||||
#[allow(unused)]
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub fn new(
|
||||
name: &'static str,
|
||||
kernel: ug::lang::ssa::Kernel,
|
||||
device: &crate::Device,
|
||||
) -> Result<Self> {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
let device = device.as_cuda_device()?;
|
||||
let func = device.compile(name, kernel)?;
|
||||
Ok(Self {
|
||||
name,
|
||||
func: func.into_cuda_function(),
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "metal")]
|
||||
{
|
||||
let device = device.as_metal_device()?;
|
||||
let func = device.compile(name, kernel)?;
|
||||
Ok(Self { name, func })
|
||||
}
|
||||
#[cfg(not(any(feature = "cuda", feature = "metal")))]
|
||||
{
|
||||
Ok(Self { name })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl InplaceOp1 for UgIOp1 {
|
||||
fn name(&self) -> &'static str {
|
||||
self.name
|
||||
}
|
||||
|
||||
fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
|
||||
crate::bail!("ug ops are only supported on metal/cuda at the moment")
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
|
||||
use crate::backend::BackendStorage;
|
||||
use candle_metal_kernels::utils::EncoderProvider;
|
||||
|
||||
let elem_count = layout.shape().elem_count();
|
||||
if sto.dtype() != crate::DType::F32 {
|
||||
// TODO: support more dtypes.
|
||||
crate::bail!("input is not a f32 tensor")
|
||||
}
|
||||
let device = sto.device();
|
||||
println!("here");
|
||||
let command_buffer = device.command_buffer()?;
|
||||
let command_buffer = &command_buffer;
|
||||
let encoder = command_buffer.encoder();
|
||||
let encoder = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&self.func);
|
||||
let (g, b) = if elem_count % 32 == 0 {
|
||||
(elem_count / 32, 32)
|
||||
} else {
|
||||
(elem_count, 1)
|
||||
};
|
||||
let grid_dims = metal::MTLSize {
|
||||
width: g as u64,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1);
|
||||
candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize));
|
||||
|
||||
encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_threads(grid_dims, group_dims);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
|
||||
use crate::cuda_backend::WrapErr;
|
||||
use cudarc::driver::PushKernelArg;
|
||||
|
||||
let elem_count = layout.shape().elem_count();
|
||||
let stream = sto.device.cuda_stream();
|
||||
// TODO: support more dtypes.
|
||||
let sto = sto.as_cuda_slice::<f32>()?;
|
||||
let sto = match layout.contiguous_offsets() {
|
||||
None => crate::bail!("input has to be contiguous"),
|
||||
Some((o1, o2)) => sto.slice(o1..o2),
|
||||
};
|
||||
let (g, b) = if elem_count % 32 == 0 {
|
||||
(elem_count / 32, 32)
|
||||
} else {
|
||||
(elem_count, 1)
|
||||
};
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
grid_dim: (g as u32, 1, 1),
|
||||
block_dim: (b as u32, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
let mut builder = stream.launch_builder(&self.func);
|
||||
builder.arg(&sto);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -11,6 +11,7 @@ pub enum DeviceLocation {
|
||||
Metal { gpu_id: usize },
|
||||
}
|
||||
|
||||
/// Cpu, Cuda, or Metal
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Device {
|
||||
Cpu,
|
||||
@ -130,6 +131,26 @@ impl Device {
|
||||
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
||||
}
|
||||
|
||||
pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> {
|
||||
match self {
|
||||
Self::Cuda(d) => Ok(d),
|
||||
Self::Cpu => crate::bail!("expected a cuda device, got cpu"),
|
||||
Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
|
||||
match self {
|
||||
Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
|
||||
Self::Cpu => crate::bail!("expected a metal device, got cpu"),
|
||||
Self::Metal(d) => Ok(d),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
|
||||
Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
|
||||
}
|
||||
|
||||
pub fn new_metal(ordinal: usize) -> Result<Self> {
|
||||
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
|
||||
}
|
||||
@ -171,6 +192,22 @@ impl Device {
|
||||
matches!(self, Self::Metal(_))
|
||||
}
|
||||
|
||||
pub fn supports_bf16(&self) -> bool {
|
||||
match self {
|
||||
Self::Cuda(_) | Self::Metal(_) => true,
|
||||
Self::Cpu => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return `BF16` for devices that support it, otherwise default to `F32`.
|
||||
pub fn bf16_default_to_f32(&self) -> DType {
|
||||
if self.supports_bf16() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
|
||||
if crate::utils::cuda_is_available() {
|
||||
Self::new_cuda(ordinal)
|
||||
@ -201,10 +238,9 @@ impl Device {
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
Device::Metal(_device) => {
|
||||
// let storage = device.rand_uniform(shape, dtype, lo, up)?;
|
||||
// Ok(Storage::Metal(storage))
|
||||
crate::bail!("Metal rand_uniform not implemented")
|
||||
Device::Metal(device) => {
|
||||
let storage = device.rand_uniform(shape, dtype, lo, up)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -290,17 +326,48 @@ impl Device {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = CpuDevice.alloc_uninit(shape, dtype)?;
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
let storage = device.alloc_uninit(shape, dtype)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
let storage = device.alloc_uninit(shape, dtype)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
|
||||
Device::Cuda(device) => {
|
||||
let storage = device.storage_from_slice(data)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
let storage = device.storage_from_slice(data)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
|
||||
Device::Cuda(device) => {
|
||||
let storage = array.to_cpu_storage();
|
||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||
let storage = device.storage_from_cpu_storage_owned(storage)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
let storage = array.to_cpu_storage();
|
||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||
let storage = device.storage_from_cpu_storage_owned(storage)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
}
|
||||
}
|
||||
@ -311,14 +378,22 @@ impl Device {
|
||||
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
|
||||
Device::Cuda(device) => {
|
||||
let storage = S::to_cpu_storage_owned(data);
|
||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||
let storage = device.storage_from_cpu_storage_owned(storage)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
let storage = S::to_cpu_storage_owned(data);
|
||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||
let storage = device.storage_from_cpu_storage_owned(storage)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn synchronize(&self) -> Result<()> {
|
||||
match self {
|
||||
Self::Cpu => Ok(()),
|
||||
Self::Cuda(d) => d.synchronize(),
|
||||
Self::Metal(d) => d.synchronize(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
/// Pretty printing of tensors
|
||||
/// This implementation should be in line with the PyTorch version.
|
||||
/// https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py
|
||||
//! Pretty printing of tensors
|
||||
//!
|
||||
//! This implementation should be in line with the [PyTorch version](https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py).
|
||||
//!
|
||||
use crate::{DType, Result, Tensor, WithDType};
|
||||
use half::{bf16, f16};
|
||||
|
||||
@ -65,12 +66,13 @@ impl std::fmt::Debug for Tensor {
|
||||
}
|
||||
|
||||
/// Options for Tensor pretty printing
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PrinterOptions {
|
||||
precision: usize,
|
||||
threshold: usize,
|
||||
edge_items: usize,
|
||||
line_width: usize,
|
||||
sci_mode: Option<bool>,
|
||||
pub precision: usize,
|
||||
pub threshold: usize,
|
||||
pub edge_items: usize,
|
||||
pub line_width: usize,
|
||||
pub sci_mode: Option<bool>,
|
||||
}
|
||||
|
||||
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
|
||||
@ -89,6 +91,10 @@ impl PrinterOptions {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
|
||||
&PRINT_OPTS
|
||||
}
|
||||
|
||||
pub fn set_print_options(options: PrinterOptions) {
|
||||
*PRINT_OPTS.lock().unwrap() = options
|
||||
}
|
||||
@ -117,6 +123,26 @@ pub fn set_print_options_full() {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_line_width(line_width: usize) {
|
||||
PRINT_OPTS.lock().unwrap().line_width = line_width
|
||||
}
|
||||
|
||||
pub fn set_precision(precision: usize) {
|
||||
PRINT_OPTS.lock().unwrap().precision = precision
|
||||
}
|
||||
|
||||
pub fn set_edge_items(edge_items: usize) {
|
||||
PRINT_OPTS.lock().unwrap().edge_items = edge_items
|
||||
}
|
||||
|
||||
pub fn set_threshold(threshold: usize) {
|
||||
PRINT_OPTS.lock().unwrap().threshold = threshold
|
||||
}
|
||||
|
||||
pub fn set_sci_mode(sci_mode: Option<bool>) {
|
||||
PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
|
||||
}
|
||||
|
||||
struct FmtSize {
|
||||
current_size: usize,
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
//! Types for elements that can be stored and manipulated using tensors.
|
||||
#![allow(clippy::redundant_closure_call)]
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::{CpuStorage, Error, Result};
|
||||
use crate::{CpuStorage, CpuStorageRef, Error, Result};
|
||||
|
||||
/// The different types of elements allowed in tensors.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
@ -23,7 +23,15 @@ pub enum DType {
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct DTypeParseError;
|
||||
pub struct DTypeParseError(String);
|
||||
|
||||
impl std::fmt::Display for DTypeParseError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "cannot parse '{}' as a dtype", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for DTypeParseError {}
|
||||
|
||||
impl std::str::FromStr for DType {
|
||||
type Err = DTypeParseError;
|
||||
@ -36,7 +44,7 @@ impl std::str::FromStr for DType {
|
||||
"f16" => Ok(Self::F16),
|
||||
"f32" => Ok(Self::F32),
|
||||
"f64" => Ok(Self::F64),
|
||||
_ => Err(DTypeParseError),
|
||||
_ => Err(DTypeParseError(s.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -92,12 +100,14 @@ pub trait WithDType:
|
||||
+ 'static
|
||||
+ Send
|
||||
+ Sync
|
||||
+ std::any::Any
|
||||
+ crate::cpu::kernels::VecOps
|
||||
{
|
||||
const DTYPE: DType;
|
||||
|
||||
fn from_f64(v: f64) -> Self;
|
||||
fn to_f64(self) -> f64;
|
||||
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
|
||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
||||
|
||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||
@ -121,6 +131,10 @@ macro_rules! with_dtype {
|
||||
$to_f64(self)
|
||||
}
|
||||
|
||||
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
|
||||
CpuStorageRef::$dtype(data)
|
||||
}
|
||||
|
||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
|
||||
CpuStorage::$dtype(data)
|
||||
}
|
||||
|
@ -1,3 +1,5 @@
|
||||
//! Implementation of the Cuda backend when Cuda support has not been compiled in.
|
||||
//!
|
||||
#![allow(dead_code)]
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
|
||||
@ -14,6 +16,12 @@ macro_rules! fail {
|
||||
};
|
||||
}
|
||||
|
||||
impl CudaDevice {
|
||||
pub fn new_with_stream(_: usize) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::backend::BackendStorage for CudaStorage {
|
||||
type Device = CudaDevice;
|
||||
|
||||
@ -154,6 +162,19 @@ impl crate::backend::BackendStorage for CudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn copy2d(
|
||||
&self,
|
||||
_: &mut Self,
|
||||
_: usize,
|
||||
_: usize,
|
||||
_: usize,
|
||||
_: usize,
|
||||
_: usize,
|
||||
_: usize,
|
||||
) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
@ -197,10 +218,22 @@ impl crate::backend::BackendDevice for CudaDevice {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
@ -208,4 +241,38 @@ impl crate::backend::BackendDevice for CudaDevice {
|
||||
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn synchronize(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||
/// allowed with f16 GEMMs.
|
||||
pub fn gemm_reduced_precision_f16() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||
/// allowed with f16 GEMMs.
|
||||
pub fn set_gemm_reduced_precision_f16(_: bool) {}
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||
/// allowed with bf16 GEMMs.
|
||||
pub fn gemm_reduced_precision_bf16() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||
/// allowed with bf16 GEMMs.
|
||||
pub fn set_gemm_reduced_precision_bf16(_: bool) {}
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
|
||||
/// allowed with f32 GEMMs.
|
||||
pub fn gemm_reduced_precision_f32() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
|
||||
/// allowed with f32 GEMMs.
|
||||
pub fn set_gemm_reduced_precision_f32(_b: bool) {}
|
||||
|
@ -166,6 +166,19 @@ impl crate::backend::BackendStorage for MetalStorage {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn copy2d(
|
||||
&self,
|
||||
_: &mut Self,
|
||||
_: usize,
|
||||
_: usize,
|
||||
_: usize,
|
||||
_: usize,
|
||||
_: usize,
|
||||
_: usize,
|
||||
) -> Result<()> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
@ -209,10 +222,22 @@ impl crate::backend::BackendDevice for MetalDevice {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
@ -220,4 +245,8 @@ impl crate::backend::BackendDevice for MetalDevice {
|
||||
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn synchronize(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -1,3 +1,4 @@
|
||||
//! Candle-specific Error and Result
|
||||
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -8,8 +9,14 @@ pub struct MatMulUnexpectedStriding {
|
||||
pub msg: &'static str,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Error {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{self}")
|
||||
}
|
||||
}
|
||||
|
||||
/// Main library error type.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
#[derive(thiserror::Error)]
|
||||
pub enum Error {
|
||||
// === DType Errors ===
|
||||
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
|
||||
@ -165,6 +172,10 @@ pub enum Error {
|
||||
#[error("Metal error {0}")]
|
||||
Metal(#[from] MetalError),
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[error(transparent)]
|
||||
Ug(#[from] ug::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
TryFromIntError(#[from] core::num::TryFromIntError),
|
||||
|
||||
@ -179,6 +190,10 @@ pub enum Error {
|
||||
#[error(transparent)]
|
||||
ParseInt(#[from] std::num::ParseIntError),
|
||||
|
||||
/// Utf8 parse error.
|
||||
#[error(transparent)]
|
||||
FromUtf8(#[from] std::string::FromUtf8Error),
|
||||
|
||||
/// I/O error.
|
||||
#[error(transparent)]
|
||||
Io(#[from] std::io::Error),
|
||||
@ -191,8 +206,14 @@ pub enum Error {
|
||||
UnsupportedSafeTensorDtype(safetensors::Dtype),
|
||||
|
||||
/// Arbitrary errors wrapping.
|
||||
#[error(transparent)]
|
||||
Wrapped(Box<dyn std::error::Error + Send + Sync>),
|
||||
#[error("{0}")]
|
||||
Wrapped(Box<dyn std::fmt::Display + Send + Sync>),
|
||||
|
||||
#[error("{context}\n{inner}")]
|
||||
Context {
|
||||
inner: Box<Self>,
|
||||
context: Box<dyn std::fmt::Display + Send + Sync>,
|
||||
},
|
||||
|
||||
/// Adding path information to an error.
|
||||
#[error("path: {path:?} {inner}")]
|
||||
@ -210,19 +231,26 @@ pub enum Error {
|
||||
/// User generated error message, typically created via `bail!`.
|
||||
#[error("{0}")]
|
||||
Msg(String),
|
||||
|
||||
#[error("unwrap none")]
|
||||
UnwrapNone,
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
impl Error {
|
||||
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
|
||||
pub fn wrap(err: impl std::fmt::Display + Send + Sync + 'static) -> Self {
|
||||
Self::Wrapped(Box::new(err)).bt()
|
||||
}
|
||||
|
||||
pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
|
||||
pub fn msg(err: impl std::fmt::Display) -> Self {
|
||||
Self::Msg(err.to_string()).bt()
|
||||
}
|
||||
|
||||
pub fn debug(err: impl std::fmt::Debug) -> Self {
|
||||
Self::Msg(format!("{err:?}")).bt()
|
||||
}
|
||||
|
||||
pub fn bt(self) -> Self {
|
||||
let backtrace = std::backtrace::Backtrace::capture();
|
||||
match backtrace.status() {
|
||||
@ -241,6 +269,13 @@ impl Error {
|
||||
path: p.as_ref().to_path_buf(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context(self, c: impl std::fmt::Display + Send + Sync + 'static) -> Self {
|
||||
Self::Context {
|
||||
inner: Box::new(self),
|
||||
context: Box::new(c),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
@ -263,3 +298,41 @@ pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
|
||||
(_, Err(e)) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
// Taken from anyhow.
|
||||
pub trait Context<T> {
|
||||
/// Wrap the error value with additional context.
|
||||
fn context<C>(self, context: C) -> Result<T>
|
||||
where
|
||||
C: std::fmt::Display + Send + Sync + 'static;
|
||||
|
||||
/// Wrap the error value with additional context that is evaluated lazily
|
||||
/// only once an error does occur.
|
||||
fn with_context<C, F>(self, f: F) -> Result<T>
|
||||
where
|
||||
C: std::fmt::Display + Send + Sync + 'static,
|
||||
F: FnOnce() -> C;
|
||||
}
|
||||
|
||||
impl<T> Context<T> for Option<T> {
|
||||
fn context<C>(self, context: C) -> Result<T>
|
||||
where
|
||||
C: std::fmt::Display + Send + Sync + 'static,
|
||||
{
|
||||
match self {
|
||||
Some(v) => Ok(v),
|
||||
None => Err(Error::UnwrapNone.context(context).bt()),
|
||||
}
|
||||
}
|
||||
|
||||
fn with_context<C, F>(self, f: F) -> Result<T>
|
||||
where
|
||||
C: std::fmt::Display + Send + Sync + 'static,
|
||||
F: FnOnce() -> C,
|
||||
{
|
||||
match self {
|
||||
Some(v) => Ok(v),
|
||||
None => Err(Error::UnwrapNone.context(f()).bt()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -64,7 +64,7 @@ impl Tensor {
|
||||
#[derive(Debug)]
|
||||
/// Generic structure used to index a slice of the tensor
|
||||
pub enum TensorIndexer {
|
||||
/// This selects the elemnts for which an index has some specific value.
|
||||
/// This selects the elements for which an index has some specific value.
|
||||
Select(usize),
|
||||
/// This is a regular slice, purely indexing a chunk of the tensor
|
||||
Narrow(Bound<usize>, Bound<usize>),
|
||||
@ -141,28 +141,117 @@ impl<T> IndexOp<T> for Tensor
|
||||
where
|
||||
T: Into<TensorIndexer>,
|
||||
{
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, DType, Device, IndexOp};
|
||||
/// let a = Tensor::new(&[
|
||||
/// [0., 1.],
|
||||
/// [2., 3.],
|
||||
/// [4., 5.]
|
||||
/// ], &Device::Cpu)?;
|
||||
///
|
||||
/// let b = a.i(0)?;
|
||||
/// assert_eq!(b.shape().dims(), &[2]);
|
||||
/// assert_eq!(b.to_vec1::<f64>()?, &[0., 1.]);
|
||||
///
|
||||
/// let c = a.i(..2)?;
|
||||
/// assert_eq!(c.shape().dims(), &[2, 2]);
|
||||
/// assert_eq!(c.to_vec2::<f64>()?, &[
|
||||
/// [0., 1.],
|
||||
/// [2., 3.]
|
||||
/// ]);
|
||||
///
|
||||
/// let d = a.i(1..)?;
|
||||
/// assert_eq!(d.shape().dims(), &[2, 2]);
|
||||
/// assert_eq!(d.to_vec2::<f64>()?, &[
|
||||
/// [2., 3.],
|
||||
/// [4., 5.]
|
||||
/// ]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
fn i(&self, index: T) -> Result<Tensor, Error> {
|
||||
self.index(&[index.into()])
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> IndexOp<(A,)> for Tensor
|
||||
where
|
||||
A: Into<TensorIndexer>,
|
||||
{
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, DType, Device, IndexOp};
|
||||
/// let a = Tensor::new(&[
|
||||
/// [0f32, 1.],
|
||||
/// [2. , 3.],
|
||||
/// [4. , 5.]
|
||||
/// ], &Device::Cpu)?;
|
||||
///
|
||||
/// let b = a.i((0,))?;
|
||||
/// assert_eq!(b.shape().dims(), &[2]);
|
||||
/// assert_eq!(b.to_vec1::<f32>()?, &[0., 1.]);
|
||||
///
|
||||
/// let c = a.i((..2,))?;
|
||||
/// assert_eq!(c.shape().dims(), &[2, 2]);
|
||||
/// assert_eq!(c.to_vec2::<f32>()?, &[
|
||||
/// [0., 1.],
|
||||
/// [2., 3.]
|
||||
/// ]);
|
||||
///
|
||||
/// let d = a.i((1..,))?;
|
||||
/// assert_eq!(d.shape().dims(), &[2, 2]);
|
||||
/// assert_eq!(d.to_vec2::<f32>()?, &[
|
||||
/// [2., 3.],
|
||||
/// [4., 5.]
|
||||
/// ]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
fn i(&self, (a,): (A,)) -> Result<Tensor, Error> {
|
||||
self.index(&[a.into()])
|
||||
}
|
||||
}
|
||||
#[allow(non_snake_case)]
|
||||
impl<A, B> IndexOp<(A, B)> for Tensor
|
||||
where
|
||||
A: Into<TensorIndexer>,
|
||||
B: Into<TensorIndexer>,
|
||||
{
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, DType, Device, IndexOp};
|
||||
/// let a = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.], [6., 7., 8.]], &Device::Cpu)?;
|
||||
///
|
||||
/// let b = a.i((1, 0))?;
|
||||
/// assert_eq!(b.to_vec0::<f32>()?, 3.);
|
||||
///
|
||||
/// let c = a.i((..2, 1))?;
|
||||
/// assert_eq!(c.shape().dims(), &[2]);
|
||||
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
|
||||
///
|
||||
/// let d = a.i((2.., ..))?;
|
||||
/// assert_eq!(c.shape().dims(), &[2]);
|
||||
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
fn i(&self, (a, b): (A, B)) -> Result<Tensor, Error> {
|
||||
self.index(&[a.into(), b.into()])
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! index_op_tuple {
|
||||
($($t:ident),+) => {
|
||||
($doc:tt, $($t:ident),+) => {
|
||||
#[allow(non_snake_case)]
|
||||
impl<$($t),*> IndexOp<($($t,)*)> for Tensor
|
||||
where
|
||||
$($t: Into<TensorIndexer>,)*
|
||||
{
|
||||
#[doc=$doc]
|
||||
fn i(&self, ($($t,)*): ($($t,)*)) -> Result<Tensor, Error> {
|
||||
self.index(&[$($t.into(),)*])
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
index_op_tuple!(A);
|
||||
index_op_tuple!(A, B);
|
||||
index_op_tuple!(A, B, C);
|
||||
index_op_tuple!(A, B, C, D);
|
||||
index_op_tuple!(A, B, C, D, E);
|
||||
index_op_tuple!(A, B, C, D, E, F);
|
||||
index_op_tuple!(A, B, C, D, E, F, G);
|
||||
|
||||
index_op_tuple!("see [TensorIndex#method.i]", A, B, C);
|
||||
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D);
|
||||
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E);
|
||||
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F);
|
||||
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F, G);
|
||||
|
@ -1,3 +1,4 @@
|
||||
//! Tensor Layouts including contiguous or sparse strides
|
||||
use crate::{Error, Result, Shape};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
@ -35,6 +36,12 @@ impl Layout {
|
||||
self.shape.dims()
|
||||
}
|
||||
|
||||
/// The dimension size for a specified dimension index.
|
||||
pub fn dim<D: crate::shape::Dim>(&self, dim: D) -> Result<usize> {
|
||||
let dim = dim.to_index(&self.shape, "dim")?;
|
||||
Ok(self.dims()[dim])
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> &Shape {
|
||||
&self.shape
|
||||
}
|
||||
@ -70,7 +77,7 @@ impl Layout {
|
||||
self.shape.is_fortran_contiguous(&self.stride)
|
||||
}
|
||||
|
||||
pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
|
||||
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
|
||||
let dims = self.shape().dims();
|
||||
if dim >= dims.len() {
|
||||
Err(Error::DimOutOfRange {
|
||||
@ -99,7 +106,7 @@ impl Layout {
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
|
||||
pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
|
||||
let rank = self.shape.rank();
|
||||
if rank <= dim1 || rank <= dim2 {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
@ -120,7 +127,7 @@ impl Layout {
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
|
||||
pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
|
||||
let is_permutation =
|
||||
idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
|
||||
if !is_permutation {
|
||||
|
@ -7,14 +7,14 @@
|
||||
//!
|
||||
//! let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
|
||||
//! let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
|
||||
//!
|
||||
//! let c = a.matmul(&b)?;
|
||||
//!
|
||||
//! # Ok(())}
|
||||
//! ```
|
||||
//!
|
||||
//! ## Features
|
||||
//!
|
||||
//! - Simple syntax (looks and like PyTorch)
|
||||
//! - Simple syntax (looks and feels like PyTorch)
|
||||
//! - CPU and Cuda backends (and M1 support)
|
||||
//! - Enable serverless (CPU) small and fast deployments
|
||||
//! - Model training
|
||||
@ -32,23 +32,36 @@
|
||||
//! Python can really add overhead in more complex workflows and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-future/) is a notorious source of headaches.
|
||||
//!
|
||||
//! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
|
||||
//!
|
||||
//! ## Other Crates
|
||||
//!
|
||||
//! Candle consists of a number of crates. This crate holds core the common data structures but you may wish
|
||||
//! to look at the docs for the other crates which can be found here:
|
||||
//!
|
||||
//! - [candle-core](https://docs.rs/candle-core/). Core Datastructures and DataTypes.
|
||||
//! - [candle-nn](https://docs.rs/candle-nn/). Building blocks for Neural Nets.
|
||||
//! - [candle-datasets](https://docs.rs/candle-datasets/). Rust access to commonly used Datasets like MNIST.
|
||||
//! - [candle-examples](https://docs.rs/candle-examples/). Examples of Candle in Use.
|
||||
//! - [candle-onnx](https://docs.rs/candle-onnx/). Loading and using ONNX models.
|
||||
//! - [candle-pyo3](https://docs.rs/candle-pyo3/). Access to Candle from Python.
|
||||
//! - [candle-transformers](https://docs.rs/candle-transformers/). Candle implemntation of many published transformer models.
|
||||
//!
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
mod accelerate;
|
||||
pub mod backend;
|
||||
pub mod backprop;
|
||||
mod conv;
|
||||
pub mod conv;
|
||||
mod convert;
|
||||
pub mod cpu;
|
||||
pub mod cpu_backend;
|
||||
#[cfg(feature = "cuda")]
|
||||
pub mod cuda_backend;
|
||||
#[cfg(feature = "cudnn")]
|
||||
pub mod cudnn;
|
||||
mod custom_op;
|
||||
mod device;
|
||||
pub mod display;
|
||||
mod dtype;
|
||||
mod dummy_cuda_backend;
|
||||
pub mod dummy_cuda_backend;
|
||||
mod dummy_metal_backend;
|
||||
pub mod error;
|
||||
mod indexer;
|
||||
@ -58,37 +71,46 @@ pub mod metal_backend;
|
||||
#[cfg(feature = "mkl")]
|
||||
mod mkl;
|
||||
pub mod npy;
|
||||
mod op;
|
||||
pub mod op;
|
||||
pub mod pickle;
|
||||
pub mod quantized;
|
||||
pub mod safetensors;
|
||||
pub mod scalar;
|
||||
pub mod shape;
|
||||
mod sort;
|
||||
mod storage;
|
||||
pub mod streaming;
|
||||
mod strided_index;
|
||||
mod tensor;
|
||||
mod tensor_cat;
|
||||
pub mod test_utils;
|
||||
pub mod utils;
|
||||
mod variable;
|
||||
|
||||
pub use cpu_backend::CpuStorage;
|
||||
pub use device::{Device, DeviceLocation};
|
||||
pub use dtype::{DType, FloatDType, IntDType, WithDType};
|
||||
pub use error::{Error, Result};
|
||||
pub use indexer::IndexOp;
|
||||
#[cfg(feature = "cudnn")]
|
||||
pub use cuda_backend::cudnn;
|
||||
|
||||
pub use cpu_backend::{CpuStorage, CpuStorageRef};
|
||||
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
|
||||
pub use device::{Device, DeviceLocation, NdArray};
|
||||
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
|
||||
pub use error::{Context, Error, Result};
|
||||
pub use indexer::{IndexOp, TensorIndexer};
|
||||
pub use layout::Layout;
|
||||
pub use op::{CustomOp1, CustomOp2, CustomOp3};
|
||||
pub use shape::{Shape, D};
|
||||
pub use storage::Storage;
|
||||
pub use streaming::{StreamTensor, StreamingBinOp, StreamingModule};
|
||||
pub use strided_index::{StridedBlocks, StridedIndex};
|
||||
pub use tensor::{Tensor, TensorId};
|
||||
pub use variable::Var;
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
pub use cuda_backend::{CudaDevice, CudaStorage};
|
||||
pub use cuda_backend as cuda;
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
|
||||
pub use dummy_cuda_backend as cuda;
|
||||
|
||||
pub use cuda::{CudaDevice, CudaStorage};
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
|
||||
@ -118,7 +140,7 @@ impl ToUsize2 for (usize, usize) {
|
||||
}
|
||||
}
|
||||
|
||||
// A simple trait defining a module with forward method using a single argument.
|
||||
/// Defining a module with forward method using a single argument.
|
||||
pub trait Module {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
|
||||
}
|
||||
@ -129,8 +151,17 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
|
||||
}
|
||||
}
|
||||
|
||||
// A trait defining a module with forward method using a single tensor argument and a flag to
|
||||
// separate the training and evaluation behaviors.
|
||||
impl<M: Module> Module for Option<&M> {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
None => Ok(xs.clone()),
|
||||
Some(m) => m.forward(xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single forward method using a single single tensor argument and a flag to
|
||||
/// separate the training and evaluation behaviors.
|
||||
pub trait ModuleT {
|
||||
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
340
candle-core/src/metal_backend/device.rs
Normal file
340
candle-core/src/metal_backend/device.rs
Normal file
@ -0,0 +1,340 @@
|
||||
use crate::{DType, Result};
|
||||
use candle_metal_kernels::Kernels;
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, Mutex, RwLock};
|
||||
|
||||
use super::MetalError;
|
||||
|
||||
/// Unique identifier for cuda devices.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct DeviceId(usize);
|
||||
|
||||
impl DeviceId {
|
||||
pub(crate) fn new() -> Self {
|
||||
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
|
||||
use std::sync::atomic;
|
||||
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
|
||||
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
|
||||
}
|
||||
}
|
||||
|
||||
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
|
||||
pub(crate) struct Commands {
|
||||
/// Single command queue for the entire device.
|
||||
command_queue: CommandQueue,
|
||||
/// One command buffer at a time.
|
||||
/// The scheduler works by allowing multiple
|
||||
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
|
||||
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
|
||||
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
|
||||
/// to start to work).
|
||||
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
|
||||
/// for their START time, but there's no guarantee that command buffer1 will finish before
|
||||
/// command buffer2 starts (or there are metal bugs there)
|
||||
command_buffer: CommandBuffer,
|
||||
/// Keeps track of the current amount of compute command encoders on the current
|
||||
/// command buffer
|
||||
/// Arc, RwLock because of the interior mutability.
|
||||
command_buffer_index: usize,
|
||||
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
|
||||
compute_per_buffer: usize,
|
||||
}
|
||||
|
||||
impl Commands {
|
||||
pub(crate) fn new(command_queue: CommandQueue) -> Result<Self> {
|
||||
let command_buffer = command_queue.new_command_buffer().to_owned();
|
||||
command_buffer.enqueue();
|
||||
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
|
||||
Ok(val) => val.parse()?,
|
||||
_ => 50,
|
||||
};
|
||||
Ok(Self {
|
||||
command_queue,
|
||||
command_buffer,
|
||||
command_buffer_index: 0,
|
||||
compute_per_buffer,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn command_buffer(&mut self) -> Result<(bool, CommandBuffer)> {
|
||||
let mut command_buffer = self.command_buffer.to_owned();
|
||||
let mut flushed = false;
|
||||
if self.command_buffer_index > self.compute_per_buffer {
|
||||
self.command_buffer.commit();
|
||||
command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
self.command_buffer = command_buffer.clone();
|
||||
self.command_buffer_index = 0;
|
||||
flushed = true;
|
||||
}
|
||||
self.command_buffer_index += 1;
|
||||
Ok((flushed, command_buffer))
|
||||
}
|
||||
|
||||
pub fn wait_until_completed(&mut self) -> Result<()> {
|
||||
match self.command_buffer.status() {
|
||||
metal::MTLCommandBufferStatus::Committed
|
||||
| metal::MTLCommandBufferStatus::Scheduled
|
||||
| metal::MTLCommandBufferStatus::Completed => {
|
||||
panic!("Already committed");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
self.command_buffer.commit();
|
||||
self.command_buffer.wait_until_completed();
|
||||
self.command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MetalDevice {
|
||||
/// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
|
||||
/// the device itself.
|
||||
pub(crate) id: DeviceId,
|
||||
|
||||
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
|
||||
pub(crate) device: metal::Device,
|
||||
|
||||
pub(crate) commands: Arc<RwLock<Commands>>,
|
||||
|
||||
/// Simple allocator struct.
|
||||
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
|
||||
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
|
||||
/// (could be linked to FFI communication overhead).
|
||||
///
|
||||
/// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the
|
||||
/// graph calculation, and only we the allocator kept a reference to it, therefore it's free
|
||||
/// to be reused. However, in order for this to work, we need to guarantee the order of
|
||||
/// operation, so that this buffer is not being used by another kernel at the same time.
|
||||
/// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
|
||||
///
|
||||
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
|
||||
/// (strong_count = 1).
|
||||
pub(crate) buffers: Arc<RwLock<BufferMap>>,
|
||||
|
||||
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
|
||||
/// Heavily used by [`candle_metal_kernels`]
|
||||
pub(crate) kernels: Arc<Kernels>,
|
||||
/// Seed for random number generation.
|
||||
pub(crate) seed: Arc<Mutex<Buffer>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MetalDevice {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "MetalDevice({:?})", self.id)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for MetalDevice {
|
||||
type Target = metal::DeviceRef;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.device
|
||||
}
|
||||
}
|
||||
|
||||
impl MetalDevice {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub fn compile(
|
||||
&self,
|
||||
func_name: &'static str,
|
||||
kernel: ug::lang::ssa::Kernel,
|
||||
) -> Result<metal::ComputePipelineState> {
|
||||
let mut buf = vec![];
|
||||
ug_metal::code_gen::gen(&mut buf, func_name, &kernel)?;
|
||||
let metal_code = String::from_utf8(buf)?;
|
||||
let lib = self
|
||||
.device
|
||||
.new_library_with_source(&metal_code, &metal::CompileOptions::new())
|
||||
.map_err(MetalError::from)?;
|
||||
let func = lib
|
||||
.get_function(func_name, None)
|
||||
.map_err(MetalError::from)?;
|
||||
let pl = self
|
||||
.device
|
||||
.new_compute_pipeline_state_with_function(&func)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(pl)
|
||||
}
|
||||
|
||||
pub fn id(&self) -> DeviceId {
|
||||
self.id
|
||||
}
|
||||
|
||||
pub fn metal_device(&self) -> &metal::Device {
|
||||
&self.device
|
||||
}
|
||||
|
||||
fn drop_unused_buffers(&self) -> Result<()> {
|
||||
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||
for subbuffers in buffers.values_mut() {
|
||||
let newbuffers = subbuffers
|
||||
.iter()
|
||||
.filter(|s| Arc::strong_count(*s) > 1)
|
||||
.map(Arc::clone)
|
||||
.collect();
|
||||
*subbuffers = newbuffers;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn command_buffer(&self) -> Result<CommandBuffer> {
|
||||
let mut commands = self.commands.write().map_err(MetalError::from)?;
|
||||
let (flushed, command_buffer) = commands.command_buffer()?;
|
||||
if flushed {
|
||||
self.drop_unused_buffers()?
|
||||
}
|
||||
Ok(command_buffer)
|
||||
}
|
||||
|
||||
pub fn wait_until_completed(&self) -> Result<()> {
|
||||
let mut commands = self.commands.write().map_err(MetalError::from)?;
|
||||
commands.wait_until_completed()
|
||||
}
|
||||
|
||||
pub fn kernels(&self) -> &Kernels {
|
||||
&self.kernels
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &metal::Device {
|
||||
&self.device
|
||||
}
|
||||
|
||||
/// Creates a new buffer (not necessarily zeroed).
|
||||
/// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
|
||||
/// This means the buffer data cannot be read on the CPU directly.
|
||||
///
|
||||
/// [`name`] is only used to keep track of the resource origin in case of bugs
|
||||
pub fn new_buffer(
|
||||
&self,
|
||||
element_count: usize,
|
||||
dtype: DType,
|
||||
name: &str,
|
||||
) -> Result<Arc<Buffer>> {
|
||||
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
||||
self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
|
||||
}
|
||||
|
||||
/// Creates a new buffer (not necessarily zeroed).
|
||||
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
|
||||
/// This means the buffer can be read on the CPU but will require manual
|
||||
/// synchronization when the CPU memory is modified
|
||||
/// Used as a bridge to gather data back from the GPU
|
||||
pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
|
||||
self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
|
||||
}
|
||||
|
||||
/// Creates a new buffer from data.
|
||||
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
|
||||
///
|
||||
/// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
|
||||
/// allocates the buffer and copies over the existing data before returning the MTLBuffer.
|
||||
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
|
||||
let size = core::mem::size_of_val(data) as NSUInteger;
|
||||
let new_buffer = self.device.new_buffer_with_data(
|
||||
data.as_ptr().cast(),
|
||||
size,
|
||||
MTLResourceOptions::StorageModeManaged,
|
||||
);
|
||||
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||
|
||||
let subbuffers = buffers
|
||||
.entry((size, MTLResourceOptions::StorageModeManaged))
|
||||
.or_insert(vec![]);
|
||||
|
||||
let new_buffer = Arc::new(new_buffer);
|
||||
subbuffers.push(new_buffer.clone());
|
||||
Ok(new_buffer)
|
||||
}
|
||||
|
||||
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
|
||||
let buffer = self.allocate_buffer(
|
||||
size_in_bytes as NSUInteger,
|
||||
MTLResourceOptions::StorageModePrivate,
|
||||
"allocate_zeros",
|
||||
)?;
|
||||
let command_buffer = self.command_buffer()?;
|
||||
command_buffer.set_label("zeros");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.fill_buffer(
|
||||
&buffer,
|
||||
metal::NSRange {
|
||||
location: 0,
|
||||
length: buffer.length(),
|
||||
},
|
||||
0,
|
||||
);
|
||||
blit.end_encoding();
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
/// The critical allocator algorithm
|
||||
fn allocate_buffer(
|
||||
&self,
|
||||
size: NSUInteger,
|
||||
option: MTLResourceOptions,
|
||||
_name: &str,
|
||||
) -> Result<Arc<Buffer>> {
|
||||
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||
if let Some(b) = find_available_buffer(size, option, &buffers) {
|
||||
// Cloning also ensures we increment the strong count
|
||||
return Ok(b.clone());
|
||||
}
|
||||
|
||||
let size = buf_size(size);
|
||||
let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
|
||||
|
||||
let new_buffer = self.device.new_buffer(size as NSUInteger, option);
|
||||
let new_buffer = Arc::new(new_buffer);
|
||||
subbuffers.push(new_buffer.clone());
|
||||
|
||||
Ok(new_buffer)
|
||||
}
|
||||
|
||||
/// Create a metal GPU capture trace on [`path`].
|
||||
pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
|
||||
let capture = metal::CaptureManager::shared();
|
||||
let descriptor = metal::CaptureDescriptor::new();
|
||||
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
|
||||
descriptor.set_capture_device(self);
|
||||
// The [set_output_url] call requires an absolute path so we convert it if needed.
|
||||
if path.as_ref().is_absolute() {
|
||||
descriptor.set_output_url(path);
|
||||
} else {
|
||||
let path = std::env::current_dir()?.join(path);
|
||||
descriptor.set_output_url(path);
|
||||
}
|
||||
|
||||
capture
|
||||
.start_capture(&descriptor)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn buf_size(size: NSUInteger) -> NSUInteger {
|
||||
size.saturating_sub(1).next_power_of_two() as NSUInteger
|
||||
}
|
||||
|
||||
fn find_available_buffer(
|
||||
size: NSUInteger,
|
||||
option: MTLResourceOptions,
|
||||
buffers: &BufferMap,
|
||||
) -> Option<Arc<Buffer>> {
|
||||
let mut best_buffer: Option<&Arc<Buffer>> = None;
|
||||
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
|
||||
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
|
||||
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
|
||||
for sub in subbuffers {
|
||||
if Arc::strong_count(sub) == 1 {
|
||||
best_buffer = Some(sub);
|
||||
best_buffer_size = *buffer_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
best_buffer.cloned()
|
||||
}
|
2134
candle-core/src/metal_backend/mod.rs
Normal file
2134
candle-core/src/metal_backend/mod.rs
Normal file
File diff suppressed because it is too large
Load Diff
@ -333,6 +333,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_exp_inplace(y: &mut [f32]) {
|
||||
unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_exp_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
@ -355,6 +365,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = -v
|
||||
}
|
||||
vs_exp_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = v / (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = -v
|
||||
}
|
||||
vd_exp_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = v / (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! binary_op {
|
||||
($fn_name:ident, $ty:ty, $mkl_name:ident) => {
|
||||
#[inline]
|
||||
|
@ -330,7 +330,7 @@ impl Tensor {
|
||||
path: P,
|
||||
) -> Result<()> {
|
||||
let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
|
||||
let options =
|
||||
let options: zip::write::FileOptions<()> =
|
||||
zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
|
||||
|
||||
for (name, tensor) in ts.iter() {
|
||||
|
@ -1,5 +1,7 @@
|
||||
//! Tensor Opertion Enums and Traits
|
||||
//!
|
||||
#![allow(clippy::redundant_closure_call)]
|
||||
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
|
||||
use crate::Tensor;
|
||||
use half::{bf16, f16};
|
||||
use num_traits::float::Float;
|
||||
|
||||
@ -61,10 +63,12 @@ pub enum UnaryOp {
|
||||
GeluErf,
|
||||
Erf,
|
||||
Relu,
|
||||
Silu,
|
||||
Tanh,
|
||||
Floor,
|
||||
Ceil,
|
||||
Round,
|
||||
Sign,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -131,8 +135,15 @@ pub enum Op {
|
||||
stride: (usize, usize),
|
||||
},
|
||||
|
||||
UpsampleNearest1D(Tensor),
|
||||
UpsampleNearest2D(Tensor),
|
||||
UpsampleNearest1D {
|
||||
arg: Tensor,
|
||||
target_size: usize,
|
||||
},
|
||||
UpsampleNearest2D {
|
||||
arg: Tensor,
|
||||
target_h: usize,
|
||||
target_w: usize,
|
||||
},
|
||||
|
||||
Cat(Vec<Tensor>, usize),
|
||||
|
||||
@ -153,168 +164,23 @@ pub enum Op {
|
||||
Permute(Tensor, Vec<usize>),
|
||||
Elu(Tensor, f64),
|
||||
Powf(Tensor, f64),
|
||||
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
|
||||
CustomOp1(
|
||||
Tensor,
|
||||
std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
|
||||
),
|
||||
CustomOp2(
|
||||
Tensor,
|
||||
Tensor,
|
||||
std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
|
||||
std::sync::Arc<Box<dyn crate::CustomOp2 + Send + Sync>>,
|
||||
),
|
||||
CustomOp3(
|
||||
Tensor,
|
||||
Tensor,
|
||||
Tensor,
|
||||
std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
|
||||
std::sync::Arc<Box<dyn crate::CustomOp3 + Send + Sync>>,
|
||||
),
|
||||
}
|
||||
|
||||
/// Unary ops that can be defined in user-land.
|
||||
pub trait CustomOp1 {
|
||||
// Box<dyn> does not support const yet, so use a function to get the name.
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
|
||||
|
||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
|
||||
Err(crate::Error::Cuda(
|
||||
format!("no cuda implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
_storage: &MetalStorage,
|
||||
_layout: &Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// This function takes as argument the argument `arg` used in the forward pass, the result
|
||||
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
|
||||
/// The function should return the gradient of the argument.
|
||||
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
|
||||
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CustomOp2 {
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
s1: &CpuStorage,
|
||||
l1: &Layout,
|
||||
s2: &CpuStorage,
|
||||
l2: &Layout,
|
||||
) -> Result<(CpuStorage, Shape)>;
|
||||
|
||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
_: &CudaStorage,
|
||||
_: &Layout,
|
||||
_: &CudaStorage,
|
||||
_: &Layout,
|
||||
) -> Result<(CudaStorage, Shape)> {
|
||||
Err(crate::Error::Cuda(
|
||||
format!("no cuda implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
fn bwd(
|
||||
&self,
|
||||
_arg1: &Tensor,
|
||||
_arg2: &Tensor,
|
||||
_res: &Tensor,
|
||||
_grad_res: &Tensor,
|
||||
) -> Result<(Option<Tensor>, Option<Tensor>)> {
|
||||
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CustomOp3 {
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
s1: &CpuStorage,
|
||||
l1: &Layout,
|
||||
s2: &CpuStorage,
|
||||
l2: &Layout,
|
||||
s3: &CpuStorage,
|
||||
l3: &Layout,
|
||||
) -> Result<(CpuStorage, Shape)>;
|
||||
|
||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
_: &CudaStorage,
|
||||
_: &Layout,
|
||||
_: &CudaStorage,
|
||||
_: &Layout,
|
||||
_: &CudaStorage,
|
||||
_: &Layout,
|
||||
) -> Result<(CudaStorage, Shape)> {
|
||||
Err(crate::Error::Cuda(
|
||||
format!("no cuda implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
fn bwd(
|
||||
&self,
|
||||
_arg1: &Tensor,
|
||||
_arg2: &Tensor,
|
||||
_arg3: &Tensor,
|
||||
_res: &Tensor,
|
||||
_grad_res: &Tensor,
|
||||
) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
|
||||
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
||||
}
|
||||
}
|
||||
|
||||
pub trait UnaryOpT {
|
||||
const NAME: &'static str;
|
||||
const KERNEL: &'static str;
|
||||
@ -386,10 +252,12 @@ pub(crate) struct Gelu;
|
||||
pub(crate) struct GeluErf;
|
||||
pub(crate) struct Erf;
|
||||
pub(crate) struct Relu;
|
||||
pub(crate) struct Silu;
|
||||
pub(crate) struct Tanh;
|
||||
pub(crate) struct Floor;
|
||||
pub(crate) struct Ceil;
|
||||
pub(crate) struct Round;
|
||||
pub(crate) struct Sign;
|
||||
|
||||
macro_rules! bin_op {
|
||||
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
|
||||
@ -593,6 +461,13 @@ unary_op!(Recip, "recip", v, v.recip());
|
||||
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
|
||||
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
|
||||
|
||||
// Hardcode the value for sqrt(2/pi)
|
||||
// https://github.com/huggingface/candle/issues/1982
|
||||
#[allow(clippy::excessive_precision)]
|
||||
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
|
||||
#[allow(clippy::excessive_precision)]
|
||||
const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
|
||||
|
||||
/// Tanh based approximation of the `gelu` operation
|
||||
/// GeluErf is the more precise one.
|
||||
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
||||
@ -605,7 +480,7 @@ impl UnaryOpT for Gelu {
|
||||
* v
|
||||
* (bf16::ONE
|
||||
+ bf16::tanh(
|
||||
(bf16::from_f32_const(2.0) / bf16::PI).sqrt()
|
||||
bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
|
||||
* v
|
||||
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
|
||||
))
|
||||
@ -616,22 +491,18 @@ impl UnaryOpT for Gelu {
|
||||
* v
|
||||
* (f16::ONE
|
||||
+ f16::tanh(
|
||||
(f16::from_f32_const(2.0) / f16::PI).sqrt()
|
||||
f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
|
||||
* v
|
||||
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
|
||||
))
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
0.5 * v
|
||||
* (1.0
|
||||
+ f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
|
||||
0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
0.5 * v
|
||||
* (1.0
|
||||
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
|
||||
0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(_: u8) -> u8 {
|
||||
@ -720,6 +591,77 @@ impl UnaryOpT for Erf {
|
||||
}
|
||||
}
|
||||
|
||||
/// Silu operation
|
||||
impl UnaryOpT for Silu {
|
||||
const NAME: &'static str = "silu";
|
||||
const V: Self = Silu;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
v / (bf16::ONE + (-v).exp())
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
v / (f16::ONE + (-v).exp())
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
v / (1.0 + (-v).exp())
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
v / (1.0 + (-v).exp())
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(_: u8) -> u8 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(_: i64) -> i64 {
|
||||
0
|
||||
}
|
||||
const KERNEL: &'static str = "usilu";
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
const F32_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
#[inline(always)]
|
||||
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
||||
crate::mkl::vs_silu(xs, ys)
|
||||
}
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
const F64_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
#[inline(always)]
|
||||
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||
crate::mkl::vd_silu(xs, ys)
|
||||
}
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
const F32_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
#[inline(always)]
|
||||
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
||||
crate::accelerate::vs_silu(xs, ys)
|
||||
}
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
const F64_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
#[inline(always)]
|
||||
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||
crate::accelerate::vd_silu(xs, ys)
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Abs {
|
||||
const NAME: &'static str = "abs";
|
||||
const KERNEL: &'static str = "uabs";
|
||||
@ -987,3 +929,37 @@ impl std::ops::Deref for BackpropOp {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Sign {
|
||||
const NAME: &'static str = "sign";
|
||||
const KERNEL: &'static str = "usign";
|
||||
const V: Self = Sign;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
f32::from(v > 0.) - f32::from(v < 0.)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
f64::from(v > 0.) - f64::from(v < 0.)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(v: u8) -> u8 {
|
||||
u8::min(1, v)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(v: u32) -> u32 {
|
||||
u32::min(1, v)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(v: i64) -> i64 {
|
||||
(v > 0) as i64 - (v < 0) as i64
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
// Just enough pickle support to be able to read PyTorch checkpoints.
|
||||
//! Just enough pickle support to be able to read PyTorch checkpoints.
|
||||
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
|
||||
// composable/tensor agnostic at some point.
|
||||
use crate::{DType, Error as E, Layout, Result, Tensor};
|
||||
use crate::{Context, DType, Error as E, Layout, Result, Tensor};
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use std::collections::HashMap;
|
||||
use std::io::BufRead;
|
||||
@ -42,9 +42,10 @@ pub enum OpCode {
|
||||
Stop = b'.',
|
||||
NewObj = 0x81,
|
||||
EmptyList = b']',
|
||||
BinFloat = b'g',
|
||||
BinFloat = b'G',
|
||||
Append = b'a',
|
||||
Appends = b'e',
|
||||
Long1 = 0x8a,
|
||||
}
|
||||
|
||||
// Avoid using FromPrimitive so as not to drag another dependency.
|
||||
@ -84,6 +85,7 @@ impl TryFrom<u8> for OpCode {
|
||||
b'G' => Ok(Self::BinFloat),
|
||||
b'a' => Ok(Self::Append),
|
||||
b'e' => Ok(Self::Appends),
|
||||
0x8a => Ok(Self::Long1),
|
||||
value => Err(value),
|
||||
}
|
||||
}
|
||||
@ -106,6 +108,7 @@ pub enum Object {
|
||||
class_name: String,
|
||||
},
|
||||
Int(i32),
|
||||
Long(i64),
|
||||
Float(f64),
|
||||
Unicode(String),
|
||||
Bool(bool),
|
||||
@ -170,6 +173,14 @@ impl Object {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn int_or_long(self) -> OResult<i64> {
|
||||
match self {
|
||||
Self::Int(t) => Ok(t as i64),
|
||||
Self::Long(t) => Ok(t),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tuple(self) -> OResult<Vec<Self>> {
|
||||
match self {
|
||||
Self::Tuple(t) => Ok(t),
|
||||
@ -217,6 +228,13 @@ impl Object {
|
||||
let args = args.remove(1);
|
||||
(callable, args)
|
||||
}
|
||||
Object::Class {
|
||||
module_name,
|
||||
class_name,
|
||||
} if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
|
||||
let mut args = args.tuple()?;
|
||||
args.remove(0).reduce()?
|
||||
}
|
||||
_ => (callable, args),
|
||||
};
|
||||
match callable {
|
||||
@ -227,13 +245,11 @@ impl Object {
|
||||
_ => return Ok(None),
|
||||
};
|
||||
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
|
||||
let mut path = dir_name.to_path_buf();
|
||||
path.push(file_path);
|
||||
Ok(Some(TensorInfo {
|
||||
name,
|
||||
dtype,
|
||||
layout,
|
||||
path: path.to_string_lossy().into_owned(),
|
||||
path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
|
||||
storage_size,
|
||||
}))
|
||||
}
|
||||
@ -345,8 +361,10 @@ impl Stack {
|
||||
module_name,
|
||||
class_name,
|
||||
} => {
|
||||
if module_name == "collections" && class_name == "OrderedDict" {
|
||||
// TODO: have a separate ordered dict.
|
||||
if module_name == "collections"
|
||||
&& (class_name == "OrderedDict" || class_name == "defaultdict")
|
||||
{
|
||||
// TODO: have a separate ordered dict and a separate default dict.
|
||||
Some(Object::Dict(vec![]))
|
||||
} else {
|
||||
None
|
||||
@ -455,7 +473,10 @@ impl Stack {
|
||||
self.push(Object::Int(arg))
|
||||
}
|
||||
OpCode::BinFloat => {
|
||||
let arg = r.read_f64::<LittleEndian>()?;
|
||||
// Somehow floats are encoded using BigEndian whereas int types use LittleEndian.
|
||||
// https://github.com/python/cpython/blob/0c80da4c14d904a367968955544dd6ae58c8101c/Lib/pickletools.py#L855
|
||||
// https://github.com/pytorch/pytorch/blob/372d078f361e726bb4ac0884ac334b04c58179ef/torch/_weights_only_unpickler.py#L243
|
||||
let arg = r.read_f64::<byteorder::BigEndian>()?;
|
||||
self.push(Object::Float(arg))
|
||||
}
|
||||
OpCode::BinUnicode => {
|
||||
@ -527,7 +548,7 @@ impl Stack {
|
||||
crate::bail!("setitems: not an even number of objects")
|
||||
}
|
||||
while let Some(value) = objs.pop() {
|
||||
let key = objs.pop().unwrap();
|
||||
let key = objs.pop().context("empty objs")?;
|
||||
d.push((key, value))
|
||||
}
|
||||
} else {
|
||||
@ -547,7 +568,7 @@ impl Stack {
|
||||
crate::bail!("setitems: not an even number of objects")
|
||||
}
|
||||
while let Some(value) = objs.pop() {
|
||||
let key = objs.pop().unwrap();
|
||||
let key = objs.pop().context("empty objs")?;
|
||||
pydict.push((key, value))
|
||||
}
|
||||
self.push(Object::Dict(pydict))
|
||||
@ -580,6 +601,15 @@ impl Stack {
|
||||
let obj = self.new_obj(class, args)?;
|
||||
self.push(obj)
|
||||
}
|
||||
OpCode::Long1 => {
|
||||
let n_bytes = r.read_u8()?;
|
||||
let mut v = 0;
|
||||
// Decode the next n bytes in little endian
|
||||
for i in 0..n_bytes {
|
||||
v |= (r.read_u8()? as i64) << (i * 8);
|
||||
}
|
||||
self.push(Object::Long(v))
|
||||
}
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
@ -597,10 +627,10 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
|
||||
let mut args = args.tuple()?;
|
||||
let stride = Vec::<usize>::try_from(args.remove(3))?;
|
||||
let size = Vec::<usize>::try_from(args.remove(2))?;
|
||||
let offset = args.remove(1).int()? as usize;
|
||||
let offset = args.remove(1).int_or_long()? as usize;
|
||||
let storage = args.remove(0).persistent_load()?;
|
||||
let mut storage = storage.tuple()?;
|
||||
let storage_size = storage.remove(4).int()? as usize;
|
||||
let storage_size = storage.remove(4).int_or_long()? as usize;
|
||||
let path = storage.remove(2).unicode()?;
|
||||
let (_module_name, class_name) = storage.remove(1).class()?;
|
||||
let dtype = match class_name.as_str() {
|
||||
@ -614,7 +644,11 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
|
||||
crate::bail!("unsupported storage type {other}")
|
||||
}
|
||||
};
|
||||
let layout = Layout::new(crate::Shape::from(size), stride, offset);
|
||||
let layout = Layout::new(
|
||||
crate::Shape::from(size),
|
||||
stride,
|
||||
offset * dtype.size_in_bytes(),
|
||||
);
|
||||
Ok((layout, dtype, path, storage_size))
|
||||
}
|
||||
|
||||
@ -627,9 +661,16 @@ pub struct TensorInfo {
|
||||
pub storage_size: usize,
|
||||
}
|
||||
|
||||
/// Read the tensor info from a .pth file.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `file` - The path to the .pth file.
|
||||
/// * `verbose` - Whether to print debug information.
|
||||
/// * `key` - Optional key to retrieve `state_dict` from the pth file.
|
||||
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
||||
file: P,
|
||||
verbose: bool,
|
||||
key: Option<&str>,
|
||||
) -> Result<Vec<TensorInfo>> {
|
||||
let file = std::fs::File::open(file)?;
|
||||
let zip_reader = std::io::BufReader::new(file);
|
||||
@ -644,15 +685,16 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
||||
if !file_name.ends_with("data.pkl") {
|
||||
continue;
|
||||
}
|
||||
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
|
||||
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").context("no .pkl")?);
|
||||
let reader = zip.by_name(file_name)?;
|
||||
let mut reader = std::io::BufReader::new(reader);
|
||||
let mut stack = Stack::empty();
|
||||
stack.read_loop(&mut reader)?;
|
||||
let obj = stack.finalize()?;
|
||||
if VERBOSE || verbose {
|
||||
println!("{obj:?}");
|
||||
println!("{obj:#?}");
|
||||
}
|
||||
|
||||
let obj = match obj {
|
||||
Object::Build { callable, args } => match *callable {
|
||||
Object::Reduce { callable, args: _ } => match *callable {
|
||||
@ -666,6 +708,24 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
||||
},
|
||||
obj => obj,
|
||||
};
|
||||
|
||||
// If key is provided, then we need to extract the state_dict from the object.
|
||||
let obj = if let Some(key) = key {
|
||||
if let Object::Dict(key_values) = obj {
|
||||
key_values
|
||||
.into_iter()
|
||||
.find(|(k, _)| *k == Object::Unicode(key.to_owned()))
|
||||
.map(|(_, v)| v)
|
||||
.ok_or_else(|| E::Msg(format!("key {key} not found")))?
|
||||
} else {
|
||||
obj
|
||||
}
|
||||
} else {
|
||||
obj
|
||||
};
|
||||
|
||||
// If the object is a dict, then we can extract the tensor info from it.
|
||||
// NOTE: We are assuming that the `obj` is state_dict by this stage.
|
||||
if let Object::Dict(key_values) = obj {
|
||||
for (name, value) in key_values.into_iter() {
|
||||
match value.into_tensor_info(name, &dir_name) {
|
||||
@ -688,8 +748,8 @@ pub struct PthTensors {
|
||||
}
|
||||
|
||||
impl PthTensors {
|
||||
pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
|
||||
let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
|
||||
pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
|
||||
let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
|
||||
let tensor_infos = tensor_infos
|
||||
.into_iter()
|
||||
.map(|ti| (ti.name.to_string(), ti))
|
||||
@ -703,6 +763,7 @@ impl PthTensors {
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
||||
use std::io::Read;
|
||||
let tensor_info = match self.tensor_infos.get(name) {
|
||||
None => return Ok(None),
|
||||
Some(tensor_info) => tensor_info,
|
||||
@ -711,27 +772,56 @@ impl PthTensors {
|
||||
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
|
||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||
let mut reader = zip.by_name(&tensor_info.path)?;
|
||||
let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
|
||||
let rank = tensor_info.layout.shape().rank();
|
||||
|
||||
// Reading the data is a bit tricky as it can be strided, use an offset, etc.
|
||||
// For now only support the basic case.
|
||||
if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
|
||||
// Reading the data is a bit tricky as it can be strided, for now only support the basic
|
||||
// case and when the tensor is fortran contiguous.
|
||||
if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
|
||||
crate::bail!(
|
||||
"cannot retrieve non-contiguous tensors {:?}",
|
||||
tensor_info.layout
|
||||
)
|
||||
}
|
||||
let start_offset = tensor_info.layout.start_offset();
|
||||
if start_offset > 0 {
|
||||
std::io::copy(
|
||||
&mut reader.by_ref().take(start_offset as u64),
|
||||
&mut std::io::sink(),
|
||||
)?;
|
||||
}
|
||||
let tensor = Tensor::from_reader(
|
||||
tensor_info.layout.shape().clone(),
|
||||
tensor_info.dtype,
|
||||
&mut reader,
|
||||
)?;
|
||||
Ok(Some(tensor))
|
||||
|
||||
if rank > 1 && is_fortran_contiguous {
|
||||
// Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
|
||||
let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
|
||||
let tensor = tensor.reshape(shape_reversed)?;
|
||||
|
||||
// Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
|
||||
let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
|
||||
let tensor = tensor.permute(dim_indeces_reversed)?;
|
||||
Ok(Some(tensor))
|
||||
} else {
|
||||
Ok(Some(tensor))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read all the tensors from a PyTorch pth file.
|
||||
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
|
||||
let pth = PthTensors::new(path)?;
|
||||
/// Read all the tensors from a PyTorch pth file with a given key.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path` - Path to the pth file.
|
||||
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
|
||||
/// contains multiple objects and the state_dict is the one we are interested in.
|
||||
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
||||
path: P,
|
||||
key: Option<&str>,
|
||||
) -> Result<Vec<(String, Tensor)>> {
|
||||
let pth = PthTensors::new(path, key)?;
|
||||
let tensor_names = pth.tensor_infos.keys();
|
||||
let mut tensors = Vec::with_capacity(tensor_names.len());
|
||||
for name in tensor_names {
|
||||
@ -741,3 +831,11 @@ pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tenso
|
||||
}
|
||||
Ok(tensors)
|
||||
}
|
||||
|
||||
/// Read all the tensors from a PyTorch pth file.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path` - Path to the pth file.
|
||||
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
|
||||
read_all_with_key(path, None)
|
||||
}
|
||||
|
@ -353,7 +353,7 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
|
||||
q3 = q3.add(32);
|
||||
|
||||
// Prepare low and high bits
|
||||
// We hardcode the shifts here to avoid loading them into a seperate register
|
||||
// We hardcode the shifts here to avoid loading them into a separate register
|
||||
let q3l_0 = _mm256_and_si256(q3bits, m3);
|
||||
let q3h_0 = if j == 0 {
|
||||
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
|
||||
@ -586,7 +586,7 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
|
||||
let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
|
||||
q5 = q5.add(32);
|
||||
|
||||
//Similar to q3k we hardcode the shifts here to avoid loading them into a seperate register
|
||||
//Similar to q3k we hardcode the shifts here to avoid loading them into a separate register
|
||||
let q5l_0 = _mm256_and_si256(q5bits, m4);
|
||||
let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
|
||||
let q5l_0_right_shift = match j {
|
||||
|
737
candle-core/src/quantized/cuda.rs
Normal file
737
candle-core/src/quantized/cuda.rs
Normal file
@ -0,0 +1,737 @@
|
||||
use super::{GgmlDType, QStorage};
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
|
||||
use crate::{builder_arg as barg, CudaDevice, CudaStorage, Result};
|
||||
use half::f16;
|
||||
|
||||
use cudarc::driver::{CudaSlice, CudaView, PushKernelArg};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct PaddedCudaSlice {
|
||||
inner: CudaSlice<u8>,
|
||||
len: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct QCudaStorage {
|
||||
data: PaddedCudaSlice,
|
||||
dtype: GgmlDType,
|
||||
device: CudaDevice,
|
||||
}
|
||||
|
||||
static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
|
||||
|
||||
pub fn set_force_dmmv(f: bool) {
|
||||
FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub const WARP_SIZE: usize = 32;
|
||||
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
|
||||
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
|
||||
pub const NWARPS_Q4_0_AMPERE: usize = 4;
|
||||
pub const GGML_CUDA_MMV_X: usize = 32;
|
||||
pub const GGML_CUDA_MMV_Y: usize = 1;
|
||||
pub const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256;
|
||||
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
|
||||
pub const MATRIX_ROW_PADDING: usize = 512;
|
||||
|
||||
fn ceil_div(p: usize, q: usize) -> usize {
|
||||
p.div_ceil(q)
|
||||
}
|
||||
|
||||
fn pad(p: usize, q: usize) -> usize {
|
||||
ceil_div(p, q) * q
|
||||
}
|
||||
|
||||
fn quantize_q8_1(
|
||||
src: &CudaView<f32>,
|
||||
dst: &mut CudaSlice<u8>,
|
||||
elem_count: usize,
|
||||
ky: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<()> {
|
||||
let kx = elem_count;
|
||||
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
|
||||
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
|
||||
let func = dev.get_or_load_func("quantize_q8_1", &candle_kernels::QUANTIZED)?;
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
grid_dim: (num_blocks as u32, ky as u32, 1),
|
||||
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
let mut builder = func.builder();
|
||||
builder.arg(src);
|
||||
builder.arg(dst);
|
||||
barg!(builder, kx as i32, kx_padded as i32);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn dequantize_f32(
|
||||
data: &PaddedCudaSlice,
|
||||
dtype: GgmlDType,
|
||||
elem_count: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
let nb = elem_count.div_ceil(256);
|
||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
|
||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
|
||||
GgmlDType::Q5_0 => (
|
||||
"dequantize_block_q5_0_f32",
|
||||
false,
|
||||
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||
),
|
||||
GgmlDType::Q5_1 => (
|
||||
"dequantize_block_q5_1_f32",
|
||||
false,
|
||||
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||
),
|
||||
GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb),
|
||||
GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb),
|
||||
GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb),
|
||||
GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb),
|
||||
GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb),
|
||||
GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb),
|
||||
GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
|
||||
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||
};
|
||||
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f32>(elem_count)? };
|
||||
// See e.g.
|
||||
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
grid_dim: (num_blocks as u32, 1, 1),
|
||||
block_dim: (block_dim as u32, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
|
||||
if is_k {
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&data.inner);
|
||||
builder.arg(&dst);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
} else {
|
||||
let nb32 = match dtype {
|
||||
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
||||
_ => elem_count / 32,
|
||||
};
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&data.inner);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, nb32 as i32);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
}
|
||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||
}
|
||||
|
||||
fn dequantize_f16(
|
||||
data: &PaddedCudaSlice,
|
||||
dtype: GgmlDType,
|
||||
elem_count: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
let nb = elem_count.div_ceil(256);
|
||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
|
||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
|
||||
GgmlDType::Q5_0 => (
|
||||
"dequantize_block_q5_0_f16",
|
||||
false,
|
||||
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||
),
|
||||
GgmlDType::Q5_1 => (
|
||||
"dequantize_block_q5_1_f16",
|
||||
false,
|
||||
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||
),
|
||||
GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb),
|
||||
GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb),
|
||||
GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb),
|
||||
GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb),
|
||||
GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb),
|
||||
GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb),
|
||||
GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
|
||||
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||
};
|
||||
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f16>(elem_count)? };
|
||||
// See e.g.
|
||||
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
grid_dim: (num_blocks as u32, 1, 1),
|
||||
block_dim: (block_dim as u32, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
|
||||
if is_k {
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&data.inner);
|
||||
builder.arg(&dst);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
} else {
|
||||
let nb32 = match dtype {
|
||||
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
||||
_ => elem_count / 32,
|
||||
};
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&data.inner);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, nb32 as i32);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
}
|
||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||
}
|
||||
|
||||
fn dequantize_mul_mat_vec(
|
||||
data: &PaddedCudaSlice,
|
||||
y: &CudaView<f32>,
|
||||
dtype: GgmlDType,
|
||||
ncols: usize,
|
||||
nrows: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
||||
if data_elems < ncols * nrows {
|
||||
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
||||
}
|
||||
if y.len() != ncols {
|
||||
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
|
||||
}
|
||||
let kernel_name = match dtype {
|
||||
GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
|
||||
GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
|
||||
GgmlDType::Q5_0 => "dequantize_mul_mat_vec_q5_0_cuda",
|
||||
GgmlDType::Q5_1 => "dequantize_mul_mat_vec_q5_1_cuda",
|
||||
GgmlDType::Q8_0 => "dequantize_mul_mat_vec_q8_0_cuda",
|
||||
GgmlDType::Q2K => "dequantize_mul_mat_vec_q2_k",
|
||||
GgmlDType::Q3K => "dequantize_mul_mat_vec_q3_k",
|
||||
GgmlDType::Q4K => "dequantize_mul_mat_vec_q4_k",
|
||||
GgmlDType::Q5K => "dequantize_mul_mat_vec_q5_k",
|
||||
GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
|
||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||
};
|
||||
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f32>(nrows)? };
|
||||
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
grid_dim: (block_num_y as u32, 1, 1),
|
||||
block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&data.inner);
|
||||
builder.arg(y);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, ncols as i32, nrows as i32);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||
}
|
||||
|
||||
fn mul_mat_vec_via_q8_1(
|
||||
data: &PaddedCudaSlice,
|
||||
y: &CudaView<f32>,
|
||||
dtype: GgmlDType,
|
||||
ncols: usize,
|
||||
nrows: usize,
|
||||
b_size: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
||||
if data_elems < ncols * nrows {
|
||||
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
||||
}
|
||||
if y.len() != ncols * b_size {
|
||||
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
|
||||
}
|
||||
if b_size == 0 || b_size > 8 {
|
||||
crate::bail!("only bsize between 1 and 8 are supported, got {b_size}")
|
||||
}
|
||||
// Start by quantizing y
|
||||
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
|
||||
let y_size_in_bytes =
|
||||
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
|
||||
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
|
||||
|
||||
let kernel_name = match dtype {
|
||||
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
|
||||
GgmlDType::Q4_1 => "mul_mat_vec_q4_1_q8_1_cuda",
|
||||
GgmlDType::Q5_0 => "mul_mat_vec_q5_0_q8_1_cuda",
|
||||
GgmlDType::Q5_1 => "mul_mat_vec_q5_1_q8_1_cuda",
|
||||
GgmlDType::Q8_0 => "mul_mat_vec_q8_0_q8_1_cuda",
|
||||
GgmlDType::Q2K => "mul_mat_vec_q2_K_q8_1_cuda",
|
||||
GgmlDType::Q3K => "mul_mat_vec_q3_K_q8_1_cuda",
|
||||
GgmlDType::Q4K => "mul_mat_vec_q4_K_q8_1_cuda",
|
||||
GgmlDType::Q5K => "mul_mat_vec_q5_K_q8_1_cuda",
|
||||
GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
|
||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||
};
|
||||
let kernel_name = format!("{kernel_name}{b_size}");
|
||||
let func = dev.get_or_load_func(&kernel_name, &candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f32>(nrows * b_size)? };
|
||||
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
|
||||
let (nblocks, nwarps) = match b_size {
|
||||
1 => (nrows as u32, 4),
|
||||
2..=4 => ((nrows as u32).div_ceil(2), 4),
|
||||
5..=8 => ((nrows as u32).div_ceil(2), 2),
|
||||
_ => crate::bail!("unexpected bsize {b_size}"),
|
||||
};
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
grid_dim: (nblocks, 1, 1),
|
||||
block_dim: (WARP_SIZE as u32, nwarps, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&data.inner);
|
||||
builder.arg(&y_q8_1);
|
||||
builder.arg(&dst);
|
||||
barg!(
|
||||
builder,
|
||||
/* ncols_x */ ncols as i32,
|
||||
/* nrows_x */ nrows as i32,
|
||||
/* nrows_y */ ncols_padded as i32,
|
||||
/* nrows_dst */ nrows as i32
|
||||
);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn mul_mat_via_q8_1(
|
||||
data: &PaddedCudaSlice,
|
||||
y: &CudaView<f32>,
|
||||
dtype: GgmlDType,
|
||||
x_rows: usize,
|
||||
x_cols: usize,
|
||||
y_rows: usize,
|
||||
y_cols: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
||||
if data_elems < x_rows * x_cols {
|
||||
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
|
||||
}
|
||||
if y.len() != y_rows * y_cols {
|
||||
crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len())
|
||||
}
|
||||
if x_cols != y_rows {
|
||||
crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}")
|
||||
}
|
||||
let k = x_cols;
|
||||
// Start by quantizing y
|
||||
let k_padded = pad(k, MATRIX_ROW_PADDING);
|
||||
let y_size_in_bytes =
|
||||
k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
|
||||
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
|
||||
|
||||
let (kernel_name, mmq_x, mmq_y) = match dtype {
|
||||
GgmlDType::Q4_0 => ("mul_mat_q4_0", 64, 128),
|
||||
GgmlDType::Q4_1 => ("mul_mat_q4_1", 64, 128),
|
||||
GgmlDType::Q5_0 => ("mul_mat_q5_0", 128, 64),
|
||||
GgmlDType::Q5_1 => ("mul_mat_q5_1", 128, 64),
|
||||
GgmlDType::Q8_0 => ("mul_mat_q8_0", 128, 64),
|
||||
GgmlDType::Q2K => ("mul_mat_q2_K", 64, 128),
|
||||
GgmlDType::Q3K => ("mul_mat_q3_K", 128, 128),
|
||||
GgmlDType::Q4K => ("mul_mat_q4_K", 64, 128),
|
||||
GgmlDType::Q5K => ("mul_mat_q5_K", 64, 128),
|
||||
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
|
||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||
};
|
||||
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols)? };
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
grid_dim: (
|
||||
ceil_div(x_rows, mmq_y) as u32,
|
||||
ceil_div(y_cols, mmq_x) as u32,
|
||||
1,
|
||||
),
|
||||
block_dim: (WARP_SIZE as u32, 4, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
|
||||
let mut builder = func.builder();
|
||||
builder.arg(/* vx */ &data.inner);
|
||||
builder.arg(/* vy */ &y_q8_1);
|
||||
builder.arg(/* dst */ &dst);
|
||||
barg!(
|
||||
builder,
|
||||
/* ncols_x */ x_cols as i32,
|
||||
/* nrows_x */ x_rows as i32,
|
||||
/* ncols_y */ y_cols as i32,
|
||||
/* nrows_y */ k_padded as i32,
|
||||
/* nrows_dst */ x_rows as i32
|
||||
);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||
}
|
||||
|
||||
impl QCudaStorage {
|
||||
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
|
||||
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
|
||||
let padded_size_in_bytes =
|
||||
ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size();
|
||||
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes)?;
|
||||
Ok(QCudaStorage {
|
||||
data: PaddedCudaSlice {
|
||||
inner,
|
||||
len: size_in_bytes,
|
||||
},
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &CudaDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
|
||||
fn deq<T: GgmlType>(buffer: &[u8], n: usize, dst: &mut [f32]) -> Result<()> {
|
||||
let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
|
||||
let vec = slice.to_vec();
|
||||
T::to_float(&vec, dst)
|
||||
}
|
||||
|
||||
let fast_kernel = matches!(
|
||||
self.dtype,
|
||||
GgmlDType::Q4_0
|
||||
| GgmlDType::Q4_1
|
||||
| GgmlDType::Q5_0
|
||||
| GgmlDType::Q5_1
|
||||
| GgmlDType::Q8_0
|
||||
| GgmlDType::Q2K
|
||||
| GgmlDType::Q3K
|
||||
| GgmlDType::Q4K
|
||||
| GgmlDType::Q5K
|
||||
| GgmlDType::Q6K
|
||||
| GgmlDType::Q8K
|
||||
);
|
||||
if fast_kernel {
|
||||
return dequantize_f32(&self.data, self.dtype, elem_count, self.device());
|
||||
}
|
||||
// Run the dequantization on cpu.
|
||||
|
||||
let buffer = self
|
||||
.device
|
||||
.memcpy_dtov(&self.data.inner.slice(..self.data.len))?;
|
||||
let mut out = vec![0.0; elem_count];
|
||||
let block_len = elem_count / self.dtype.block_size();
|
||||
match self.dtype {
|
||||
GgmlDType::F32 => deq::<f32>(&buffer, block_len, &mut out)?,
|
||||
GgmlDType::F16 => deq::<half::f16>(&buffer, block_len, &mut out)?,
|
||||
GgmlDType::Q4_0 => deq::<crate::quantized::BlockQ4_0>(&buffer, block_len, &mut out)?,
|
||||
GgmlDType::Q4_1 => deq::<crate::quantized::BlockQ4_1>(&buffer, block_len, &mut out)?,
|
||||
GgmlDType::Q5_0 => deq::<crate::quantized::BlockQ5_0>(&buffer, block_len, &mut out)?,
|
||||
GgmlDType::Q5_1 => deq::<crate::quantized::BlockQ5_1>(&buffer, block_len, &mut out)?,
|
||||
GgmlDType::Q8_0 => deq::<crate::quantized::BlockQ8_0>(&buffer, block_len, &mut out)?,
|
||||
GgmlDType::Q8_1 => deq::<crate::quantized::BlockQ8_1>(&buffer, block_len, &mut out)?,
|
||||
GgmlDType::Q2K => deq::<crate::quantized::BlockQ2K>(&buffer, block_len, &mut out)?,
|
||||
GgmlDType::Q3K => deq::<crate::quantized::BlockQ3K>(&buffer, block_len, &mut out)?,
|
||||
GgmlDType::Q4K => deq::<crate::quantized::BlockQ4K>(&buffer, block_len, &mut out)?,
|
||||
GgmlDType::Q5K => deq::<crate::quantized::BlockQ5K>(&buffer, block_len, &mut out)?,
|
||||
GgmlDType::Q6K => deq::<crate::quantized::BlockQ6K>(&buffer, block_len, &mut out)?,
|
||||
GgmlDType::Q8K => deq::<crate::quantized::BlockQ8K>(&buffer, block_len, &mut out)?,
|
||||
}
|
||||
|
||||
self.device
|
||||
.storage_from_cpu_storage(&crate::CpuStorage::F32(out))
|
||||
}
|
||||
|
||||
pub fn dequantize_f16(&self, elem_count: usize) -> Result<CudaStorage> {
|
||||
dequantize_f16(&self.data, self.dtype, elem_count, self.device())
|
||||
}
|
||||
|
||||
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
|
||||
// Run the quantization on cpu.
|
||||
let src = match &src.slice {
|
||||
crate::cuda_backend::CudaStorageSlice::F32(data) => self.device.memcpy_dtov(data)?,
|
||||
_ => crate::bail!("only f32 can be quantized"),
|
||||
};
|
||||
let src_len = src.len();
|
||||
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
|
||||
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
|
||||
qcpu_storage.quantize(&src)?;
|
||||
let data = qcpu_storage.data()?;
|
||||
let padded_len =
|
||||
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
|
||||
let mut inner = unsafe { self.device.alloc::<u8>(padded_len)? };
|
||||
self.device
|
||||
.memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))?;
|
||||
self.data = PaddedCudaSlice {
|
||||
inner,
|
||||
len: data.len(),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
self.data.len
|
||||
}
|
||||
|
||||
pub fn fwd(
|
||||
&self,
|
||||
self_shape: &crate::Shape,
|
||||
storage: &CudaStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(CudaStorage, crate::Shape)> {
|
||||
let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
|
||||
1
|
||||
} else {
|
||||
8
|
||||
};
|
||||
let use_vec_kernel = match layout.shape().dims() {
|
||||
[b, m, _k] => b * m <= max_bm,
|
||||
[b, _k] => *b <= max_bm,
|
||||
_ => false,
|
||||
};
|
||||
if use_vec_kernel {
|
||||
self.dequantize_matmul_vec(self_shape, storage, layout)
|
||||
} else {
|
||||
self.dequantize_matmul(self_shape, storage, layout)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl QCudaStorage {
|
||||
fn dequantize_matmul_vec(
|
||||
&self,
|
||||
self_shape: &crate::Shape,
|
||||
rhs: &CudaStorage,
|
||||
rhs_l: &crate::Layout,
|
||||
) -> Result<(CudaStorage, crate::Shape)> {
|
||||
let (nrows, ncols) = self_shape.dims2()?;
|
||||
let rhs = rhs.as_cuda_slice::<f32>()?;
|
||||
let rhs = match rhs_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => rhs.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
|
||||
};
|
||||
let (b_size, k) = match rhs_l.shape().dims() {
|
||||
[b, m, k] => (b * m, *k),
|
||||
[b, k] => (*b, *k),
|
||||
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
|
||||
};
|
||||
if ncols != k {
|
||||
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
|
||||
}
|
||||
|
||||
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
|
||||
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
|
||||
} else {
|
||||
mul_mat_vec_via_q8_1(
|
||||
&self.data,
|
||||
&rhs,
|
||||
self.dtype,
|
||||
ncols,
|
||||
nrows,
|
||||
b_size,
|
||||
self.device(),
|
||||
)?
|
||||
};
|
||||
let mut out_shape = rhs_l.shape().dims().to_vec();
|
||||
out_shape.pop();
|
||||
out_shape.push(nrows);
|
||||
Ok((out, out_shape.into()))
|
||||
}
|
||||
|
||||
fn dequantize_matmul(
|
||||
&self,
|
||||
self_shape: &crate::Shape,
|
||||
storage: &CudaStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(CudaStorage, crate::Shape)> {
|
||||
use crate::backend::BackendStorage;
|
||||
let (n, k) = self_shape.dims2()?;
|
||||
let (b, m, k2) = match layout.shape().dims() {
|
||||
&[b, m, k2] => (b, m, k2),
|
||||
&[m, k2] => (1, m, k2),
|
||||
s => crate::bail!("unexpected shape for input {s:?}"),
|
||||
};
|
||||
if k2 != k {
|
||||
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
|
||||
}
|
||||
|
||||
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
|
||||
let data_f32 = self.dequantize(n * k)?;
|
||||
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
|
||||
storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?
|
||||
} else {
|
||||
let storage = storage.as_cuda_slice::<f32>()?;
|
||||
let storage = match layout.contiguous_offsets() {
|
||||
Some((o1, o2)) => storage.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous {
|
||||
op: "quantized-matmul",
|
||||
}
|
||||
.bt())?,
|
||||
};
|
||||
mul_mat_via_q8_1(
|
||||
&self.data,
|
||||
&storage,
|
||||
self.dtype,
|
||||
/* x_rows */ n,
|
||||
/* x_cols */ k,
|
||||
/* y_rows */ k,
|
||||
/* y_cols */ b * m,
|
||||
self.device(),
|
||||
)?
|
||||
};
|
||||
let mut out_shape = layout.shape().dims().to_vec();
|
||||
out_shape.pop();
|
||||
out_shape.push(n);
|
||||
Ok((out, out_shape.into()))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
||||
device: &CudaDevice,
|
||||
data: &[T],
|
||||
) -> Result<super::QStorage> {
|
||||
let data = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data))
|
||||
};
|
||||
let dtype = T::DTYPE;
|
||||
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
|
||||
let mut inner = unsafe { device.alloc::<u8>(padded_len)? };
|
||||
device.memcpy_htod(data, &mut inner.slice_mut(..data.len()))?;
|
||||
Ok(QStorage::Cuda(QCudaStorage {
|
||||
data: PaddedCudaSlice {
|
||||
inner,
|
||||
len: data.len(),
|
||||
},
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
}))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn cuda_quantize_q8_1() -> Result<()> {
|
||||
let dev = CudaDevice::new(0)?;
|
||||
let el = 256;
|
||||
let el_padded = pad(el, MATRIX_ROW_PADDING);
|
||||
let y_size_in_bytes =
|
||||
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
|
||||
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
|
||||
let y = dev.memcpy_stod(&vs)?;
|
||||
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cuda_mmv_q8_1() -> Result<()> {
|
||||
let dev = CudaDevice::new(0)?;
|
||||
let ncols = 256;
|
||||
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
|
||||
let y = dev.memcpy_stod(&vs)?;
|
||||
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
|
||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||
let cuda_storage = mul_mat_vec_via_q8_1(
|
||||
&xs.data,
|
||||
&y.slice(..),
|
||||
/* dtype */ GgmlDType::Q4_0,
|
||||
/* ncols */ ncols,
|
||||
/* nrows */ 1,
|
||||
/* b_size */ 1,
|
||||
&dev,
|
||||
)?;
|
||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||
let vs = dev.memcpy_dtov(&vs.slice(..))?;
|
||||
assert_eq!(vs.len(), 1);
|
||||
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
|
||||
// Q8 means 1/256 precision.
|
||||
assert_eq!(vs[0], 5561664.5);
|
||||
|
||||
let cuda_storage = dequantize_mul_mat_vec(
|
||||
&xs.data,
|
||||
&y.slice(..),
|
||||
/* dtype */ GgmlDType::Q4_0,
|
||||
/* ncols */ ncols,
|
||||
/* nrows */ 1,
|
||||
&dev,
|
||||
)?;
|
||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||
let vs = dev.memcpy_dtov(&vs.slice(..))?;
|
||||
assert_eq!(vs.len(), 1);
|
||||
assert_eq!(vs[0], 5561851.0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cuda_mm_q8_1() -> Result<()> {
|
||||
let dev = CudaDevice::new(0)?;
|
||||
let ncols = 256;
|
||||
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
|
||||
let y = dev.memcpy_stod(&vs)?;
|
||||
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
|
||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||
let cuda_storage = mul_mat_via_q8_1(
|
||||
&xs.data,
|
||||
&y.slice(..),
|
||||
/* dtype */ GgmlDType::Q4_0,
|
||||
/* x_rows */ 4,
|
||||
/* x_cols */ ncols,
|
||||
/* y_rows */ ncols,
|
||||
/* y_cols */ 4,
|
||||
&dev,
|
||||
)?;
|
||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||
let vs = dev.memcpy_dtov(&vs.slice(..))?;
|
||||
|
||||
/*
|
||||
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
|
||||
x @ x.t() / 16
|
||||
tensor([[ 347480.0000, 869720.0000, 1391960.0000, 1914200.0000],
|
||||
[ 869720.0000, 2440536.0000, 4011352.0000, 5582166.5000],
|
||||
[ 1391960.0000, 4011352.0000, 6630742.0000, 9250132.0000],
|
||||
[ 1914200.0000, 5582166.5000, 9250132.0000, 12918099.0000]])
|
||||
*/
|
||||
assert_eq!(vs.len(), 16);
|
||||
assert_eq!(vs[0], 347604.0);
|
||||
assert_eq!(vs[1], 888153.06);
|
||||
assert_eq!(vs[4], 869780.7);
|
||||
assert_eq!(vs[5], 2483145.0);
|
||||
assert_eq!(vs[11], 9407368.0);
|
||||
assert_eq!(vs[14], 9470856.0);
|
||||
assert_eq!(vs[15], 13138824.0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// The following test used to fail under compute-sanitizer until #2526.
|
||||
#[test]
|
||||
fn cuda_mm_q8_1_pad() -> Result<()> {
|
||||
let dev = CudaDevice::new(0)?;
|
||||
let (x_rows, ncols, y_cols) = (4, 16, 2048);
|
||||
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
|
||||
let y = dev.memcpy_stod(&vs)?;
|
||||
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
|
||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||
let cuda_storage = mul_mat_via_q8_1(
|
||||
&xs.data,
|
||||
&y.slice(..),
|
||||
/* dtype */ GgmlDType::Q4_0,
|
||||
/* x_rows */ x_rows,
|
||||
/* x_cols */ ncols,
|
||||
/* y_rows */ ncols,
|
||||
/* y_cols */ y_cols,
|
||||
&dev,
|
||||
)?;
|
||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||
let _vs = dev.memcpy_dtov(&vs.slice(..))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
54
candle-core/src/quantized/dummy_cuda.rs
Normal file
54
candle-core/src/quantized/dummy_cuda.rs
Normal file
@ -0,0 +1,54 @@
|
||||
#![allow(unused)]
|
||||
use super::GgmlDType;
|
||||
use crate::{CudaDevice, CudaStorage, Error, Result};
|
||||
|
||||
pub struct QCudaStorage {
|
||||
dtype: GgmlDType,
|
||||
device: CudaDevice,
|
||||
}
|
||||
|
||||
impl QCudaStorage {
|
||||
pub fn zeros(_: &CudaDevice, _: usize, _: GgmlDType) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &CudaDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, _elem_count: usize) -> Result<CudaStorage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub fn dequantize_f16(&self, _elem_count: usize) -> Result<CudaStorage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
pub fn fwd(
|
||||
&self,
|
||||
_self_shape: &crate::Shape,
|
||||
_storage: &CudaStorage,
|
||||
_layout: &crate::Layout,
|
||||
) -> Result<(CudaStorage, crate::Shape)> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
||||
_device: &CudaDevice,
|
||||
_data: &[T],
|
||||
) -> Result<super::QStorage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
50
candle-core/src/quantized/dummy_metal.rs
Normal file
50
candle-core/src/quantized/dummy_metal.rs
Normal file
@ -0,0 +1,50 @@
|
||||
#![allow(unused)]
|
||||
use super::GgmlDType;
|
||||
use crate::{Error, MetalDevice, MetalStorage, Result};
|
||||
|
||||
pub struct QMetalStorage {
|
||||
dtype: GgmlDType,
|
||||
device: MetalDevice,
|
||||
}
|
||||
|
||||
impl QMetalStorage {
|
||||
pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &MetalDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
pub fn fwd(
|
||||
&self,
|
||||
_self_shape: &crate::Shape,
|
||||
_storage: &MetalStorage,
|
||||
_layout: &crate::Layout,
|
||||
) -> Result<(MetalStorage, crate::Shape)> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
||||
_device: &MetalDevice,
|
||||
_data: &[T],
|
||||
) -> Result<super::QStorage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
@ -1,7 +1,7 @@
|
||||
//! Support for the GGML file format.
|
||||
|
||||
use super::{k_quants, GgmlDType};
|
||||
use crate::Result;
|
||||
use super::{k_quants, GgmlDType, QStorage};
|
||||
use crate::{Device, Result};
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@ -121,41 +121,68 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
|
||||
raw_data: &[u8],
|
||||
size_in_bytes: usize,
|
||||
dims: Vec<usize>,
|
||||
device: &Device,
|
||||
) -> Result<super::QTensor> {
|
||||
let raw_data_ptr = raw_data.as_ptr();
|
||||
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
|
||||
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
||||
super::QTensor::new(data.to_vec(), dims)
|
||||
let data: QStorage = match device {
|
||||
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
|
||||
Device::Metal(metal) => super::metal::load_quantized(metal, data)?,
|
||||
Device::Cuda(cuda) => super::cuda::load_quantized(cuda, data)?,
|
||||
};
|
||||
super::QTensor::new(data, dims)
|
||||
}
|
||||
|
||||
/// Creates a [Tensor] from a raw GGML tensor.
|
||||
/// Creates a Tensor from a raw GGML tensor.
|
||||
pub fn qtensor_from_ggml(
|
||||
ggml_dtype: GgmlDType,
|
||||
raw_data: &[u8],
|
||||
dims: Vec<usize>,
|
||||
device: &Device,
|
||||
) -> Result<super::QTensor> {
|
||||
let tensor_elems = dims.iter().product::<usize>();
|
||||
let blck_size = ggml_dtype.blck_size();
|
||||
if tensor_elems % blck_size != 0 {
|
||||
let block_size = ggml_dtype.block_size();
|
||||
if tensor_elems % block_size != 0 {
|
||||
crate::bail!(
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
|
||||
)
|
||||
}
|
||||
let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
|
||||
let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size();
|
||||
|
||||
match ggml_dtype {
|
||||
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
|
||||
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
|
||||
GgmlDType::Q4_0 => {
|
||||
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q4_1 => {
|
||||
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5_0 => {
|
||||
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5_1 => {
|
||||
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q8_0 => {
|
||||
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q2K => {
|
||||
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q3K => {
|
||||
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q4K => {
|
||||
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5K => {
|
||||
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q6K => {
|
||||
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
|
||||
}
|
||||
}
|
||||
@ -163,6 +190,7 @@ pub fn qtensor_from_ggml(
|
||||
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||
reader: &mut R,
|
||||
magic: VersionedMagic,
|
||||
device: &Device,
|
||||
) -> Result<(String, super::QTensor)> {
|
||||
let n_dims = reader.read_u32::<LittleEndian>()?;
|
||||
let name_len = reader.read_u32::<LittleEndian>()?;
|
||||
@ -183,11 +211,11 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||
}
|
||||
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
|
||||
let tensor_elems = dims.iter().product::<usize>();
|
||||
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
|
||||
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size();
|
||||
// TODO: Mmap version to avoid copying the data around?
|
||||
let mut raw_data = vec![0u8; size_in_bytes];
|
||||
reader.read_exact(&mut raw_data)?;
|
||||
match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
|
||||
match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
|
||||
Ok(tensor) => Ok((name, tensor)),
|
||||
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
|
||||
}
|
||||
@ -198,10 +226,14 @@ pub struct Content {
|
||||
pub hparams: HParams,
|
||||
pub vocab: Vocab,
|
||||
pub tensors: HashMap<String, super::QTensor>,
|
||||
pub device: Device,
|
||||
}
|
||||
|
||||
impl Content {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
) -> Result<Content> {
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
|
||||
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
|
||||
reader.seek(std::io::SeekFrom::Start(0))?;
|
||||
@ -211,14 +243,16 @@ impl Content {
|
||||
let mut tensors = HashMap::new();
|
||||
|
||||
while reader.stream_position()? != last_position {
|
||||
let (name, tensor) = read_one_tensor(reader, magic)?;
|
||||
let (name, tensor) = read_one_tensor(reader, magic, device)?;
|
||||
tensors.insert(name, tensor);
|
||||
}
|
||||
let device = device.clone();
|
||||
Ok(Self {
|
||||
magic,
|
||||
hparams,
|
||||
vocab,
|
||||
tensors,
|
||||
device,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1,9 +1,8 @@
|
||||
//! Support for the GGUF file format.
|
||||
//! Support for the [GGUF file format](https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md).
|
||||
//!
|
||||
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
|
||||
|
||||
use super::{GgmlDType, QTensor};
|
||||
use crate::Result;
|
||||
use crate::{Context, Device, Result};
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@ -41,7 +40,7 @@ impl VersionedMagic {
|
||||
(Magic::Gguf, 1) => Self::GgufV1,
|
||||
(Magic::Gguf, 2) => Self::GgufV2,
|
||||
(Magic::Gguf, 3) => Self::GgufV3,
|
||||
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
|
||||
_ => crate::bail!("gguf: unsupported magic/version {magic:?}/{version}"),
|
||||
};
|
||||
Ok(versioned_magic)
|
||||
}
|
||||
@ -59,19 +58,25 @@ impl TensorInfo {
|
||||
&self,
|
||||
reader: &mut R,
|
||||
tensor_data_offset: u64,
|
||||
device: &Device,
|
||||
) -> Result<QTensor> {
|
||||
let tensor_elems = self.shape.elem_count();
|
||||
let blck_size = self.ggml_dtype.blck_size();
|
||||
if tensor_elems % blck_size != 0 {
|
||||
let block_size = self.ggml_dtype.block_size();
|
||||
if tensor_elems % block_size != 0 {
|
||||
crate::bail!(
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
|
||||
)
|
||||
}
|
||||
let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
|
||||
let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size();
|
||||
let mut raw_data = vec![0u8; size_in_bytes];
|
||||
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
|
||||
reader.read_exact(&mut raw_data)?;
|
||||
super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
|
||||
super::ggml_file::qtensor_from_ggml(
|
||||
self.ggml_dtype,
|
||||
&raw_data,
|
||||
self.shape.dims().to_vec(),
|
||||
device,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -129,7 +134,6 @@ pub enum ValueType {
|
||||
// The value is a UTF-8 non-null-terminated string, with length prepended.
|
||||
String,
|
||||
// The value is an array of other values, with the length and type prepended.
|
||||
///
|
||||
// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
|
||||
Array,
|
||||
}
|
||||
@ -212,10 +216,16 @@ impl Value {
|
||||
}
|
||||
}
|
||||
|
||||
/// This will also automatically upcast any integral types which will not truncate.
|
||||
pub fn to_u64(&self) -> Result<u64> {
|
||||
match self {
|
||||
Self::U64(v) => Ok(*v),
|
||||
v => crate::bail!("not a u64 {v:?}"),
|
||||
// Autoupcast cases here
|
||||
Self::U8(v) => Ok(*v as u64),
|
||||
Self::U16(v) => Ok(*v as u64),
|
||||
Self::U32(v) => Ok(*v as u64),
|
||||
Self::Bool(v) => Ok(*v as u64),
|
||||
v => crate::bail!("not a u64 or upcastable to u64 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
@ -328,7 +338,7 @@ impl Value {
|
||||
if value_type.len() != 1 {
|
||||
crate::bail!("multiple value-types in the same array {value_type:?}")
|
||||
}
|
||||
value_type.into_iter().next().unwrap()
|
||||
value_type.into_iter().next().context("empty value_type")?
|
||||
};
|
||||
w.write_u32::<LittleEndian>(value_type.to_u32())?;
|
||||
w.write_u64::<LittleEndian>(v.len() as u64)?;
|
||||
@ -447,7 +457,7 @@ impl Content {
|
||||
Some(Value::I32(v)) if *v >= 0 => *v as u64,
|
||||
_ => DEFAULT_ALIGNMENT,
|
||||
};
|
||||
let tensor_data_offset = (position + alignment - 1) / alignment * alignment;
|
||||
let tensor_data_offset = position.div_ceil(alignment) * alignment;
|
||||
Ok(Self {
|
||||
magic,
|
||||
metadata,
|
||||
@ -460,12 +470,13 @@ impl Content {
|
||||
&self,
|
||||
reader: &mut R,
|
||||
name: &str,
|
||||
device: &Device,
|
||||
) -> Result<QTensor> {
|
||||
let tensor_info = match self.tensor_infos.get(name) {
|
||||
Some(tensor_info) => tensor_info,
|
||||
None => crate::bail!("cannot find tensor-infor for {name}"),
|
||||
None => crate::bail!("cannot find tensor info for {name}"),
|
||||
};
|
||||
tensor_info.read(reader, self.tensor_data_offset)
|
||||
tensor_info.read(reader, self.tensor_data_offset, device)
|
||||
}
|
||||
}
|
||||
|
||||
@ -517,10 +528,9 @@ pub fn write<W: std::io::Seek + std::io::Write>(
|
||||
"internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
|
||||
)
|
||||
}
|
||||
let data_ptr = tensor.as_ptr();
|
||||
let size_in_bytes = tensor.storage_size_in_bytes();
|
||||
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
||||
w.write_all(data)?;
|
||||
let data = tensor.data()?;
|
||||
let size_in_bytes = data.len();
|
||||
w.write_all(&data)?;
|
||||
let padding = 31 - (31 + size_in_bytes) % 32;
|
||||
w.write_all(&vec![0u8; padding])?;
|
||||
}
|
||||
|
@ -1545,13 +1545,13 @@ impl GgmlType for BlockQ5K {
|
||||
let d2 = d * sc as f32;
|
||||
let m2 = min * m as f32;
|
||||
for (ql, qh) in ql.iter().zip(qh) {
|
||||
let to_add = if qh & u1 != 0 { 16 } else { 1 };
|
||||
y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1;
|
||||
let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
|
||||
y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1;
|
||||
ys_index += 1;
|
||||
}
|
||||
for (ql, qh) in ql.iter().zip(qh) {
|
||||
let to_add = if qh & u2 != 0 { 16 } else { 1 };
|
||||
y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2;
|
||||
let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 };
|
||||
y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2;
|
||||
ys_index += 1;
|
||||
}
|
||||
is += 2;
|
||||
@ -1850,8 +1850,8 @@ pub fn matmul<T: GgmlType>(
|
||||
crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len());
|
||||
}
|
||||
|
||||
let k_in_lhs_blocks = (k + T::BLCK_SIZE - 1) / T::BLCK_SIZE;
|
||||
let k_in_rhs_blocks = (k + T::VecDotType::BLCK_SIZE - 1) / T::VecDotType::BLCK_SIZE;
|
||||
let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE);
|
||||
let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE);
|
||||
// TODO: Do not make this copy if the DotType is f32.
|
||||
// TODO: Pre-allocate this.
|
||||
let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks];
|
||||
|
230
candle-core/src/quantized/metal.rs
Normal file
230
candle-core/src/quantized/metal.rs
Normal file
@ -0,0 +1,230 @@
|
||||
use super::{GgmlDType, QStorage};
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
|
||||
use metal::Buffer;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct QMetalStorage {
|
||||
dtype: GgmlDType,
|
||||
device: MetalDevice,
|
||||
buffer: Arc<Buffer>,
|
||||
}
|
||||
|
||||
impl QMetalStorage {
|
||||
pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> {
|
||||
let size = elem_count * dtype.type_size() / dtype.block_size();
|
||||
let buffer = device.allocate_zeros(size)?;
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &MetalDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn buffer(&self) -> &Buffer {
|
||||
&self.buffer
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
|
||||
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
command_buffer.set_label("to_cpu");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("blit_to_cpu");
|
||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||
blit.end_encoding();
|
||||
self.device.wait_until_completed()?;
|
||||
let mut out = vec![0.0; elem_count];
|
||||
let block_len = elem_count / self.dtype.block_size();
|
||||
match self.dtype {
|
||||
GgmlDType::F32 => {
|
||||
let vec: Vec<f32> = read_to_vec(&buffer, block_len);
|
||||
f32::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::F16 => {
|
||||
let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
|
||||
half::f16::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q2K => {
|
||||
let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q3K => {
|
||||
let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4K => {
|
||||
let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5K => {
|
||||
let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q6K => {
|
||||
let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8K => {
|
||||
let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
}
|
||||
|
||||
let buffer = self.device.new_buffer_with_data(&out)?;
|
||||
Ok(MetalStorage::new(
|
||||
buffer,
|
||||
self.device.clone(),
|
||||
elem_count,
|
||||
DType::F32,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
|
||||
// Quantization only happens on CPU for now.
|
||||
let src = src.to_cpu::<f32>()?;
|
||||
let elem_count = src.len();
|
||||
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
|
||||
let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
|
||||
qcpu_storage.quantize(&src)?;
|
||||
let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
|
||||
self.buffer = buffer;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
self.buffer.length() as usize
|
||||
}
|
||||
|
||||
pub fn fwd(
|
||||
&self,
|
||||
self_shape: &Shape,
|
||||
storage: &MetalStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
use crate::MetalError;
|
||||
|
||||
if !layout.is_contiguous() {
|
||||
crate::bail!("input tensor is not contiguous {layout:?}")
|
||||
}
|
||||
let src_shape = layout.shape();
|
||||
// self is transposed so n is first then k.
|
||||
if src_shape.rank() < 2 {
|
||||
crate::bail!("input tensor has only one dimension {layout:?}")
|
||||
}
|
||||
let (n, k) = self_shape.dims2()?;
|
||||
let mut dst_shape = src_shape.dims().to_vec();
|
||||
|
||||
// We always use a single batch dimension and stack all the tensors in the batch on the
|
||||
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
|
||||
// properly.
|
||||
let m = match dst_shape.len() {
|
||||
3 => dst_shape[0] * dst_shape[1],
|
||||
2 => dst_shape[0],
|
||||
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
||||
};
|
||||
let last_k = dst_shape.pop().unwrap();
|
||||
if last_k != k {
|
||||
crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
|
||||
}
|
||||
dst_shape.push(n);
|
||||
let dst_shape = Shape::from(dst_shape);
|
||||
let device = storage.device().clone();
|
||||
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
|
||||
let command_buffer = device.command_buffer()?;
|
||||
// In some cases it would be better to use the mm variant, though it has its drawbacks
|
||||
// around memory alignemnt.
|
||||
for batch_id in 0..m {
|
||||
candle_metal_kernels::call_quantized_matmul_mv_t(
|
||||
device.device(),
|
||||
&command_buffer,
|
||||
device.kernels(),
|
||||
self.dtype.into(),
|
||||
(1, 1, n, k),
|
||||
storage.buffer(),
|
||||
(layout.start_offset() + batch_id * k) * storage.dtype().size_in_bytes(),
|
||||
&self.buffer,
|
||||
batch_id * n * DType::F32.size_in_bytes(),
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
|
||||
Ok((dst_storage, dst_shape))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
||||
device: &MetalDevice,
|
||||
data: &[T],
|
||||
) -> Result<QStorage> {
|
||||
let buffer = device.new_buffer_with_data(data)?;
|
||||
let device = device.clone();
|
||||
Ok(QStorage::Metal(QMetalStorage {
|
||||
dtype: T::DTYPE,
|
||||
device,
|
||||
buffer,
|
||||
}))
|
||||
}
|
||||
|
||||
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
let ptr = buffer.contents() as *const T;
|
||||
assert!(!ptr.is_null());
|
||||
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
|
||||
slice.to_vec()
|
||||
}
|
||||
|
||||
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
|
||||
fn from(value: GgmlDType) -> Self {
|
||||
match value {
|
||||
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
|
||||
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
|
||||
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
|
||||
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
|
||||
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
|
||||
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
|
||||
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
|
||||
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
|
||||
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
|
||||
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
|
||||
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
|
||||
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
|
||||
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
|
||||
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
|
||||
}
|
||||
}
|
||||
}
|
@ -1,23 +1,135 @@
|
||||
use crate::{Device, Result, Shape, Tensor};
|
||||
//! Code for GGML and GGUF files
|
||||
use crate::{Context, CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
|
||||
use k_quants::*;
|
||||
use std::borrow::Cow;
|
||||
|
||||
#[cfg(target_feature = "avx")]
|
||||
pub mod avx;
|
||||
mod dummy_cuda;
|
||||
mod dummy_metal;
|
||||
pub mod ggml_file;
|
||||
pub mod gguf_file;
|
||||
pub mod k_quants;
|
||||
#[cfg(feature = "metal")]
|
||||
pub mod metal;
|
||||
#[cfg(not(feature = "metal"))]
|
||||
mod metal {
|
||||
pub use super::dummy_metal::*;
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
pub mod cuda;
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
mod cuda {
|
||||
pub use super::dummy_cuda::*;
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "neon")]
|
||||
pub mod neon;
|
||||
#[cfg(target_feature = "simd128")]
|
||||
pub mod simd128;
|
||||
pub mod utils;
|
||||
use half::f16;
|
||||
|
||||
pub use k_quants::GgmlType;
|
||||
|
||||
pub struct QTensor {
|
||||
data: Box<dyn QuantizedType>,
|
||||
storage: QStorage,
|
||||
shape: Shape,
|
||||
}
|
||||
|
||||
impl Device {
|
||||
fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = dtype.cpu_zeros(elem_count);
|
||||
Ok(QStorage::Cpu(storage))
|
||||
}
|
||||
Device::Metal(metal) => {
|
||||
let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
|
||||
Ok(QStorage::Metal(storage))
|
||||
}
|
||||
Device::Cuda(cuda) => {
|
||||
let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
|
||||
Ok(QStorage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum QStorage {
|
||||
Cpu(Box<dyn QuantizedType>),
|
||||
Metal(metal::QMetalStorage),
|
||||
Cuda(cuda::QCudaStorage),
|
||||
}
|
||||
|
||||
impl QStorage {
|
||||
fn block_size(&self) -> usize {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.block_size(),
|
||||
QStorage::Metal(storage) => storage.dtype().block_size(),
|
||||
QStorage::Cuda(storage) => storage.dtype().block_size(),
|
||||
}
|
||||
}
|
||||
|
||||
fn dtype(&self) -> GgmlDType {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.dtype(),
|
||||
QStorage::Metal(storage) => storage.dtype(),
|
||||
QStorage::Cuda(storage) => storage.dtype(),
|
||||
}
|
||||
}
|
||||
|
||||
fn device(&self) -> Device {
|
||||
match self {
|
||||
QStorage::Cpu(_storage) => Device::Cpu,
|
||||
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
|
||||
QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn size_in_bytes(&self) -> usize {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
|
||||
QStorage::Metal(storage) => storage.storage_size_in_bytes(),
|
||||
QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
fn quantize(&mut self, src: &Storage) -> Result<()> {
|
||||
match (self, src) {
|
||||
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
|
||||
storage.from_float(src.as_slice::<f32>()?)?;
|
||||
}
|
||||
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
|
||||
(QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
|
||||
_ => crate::bail!("Invalid dequantize storage locations do not match"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
|
||||
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
|
||||
QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
|
||||
}
|
||||
}
|
||||
|
||||
fn data(&self) -> Result<Cow<[u8]>> {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => {
|
||||
let data_ptr = storage.as_ptr();
|
||||
let size_in_bytes = storage.storage_size_in_bytes();
|
||||
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
||||
Ok(Cow::from(data))
|
||||
}
|
||||
QStorage::Metal(_) | QStorage::Cuda(_) => {
|
||||
crate::bail!("not implemented");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum GgmlDType {
|
||||
F32,
|
||||
@ -77,6 +189,25 @@ impl GgmlDType {
|
||||
}
|
||||
}
|
||||
|
||||
/// The block dtype
|
||||
pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
|
||||
match self {
|
||||
Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
|
||||
Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
|
||||
Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
|
||||
Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
|
||||
Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
|
||||
Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
|
||||
Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
|
||||
Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
|
||||
Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
|
||||
Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
|
||||
Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
|
||||
Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
|
||||
Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
|
||||
Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
|
||||
}
|
||||
}
|
||||
/// The type size for blocks in bytes.
|
||||
pub fn type_size(&self) -> usize {
|
||||
use k_quants::*;
|
||||
@ -100,7 +231,7 @@ impl GgmlDType {
|
||||
}
|
||||
|
||||
/// The block size, i.e. the number of elements stored in each block.
|
||||
pub fn blck_size(&self) -> usize {
|
||||
pub fn block_size(&self) -> usize {
|
||||
match self {
|
||||
Self::F32 => 1,
|
||||
Self::F16 => 1,
|
||||
@ -119,9 +250,13 @@ impl GgmlDType {
|
||||
pub trait QuantizedType: Send + Sync {
|
||||
fn dtype(&self) -> GgmlDType;
|
||||
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
|
||||
fn to_float(&self, ys: &mut [f32]) -> Result<()>;
|
||||
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
|
||||
fn storage_size_in_bytes(&self) -> usize;
|
||||
fn as_ptr(&self) -> *const u8;
|
||||
fn block_size(&self) -> usize;
|
||||
#[allow(clippy::wrong_self_convention)]
|
||||
fn from_float(&mut self, xs: &[f32]) -> Result<()>;
|
||||
fn size(&self) -> usize;
|
||||
}
|
||||
|
||||
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
||||
@ -129,12 +264,26 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
||||
k_quants::matmul(mkn, lhs, self.as_slice(), dst)
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.len() * core::mem::size_of::<T>()
|
||||
}
|
||||
|
||||
fn from_float(&mut self, xs: &[f32]) -> Result<()> {
|
||||
T::from_float(xs, self)
|
||||
}
|
||||
|
||||
fn dtype(&self) -> GgmlDType {
|
||||
T::DTYPE
|
||||
}
|
||||
|
||||
fn to_float(&self, ys: &mut [f32]) -> Result<()> {
|
||||
T::to_float(self.as_slice(), ys)
|
||||
fn block_size(&self) -> usize {
|
||||
T::BLCK_SIZE
|
||||
}
|
||||
|
||||
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
|
||||
let mut ys = vec![0.0f32; elem_count];
|
||||
T::to_float(self.as_slice(), &mut ys)?;
|
||||
Ok(CpuStorage::F32(ys))
|
||||
}
|
||||
|
||||
fn storage_size_in_bytes(&self) -> usize {
|
||||
@ -152,56 +301,53 @@ impl std::fmt::Debug for QTensor {
|
||||
}
|
||||
}
|
||||
|
||||
fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
|
||||
fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
|
||||
let dims = shape.dims();
|
||||
if dims.is_empty() {
|
||||
crate::bail!("scalar tensor cannot be quantized {shape:?}")
|
||||
}
|
||||
if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
|
||||
if dims[dims.len() - 1] % block_size != 0 {
|
||||
crate::bail!(
|
||||
"quantized tensor must have their last dim divisible by block size {shape:?} {}",
|
||||
T::BLCK_SIZE
|
||||
block_size
|
||||
)
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl QTensor {
|
||||
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
|
||||
data: Vec<T>,
|
||||
shape: S,
|
||||
) -> Result<Self> {
|
||||
pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
check_shape::<T>(&shape)?;
|
||||
Ok(Self {
|
||||
data: Box::new(data),
|
||||
shape,
|
||||
})
|
||||
check_shape(&shape, storage.block_size())?;
|
||||
Ok(Self { storage, shape })
|
||||
}
|
||||
|
||||
pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
|
||||
pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
|
||||
let shape = src.shape();
|
||||
check_shape::<T>(shape)?;
|
||||
let src = src
|
||||
.to_dtype(crate::DType::F32)?
|
||||
.flatten_all()?
|
||||
.to_vec1::<f32>()?;
|
||||
if src.len() % T::BLCK_SIZE != 0 {
|
||||
let block_size = dtype.block_size();
|
||||
check_shape(shape, block_size)?;
|
||||
let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
|
||||
let elem_count = shape.elem_count();
|
||||
if elem_count % block_size != 0 {
|
||||
crate::bail!(
|
||||
"tensor size ({shape:?}) is not divisible by block size {}",
|
||||
T::BLCK_SIZE
|
||||
block_size
|
||||
)
|
||||
}
|
||||
let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
|
||||
T::from_float(&src, &mut data)?;
|
||||
let mut storage = src.device().qzeros(elem_count, dtype)?;
|
||||
storage.quantize(&src.storage())?;
|
||||
Ok(Self {
|
||||
data: Box::new(data),
|
||||
storage,
|
||||
shape: shape.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.data.dtype()
|
||||
self.storage.dtype()
|
||||
}
|
||||
|
||||
pub fn device(&self) -> Device {
|
||||
self.storage.device()
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
@ -213,21 +359,34 @@ impl QTensor {
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
|
||||
let mut f32_data = vec![0f32; self.shape.elem_count()];
|
||||
self.data.to_float(&mut f32_data)?;
|
||||
Tensor::from_vec(f32_data, &self.shape, device)
|
||||
let storage = self.storage.dequantize(self.shape.elem_count())?;
|
||||
let none = crate::op::BackpropOp::none();
|
||||
crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
|
||||
}
|
||||
|
||||
pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
|
||||
self.data.matmul_t(mkn, lhs, dst)
|
||||
pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
|
||||
// In the CUDA case, we have a specialized kernel as this can be useful for volta
|
||||
// architectures. https://github.com/huggingface/candle/issues/2136
|
||||
match &self.storage {
|
||||
QStorage::Cuda(s) => {
|
||||
let s = s.dequantize_f16(self.shape.elem_count())?;
|
||||
let none = crate::op::BackpropOp::none();
|
||||
crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
|
||||
.to_device(device)
|
||||
}
|
||||
_ => {
|
||||
let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
|
||||
Ok(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
self.data.storage_size_in_bytes()
|
||||
self.storage.size_in_bytes()
|
||||
}
|
||||
|
||||
pub fn as_ptr(&self) -> *const u8 {
|
||||
self.data.as_ptr()
|
||||
pub fn data(&self) -> Result<Cow<'_, [u8]>> {
|
||||
self.storage.data()
|
||||
}
|
||||
}
|
||||
|
||||
@ -235,6 +394,7 @@ impl QTensor {
|
||||
pub enum QMatMul {
|
||||
QTensor(std::sync::Arc<QTensor>),
|
||||
Tensor(Tensor),
|
||||
TensorF16(Tensor),
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
@ -248,6 +408,17 @@ thread_local! {
|
||||
}
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static DEQUANTIZE_ALL_F16: bool = {
|
||||
match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
|
||||
Ok(s) => {
|
||||
!s.is_empty() && s != "0"
|
||||
},
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl QMatMul {
|
||||
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
|
||||
let dequantize = match qtensor.dtype() {
|
||||
@ -255,8 +426,11 @@ impl QMatMul {
|
||||
_ => DEQUANTIZE_ALL.with(|b| *b),
|
||||
};
|
||||
let t = if dequantize {
|
||||
let tensor = qtensor.dequantize(&Device::Cpu)?;
|
||||
let tensor = qtensor.dequantize(&qtensor.device())?;
|
||||
Self::Tensor(tensor)
|
||||
} else if DEQUANTIZE_ALL_F16.with(|b| *b) {
|
||||
let tensor = qtensor.dequantize_f16(&qtensor.device())?;
|
||||
Self::TensorF16(tensor)
|
||||
} else {
|
||||
Self::QTensor(qtensor)
|
||||
};
|
||||
@ -266,6 +440,25 @@ impl QMatMul {
|
||||
pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
|
||||
Self::from_arc(std::sync::Arc::new(qtensor))
|
||||
}
|
||||
|
||||
pub fn dequantize_f16(&self) -> Result<Tensor> {
|
||||
match self {
|
||||
Self::QTensor(t) => t.dequantize_f16(&t.device()),
|
||||
Self::Tensor(t) => t.to_dtype(DType::F16),
|
||||
Self::TensorF16(t) => Ok(t.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let w = self.dequantize_f16()?;
|
||||
let in_dtype = xs.dtype();
|
||||
let w = match *xs.dims() {
|
||||
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
|
||||
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
|
||||
_ => w.t()?,
|
||||
};
|
||||
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::CustomOp1 for QTensor {
|
||||
@ -288,23 +481,47 @@ impl crate::CustomOp1 for QTensor {
|
||||
crate::bail!("input tensor has only one dimension {layout:?}")
|
||||
}
|
||||
let mut dst_shape = src_shape.dims().to_vec();
|
||||
let last_k = dst_shape.pop().unwrap();
|
||||
let last_k = dst_shape.pop().context("empty dst_shape")?;
|
||||
if last_k != k {
|
||||
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
|
||||
}
|
||||
dst_shape.push(n);
|
||||
let dst_shape = Shape::from(dst_shape);
|
||||
let storage = storage.as_slice::<f32>()?;
|
||||
let storage =
|
||||
&storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||
#[allow(clippy::infallible_destructuring_match)]
|
||||
let self_storage = match &self.storage {
|
||||
QStorage::Cpu(storage) => storage,
|
||||
QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
|
||||
};
|
||||
let slice = storage.as_slice::<f32>()?;
|
||||
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||
let mut dst_storage = vec![0f32; dst_shape.elem_count()];
|
||||
self.matmul_t(
|
||||
(dst_shape.elem_count() / n, k, n),
|
||||
storage,
|
||||
&mut dst_storage,
|
||||
)?;
|
||||
self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
|
||||
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
|
||||
}
|
||||
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
storage: &crate::MetalStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(crate::MetalStorage, Shape)> {
|
||||
let self_storage = match &self.storage {
|
||||
QStorage::Metal(metal) => metal,
|
||||
_ => unreachable!("Cannot call metal matmul on non metal QTensor"),
|
||||
};
|
||||
self_storage.fwd(&self.shape, storage, layout)
|
||||
}
|
||||
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
storage: &crate::CudaStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(crate::CudaStorage, Shape)> {
|
||||
let self_storage = match &self.storage {
|
||||
QStorage::Cuda(cuda) => cuda,
|
||||
_ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
|
||||
};
|
||||
self_storage.fwd(&self.shape, storage, layout)
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::Module for QMatMul {
|
||||
@ -319,6 +536,15 @@ impl crate::Module for QMatMul {
|
||||
};
|
||||
xs.matmul(&w)
|
||||
}
|
||||
Self::TensorF16(w) => {
|
||||
let in_dtype = xs.dtype();
|
||||
let w = match *xs.dims() {
|
||||
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
|
||||
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
|
||||
_ => w.t()?,
|
||||
};
|
||||
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -12,6 +12,14 @@ use core::arch::arm::*;
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
use core::arch::aarch64::*;
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
|
||||
// TODO: dotprod
|
||||
let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
|
||||
let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
|
||||
vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
@ -43,15 +51,8 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
|
||||
let v1_0l = vld1q_s8(y0.qs.as_ptr());
|
||||
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
|
||||
|
||||
// TODO: Support dotprod when it's available outside of nightly.
|
||||
let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
|
||||
let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
|
||||
let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
|
||||
let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
|
||||
|
||||
let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
||||
let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
||||
|
||||
let pl0 = vdotq_s32(v0_0ls, v1_0l);
|
||||
let ph0 = vdotq_s32(v0_0hs, v1_0h);
|
||||
sumv0 = vmlaq_n_f32(
|
||||
sumv0,
|
||||
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
|
||||
@ -82,14 +83,8 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
|
||||
let y0_0 = vld1q_s8(y0.qs.as_ptr());
|
||||
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
|
||||
|
||||
// TODO dotprod once this is the intrinsics are.
|
||||
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
|
||||
let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
|
||||
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
|
||||
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
|
||||
|
||||
let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
|
||||
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
|
||||
let p0 = vdotq_s32(x0_0, y0_0);
|
||||
let p1 = vdotq_s32(x0_1, y0_1);
|
||||
|
||||
sumv0 = vmlaq_n_f32(
|
||||
sumv0,
|
||||
@ -118,10 +113,7 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res
|
||||
for i in (0..QK_K).step_by(16) {
|
||||
let xs = vld1q_s8(xs.add(i));
|
||||
let ys = vld1q_s8(ys.add(i));
|
||||
let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
|
||||
let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
|
||||
|
||||
let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
|
||||
let xy = vdotq_s32(xs, ys);
|
||||
sum_i = vaddq_s32(sum_i, xy)
|
||||
}
|
||||
sumf += vaddvq_s32(sum_i) as f32 * scale
|
||||
@ -191,30 +183,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
|
||||
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
|
||||
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));
|
||||
|
||||
// TODO: dotprod
|
||||
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
|
||||
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
|
||||
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
|
||||
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let q8bytes = vld1q_s8_x4(q8);
|
||||
@ -234,29 +212,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
|
||||
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
|
||||
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));
|
||||
|
||||
// TODO: dotprod case.
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
|
||||
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
|
||||
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
|
||||
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
|
||||
scale = scale.add(2);
|
||||
}
|
||||
sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
|
||||
@ -333,28 +298,14 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
|
||||
let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
|
||||
let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
|
||||
|
||||
// TODO: dotprod
|
||||
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
|
||||
let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
|
||||
sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
|
||||
scales = scales.add(1);
|
||||
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
|
||||
let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
|
||||
let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
|
||||
sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
|
||||
scales = scales.add(1);
|
||||
}
|
||||
sumf += d * sumi as f32 - dmin * sumi_mins as f32;
|
||||
@ -417,22 +368,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
|
||||
for j in 0..QK_K / 64 {
|
||||
let q4bits = vld1q_u8_x2(q4);
|
||||
q4 = q4.add(32);
|
||||
// TODO: dotprod
|
||||
let q8bytes = vld1q_s8_x2(q8);
|
||||
q8 = q8.add(32);
|
||||
let q4bytes = int8x16x2_t(
|
||||
vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
|
||||
vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
|
||||
);
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
|
||||
let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
|
||||
sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;
|
||||
|
||||
let q8bytes = vld1q_s8_x2(q8);
|
||||
q8 = q8.add(32);
|
||||
@ -440,15 +384,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
|
||||
let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
|
||||
let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
|
||||
sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
|
||||
}
|
||||
sumf += d * (sumi1 + sumi2) as f32;
|
||||
}
|
||||
@ -526,27 +464,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
|
||||
vreinterpretq_s8_u8(q3h_3),
|
||||
);
|
||||
|
||||
// TODO: dotprod
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
|
||||
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
|
||||
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
|
||||
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
|
||||
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
|
||||
);
|
||||
isum += vaddvq_s16(p0) as i32 * *scale as i32
|
||||
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
|
||||
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
|
||||
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
|
||||
let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
|
||||
let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
|
||||
let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
|
||||
let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
|
||||
isum += vaddvq_s32(p0) * *scale as i32
|
||||
+ vaddvq_s32(p1) * *scale.add(1) as i32
|
||||
+ vaddvq_s32(p2) * *scale.add(2) as i32
|
||||
+ vaddvq_s32(p3) * *scale.add(3) as i32;
|
||||
scale = scale.add(4);
|
||||
|
||||
let q3h_0 = vbicq_u8(m2, qhbits.0);
|
||||
@ -571,27 +496,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
|
||||
vreinterpretq_s8_u8(q3h_3),
|
||||
);
|
||||
|
||||
// TODO: dotprod
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
|
||||
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
|
||||
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
|
||||
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
|
||||
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
|
||||
);
|
||||
isum += vaddvq_s16(p0) as i32 * *scale as i32
|
||||
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
|
||||
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
|
||||
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
|
||||
let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
|
||||
let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
|
||||
let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
|
||||
let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
|
||||
isum += vaddvq_s32(p0) * *scale as i32
|
||||
+ vaddvq_s32(p1) * *scale.add(1) as i32
|
||||
+ vaddvq_s32(p2) * *scale.add(2) as i32
|
||||
+ vaddvq_s32(p3) * *scale.add(3) as i32;
|
||||
scale = scale.add(4);
|
||||
|
||||
if j == 0 {
|
||||
@ -649,7 +561,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
|
||||
let mut is = 0usize;
|
||||
|
||||
// TODO: dotprod
|
||||
|
||||
for _j in 0..QK_K / 128 {
|
||||
let q2bits = vld1q_u8_x2(q2);
|
||||
q2 = q2.add(32);
|
||||
@ -696,14 +607,7 @@ unsafe fn multiply_accum_with_scale(
|
||||
q2bytes: int8x16x2_t,
|
||||
q8bytes: int8x16x2_t,
|
||||
) -> i32 {
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
vaddvq_s16(p1) as i32 * aux[is + index] as i32
|
||||
+ vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
|
||||
let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
|
||||
let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
|
||||
vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
|
||||
}
|
||||
|
@ -1,3 +1,14 @@
|
||||
//! Module to load `safetensor` files into CPU/GPU memory.
|
||||
//!
|
||||
//! There are multiple ways to load tensors from safetensor files:
|
||||
//! - `load` function for loading directly into memory and returning a HashMap of tensors
|
||||
//! - `MmapedSafetensors` for memory mapping files and avoiding full allocation
|
||||
//! - `SliceSafetensors` for working with in-memory buffers
|
||||
//! - `BufferedSafetensors` for owning a buffer of data
|
||||
//!
|
||||
//! Tensors can also be serialized to safetensor format using the `save` function or
|
||||
//! `Tensor::save_safetensors` method.
|
||||
//!
|
||||
use crate::{DType, Device, Error, Result, Tensor, WithDType};
|
||||
use safetensors::tensor as st;
|
||||
use safetensors::tensor::SafeTensors;
|
||||
@ -171,7 +182,7 @@ pub trait Load {
|
||||
fn load(&self, device: &Device) -> Result<Tensor>;
|
||||
}
|
||||
|
||||
impl<'a> Load for st::TensorView<'a> {
|
||||
impl Load for st::TensorView<'_> {
|
||||
fn load(&self, device: &Device) -> Result<Tensor> {
|
||||
convert(self, device)
|
||||
}
|
||||
@ -349,6 +360,30 @@ impl MmapedSafetensors {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SliceSafetensors<'a> {
|
||||
safetensors: SafeTensors<'a>,
|
||||
}
|
||||
|
||||
impl<'a> SliceSafetensors<'a> {
|
||||
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
|
||||
pub fn new(buffer: &'a [u8]) -> Result<Self> {
|
||||
let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
|
||||
Ok(Self { safetensors })
|
||||
}
|
||||
|
||||
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
|
||||
self.safetensors.tensor(name)?.load(dev)
|
||||
}
|
||||
|
||||
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
|
||||
self.safetensors.tensors()
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
|
||||
Ok(self.safetensors.tensor(name)?)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BufferedSafetensors {
|
||||
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
|
||||
}
|
||||
|
@ -1,3 +1,5 @@
|
||||
//! TensorScalar Enum and Trait
|
||||
//!
|
||||
use crate::{Result, Tensor, WithDType};
|
||||
|
||||
pub enum TensorScalar {
|
||||
|
@ -43,43 +43,22 @@ impl From<usize> for Shape {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(usize,)> for Shape {
|
||||
fn from(d1: (usize,)) -> Self {
|
||||
Self(vec![d1.0])
|
||||
macro_rules! impl_from_tuple {
|
||||
($tuple:ty, $($index:tt),+) => {
|
||||
impl From<$tuple> for Shape {
|
||||
fn from(d: $tuple) -> Self {
|
||||
Self(vec![$(d.$index,)+])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(usize, usize)> for Shape {
|
||||
fn from(d12: (usize, usize)) -> Self {
|
||||
Self(vec![d12.0, d12.1])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(usize, usize, usize)> for Shape {
|
||||
fn from(d123: (usize, usize, usize)) -> Self {
|
||||
Self(vec![d123.0, d123.1, d123.2])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(usize, usize, usize, usize)> for Shape {
|
||||
fn from(d1234: (usize, usize, usize, usize)) -> Self {
|
||||
Self(vec![d1234.0, d1234.1, d1234.2, d1234.3])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(usize, usize, usize, usize, usize)> for Shape {
|
||||
fn from(d12345: (usize, usize, usize, usize, usize)) -> Self {
|
||||
Self(vec![d12345.0, d12345.1, d12345.2, d12345.3, d12345.4])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
|
||||
fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
|
||||
Self(vec![
|
||||
d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5,
|
||||
])
|
||||
}
|
||||
}
|
||||
impl_from_tuple!((usize,), 0);
|
||||
impl_from_tuple!((usize, usize), 0, 1);
|
||||
impl_from_tuple!((usize, usize, usize), 0, 1, 2);
|
||||
impl_from_tuple!((usize, usize, usize, usize), 0, 1, 2, 3);
|
||||
impl_from_tuple!((usize, usize, usize, usize, usize), 0, 1, 2, 3, 4);
|
||||
impl_from_tuple!((usize, usize, usize, usize, usize, usize), 0, 1, 2, 3, 4, 5);
|
||||
|
||||
impl From<Vec<usize>> for Shape {
|
||||
fn from(dims: Vec<usize>) -> Self {
|
||||
@ -142,6 +121,12 @@ impl Shape {
|
||||
&self.0
|
||||
}
|
||||
|
||||
/// The dimension size for a specified dimension index.
|
||||
pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
|
||||
let dim = dim.to_index(self, "dim")?;
|
||||
Ok(self.dims()[dim])
|
||||
}
|
||||
|
||||
/// The total number of elements, this is the product of all dimension sizes.
|
||||
pub fn elem_count(&self) -> usize {
|
||||
self.0.iter().product()
|
||||
@ -171,7 +156,7 @@ impl Shape {
|
||||
}
|
||||
let mut acc = 1;
|
||||
for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
|
||||
if stride != acc {
|
||||
if dim > 1 && stride != acc {
|
||||
return false;
|
||||
}
|
||||
acc *= dim;
|
||||
@ -186,7 +171,7 @@ impl Shape {
|
||||
}
|
||||
let mut acc = 1;
|
||||
for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
|
||||
if stride != acc {
|
||||
if dim > 1 && stride != acc {
|
||||
return false;
|
||||
}
|
||||
acc *= dim;
|
||||
@ -304,6 +289,7 @@ impl Dim for usize {
|
||||
pub enum D {
|
||||
Minus1,
|
||||
Minus2,
|
||||
Minus(usize),
|
||||
}
|
||||
|
||||
impl D {
|
||||
@ -311,6 +297,7 @@ impl D {
|
||||
let dim = match self {
|
||||
Self::Minus1 => -1,
|
||||
Self::Minus2 => -2,
|
||||
Self::Minus(u) => -(*u as i32),
|
||||
};
|
||||
Error::DimOutOfRange {
|
||||
shape: shape.clone(),
|
||||
@ -327,6 +314,7 @@ impl Dim for D {
|
||||
match self {
|
||||
Self::Minus1 if rank >= 1 => Ok(rank - 1),
|
||||
Self::Minus2 if rank >= 2 => Ok(rank - 2),
|
||||
Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
|
||||
_ => Err(self.out_of_range(shape, op)),
|
||||
}
|
||||
}
|
||||
@ -336,6 +324,7 @@ impl Dim for D {
|
||||
match self {
|
||||
Self::Minus1 => Ok(rank),
|
||||
Self::Minus2 if rank >= 1 => Ok(rank - 1),
|
||||
Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
|
||||
_ => Err(self.out_of_range(shape, op)),
|
||||
}
|
||||
}
|
||||
@ -478,23 +467,6 @@ extract_dims!(
|
||||
(usize, usize, usize, usize, usize)
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn stride() {
|
||||
let shape = Shape::from(());
|
||||
assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
|
||||
let shape = Shape::from(42);
|
||||
assert_eq!(shape.stride_contiguous(), [1]);
|
||||
let shape = Shape::from((42, 1337));
|
||||
assert_eq!(shape.stride_contiguous(), [1337, 1]);
|
||||
let shape = Shape::from((299, 792, 458));
|
||||
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ShapeWithOneHole {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape>;
|
||||
}
|
||||
@ -627,3 +599,36 @@ impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
|
||||
Ok((d1, d2, d3, d4, d).into())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn stride() {
|
||||
let shape = Shape::from(());
|
||||
assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
|
||||
let shape = Shape::from(42);
|
||||
assert_eq!(shape.stride_contiguous(), [1]);
|
||||
let shape = Shape::from((42, 1337));
|
||||
assert_eq!(shape.stride_contiguous(), [1337, 1]);
|
||||
let shape = Shape::from((299, 792, 458));
|
||||
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_tuple() {
|
||||
let shape = Shape::from((2,));
|
||||
assert_eq!(shape.dims(), &[2]);
|
||||
let shape = Shape::from((2, 3));
|
||||
assert_eq!(shape.dims(), &[2, 3]);
|
||||
let shape = Shape::from((2, 3, 4));
|
||||
assert_eq!(shape.dims(), &[2, 3, 4]);
|
||||
let shape = Shape::from((2, 3, 4, 5));
|
||||
assert_eq!(shape.dims(), &[2, 3, 4, 5]);
|
||||
let shape = Shape::from((2, 3, 4, 5, 6));
|
||||
assert_eq!(shape.dims(), &[2, 3, 4, 5, 6]);
|
||||
let shape = Shape::from((2, 3, 4, 5, 6, 7));
|
||||
assert_eq!(shape.dims(), &[2, 3, 4, 5, 6, 7]);
|
||||
}
|
||||
}
|
||||
|
250
candle-core/src/sort.rs
Normal file
250
candle-core/src/sort.rs
Normal file
@ -0,0 +1,250 @@
|
||||
use crate::{Result, Tensor};
|
||||
use rayon::prelude::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct ArgSort {
|
||||
asc: bool,
|
||||
last_dim: usize,
|
||||
}
|
||||
|
||||
impl ArgSort {
|
||||
fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
|
||||
#[allow(clippy::uninit_vec)]
|
||||
// Safety: indexes are set later in the parallelized section.
|
||||
let mut sort_indexes = unsafe {
|
||||
let el_count = layout.shape().elem_count();
|
||||
let mut v = Vec::with_capacity(el_count);
|
||||
v.set_len(el_count);
|
||||
v
|
||||
};
|
||||
if self.asc {
|
||||
sort_indexes
|
||||
.par_chunks_exact_mut(self.last_dim)
|
||||
.zip(vs.par_chunks_exact(self.last_dim))
|
||||
.for_each(|(indexes, vs)| {
|
||||
indexes
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, v)| *v = i as u32);
|
||||
indexes.sort_by(|&i, &j| {
|
||||
vs[i as usize]
|
||||
.partial_cmp(&vs[j as usize])
|
||||
.unwrap_or(std::cmp::Ordering::Greater)
|
||||
})
|
||||
});
|
||||
} else {
|
||||
sort_indexes
|
||||
.par_chunks_exact_mut(self.last_dim)
|
||||
.zip(vs.par_chunks_exact(self.last_dim))
|
||||
.for_each(|(indexes, vs)| {
|
||||
indexes
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, v)| *v = i as u32);
|
||||
indexes.sort_by(|&j, &i| {
|
||||
vs[i as usize]
|
||||
.partial_cmp(&vs[j as usize])
|
||||
.unwrap_or(std::cmp::Ordering::Greater)
|
||||
})
|
||||
});
|
||||
}
|
||||
sort_indexes
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
mod cuda {
|
||||
use super::*;
|
||||
use crate::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchConfig, ValidAsZeroBits,
|
||||
};
|
||||
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
|
||||
use crate::{CudaDevice, WithDType};
|
||||
|
||||
impl crate::cuda_backend::Map1Any for ArgSort {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
layout: &crate::Layout,
|
||||
_wrap: W,
|
||||
) -> Result<S> {
|
||||
use cudarc::driver::PushKernelArg;
|
||||
|
||||
let slice = match layout.contiguous_offsets() {
|
||||
None => crate::bail!("input has to be contiguous"),
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
};
|
||||
let elem_count = layout.shape().elem_count();
|
||||
let dst = unsafe { dev.alloc::<u32>(elem_count)? };
|
||||
let func = if self.asc {
|
||||
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
|
||||
} else {
|
||||
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), &kernels::SORT)?
|
||||
};
|
||||
let ncols = self.last_dim;
|
||||
let nrows = elem_count / ncols;
|
||||
let ncols_pad = next_power_of_2(ncols);
|
||||
let cfg = LaunchConfig {
|
||||
grid_dim: (1, nrows as u32, 1),
|
||||
block_dim: (ncols_pad as u32, 1, 1),
|
||||
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
|
||||
};
|
||||
let stream = dev.cuda_stream();
|
||||
let mut builder = stream.launch_builder(&func);
|
||||
let ncols = ncols as i32;
|
||||
let ncols_pad = ncols_pad as i32;
|
||||
builder.arg(&slice).arg(&dst).arg(&ncols).arg(&ncols_pad);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(S::U32(dst))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::CustomOp1 for ArgSort {
|
||||
fn name(&self) -> &'static str {
|
||||
"argsort"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
storage: &crate::CpuStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(crate::CpuStorage, crate::Shape)> {
|
||||
let sort_indexes = match storage {
|
||||
crate::CpuStorage::U8(vs) => self.asort(vs, layout),
|
||||
crate::CpuStorage::U32(vs) => self.asort(vs, layout),
|
||||
crate::CpuStorage::I64(vs) => self.asort(vs, layout),
|
||||
crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
|
||||
crate::CpuStorage::F16(vs) => self.asort(vs, layout),
|
||||
crate::CpuStorage::F32(vs) => self.asort(vs, layout),
|
||||
crate::CpuStorage::F64(vs) => self.asort(vs, layout),
|
||||
};
|
||||
let sort_indexes = crate::CpuStorage::U32(sort_indexes);
|
||||
Ok((sort_indexes, layout.shape().into()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
storage: &crate::CudaStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(crate::CudaStorage, crate::Shape)> {
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::cuda_backend::Map1Any;
|
||||
let dev = storage.device();
|
||||
let slice = self.map(&storage.slice, dev, layout)?;
|
||||
let dst = crate::cuda_backend::CudaStorage {
|
||||
slice,
|
||||
device: dev.clone(),
|
||||
};
|
||||
Ok((dst, layout.shape().clone()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
storage: &crate::MetalStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(crate::MetalStorage, crate::Shape)> {
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::DType;
|
||||
|
||||
let name = {
|
||||
if self.asc {
|
||||
match storage.dtype() {
|
||||
DType::BF16 => "asort_asc_bf16",
|
||||
DType::F16 => "asort_asc_f16",
|
||||
DType::F32 => "asort_asc_f32",
|
||||
DType::F64 => "asort_asc_f64",
|
||||
DType::U8 => "asort_asc_u8",
|
||||
DType::U32 => "asort_asc_u32",
|
||||
DType::I64 => "asort_asc_i64",
|
||||
}
|
||||
} else {
|
||||
match storage.dtype() {
|
||||
DType::BF16 => "asort_desc_bf16",
|
||||
DType::F16 => "asort_desc_f16",
|
||||
DType::F32 => "asort_desc_f32",
|
||||
DType::F64 => "asort_desc_f64",
|
||||
DType::U8 => "asort_desc_u8",
|
||||
DType::U32 => "asort_desc_u32",
|
||||
DType::I64 => "asort_desc_i64",
|
||||
}
|
||||
}
|
||||
};
|
||||
let device = storage.device();
|
||||
let kernels = device.kernels();
|
||||
let command_buffer = device.command_buffer()?;
|
||||
let el = layout.shape().elem_count();
|
||||
let ncols = self.last_dim;
|
||||
let nrows = el / ncols;
|
||||
let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
|
||||
let dst = device.new_buffer(el, DType::U32, "asort")?;
|
||||
let mut ncols_pad = 1;
|
||||
while ncols_pad < ncols {
|
||||
ncols_pad *= 2;
|
||||
}
|
||||
candle_metal_kernels::call_arg_sort(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
kernels,
|
||||
name,
|
||||
nrows,
|
||||
ncols,
|
||||
ncols_pad,
|
||||
src,
|
||||
&dst,
|
||||
)
|
||||
.map_err(crate::Error::wrap)?;
|
||||
let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
|
||||
Ok((dst, layout.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn next_power_of_2(x: usize) -> usize {
|
||||
let mut n = 1;
|
||||
while n < x {
|
||||
n *= 2
|
||||
}
|
||||
n
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
/// Returns the indices that sort the tensor along the last dimension.
|
||||
///
|
||||
/// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
|
||||
/// descending order. The sort is unstable so there is no guarantees on the final order when it
|
||||
/// comes to ties.
|
||||
pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
|
||||
if !self.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous {
|
||||
op: "arg_sort_last_dim",
|
||||
});
|
||||
}
|
||||
let last_dim = match self.dims().last() {
|
||||
None => crate::bail!("empty last-dim in arg-sort"),
|
||||
Some(last_dim) => *last_dim,
|
||||
};
|
||||
// No need for a backward pass for arg sort.
|
||||
self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
|
||||
}
|
||||
|
||||
/// Sorts the tensor along the last dimension, returns the sorted tensor together with the
|
||||
/// sorted indexes.
|
||||
///
|
||||
/// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
|
||||
/// descending order. The sort is unstable so there is no guarantees on the final order when it
|
||||
/// comes to ties.
|
||||
pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
|
||||
if !self.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous {
|
||||
op: "sort_last_dim",
|
||||
});
|
||||
}
|
||||
let asort = self.arg_sort_last_dim(asc)?;
|
||||
let sorted = self.gather(&asort, crate::D::Minus1)?;
|
||||
Ok((sorted, asort))
|
||||
}
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
|
||||
use crate::op::{self, CmpOp, ReduceOp};
|
||||
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
|
||||
use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
|
||||
|
||||
// We do not want to implement Clone on Storage as cloning may fail because of
|
||||
// out of memory. Instead try_clone should be used.
|
||||
@ -43,9 +44,19 @@ impl Storage {
|
||||
}
|
||||
|
||||
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
|
||||
let lhs = self.device().location();
|
||||
let rhs = rhs.device().location();
|
||||
if lhs != rhs {
|
||||
let lhs_device = self.device();
|
||||
let rhs_device = rhs.device();
|
||||
let lhs = lhs_device.location();
|
||||
let rhs = rhs_device.location();
|
||||
let same_device = if self.device().is_metal() {
|
||||
// On metal, we require the device to be exactly the same rather than
|
||||
// having the same location. In cuda this is not necessary as all CudaDevice on the
|
||||
// same GPU will use the same cuda stream.
|
||||
lhs_device.same_device(&rhs_device)
|
||||
} else {
|
||||
lhs == rhs
|
||||
};
|
||||
if !same_device {
|
||||
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
|
||||
} else {
|
||||
Ok(())
|
||||
@ -252,6 +263,51 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> {
|
||||
match self {
|
||||
Self::Cpu(storage) => c.cpu_fwd(storage, l),
|
||||
Self::Cuda(storage) => c.cuda_fwd(storage, l),
|
||||
Self::Metal(storage) => c.metal_fwd(storage, l),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn inplace_op2(
|
||||
&mut self,
|
||||
l1: &Layout,
|
||||
t2: &Self,
|
||||
l2: &Layout,
|
||||
c: &dyn InplaceOp2,
|
||||
) -> Result<()> {
|
||||
self.same_device(t2, c.name())?;
|
||||
match (self, t2) {
|
||||
(Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2),
|
||||
(Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2),
|
||||
(Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn inplace_op3(
|
||||
&mut self,
|
||||
l1: &Layout,
|
||||
t2: &Self,
|
||||
l2: &Layout,
|
||||
t3: &Self,
|
||||
l3: &Layout,
|
||||
c: &dyn InplaceOp3,
|
||||
) -> Result<()> {
|
||||
self.same_device(t2, c.name())?;
|
||||
self.same_device(t3, c.name())?;
|
||||
match (self, t2, t3) {
|
||||
(Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3),
|
||||
(Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3),
|
||||
(Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
|
||||
c.metal_fwd(s1, l1, s2, l2, s3, l3)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
@ -352,6 +408,10 @@ impl Storage {
|
||||
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cuda(s))
|
||||
}
|
||||
(Storage::Metal(inp), Storage::Metal(kernel)) => {
|
||||
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Metal(s))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
@ -697,4 +757,32 @@ impl Storage {
|
||||
.bt()),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn copy2d(
|
||||
&self,
|
||||
dst: &mut Self,
|
||||
d1: usize,
|
||||
d2: usize,
|
||||
src_s: usize,
|
||||
dst_s: usize,
|
||||
src_o: usize,
|
||||
dst_o: usize,
|
||||
) -> Result<()> {
|
||||
match (self, dst) {
|
||||
(Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
|
||||
(Self::Cuda(src), Self::Cuda(dst)) => {
|
||||
Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
|
||||
}
|
||||
(Self::Metal(src), Self::Metal(dst)) => {
|
||||
Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
op: "copy2d",
|
||||
}
|
||||
.bt()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
208
candle-core/src/streaming.rs
Normal file
208
candle-core/src/streaming.rs
Normal file
@ -0,0 +1,208 @@
|
||||
//! StreamTensror useful for streaming ops.
|
||||
//!
|
||||
use crate::{Result, Shape, Tensor};
|
||||
|
||||
pub trait Dim: crate::shape::Dim + Copy {}
|
||||
impl<T: crate::shape::Dim + Copy> Dim for T {}
|
||||
|
||||
/// A stream tensor is used in streaming module. It can either contain an actual tensor or be
|
||||
/// empty.
|
||||
#[derive(Clone)]
|
||||
pub struct StreamTensor(Option<Tensor>);
|
||||
|
||||
impl std::fmt::Debug for StreamTensor {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match &self.0 {
|
||||
Some(t) => write!(f, "{:?}", t.shape()),
|
||||
None => write!(f, "Empty"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::From<Option<Tensor>> for StreamTensor {
|
||||
fn from(value: Option<Tensor>) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::From<Tensor> for StreamTensor {
|
||||
fn from(value: Tensor) -> Self {
|
||||
Self(Some(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::From<()> for StreamTensor {
|
||||
fn from(_value: ()) -> Self {
|
||||
Self(None)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamTensor {
|
||||
pub fn empty() -> Self {
|
||||
Self(None)
|
||||
}
|
||||
|
||||
pub fn from_tensor(tensor: Tensor) -> Self {
|
||||
Self(Some(tensor))
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> Option<&Shape> {
|
||||
self.0.as_ref().map(|t| t.shape())
|
||||
}
|
||||
|
||||
pub fn cat2<D: Dim>(&self, rhs: &Self, dim: D) -> Result<Self> {
|
||||
let xs = match (&self.0, &rhs.0) {
|
||||
(Some(lhs), Some(rhs)) => {
|
||||
let xs = Tensor::cat(&[lhs, rhs], dim)?;
|
||||
Some(xs)
|
||||
}
|
||||
(Some(xs), None) | (None, Some(xs)) => Some(xs.clone()),
|
||||
(None, None) => None,
|
||||
};
|
||||
Ok(Self(xs))
|
||||
}
|
||||
|
||||
pub fn seq_len<D: Dim>(&self, dim: D) -> Result<usize> {
|
||||
match &self.0 {
|
||||
None => Ok(0),
|
||||
Some(v) => v.dim(dim),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.0 = None
|
||||
}
|
||||
|
||||
pub fn narrow<D: Dim>(&self, dim: D, offset: usize, len: usize) -> Result<StreamTensor> {
|
||||
let t = match &self.0 {
|
||||
None => None,
|
||||
Some(t) => {
|
||||
let seq_len = t.dim(dim)?;
|
||||
if seq_len <= offset {
|
||||
None
|
||||
} else {
|
||||
let t = t.narrow(dim, offset, usize::min(len, seq_len - offset))?;
|
||||
Some(t)
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(Self(t))
|
||||
}
|
||||
|
||||
/// Splits the Streaming Tensor on the time axis `dim` with the first `lhs_len` elements
|
||||
/// returned in the first output and the remaining in the second output.
|
||||
pub fn split<D: Dim>(&self, dim: D, lhs_len: usize) -> Result<(Self, Self)> {
|
||||
match &self.0 {
|
||||
None => Ok((Self::empty(), Self::empty())),
|
||||
Some(t) => {
|
||||
let seq_len = t.dim(dim)?;
|
||||
let lhs_len = usize::min(seq_len, lhs_len);
|
||||
if lhs_len == 0 {
|
||||
Ok((Self::empty(), t.clone().into()))
|
||||
} else {
|
||||
let lhs = Self::from_tensor(t.narrow(dim, 0, lhs_len)?);
|
||||
let rhs_len = seq_len - lhs_len;
|
||||
let rhs = if rhs_len == 0 {
|
||||
Self::empty()
|
||||
} else {
|
||||
Self::from_tensor(t.narrow(dim, lhs_len, rhs_len)?)
|
||||
};
|
||||
Ok((lhs, rhs))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_option(&self) -> Option<&Tensor> {
|
||||
self.0.as_ref()
|
||||
}
|
||||
|
||||
pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
|
||||
match &self.0 {
|
||||
None => Ok(Self::empty()),
|
||||
Some(t) => Ok(Self::from_tensor(t.apply(m)?)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Streaming modules take as input a stream tensor and return a stream tensor. They may perform
|
||||
/// some internal buffering so that enough data has been received for the module to be able to
|
||||
/// perform some operations.
|
||||
pub trait StreamingModule {
|
||||
// TODO: Should we also have a flush method?
|
||||
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor>;
|
||||
fn reset_state(&mut self);
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum BinOp {
|
||||
Add,
|
||||
Mul,
|
||||
Sub,
|
||||
Div,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamingBinOp {
|
||||
prev_lhs: StreamTensor,
|
||||
prev_rhs: StreamTensor,
|
||||
pub op: BinOp,
|
||||
pub dim: crate::D,
|
||||
}
|
||||
|
||||
impl StreamingBinOp {
|
||||
pub fn new(op: BinOp, dim: crate::D) -> Self {
|
||||
Self {
|
||||
prev_lhs: StreamTensor::empty(),
|
||||
prev_rhs: StreamTensor::empty(),
|
||||
op,
|
||||
dim,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reset_state(&mut self) {
|
||||
self.prev_lhs.reset();
|
||||
self.prev_rhs.reset();
|
||||
}
|
||||
|
||||
pub fn forward(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor> {
|
||||
match self.op {
|
||||
BinOp::Add => Tensor::add(lhs, rhs),
|
||||
BinOp::Mul => Tensor::mul(lhs, rhs),
|
||||
BinOp::Sub => Tensor::sub(lhs, rhs),
|
||||
BinOp::Div => Tensor::div(lhs, rhs),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn step(&mut self, lhs: &StreamTensor, rhs: &StreamTensor) -> Result<StreamTensor> {
|
||||
let lhs = StreamTensor::cat2(&self.prev_lhs, lhs, self.dim)?;
|
||||
let rhs = StreamTensor::cat2(&self.prev_rhs, rhs, self.dim)?;
|
||||
let lhs_len = lhs.seq_len(self.dim)?;
|
||||
let rhs_len = rhs.seq_len(self.dim)?;
|
||||
let common_len = usize::min(lhs_len, rhs_len);
|
||||
let (lhs, prev_lhs) = lhs.split(self.dim, common_len)?;
|
||||
let (rhs, prev_rhs) = rhs.split(self.dim, common_len)?;
|
||||
let ys = match (lhs.0, rhs.0) {
|
||||
(Some(lhs), Some(rhs)) => {
|
||||
let ys = self.forward(&lhs, &rhs)?;
|
||||
StreamTensor::from_tensor(ys)
|
||||
}
|
||||
(None, None) => StreamTensor::empty(),
|
||||
(lhs, rhs) => crate::bail!("INTERNAL ERROR inconsistent lhs and rhs {lhs:?} {rhs:?}"),
|
||||
};
|
||||
self.prev_lhs = prev_lhs;
|
||||
self.prev_rhs = prev_rhs;
|
||||
Ok(ys)
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple wrapper that doesn't do any buffering.
|
||||
pub struct Map<T: crate::Module>(T);
|
||||
|
||||
impl<T: crate::Module> StreamingModule for Map<T> {
|
||||
fn reset_state(&mut self) {}
|
||||
|
||||
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||
xs.apply(&self.0)
|
||||
}
|
||||
}
|
@ -32,14 +32,11 @@ impl<'a> StridedIndex<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for StridedIndex<'a> {
|
||||
impl Iterator for StridedIndex<'_> {
|
||||
type Item = usize;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let storage_index = match self.next_storage_index {
|
||||
None => return None,
|
||||
Some(storage_index) => storage_index,
|
||||
};
|
||||
let storage_index = self.next_storage_index?;
|
||||
let mut updated = false;
|
||||
let mut next_storage_index = storage_index;
|
||||
for ((multi_i, max_i), stride_i) in self
|
||||
|
@ -1,9 +1,7 @@
|
||||
//! Tensors are N-dimenional matrixes of elements using a single data type.
|
||||
//! Tensors are N-dimensional matrixes of elements using a single data type.
|
||||
#![allow(clippy::redundant_closure_call)]
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{
|
||||
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
|
||||
};
|
||||
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
|
||||
use crate::scalar::TensorOrScalar;
|
||||
use crate::shape::{Dim, Dims};
|
||||
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
@ -81,6 +79,9 @@ macro_rules! unary_op {
|
||||
($fn_name:ident, $op_name:ident) => {
|
||||
pub fn $fn_name(&self) -> Result<Self> {
|
||||
let shape = self.shape();
|
||||
if shape.elem_count() == 0 {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
let storage = self
|
||||
.storage()
|
||||
.unary_impl::<crate::op::$op_name>(self.layout())?;
|
||||
@ -94,6 +95,9 @@ macro_rules! binary_op {
|
||||
($fn_name:ident, $op_name:ident) => {
|
||||
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
|
||||
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
|
||||
if shape.elem_count() == 0 {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
let storage = self.storage().binary_impl::<crate::op::$op_name>(
|
||||
&*rhs.storage(),
|
||||
self.layout(),
|
||||
@ -116,6 +120,9 @@ macro_rules! binary_op_scalar {
|
||||
.broadcast_as(self.shape())?,
|
||||
};
|
||||
let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
|
||||
if self.elem_count() == 0 {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
let storage = self.storage().binary_impl::<crate::op::$op_name>(
|
||||
&*rhs.storage(),
|
||||
self.layout(),
|
||||
@ -235,7 +242,7 @@ impl Tensor {
|
||||
Self::zeros_impl(shape, dtype, device, false)
|
||||
}
|
||||
|
||||
/// Creates a new tensor filled with ones with same shape, dtype, and device as the other
|
||||
/// Creates a new tensor filled with zeros with same shape, dtype, and device as the other
|
||||
/// tensor.
|
||||
///
|
||||
/// ```rust
|
||||
@ -361,7 +368,33 @@ impl Tensor {
|
||||
Self::new_impl(array, shape, device, false)
|
||||
}
|
||||
|
||||
/// Returns a new tensor with all the elements having the same specified value. Note that
|
||||
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
|
||||
///
|
||||
/// assert_eq!(a.to_vec2::<f64>()?, &[
|
||||
/// [3.5, 3.5, 3.5, 3.5],
|
||||
/// [3.5, 3.5, 3.5, 3.5],
|
||||
/// ]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
pub fn full<D: crate::WithDType, S: Into<Shape>>(
|
||||
value: D,
|
||||
shape: S,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
|
||||
}
|
||||
|
||||
/// Creates a new 1D tensor from an iterator.
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let a = Tensor::from_iter( [1.0, 2.0, 3.0, 4.0].into_iter(), &Device::Cpu)?;
|
||||
///
|
||||
/// assert_eq!(a.to_vec1::<f64>()?, &[1.0, 2.0, 3.0, 4.0]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn from_iter<D: crate::WithDType>(
|
||||
iter: impl IntoIterator<Item = D>,
|
||||
device: &Device,
|
||||
@ -373,12 +406,26 @@ impl Tensor {
|
||||
|
||||
/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
|
||||
/// difference `1` from `start`.
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let a = Tensor::arange(2., 5., &Device::Cpu)?;
|
||||
///
|
||||
/// assert_eq!(a.to_vec1::<f64>()?, &[2., 3., 4.]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
|
||||
Self::arange_step(start, end, D::one(), device)
|
||||
}
|
||||
|
||||
/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
|
||||
/// difference `step` from `start`.
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let a = Tensor::arange_step(2.0, 4.0, 0.5, &Device::Cpu)?;
|
||||
///
|
||||
/// assert_eq!(a.to_vec1::<f64>()?, &[2.0, 2.5, 3.0, 3.5]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn arange_step<D: crate::WithDType>(
|
||||
start: D,
|
||||
end: D,
|
||||
@ -386,7 +433,7 @@ impl Tensor {
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
if D::is_zero(&step) {
|
||||
crate::bail!("step cannot be zero")
|
||||
bail!("step cannot be zero")
|
||||
}
|
||||
let mut data = vec![];
|
||||
let mut current = start;
|
||||
@ -424,6 +471,16 @@ impl Tensor {
|
||||
/// Creates a new tensor initialized with values from the input vector. The number of elements
|
||||
/// in this vector must be the same as the number of elements defined by the shape.
|
||||
/// If the device is cpu, no data copy is made.
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let a = Tensor::from_vec(vec!{1., 2., 3., 4., 5., 6.}, (2, 3), &Device::Cpu)?;
|
||||
///
|
||||
/// assert_eq!(a.to_vec2::<f64>()?, &[
|
||||
/// [1., 2., 3.],
|
||||
/// [4., 5., 6.]
|
||||
/// ]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
|
||||
data: Vec<D>,
|
||||
shape: S,
|
||||
@ -434,12 +491,31 @@ impl Tensor {
|
||||
|
||||
/// Creates a new tensor initialized with values from the input slice. The number of elements
|
||||
/// in this vector must be the same as the number of elements defined by the shape.
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let values = vec![1., 2., 3., 4., 5., 6., 7., 8.];
|
||||
/// let a = Tensor::from_slice(&values[1..7], (2, 3), &Device::Cpu)?;
|
||||
///
|
||||
/// assert_eq!(a.to_vec2::<f64>()?, &[
|
||||
/// [2., 3., 4.],
|
||||
/// [5., 6., 7.]
|
||||
/// ]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
|
||||
array: &[D],
|
||||
shape: S,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
Self::new_impl(array, shape.into(), device, false)
|
||||
let shape = shape.into();
|
||||
let n: usize = shape.elem_count();
|
||||
let buffer_size: usize = array.len();
|
||||
if buffer_size != n {
|
||||
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
|
||||
}
|
||||
let storage = device.storage_from_slice(array)?;
|
||||
let none = BackpropOp::none();
|
||||
Ok(from_storage(storage, shape, none, false))
|
||||
}
|
||||
|
||||
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
|
||||
@ -498,9 +574,11 @@ impl Tensor {
|
||||
unary_op!(gelu_erf, GeluErf);
|
||||
unary_op!(erf, Erf);
|
||||
unary_op!(relu, Relu);
|
||||
unary_op!(silu, Silu);
|
||||
unary_op!(ceil, Ceil);
|
||||
unary_op!(floor, Floor);
|
||||
unary_op!(round, Round);
|
||||
unary_op!(sign, Sign);
|
||||
|
||||
/// Round element of the input tensor to the nearest integer.
|
||||
///
|
||||
@ -563,9 +641,9 @@ impl Tensor {
|
||||
///
|
||||
/// * `args` - A slice of 1D tensors.
|
||||
/// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
|
||||
/// first dimension corresponds to the cardinality of the second input and the second
|
||||
/// dimension corresponds to the cardinality of the first input. If ij is selected, the
|
||||
/// dimensions are in the same order as the cardinality of the inputs.
|
||||
/// first dimension corresponds to the cardinality of the second input and the second
|
||||
/// dimension corresponds to the cardinality of the first input. If ij is selected, the
|
||||
/// dimensions are in the same order as the cardinality of the inputs.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
@ -636,6 +714,9 @@ impl Tensor {
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
|
||||
if self.elem_count() == 0 {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
let storage = self.storage().affine(self.layout(), mul, add)?;
|
||||
let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
@ -643,6 +724,9 @@ impl Tensor {
|
||||
|
||||
/// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
|
||||
pub fn elu(&self, alpha: f64) -> Result<Self> {
|
||||
if self.elem_count() == 0 {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
let storage = self.storage().elu(self.layout(), alpha)?;
|
||||
let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
@ -650,12 +734,15 @@ impl Tensor {
|
||||
|
||||
/// Raise the tensor to some float exponent `e`.
|
||||
pub fn powf(&self, e: f64) -> Result<Self> {
|
||||
if self.elem_count() == 0 {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
let storage = self.storage().powf(self.layout(), e)?;
|
||||
let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
|
||||
pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
|
||||
if dim >= self.dims().len() {
|
||||
Err(Error::DimOutOfRange {
|
||||
shape: self.shape().clone(),
|
||||
@ -669,7 +756,7 @@ impl Tensor {
|
||||
}
|
||||
|
||||
/// Split a tensor into the specified number of chunks, this may return less chunks than
|
||||
/// specificed.
|
||||
/// specified.
|
||||
pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
|
||||
let dim = dim.to_index(self.shape(), "chunk")?;
|
||||
let size = self.dim(dim)?;
|
||||
@ -696,6 +783,30 @@ impl Tensor {
|
||||
|
||||
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
||||
/// ranges from `start` to `start + len`.
|
||||
/// ```
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let a = Tensor::new(&[
|
||||
/// [0f32, 1., 2.],
|
||||
/// [3. , 4., 5.],
|
||||
/// [6. , 7., 8.]
|
||||
/// ], &Device::Cpu)?;
|
||||
///
|
||||
/// let b = a.narrow(0, 1, 2)?;
|
||||
/// assert_eq!(b.shape().dims(), &[2, 3]);
|
||||
/// assert_eq!(b.to_vec2::<f32>()?, &[
|
||||
/// [3., 4., 5.],
|
||||
/// [6., 7., 8.]
|
||||
/// ]);
|
||||
///
|
||||
/// let c = a.narrow(1, 1, 1)?;
|
||||
/// assert_eq!(c.shape().dims(), &[3, 1]);
|
||||
/// assert_eq!(c.to_vec2::<f32>()?, &[
|
||||
/// [1.],
|
||||
/// [4.],
|
||||
/// [7.]
|
||||
/// ]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
|
||||
let dims = self.dims();
|
||||
let dim = dim.to_index(self.shape(), "narrow")?;
|
||||
@ -794,6 +905,35 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Roll the tensor input along the given dimension.
|
||||
/// Elements that are shifted beyond the last position are re-introduced at the first position.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle_core::{Tensor, Device};
|
||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let tensor = tensor.roll(1, 0)?;
|
||||
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
|
||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let tensor = tensor.roll(-1, 0)?;
|
||||
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
|
||||
where
|
||||
D: Dim + Clone,
|
||||
{
|
||||
let dim = dim.to_index(self.shape(), "roll")?;
|
||||
let dim_size = self.dim(dim)?;
|
||||
let shift = shift.rem_euclid(dim_size as i32) as usize;
|
||||
if shift == 0 {
|
||||
Ok(self.clone())
|
||||
} else {
|
||||
let a = self.narrow(dim, 0, dim_size - shift)?;
|
||||
let b = self.narrow(dim, dim_size - shift, shift)?;
|
||||
Tensor::cat(&[&b, &a], dim)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
|
||||
/// input dimensions.
|
||||
///
|
||||
@ -975,7 +1115,7 @@ impl Tensor {
|
||||
/// tensor also has three dimensions, `(batch, channels, target_size)`.
|
||||
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
|
||||
let (n, c, _l) = self.dims3()?;
|
||||
let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
|
||||
let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
|
||||
let storage = self
|
||||
.storage()
|
||||
.upsample_nearest1d(self.layout(), target_size)?;
|
||||
@ -994,7 +1134,11 @@ impl Tensor {
|
||||
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
|
||||
pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
|
||||
let (n, c, _h, _w) = self.dims4()?;
|
||||
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
|
||||
let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
|
||||
arg,
|
||||
target_h,
|
||||
target_w,
|
||||
});
|
||||
let storage = self
|
||||
.storage()
|
||||
.upsample_nearest2d(self.layout(), target_h, target_w)?;
|
||||
@ -1027,6 +1171,9 @@ impl Tensor {
|
||||
let kernel_size = kernel_size.to_usize2();
|
||||
let stride = stride.to_usize2();
|
||||
let (n, c, h, w) = self.dims4()?;
|
||||
if h < kernel_size.0 || w < kernel_size.1 {
|
||||
bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
|
||||
}
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
|
||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||
let w_out = (w - kernel_size.1) / stride.1 + 1;
|
||||
@ -1062,6 +1209,9 @@ impl Tensor {
|
||||
let kernel_size = kernel_size.to_usize2();
|
||||
let stride = stride.to_usize2();
|
||||
let (n, c, h, w) = self.dims4()?;
|
||||
if h < kernel_size.0 || w < kernel_size.1 {
|
||||
bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
|
||||
}
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
|
||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||
let w_out = (w - kernel_size.1) / stride.1 + 1;
|
||||
@ -1105,6 +1255,9 @@ impl Tensor {
|
||||
let n = b_dims[dim - 1];
|
||||
|
||||
let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
|
||||
if c_shape.elem_count() == 0 || k == 0 {
|
||||
return Tensor::zeros(c_shape, self.dtype(), self.device());
|
||||
}
|
||||
let batching: usize = a_dims[..dim - 2].iter().product();
|
||||
let batching_b: usize = b_dims[..dim - 2].iter().product();
|
||||
if k != k2 || batching != batching_b {
|
||||
@ -1301,7 +1454,7 @@ impl Tensor {
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
let mut storage = self.device().zeros(self.shape(), self.dtype())?;
|
||||
let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
let offset = start * src.dims()[1..].iter().product::<usize>();
|
||||
@ -1367,14 +1520,15 @@ impl Tensor {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `self` - The input tensor.
|
||||
/// * `indexes` - The indices of elements to gather, this should have the same shape as `self`
|
||||
/// but can have a different number of elements on the target dimension.
|
||||
/// * `indexes` - The indices of elements to gather, this should have same number of dimensions as `self`
|
||||
/// and indexes.dims()[d] <= self.dims()[d] for all dimensions d != dim
|
||||
/// * `dim` - the target dimension.
|
||||
///
|
||||
/// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on
|
||||
/// dimension `dim` by the values in `indexes`.
|
||||
pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "gather")?;
|
||||
|
||||
let self_dims = self.dims();
|
||||
let indexes_dims = indexes.dims();
|
||||
let mismatch = if indexes_dims.len() != self_dims.len() {
|
||||
@ -1382,7 +1536,7 @@ impl Tensor {
|
||||
} else {
|
||||
let mut mismatch = false;
|
||||
for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
|
||||
if i != dim && d1 != d2 {
|
||||
if i != dim && d1 < d2 {
|
||||
mismatch = true;
|
||||
break;
|
||||
}
|
||||
@ -1606,6 +1760,42 @@ impl Tensor {
|
||||
&self.op
|
||||
}
|
||||
|
||||
/// Computes the max of all the elements in this tensor and returns a tensor holding this
|
||||
/// scalar with zero dimensions.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let tensor = tensor.max_all()?;
|
||||
/// assert_eq!(tensor.to_scalar::<f32>()?, 5.);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn max_all(&self) -> Result<Tensor> {
|
||||
if self.rank() == 0 {
|
||||
Ok(self.clone())
|
||||
} else {
|
||||
self.flatten_all()?.max(0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the min of all the elements in this tensor and returns a tensor holding this
|
||||
/// scalar with zero dimensions.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let tensor = tensor.min_all()?;
|
||||
/// assert_eq!(tensor.to_scalar::<f32>()?, 0.);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn min_all(&self) -> Result<Tensor> {
|
||||
if self.rank() == 0 {
|
||||
Ok(self.clone())
|
||||
} else {
|
||||
self.flatten_all()?.min(0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the sum of all the elements in this tensor and returns a tensor holding this
|
||||
/// scalar with zero dimensions.
|
||||
///
|
||||
@ -1784,7 +1974,7 @@ impl Tensor {
|
||||
let is_permutation =
|
||||
dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
|
||||
if !is_permutation {
|
||||
crate::bail!(
|
||||
bail!(
|
||||
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
|
||||
self.dims(),
|
||||
dims
|
||||
@ -1833,9 +2023,9 @@ impl Tensor {
|
||||
/// this new node. The storage of this tensor is shared with the initial tensor.
|
||||
///
|
||||
/// If the tensor is already detached from the computation graph, the same tensor is returned.
|
||||
pub fn detach(&self) -> Result<Tensor> {
|
||||
pub fn detach(&self) -> Tensor {
|
||||
if self.op.is_none() && !self.is_variable {
|
||||
Ok(self.clone())
|
||||
self.clone()
|
||||
} else {
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
@ -1846,7 +2036,7 @@ impl Tensor {
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
Tensor(Arc::new(tensor_))
|
||||
}
|
||||
}
|
||||
|
||||
@ -1863,10 +2053,7 @@ impl Tensor {
|
||||
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
|
||||
}
|
||||
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
||||
(Storage::Metal(storage), Device::Cpu) => {
|
||||
println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
|
||||
Storage::Cpu(storage.to_cpu_storage()?)
|
||||
}
|
||||
(Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
||||
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
|
||||
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
|
||||
// are the same.
|
||||
@ -1875,7 +2062,11 @@ impl Tensor {
|
||||
}
|
||||
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
|
||||
_ => {
|
||||
bail!("not implemented yet")
|
||||
bail!(
|
||||
"not implemented yet, self.device: {:?}, device: {:?}",
|
||||
self.device(),
|
||||
device
|
||||
)
|
||||
}
|
||||
};
|
||||
let op = BackpropOp::new1(self, Op::ToDevice);
|
||||
@ -1954,7 +2145,7 @@ impl Tensor {
|
||||
Ok(self.clone())
|
||||
} else {
|
||||
let shape = self.shape();
|
||||
let mut storage = self.device().zeros(shape, self.dtype())?;
|
||||
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
let op = BackpropOp::new1(self, Op::Copy);
|
||||
@ -1962,11 +2153,21 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a tensor that is in row major order. This always makes a copy.
|
||||
pub fn force_contiguous(&self) -> Result<Tensor> {
|
||||
let shape = self.shape();
|
||||
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
let op = BackpropOp::new1(self, Op::Copy);
|
||||
Ok(from_storage(storage, shape.clone(), op, false))
|
||||
}
|
||||
|
||||
/// Create a variable based on the values currently stored in a tensor. The storage is always
|
||||
/// copied.
|
||||
pub(crate) fn make_var(&self) -> Result<Tensor> {
|
||||
let shape = self.shape().clone();
|
||||
let mut storage = self.device().zeros(&shape, self.dtype())?;
|
||||
let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
Ok(from_storage(storage, shape, BackpropOp::none(), true))
|
||||
@ -2019,7 +2220,7 @@ impl Tensor {
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
} else {
|
||||
let mut storage = self.device().zeros(&shape, self.dtype())?;
|
||||
let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
@ -2046,8 +2247,19 @@ impl Tensor {
|
||||
let dim = dim.to_index(self.shape(), "squeeze")?;
|
||||
if dims[dim] == 1 {
|
||||
let mut dims = dims.to_vec();
|
||||
let mut strides = self.stride().to_vec();
|
||||
dims.remove(dim);
|
||||
self.reshape(dims)
|
||||
strides.remove(dim);
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
|
||||
op: BackpropOp::new1(self, Op::Reshape),
|
||||
is_variable: false,
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
} else {
|
||||
Ok(self.clone())
|
||||
}
|
||||
@ -2068,10 +2280,24 @@ impl Tensor {
|
||||
/// ```
|
||||
pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
let mut dims = self.dims().to_vec();
|
||||
let mut strides = self.stride().to_vec();
|
||||
let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
|
||||
// Cannot panic because to_index_plus_one already checks dimensions
|
||||
dims.insert(dim, 1);
|
||||
self.reshape(dims)
|
||||
// Any stride would work here, but we pick one so as to maximize the probability to remain
|
||||
// C contiguous.
|
||||
let stride = if dim < strides.len() { strides[dim] } else { 1 };
|
||||
strides.insert(dim, stride);
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
|
||||
op: BackpropOp::new1(self, Op::Reshape),
|
||||
is_variable: false,
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
|
||||
/// Stacks two or more tensors along a particular dimension.
|
||||
@ -2102,152 +2328,6 @@ impl Tensor {
|
||||
Self::cat(&args, dim)
|
||||
}
|
||||
|
||||
/// Concatenates two or more tensors along a particular dimension.
|
||||
///
|
||||
/// All tensors must of the same rank, and the output will have
|
||||
/// the same rank
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle_core::{Tensor, DType, Device};
|
||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
///
|
||||
/// let c = Tensor::cat(&[&a, &b], 0)?;
|
||||
/// assert_eq!(c.shape().dims(), &[4, 3]);
|
||||
///
|
||||
/// let c = Tensor::cat(&[&a, &b], 1)?;
|
||||
/// assert_eq!(c.shape().dims(), &[2, 6]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||
}
|
||||
let arg0 = args[0].as_ref();
|
||||
if args.len() == 1 {
|
||||
return Ok(arg0.clone());
|
||||
}
|
||||
let dim = dim.to_index(arg0.shape(), "cat")?;
|
||||
for arg in args {
|
||||
arg.as_ref().check_dim(dim, "cat")?;
|
||||
}
|
||||
for (arg_idx, arg) in args.iter().enumerate() {
|
||||
let arg = arg.as_ref();
|
||||
if arg0.rank() != arg.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: arg0.rank(),
|
||||
got: arg.rank(),
|
||||
shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in arg0
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.zip(arg.shape().dims().iter())
|
||||
.enumerate()
|
||||
{
|
||||
if dim_idx != dim && v1 != v2 {
|
||||
Err(Error::ShapeMismatchCat {
|
||||
dim: dim_idx,
|
||||
first_shape: arg0.shape().clone(),
|
||||
n: arg_idx + 1,
|
||||
nth_shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
}
|
||||
}
|
||||
if dim == 0 {
|
||||
Self::cat0(args)
|
||||
} else {
|
||||
// TODO: Avoid these transpositions and have an implementation that works
|
||||
// for dim != 0...
|
||||
let args: Vec<Tensor> = args
|
||||
.iter()
|
||||
.map(|a| a.as_ref().transpose(0, dim))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let cat = Self::cat0(&args)?;
|
||||
cat.transpose(0, dim)
|
||||
}
|
||||
}
|
||||
|
||||
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||
}
|
||||
let arg0 = args[0].as_ref();
|
||||
if args.len() == 1 {
|
||||
return Ok(arg0.clone());
|
||||
}
|
||||
let rank = arg0.rank();
|
||||
let device = arg0.device();
|
||||
let dtype = arg0.dtype();
|
||||
let first_dims = arg0.shape().dims();
|
||||
let mut cat_dims = first_dims.to_vec();
|
||||
cat_dims[0] = 0;
|
||||
let mut offsets = vec![0usize];
|
||||
for (arg_idx, arg) in args.iter().enumerate() {
|
||||
let arg = arg.as_ref();
|
||||
if arg.dtype() != dtype {
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: dtype,
|
||||
rhs: arg.dtype(),
|
||||
op: "cat",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if arg.device().location() != device.location() {
|
||||
Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: device.location(),
|
||||
rhs: arg.device().location(),
|
||||
op: "cat",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if rank != arg.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: rank,
|
||||
got: arg.rank(),
|
||||
shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in arg0
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.zip(arg.shape().dims().iter())
|
||||
.enumerate()
|
||||
{
|
||||
if dim_idx == 0 {
|
||||
cat_dims[0] += v2;
|
||||
}
|
||||
if dim_idx != 0 && v1 != v2 {
|
||||
Err(Error::ShapeMismatchCat {
|
||||
dim: dim_idx,
|
||||
first_shape: arg0.shape().clone(),
|
||||
n: arg_idx + 1,
|
||||
nth_shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
}
|
||||
let next_offset = offsets.last().unwrap() + arg.elem_count();
|
||||
offsets.push(next_offset);
|
||||
}
|
||||
let shape = Shape::from(cat_dims);
|
||||
let op = BackpropOp::new(args, |args| Op::Cat(args, 0));
|
||||
let mut storage = device.zeros(&shape, dtype)?;
|
||||
for (arg, &offset) in args.iter().zip(offsets.iter()) {
|
||||
let arg = arg.as_ref();
|
||||
arg.storage()
|
||||
.copy_strided_src(&mut storage, offset, arg.layout())?;
|
||||
}
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
/// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
|
||||
/// input tensor values and `right` elements after.
|
||||
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
|
||||
@ -2282,7 +2362,7 @@ impl Tensor {
|
||||
if left == 0 && right == 0 {
|
||||
Ok(self.clone())
|
||||
} else if self.elem_count() == 0 {
|
||||
crate::bail!("cannot use pad_with_same on an empty tensor")
|
||||
bail!("cannot use pad_with_same on an empty tensor")
|
||||
} else if left == 0 {
|
||||
let dim = dim.to_index(self.shape(), "pad_with_same")?;
|
||||
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
|
||||
@ -2330,6 +2410,10 @@ impl Tensor {
|
||||
self.storage.read().unwrap()
|
||||
}
|
||||
|
||||
pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> {
|
||||
self.storage.write().unwrap()
|
||||
}
|
||||
|
||||
// If we extend the visibility of this function to be usable outside of this crate, we should
|
||||
// make it unsafe.
|
||||
pub(crate) fn storage_mut_and_layout(
|
||||
@ -2351,108 +2435,18 @@ impl Tensor {
|
||||
std::ptr::eq(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Applies a unary custom op without backward support
|
||||
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
|
||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||
}
|
||||
|
||||
/// Applies a binary custom op without backward support
|
||||
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
|
||||
let (storage, shape) =
|
||||
self.storage()
|
||||
.apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
|
||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||
}
|
||||
|
||||
/// Applies a ternary custom op without backward support
|
||||
pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().apply_op3(
|
||||
self.layout(),
|
||||
&t2.storage(),
|
||||
t2.layout(),
|
||||
&t3.storage(),
|
||||
t3.layout(),
|
||||
c,
|
||||
)?;
|
||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||
}
|
||||
|
||||
/// Applies a unary custom op.
|
||||
pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
|
||||
let (storage, shape) = self
|
||||
.storage()
|
||||
.apply_op1(self.layout(), c.as_ref().as_ref())?;
|
||||
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
|
||||
self.apply_op1_arc(Arc::new(Box::new(c)))
|
||||
}
|
||||
|
||||
/// Applies a binary custom op.
|
||||
pub fn apply_op2_arc(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
|
||||
) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().apply_op2(
|
||||
self.layout(),
|
||||
&rhs.storage(),
|
||||
rhs.layout(),
|
||||
c.as_ref().as_ref(),
|
||||
)?;
|
||||
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
|
||||
self.apply_op2_arc(r, Arc::new(Box::new(c)))
|
||||
}
|
||||
|
||||
/// Applies a ternary custom op.
|
||||
pub fn apply_op3_arc(
|
||||
&self,
|
||||
t2: &Self,
|
||||
t3: &Self,
|
||||
c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
|
||||
) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().apply_op3(
|
||||
self.layout(),
|
||||
&t2.storage(),
|
||||
t2.layout(),
|
||||
&t3.storage(),
|
||||
t3.layout(),
|
||||
c.as_ref().as_ref(),
|
||||
)?;
|
||||
let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
|
||||
Op::CustomOp3(t1, t2, t3, c.clone())
|
||||
});
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
|
||||
&self,
|
||||
t2: &Self,
|
||||
t3: &Self,
|
||||
c: C,
|
||||
) -> Result<Self> {
|
||||
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
|
||||
}
|
||||
|
||||
/// Normalize a 'relative' axis value: positive values are kept, negative
|
||||
/// values means counting the dimensions from the back.
|
||||
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
|
||||
let rank = self.rank() as i64;
|
||||
if rank <= axis {
|
||||
crate::bail!("axis {axis} is too large, tensor rank {rank}")
|
||||
bail!("axis {axis} is too large, tensor rank {rank}")
|
||||
} else if 0 <= axis {
|
||||
Ok(axis as usize)
|
||||
} else {
|
||||
let naxis = rank + axis;
|
||||
if naxis < 0 {
|
||||
crate::bail!("axis {axis} is too small, tensor rank {rank}")
|
||||
bail!("axis {axis} is too small, tensor rank {rank}")
|
||||
}
|
||||
Ok(naxis as usize)
|
||||
}
|
||||
@ -2514,14 +2508,14 @@ impl Tensor {
|
||||
let src_dims = src.dims();
|
||||
let self_dims = self.dims();
|
||||
if self_dims.len() != src_dims.len() {
|
||||
crate::bail!(
|
||||
bail!(
|
||||
"slice-assign requires input with the same rank {} <> {}",
|
||||
self_dims.len(),
|
||||
src_dims.len()
|
||||
)
|
||||
}
|
||||
if self_dims.len() != ranges.len() {
|
||||
crate::bail!(
|
||||
bail!(
|
||||
"slice-assign requires input with the same rank as there are ranges {} <> {}",
|
||||
self_dims.len(),
|
||||
ranges.len()
|
||||
@ -2541,18 +2535,16 @@ impl Tensor {
|
||||
std::ops::Bound::Excluded(v) => *v,
|
||||
};
|
||||
if end_excluded <= start_included {
|
||||
crate::bail!(
|
||||
"slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
|
||||
)
|
||||
bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
|
||||
}
|
||||
if self_dims[i] < end_excluded {
|
||||
crate::bail!(
|
||||
bail!(
|
||||
"slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
|
||||
self_dims[i]
|
||||
)
|
||||
}
|
||||
if end_excluded - start_included != src_dims[i] {
|
||||
crate::bail!(
|
||||
bail!(
|
||||
"slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
|
||||
)
|
||||
}
|
||||
@ -2561,6 +2553,55 @@ impl Tensor {
|
||||
}
|
||||
mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
|
||||
}
|
||||
|
||||
/// Returns log(sum(exp(tensor), dim)).
|
||||
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
|
||||
let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
|
||||
if sum_dims.is_empty() {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
let max = sum_dims[1..]
|
||||
.iter()
|
||||
.try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
|
||||
max.max_keepdim(dim)
|
||||
})?;
|
||||
let exp = self.broadcast_sub(&max)?.exp()?;
|
||||
let sum = exp.sum(sum_dims.clone())?;
|
||||
|
||||
sum.log()? + max.squeeze_dims(&sum_dims)
|
||||
}
|
||||
|
||||
/// Pointwise pow operation.
|
||||
pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
|
||||
rhs.mul(&self.log()?)?.exp()
|
||||
}
|
||||
|
||||
/// Broadcasting version of `pow`.
|
||||
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
|
||||
rhs.broadcast_mul(&self.log()?)?.exp()
|
||||
}
|
||||
|
||||
/// Returns a new tensor with the order of elements reversed along the specified dimensions.
|
||||
/// This function makes a copy of the tensor’s data.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle_core::{Tensor, Device};
|
||||
/// let t = Tensor::arange(0., 6., &Device::Cpu)?.reshape((2, 3))?;
|
||||
/// assert_eq!(t.to_vec2::<f64>()?, &[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
/// let t_flipped = t.flip(&[0])?;
|
||||
/// assert_eq!(t_flipped.to_vec2::<f64>()?, &[[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
|
||||
let mut result = self.clone();
|
||||
for &dim in dims.iter() {
|
||||
let size = result.dim(dim)?;
|
||||
let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
|
||||
let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
|
||||
result = result.index_select(&indices_tensor, dim)?;
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! bin_trait {
|
||||
|
303
candle-core/src/tensor_cat.rs
Normal file
303
candle-core/src/tensor_cat.rs
Normal file
@ -0,0 +1,303 @@
|
||||
use crate::{shape::Dim, Context, Error, Result, Shape, Tensor};
|
||||
|
||||
impl Tensor {
|
||||
/// Concatenates two or more tensors along a particular dimension.
|
||||
///
|
||||
/// All tensors must of the same rank, and the output will have
|
||||
/// the same rank
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle_core::{Tensor, DType, Device};
|
||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
///
|
||||
/// let c = Tensor::cat(&[&a, &b], 0)?;
|
||||
/// assert_eq!(c.shape().dims(), &[4, 3]);
|
||||
///
|
||||
/// let c = Tensor::cat(&[&a, &b], 1)?;
|
||||
/// assert_eq!(c.shape().dims(), &[2, 6]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||
}
|
||||
let arg0 = args[0].as_ref();
|
||||
if args.len() == 1 {
|
||||
return Ok(arg0.clone());
|
||||
}
|
||||
let dim = dim.to_index(arg0.shape(), "cat")?;
|
||||
for arg in args {
|
||||
arg.as_ref().check_dim(dim, "cat")?;
|
||||
}
|
||||
for (arg_idx, arg) in args.iter().enumerate() {
|
||||
let arg = arg.as_ref();
|
||||
if arg0.rank() != arg.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: arg0.rank(),
|
||||
got: arg.rank(),
|
||||
shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in arg0
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.zip(arg.shape().dims().iter())
|
||||
.enumerate()
|
||||
{
|
||||
if dim_idx != dim && v1 != v2 {
|
||||
Err(Error::ShapeMismatchCat {
|
||||
dim: dim_idx,
|
||||
first_shape: arg0.shape().clone(),
|
||||
n: arg_idx + 1,
|
||||
nth_shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
}
|
||||
}
|
||||
let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
|
||||
if all_contiguous {
|
||||
Self::cat_contiguous(args, dim)
|
||||
} else if dim == 0 {
|
||||
Self::cat0(args)
|
||||
} else {
|
||||
let args: Vec<Tensor> = args
|
||||
.iter()
|
||||
.map(|a| a.as_ref().transpose(0, dim))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let cat = Self::cat0(&args)?;
|
||||
cat.transpose(0, dim)
|
||||
}
|
||||
}
|
||||
|
||||
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||
}
|
||||
let arg0 = args[0].as_ref();
|
||||
if args.len() == 1 {
|
||||
return Ok(arg0.clone());
|
||||
}
|
||||
let rank = arg0.rank();
|
||||
let device = arg0.device();
|
||||
let dtype = arg0.dtype();
|
||||
let first_dims = arg0.shape().dims();
|
||||
let mut cat_dims = first_dims.to_vec();
|
||||
cat_dims[0] = 0;
|
||||
let mut offsets = vec![0usize];
|
||||
for (arg_idx, arg) in args.iter().enumerate() {
|
||||
let arg = arg.as_ref();
|
||||
if arg.dtype() != dtype {
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: dtype,
|
||||
rhs: arg.dtype(),
|
||||
op: "cat",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if arg.device().location() != device.location() {
|
||||
Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: device.location(),
|
||||
rhs: arg.device().location(),
|
||||
op: "cat",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if rank != arg.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: rank,
|
||||
got: arg.rank(),
|
||||
shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in arg0
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.zip(arg.shape().dims().iter())
|
||||
.enumerate()
|
||||
{
|
||||
if dim_idx == 0 {
|
||||
cat_dims[0] += v2;
|
||||
}
|
||||
if dim_idx != 0 && v1 != v2 {
|
||||
Err(Error::ShapeMismatchCat {
|
||||
dim: dim_idx,
|
||||
first_shape: arg0.shape().clone(),
|
||||
n: arg_idx + 1,
|
||||
nth_shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
}
|
||||
let next_offset = offsets.last().context("empty offsets")? + arg.elem_count();
|
||||
offsets.push(next_offset);
|
||||
}
|
||||
let shape = Shape::from(cat_dims);
|
||||
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
|
||||
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
|
||||
for (arg, &offset) in args.iter().zip(offsets.iter()) {
|
||||
let arg = arg.as_ref();
|
||||
arg.storage()
|
||||
.copy_strided_src(&mut storage, offset, arg.layout())?;
|
||||
}
|
||||
Ok(crate::tensor::from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
fn cat_contiguous<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||
}
|
||||
let arg0 = args[0].as_ref();
|
||||
if args.len() == 1 {
|
||||
return Ok(arg0.clone());
|
||||
}
|
||||
let rank = arg0.rank();
|
||||
let device = arg0.device();
|
||||
let dtype = arg0.dtype();
|
||||
let first_dims = arg0.shape().dims();
|
||||
let mut cat_dims = first_dims.to_vec();
|
||||
cat_dims[dim] = 0;
|
||||
for (arg_idx, arg) in args.iter().enumerate() {
|
||||
let arg = arg.as_ref();
|
||||
if arg.dtype() != dtype {
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: dtype,
|
||||
rhs: arg.dtype(),
|
||||
op: "cat",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if arg.device().location() != device.location() {
|
||||
Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: device.location(),
|
||||
rhs: arg.device().location(),
|
||||
op: "cat",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if rank != arg.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: rank,
|
||||
got: arg.rank(),
|
||||
shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in arg0
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.zip(arg.shape().dims().iter())
|
||||
.enumerate()
|
||||
{
|
||||
if dim_idx == dim {
|
||||
cat_dims[dim] += v2;
|
||||
}
|
||||
if dim_idx != dim && v1 != v2 {
|
||||
Err(Error::ShapeMismatchCat {
|
||||
dim: dim_idx,
|
||||
first_shape: arg0.shape().clone(),
|
||||
n: arg_idx + 1,
|
||||
nth_shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
}
|
||||
}
|
||||
let cat_target_dim_len = cat_dims[dim];
|
||||
let block_size: usize = cat_dims.iter().skip(1 + dim).product();
|
||||
let shape = Shape::from(cat_dims);
|
||||
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
|
||||
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
|
||||
let mut dst_o = 0;
|
||||
for arg in args.iter() {
|
||||
let arg = arg.as_ref();
|
||||
let arg_dims = arg.shape().dims();
|
||||
let d1: usize = arg_dims.iter().take(dim).product();
|
||||
let d2 = block_size * arg_dims[dim];
|
||||
let dst_s = block_size * cat_target_dim_len;
|
||||
let src_o = arg.layout().start_offset();
|
||||
arg.storage().copy2d(
|
||||
&mut storage,
|
||||
d1,
|
||||
d2,
|
||||
/* src_s */ d2,
|
||||
dst_s,
|
||||
src_o,
|
||||
dst_o,
|
||||
)?;
|
||||
dst_o += d2;
|
||||
}
|
||||
Ok(crate::tensor::from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
/// Set the values on `self` using values from `src`. The copy starts at the specified
|
||||
/// `offset` for the target dimension `dim` on `self`.
|
||||
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
|
||||
/// has to be greater than or equal to `offset` plus the `src` size.
|
||||
///
|
||||
/// Note that this modifies `self` in place and as such is not compatibel with
|
||||
/// back-propagation.
|
||||
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
|
||||
let dim = dim.to_index(self.shape(), "slice-set")?;
|
||||
if !self.is_contiguous() || !src.is_contiguous() {
|
||||
Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
|
||||
}
|
||||
if self.same_storage(src) {
|
||||
crate::bail!("cannot use slice_set when self and src share their storage")
|
||||
}
|
||||
if self.dtype() != src.dtype() {
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: self.dtype(),
|
||||
rhs: src.dtype(),
|
||||
op: "slice-set",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if self.device().location() != src.device().location() {
|
||||
Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: self.device().location(),
|
||||
rhs: src.device().location(),
|
||||
op: "slice-set",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if self.rank() != src.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: self.rank(),
|
||||
got: src.rank(),
|
||||
shape: self.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
|
||||
if dim_idx == dim && *v2 + offset > *v1 {
|
||||
crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
|
||||
}
|
||||
if dim_idx != dim && v1 != v2 {
|
||||
crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
|
||||
}
|
||||
}
|
||||
let block_size: usize = src.dims().iter().skip(1 + dim).product();
|
||||
let d1: usize = src.dims().iter().take(dim).product();
|
||||
let d2 = block_size * src.dims()[dim];
|
||||
let dst_o = self.layout().start_offset() + offset * block_size;
|
||||
let src_o = src.layout().start_offset();
|
||||
src.storage().copy2d(
|
||||
&mut self.storage_mut(),
|
||||
d1,
|
||||
d2,
|
||||
/* src_s */ d2,
|
||||
/* dst_s */ block_size * self.dims()[dim],
|
||||
src_o,
|
||||
dst_o,
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -24,6 +24,15 @@ macro_rules! test_device {
|
||||
};
|
||||
}
|
||||
|
||||
pub fn assert_tensor_eq(t1: &Tensor, t2: &Tensor) -> Result<()> {
|
||||
assert_eq!(t1.shape(), t2.shape());
|
||||
// Default U8 may not be large enough to hold the sum (`t.sum_all` defaults to the dtype of `t`)
|
||||
let eq_tensor = t1.eq(t2)?.to_dtype(crate::DType::U32)?;
|
||||
let all_equal = eq_tensor.sum_all()?;
|
||||
assert_eq!(all_equal.to_scalar::<u32>()?, eq_tensor.elem_count() as u32);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
|
||||
let b = 10f32.powi(digits);
|
||||
let t = t.to_vec0::<f32>()?;
|
||||
|
@ -1,3 +1,4 @@
|
||||
//! Useful functions for checking features.
|
||||
use std::str::FromStr;
|
||||
|
||||
pub fn get_num_threads() -> usize {
|
||||
|
@ -34,9 +34,14 @@ impl Var {
|
||||
Ok(Self(inner))
|
||||
}
|
||||
|
||||
// Convert a tensor to a variable, if the tensor is already a variable then it is returned as is.
|
||||
pub fn from_tensor(t: &Tensor) -> Result<Self> {
|
||||
let inner = t.make_var()?;
|
||||
Ok(Self(inner))
|
||||
if t.is_variable() {
|
||||
Ok(Self(t.clone()))
|
||||
} else {
|
||||
let inner = t.make_var()?;
|
||||
Ok(Self(inner))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rand_f64<S: Into<Shape>>(
|
||||
@ -107,6 +112,10 @@ impl Var {
|
||||
Ok(Self(inner))
|
||||
}
|
||||
|
||||
pub fn as_detached_tensor(&self) -> Tensor {
|
||||
self.0.detach()
|
||||
}
|
||||
|
||||
pub fn as_tensor(&self) -> &Tensor {
|
||||
&self.0
|
||||
}
|
||||
|
@ -18,6 +18,9 @@ w_t = w.transpose(0, 1)
|
||||
res = torch.nn.functional.conv_transpose1d(t, w_t)
|
||||
print(res.shape)
|
||||
print(res)
|
||||
res = torch.nn.functional.conv_transpose1d(t, w_t, groups=2)
|
||||
print(res.shape)
|
||||
print(res)
|
||||
*/
|
||||
fn conv1d(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(
|
||||
@ -50,8 +53,25 @@ fn conv1d(dev: &Device) -> Result<()> {
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||
);
|
||||
if dev.is_cpu() {
|
||||
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
let res = {
|
||||
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
|
||||
t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?
|
||||
};
|
||||
assert_eq!(res.dims(), [3, 2, 5]);
|
||||
// Same as pytorch default padding: use zeros.
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
|
||||
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
|
||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||
);
|
||||
|
||||
let w = w.transpose(0, 1)?;
|
||||
// The CPU kernels applied in the contiguous and non contiguous cases are different.
|
||||
for w in [w.clone(), w.contiguous()?] {
|
||||
let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
@ -60,6 +80,17 @@ fn conv1d(dev: &Device) -> Result<()> {
|
||||
4.7076, -5.9745, -0.8276, 1.621
|
||||
],
|
||||
);
|
||||
let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 2)?;
|
||||
assert_eq!(res.dims(), [1, 4, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
|
||||
[
|
||||
[-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
|
||||
[0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
|
||||
[1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
|
||||
[1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
|
||||
]
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@ -118,7 +149,7 @@ fn conv2d(dev: &Device) -> Result<()> {
|
||||
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
|
||||
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
|
||||
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
|
||||
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
|
||||
-0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
|
||||
],
|
||||
dev,
|
||||
)?;
|
||||
@ -146,7 +177,25 @@ fn conv2d(dev: &Device) -> Result<()> {
|
||||
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
||||
]
|
||||
);
|
||||
let res = {
|
||||
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
|
||||
t.conv2d(&w, 0, 1, 1, 1)?
|
||||
};
|
||||
assert_eq!(res.dims(), [3, 2, 3, 3]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
|
||||
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
|
||||
[
|
||||
-4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
|
||||
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
||||
]
|
||||
);
|
||||
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
|
||||
assert_eq!(res.dims(), [1, 2, 7, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
||||
@ -171,6 +220,7 @@ fn conv2d(dev: &Device) -> Result<()> {
|
||||
]
|
||||
]
|
||||
);
|
||||
|
||||
// Dilations.
|
||||
let res = t.conv2d(&w, 0, 1, 2, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 1, 1]);
|
||||
@ -209,6 +259,7 @@ fn conv2d(dev: &Device) -> Result<()> {
|
||||
]
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -255,13 +306,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[
|
||||
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
||||
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1640, -0.0111, -0.1742, 0.0000, 0.0000,
|
||||
0.0000, 0.0000, 2.6437, -2.0268, 1.1823, 0.0000, 0.0000, 0.0000, 0.0000, 3.2855,
|
||||
-1.0324, 0.2539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
||||
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1640,
|
||||
-0.0111, -0.1742, 0.0, 0.0, 0.0, 0.0, 2.6437, -2.0268, 1.1823, 0.0, 0.0, 0.0, 0.0,
|
||||
3.2855, -1.0324, 0.2539, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0
|
||||
]
|
||||
);
|
||||
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 1, 3, 3]);
|
||||
assert_eq!(
|
||||
@ -363,6 +414,7 @@ print(w.grad.shape)
|
||||
print(w.grad[0])
|
||||
*/
|
||||
fn conv2d_grad(dev: &Device) -> Result<()> {
|
||||
// conv-transposes are not implemented for metal
|
||||
use candle_core::Var;
|
||||
let t = Var::from_slice(
|
||||
&[
|
||||
@ -375,7 +427,7 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
|
||||
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
|
||||
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
|
||||
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
|
||||
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
|
||||
-0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
|
||||
],
|
||||
(1, 4, 5, 5),
|
||||
dev,
|
||||
@ -560,6 +612,251 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
|
||||
]
|
||||
);
|
||||
|
||||
// Conv Transpose 2d Test
|
||||
//tested against following python
|
||||
|
||||
// import torch
|
||||
// torch.manual_seed(4242)
|
||||
// padding = 4
|
||||
// outpadding = 2
|
||||
// dilation = 3
|
||||
// stride = 3
|
||||
// input = torch.randn((1, 4, 7, 5), requires_grad=True)
|
||||
// kernel = torch.randn((4, 2, 3, 5), requires_grad=True)
|
||||
// print("input", input.flatten())
|
||||
// print("kernel", kernel.flatten())
|
||||
// res = torch.nn.functional.conv_transpose2d(
|
||||
// input,
|
||||
// kernel,
|
||||
// stride=stride,
|
||||
// padding=padding,
|
||||
// dilation=dilation,
|
||||
// output_padding=outpadding,
|
||||
// )
|
||||
// res.retain_grad()
|
||||
// print(res.shape)
|
||||
// loss = (res**2).sum()
|
||||
// print(loss)
|
||||
// loss.backward()
|
||||
// print(input.grad.shape)
|
||||
// print("input grad", torch.round(input.grad, decimals=1))
|
||||
// print(kernel.grad.shape)
|
||||
// print("kernel grad", torch.round(kernel.grad.flatten(), decimals=1))
|
||||
|
||||
let padding = 4;
|
||||
let outpadding = 2;
|
||||
let dilation = 3;
|
||||
let stride = 3;
|
||||
|
||||
let t = Var::from_slice(
|
||||
&[
|
||||
0.4056_f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997,
|
||||
3.0616, 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843,
|
||||
0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013,
|
||||
-0.6836, 0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130,
|
||||
1.3123, 1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071,
|
||||
1.1586, 0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090,
|
||||
0.2049, 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323,
|
||||
-1.3712, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742,
|
||||
0.3790, -0.4431, -0.4720, -0.7890, 0.2620, 0.5411, -1.1715, -2.4997, 2.3249, -0.8912,
|
||||
-0.4733, -0.5701, -2.8888, -1.4112, -0.5471, -0.9234, -1.1660, 0.4189, -0.7465,
|
||||
-0.6473, 0.1402, 0.7875, 0.5377, -0.6779, -0.8088, -0.4864, -0.2312, 0.9279, 0.1264,
|
||||
1.5480, 0.8265, -0.1025, 0.5138, -0.2512, 0.1576, 1.2705, 0.3641, -0.9325, 0.6451,
|
||||
-0.8537, 0.2378, 0.1794, 0.2752, -0.3687, -1.1149, -0.1410, -0.5829, -0.0892, 1.4258,
|
||||
-2.2789, 0.5270, 0.1825, 1.7007, -0.5263, -0.2954, 0.4440, 0.5537, 0.3492, 0.6186,
|
||||
1.6475, 0.2219,
|
||||
],
|
||||
(1, 4, 7, 5),
|
||||
dev,
|
||||
)?;
|
||||
|
||||
#[rustfmt::skip]
|
||||
let w = Var::from_slice(
|
||||
&[
|
||||
-1.1744_f32, 0.3266, 2.5893, 1.0142, 0.1763, 0.7752, 0.6604, 0.2029, -0.2145, 0.7234,
|
||||
-0.3441, -1.5400, -0.6333, 0.6613, 0.2083, 0.6230, -1.7002, 0.3393, 0.4049, 1.0762,
|
||||
0.2723, 1.4181, 0.0029, -0.2122, 1.7668, 1.4168, 0.3320, -0.2719, 0.7932, -0.7204,
|
||||
0.4447, 0.1211, 0.5908, 1.0089, -0.1646, 1.8033, -0.6286, 0.2016, -0.3370, 1.2555,
|
||||
0.8009, -0.6488, -0.4652, -1.5685, 1.5860, 0.5583, 0.4623, 0.6026, 0.8828, 2.4990,
|
||||
0.6811, -0.3369, 1.3320, 1.7669, -1.1067, 1.2958, -0.9415, -0.9655, -0.4462, 0.7181,
|
||||
0.5181, -1.1658, -1.8467, -0.7763, 1.2769, 0.8651, 0.9890, 1.5092, 0.7207, -0.8481,
|
||||
0.7417, 0.3375, -1.2685, 1.4572, 1.0915, 0.1093, -0.8550, -0.5831, -0.6309, -0.2509,
|
||||
0.5220, -0.0914, 0.7900, 0.1096, 0.3258, 0.2723, -1.0942, -0.3393, -0.1653, 0.5732,
|
||||
-0.8014, 1.8194, -1.9023, 0.2127, 1.8636, -0.8979, 0.1927, -0.2778, 0.3105, 0.0071,
|
||||
-1.1823, 0.2476, -0.7178, -1.3821, 1.0769, -0.4376, -0.9967, -0.1227, 1.6197, -1.0604,
|
||||
0.1372, 0.8141, -0.6163, 0.7304, -0.8285, 2.0636, -0.7176, 0.2495, -0.2581, -0.4478,
|
||||
],
|
||||
(4, 2, 3, 5),
|
||||
dev,
|
||||
)?;
|
||||
let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?;
|
||||
let loss = res.sqr()?.sum_all()?;
|
||||
assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 2904.0);
|
||||
let grads = loss.backward()?;
|
||||
|
||||
let grad_t = grads.get(&t).unwrap();
|
||||
let grad_w = grads.get(&w).unwrap();
|
||||
assert_eq!(grad_t.dims(), [1, 4, 7, 5]);
|
||||
assert_eq!(grad_w.dims(), [4, 2, 3, 5]);
|
||||
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?,
|
||||
[
|
||||
// torch gets 89.1
|
||||
-89.0, -135.3, 136.7, 102.0, -53.4, 117.9, 118.6, -43.9, -218.0, -58.5, -114.3, -150.0,
|
||||
-15.6, 172.1, 66.3, -64.3, -27.9, -19.8, 31.7, 62.1, 5.5, 92.6, 28.2, -29.6, 55.9,
|
||||
52.7, -72.7, -119.8, 53.8, -25.5, 128.8, 19.3, 68.0, 190.9, -64.1, -86.2, -111.2,
|
||||
106.6, -67.7, 37.8, 115.9, 50.4, -77.7, -54.9, 22.3, -4.6, 89.8, 61.7, 122.4, 192.6,
|
||||
-27.8, -104.6, 57.0, 166.4, 27.1, 6.1, 18.7, -93.2, 31.5, 168.2, -3.7, -99.5, -55.5,
|
||||
-10.8, 17.5, 20.8, 16.9, 43.8, 42.0, -89.2, 18.8, -9.6, -84.1, 212.6, 19.7, -50.0,
|
||||
-52.0, -40.0, -166.6, -73.2, -10.8, -73.3, 31.5, -23.4, -79.3, -27.0, -84.4, -42.9,
|
||||
-20.3, 51.8, -16.7, 76.3, -120.5, -65.8, 96.5, -10.7, -45.9, -88.1, 65.4, -7.0, -1.5,
|
||||
92.8, -25.1, -114.2, -5.8, -14.8, -51.2, -20.7, 54.2, -79.8, 47.7, -29.2, -8.8, 53.5,
|
||||
-28.4, 85.0, -18.3, 107.0, 28.3, -71.8
|
||||
]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&grad_t.i(0)?, 1)?,
|
||||
[
|
||||
[
|
||||
[32.3, -41.6, -24.0, 14.1, 17.6],
|
||||
[-11.8, 72.5, 87.6, 46.4, 61.5],
|
||||
[115.0, 108.5, -48.6, -63.4, -50.0],
|
||||
[51.3, 5.4, 31.3, 91.1, -30.9],
|
||||
[52.7, 92.8, -68.0, -47.0, 83.0],
|
||||
// pytorch gets -107.1
|
||||
[-10.2, -107.0, -5.4, 213.1, -31.4],
|
||||
[-2.4, 65.1, 9.2, -146.2, -24.2]
|
||||
],
|
||||
[
|
||||
[-72.6, -63.9, -61.9, 45.3, 33.0],
|
||||
[79.3, -0.5, -26.2, 78.2, 42.7],
|
||||
[90.9, 141.6, 40.1, -62.7, 37.0],
|
||||
[32.8, 198.2, -0.8, -31.1, 27.3],
|
||||
// torch gets 48.0
|
||||
[34.5, 34.9, -47.9, 127.6, -12.3],
|
||||
[-61.4, -3.2, -2.9, -10.9, -16.6],
|
||||
[74.6, 60.1, -68.9, 34.5, -50.4]
|
||||
],
|
||||
[
|
||||
[37.5, -56.9, -43.6, -13.5, -9.9],
|
||||
[40.0, 97.3, 28.6, 14.2, -30.1],
|
||||
[-22.3, -126.3, -68.8, -8.2, 26.1],
|
||||
[-32.9, 37.3, 108.5, -54.8, 29.6],
|
||||
[34.9, -176.9, -125.0, -28.3, -13.9],
|
||||
[-54.9, 142.6, 62.1, -80.4, -65.6],
|
||||
[7.4, -91.1, -67.6, 35.0, 39.7]
|
||||
],
|
||||
[
|
||||
[-57.2, -40.9, -10.1, 32.6, 29.4],
|
||||
[18.7, -18.0, 29.5, -1.2, 59.2],
|
||||
[-14.0, -74.4, 19.8, -117.0, 58.2],
|
||||
[-21.8, 163.5, -71.1, -99.0, 80.9],
|
||||
[-58.9, -10.9, 93.8, -139.6, 98.0],
|
||||
// torch gets 54.5
|
||||
[-54.4, 135.3, 6.0, -79.1, 134.6],
|
||||
[27.5, -76.0, 43.4, -2.8, -7.8]
|
||||
]
|
||||
]
|
||||
);
|
||||
|
||||
// Test the same, but then with the following properties, t & w are unmodified.
|
||||
let padding = 1;
|
||||
let outpadding = 1;
|
||||
let dilation = 1;
|
||||
let stride = 2;
|
||||
|
||||
let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?;
|
||||
let loss = res.sqr()?.sum_all()?;
|
||||
assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 3627.0); // torch gives 3626.8560
|
||||
|
||||
let grads = loss.backward()?;
|
||||
|
||||
let grad_t = grads.get(&t).unwrap();
|
||||
let grad_w = grads.get(&w).unwrap();
|
||||
assert_eq!(grad_t.dims(), [1, 4, 7, 5]);
|
||||
assert_eq!(grad_w.dims(), [4, 2, 3, 5]);
|
||||
|
||||
#[rustfmt::skip]
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&grad_t.i(0)?, 1)?,
|
||||
[
|
||||
[
|
||||
[ 13.2, -40.7, -9.7, -47.3, -82.7],
|
||||
[ -98.2, 9.7, 57.7, -6.2, 180.7],
|
||||
[ 100.2, 24.1, 3.7, -100.5, -48.1],
|
||||
[ -0.3, 13.5, -2.9, 80.0, -49.8],
|
||||
[ 47.2, -25.6, -74.4, 61.2, -18.4],
|
||||
[ 4.6, -69.5, 27.9, 66.5, -88.1],
|
||||
// 4th column on next row; torch is 4.2
|
||||
[ -12.0, 79.2, -40.0, 4.1, -97.1],
|
||||
],
|
||||
[
|
||||
[ -42.2, -36.5, -51.1, 7.5, 32.3],
|
||||
[ 74.1, -44.6, -68.8, 19.5, 7.7],
|
||||
[ 137.1, 54.2, 153.8, -58.0, 45.5],
|
||||
[ 24.4, -56.8, 9.7, -41.0, -14.5],
|
||||
[ -3.7, 72.6, 8.3, 134.8, 40.5],
|
||||
[ 43.2, -56.9, -47.5, -89.4, -95.4],
|
||||
[ 68.2, 108.1, -80.0, 57.0, -121.1]
|
||||
],
|
||||
[
|
||||
[ 31.1, -11.4, -34.8, 33.1, -44.2],
|
||||
[ 29.4, -31.6, -40.2, 13.7, 13.1],
|
||||
[ -0.8, -83.8, -7.8, -17.3, 78.2],
|
||||
[ 12.0, -118.7, 137.5, -76.7, 50.8],
|
||||
[ -28.7, -114.2, -3.7, -96.3, -13.8],
|
||||
[ -31.8, 28.5, -14.3, 4.6, 13.4],
|
||||
[ 28.0, -0.2, -38.9, -29.7, -59.0]
|
||||
],
|
||||
[
|
||||
[ -16.8, 38.5, 15.5, 26.6, 48.9],
|
||||
[ 14.5, 49.6, -24.8, 65.6, 61.7],
|
||||
[ 22.1, -64.7, -4.3, -51.0, 36.3],
|
||||
[ 31.0, -88.9, 47.1, -123.5, -3.8],
|
||||
[ -14.8, -39.8, 128.2, -110.3, 42.6],
|
||||
// 1st column on next row; torch is -7.2
|
||||
[ -7.1, 95.3, -21.3, -58.7, -13.9],
|
||||
[ 26.9, 21.3, 16.1, 70.3, 32.1]
|
||||
]
|
||||
]
|
||||
);
|
||||
|
||||
#[rustfmt::skip]
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?,
|
||||
[
|
||||
// 2nd value; torch gets -3.2, 3rd value; torch gets 221.8
|
||||
-2.460e+01, -3.100e+00, 2.219e+02, 7.400e+00, 5.620e+01,
|
||||
7.420e+01, 7.830e+01, 8.900e+00, 1.050e+01, 2.810e+01,
|
||||
5.100e+00, -1.046e+02, -1.572e+02, 8.710e+01, -9.840e+01,
|
||||
-4.230e+01, -1.898e+02, 1.860e+01, -3.570e+01, 9.810e+01,
|
||||
4.680e+01, 1.182e+02, 4.020e+01, -1.900e+00, 1.508e+02,
|
||||
1.094e+02, 1.018e+02, -4.620e+01, 1.591e+02, -2.320e+01,
|
||||
// 5th value; torch gets 7.1
|
||||
-8.450e+01, -4.600e+00, 6.330e+01, 1.123e+02, -7.000e+00,
|
||||
1.101e+02, -6.620e+01, 2.090e+01, -5.120e+01, 8.990e+01,
|
||||
9.050e+01, -6.990e+01, 6.800e+01, -9.250e+01, 1.380e+02,
|
||||
4.720e+01, 4.710e+01, 6.210e+01, 8.870e+01, 2.098e+02,
|
||||
3.870e+01, -1.390e+01, 6.270e+01, 1.484e+02, -9.920e+01,
|
||||
-4.200e+01, -1.505e+02, -1.480e+01, -2.620e+01, 8.220e+01,
|
||||
-3.350e+01, -2.260e+01, -1.198e+02, -5.080e+01, 1.259e+02,
|
||||
5.600e+01, 9.270e+01, 1.209e+02, 6.590e+01, -8.330e+01,
|
||||
7.000e+00, -2.600e+01, -1.133e+02, 3.870e+01, 4.020e+01,
|
||||
-6.300e+00, -8.710e+01, -5.150e+01, -8.510e+01, 2.000e-01,
|
||||
3.640e+01, -6.100e+00, 6.590e+01, -2.700e+00, 6.550e+01,
|
||||
// 4th value; torch gets 3.8
|
||||
5.300e+00, -6.760e+01, -4.270e+01, -3.900e+00, 2.880e+01,
|
||||
5.260e+01, 6.170e+01, -1.203e+02, -1.610e+01, 7.740e+01,
|
||||
-1.008e+02, -1.070e+01, -9.900e+00, 3.300e+00, -2.620e+01,
|
||||
-4.440e+01, 2.580e+01, -6.920e+01, -4.220e+01, 1.108e+02,
|
||||
1.240e+01, -3.440e+01, -2.800e+00, 7.880e+01, -6.690e+01,
|
||||
1.480e+01, 2.310e+01, -4.260e+01, -1.500e+00, -4.760e+01,
|
||||
5.350e+01, -2.260e+01, 8.000e-01, -3.840e+01, -2.500e+00
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -112,3 +112,70 @@ fn custom_op1_with_backward() -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl candle_core::InplaceOp1 for Elu {
|
||||
fn name(&self) -> &'static str {
|
||||
"elu"
|
||||
}
|
||||
|
||||
fn cpu_fwd(&self, s: &mut CpuStorage, _l: &Layout) -> Result<()> {
|
||||
let alpha = self.alpha;
|
||||
match s {
|
||||
CpuStorage::BF16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
|
||||
CpuStorage::F16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
|
||||
CpuStorage::F32(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
|
||||
CpuStorage::F64(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
|
||||
_ => candle_core::bail!("unsupported dtype for inplace elu"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inplace_op1() -> Result<()> {
|
||||
let cpu = &Device::Cpu;
|
||||
let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
|
||||
let t = (t - 5.)?;
|
||||
t.inplace_op1(&Elu { alpha: 1. })?;
|
||||
assert_eq!(
|
||||
to_vec1_round(&t, 4)?,
|
||||
&[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "cuda", feature = "metal"))]
|
||||
#[allow(clippy::approx_constant)]
|
||||
#[test]
|
||||
fn ug_op() -> Result<()> {
|
||||
let kernel = {
|
||||
use ug::lang::op;
|
||||
|
||||
let layout = ug::Layout::from_shape(&[12]);
|
||||
let ptr = op::Arg::ptr(ug::DType::F32);
|
||||
let src = op::load(ptr.id(), layout.clone(), ug::DType::F32)?;
|
||||
let src = op::unary(op::UnaryOp::Exp, src)?;
|
||||
let st = op::store(ptr.id(), layout, src)?;
|
||||
let kernel = op::Kernel::new("exp".to_string(), vec![ptr], vec![st]);
|
||||
let opts: ug::lower_op::Opts = Default::default();
|
||||
kernel.lower(&opts)?
|
||||
};
|
||||
let device = if candle_core::utils::cuda_is_available() {
|
||||
Device::new_cuda(0)?
|
||||
} else if candle_core::utils::metal_is_available() {
|
||||
Device::new_metal(0)?
|
||||
} else {
|
||||
candle_core::bail!("metal/cuda is mandatory for this test")
|
||||
};
|
||||
let op = candle_core::UgIOp1::new("test", kernel, &device)?;
|
||||
let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?;
|
||||
t.inplace_op1(&op)?;
|
||||
assert_eq!(
|
||||
to_vec1_round(&t, 2)?,
|
||||
&[
|
||||
1.0, 2.72, 7.39, 20.09, 54.6, 148.41, 403.43, 1096.63, 2980.96, 8103.08, 22026.47,
|
||||
59874.13
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
BIN
candle-core/tests/fortran_tensor_3d.pth
Normal file
BIN
candle-core/tests/fortran_tensor_3d.pth
Normal file
Binary file not shown.
@ -1,5 +1,6 @@
|
||||
#![allow(clippy::approx_constant)]
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
|
||||
use candle_core::{test_device, test_utils, DType, Device, Shape, Tensor, Var};
|
||||
|
||||
fn simple_grad(device: &Device) -> Result<()> {
|
||||
let x = Var::new(&[3f32, 1., 4.], device)?;
|
||||
@ -96,24 +97,24 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
y.to_vec1::<f32>()?,
|
||||
[20.085537, 2.7182817, 54.59815, 1.1618342]
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[20.0855, 2.7183, 54.5982, 1.1618]
|
||||
);
|
||||
assert_eq!(
|
||||
grad_x.to_vec1::<f32>()?,
|
||||
[20.085537, 2.7182817, 54.59815, 1.1618342]
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[20.0855, 2.7183, 54.5982, 1.1618]
|
||||
);
|
||||
let y = x.exp()?.sqr()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
y.to_vec1::<f32>()?,
|
||||
[403.4288, 7.3890557, 2980.9578, 1.3498588]
|
||||
test_utils::to_vec1_round(&y, 3)?,
|
||||
[403.429, 7.389, 2980.958, 1.35]
|
||||
);
|
||||
// exp(x)^2 = exp(2*x)
|
||||
assert_eq!(
|
||||
grad_x.to_vec1::<f32>()?,
|
||||
[806.8576, 14.778111, 5961.9155, 2.6997175]
|
||||
test_utils::to_vec1_round(grad_x, 2)?,
|
||||
[806.86, 14.78, 5961.92, 2.7]
|
||||
);
|
||||
let y = x.sin()?;
|
||||
let grads = y.backward()?;
|
||||
@ -261,6 +262,7 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
let y = elu_x.elu(2.)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&elu_x).context("no grad for x")?;
|
||||
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[-1.2642, 0.0000, -1.7293, 3.0000]
|
||||
@ -270,6 +272,194 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
[0.7358, 2.0000, 0.2707, 1.0000]
|
||||
);
|
||||
|
||||
// testing compared to pytorch nn.Silu()
|
||||
let y = x.silu()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[2.8577, 0.7311, 3.9281, 0.0806]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[1.0881, 0.9277, 1.0527, 0.5747],
|
||||
);
|
||||
|
||||
if device.is_cpu() {
|
||||
let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?;
|
||||
let y = x.interpolate1d(12)?.reshape(36)?;
|
||||
|
||||
let z = Tensor::new(
|
||||
&[
|
||||
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16.,
|
||||
17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.,
|
||||
33., 34., 35., 36.,
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
|
||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
||||
let grads = loss.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(grad_x, 4)?,
|
||||
[[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]]
|
||||
);
|
||||
}
|
||||
|
||||
// manually checked: see comments
|
||||
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||
|
||||
let z = Tensor::new(
|
||||
&[
|
||||
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
|
||||
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
|
||||
35., 36.,
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
// gradient should be
|
||||
// row 1
|
||||
// 1+2+7+8 = 18
|
||||
// 3+4+9+10 = 26
|
||||
// 5+6+11+12 = 34
|
||||
// row 2
|
||||
// 13+14+19+20 = 66
|
||||
// 15+16+21+22 = 74
|
||||
// 17+18+23+24 = 82
|
||||
// row 3
|
||||
// 25+26+31+32 = 114
|
||||
// 27+28+33+34 = 122
|
||||
// 29+30+35+36 = 130
|
||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
||||
|
||||
let grads = loss.backward()?;
|
||||
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
|
||||
[[18_f32, 26., 34.], [66., 74., 82.], [114., 122., 130.]]
|
||||
);
|
||||
|
||||
// manually checked: see comments
|
||||
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
|
||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||
|
||||
let z = Tensor::new(
|
||||
&[
|
||||
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
|
||||
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
|
||||
35., 36.,
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
// gradient should be
|
||||
// row 1
|
||||
// 1+2+3+7+8+9+13+14+15 = 72
|
||||
// 4+5+6+10+11+12+16+17+18 = 99
|
||||
// row 2
|
||||
// 19+20+21+25+26+27+31+32+33 = 234
|
||||
// 22+23+24+28+29+30+34+35+36 = 243
|
||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
||||
|
||||
let grads = loss.backward()?;
|
||||
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
|
||||
[[72_f32, 99.], [234., 261.]]
|
||||
);
|
||||
|
||||
// manually checked: see comments
|
||||
let x = Var::new(&[[[[1f32, 2.], [4., 5.]], [[6f32, 7.], [8., 9.]]]], device)?;
|
||||
|
||||
let y = x.interpolate2d(4, 4)?.reshape(32)?;
|
||||
|
||||
#[rustfmt::skip]
|
||||
let z = Tensor::new(
|
||||
&[
|
||||
1_f32, 02., 03., 04.,
|
||||
05., 06., 07., 08.,
|
||||
09., 10., 11., 12.,
|
||||
13., 14., 15., 16.,
|
||||
17., 18., 19., 20.,
|
||||
21., 22., 23., 24.,
|
||||
25., 26., 27., 28.,
|
||||
29., 30., 31., 32.
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
// gradient should be
|
||||
// m1r1
|
||||
// 1+2+5+6=14
|
||||
// 3+4+7+8=22
|
||||
// m1r2
|
||||
// 9+10+13+14=46
|
||||
// 11+12+15+16=54
|
||||
// m2r1
|
||||
// 17+18+21+22=78
|
||||
// 19+20+23+24=86
|
||||
// m2r2
|
||||
// 25+26+29+30=110
|
||||
// 27+28+31+32=118
|
||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
||||
|
||||
let grads = loss.backward()?;
|
||||
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
|
||||
[[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
|
||||
);
|
||||
|
||||
// manually checked: see comments
|
||||
let x = Var::new(
|
||||
&[[[[1f32, 2.], [4., 5.]]], [[[6f32, 7.], [8., 9.]]]],
|
||||
device,
|
||||
)?;
|
||||
|
||||
let y = x.interpolate2d(4, 4)?.reshape(32)?;
|
||||
|
||||
#[rustfmt::skip]
|
||||
let z = Tensor::new(
|
||||
&[
|
||||
1_f32, 02., 03., 04.,
|
||||
05., 06., 07., 08.,
|
||||
09., 10., 11., 12.,
|
||||
13., 14., 15., 16.,
|
||||
17., 18., 19., 20.,
|
||||
21., 22., 23., 24.,
|
||||
25., 26., 27., 28.,
|
||||
29., 30., 31., 32.
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
// gradient should be
|
||||
// m1r1
|
||||
// 1+2+5+6=14
|
||||
// 3+4+7+8=22
|
||||
// m1r2
|
||||
// 9+10+13+14=46
|
||||
// 11+12+15+16=54
|
||||
// m2r1
|
||||
// 17+18+21+22=78
|
||||
// 19+20+23+24=86
|
||||
// m2r2
|
||||
// 25+26+29+30=110
|
||||
// 27+28+31+32=118
|
||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
||||
|
||||
let grads = loss.backward()?;
|
||||
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
|
||||
[[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -315,6 +505,36 @@ fn binary_grad(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flip_backprop() -> Result<()> {
|
||||
let device = &Device::Cpu;
|
||||
|
||||
// Create a tensor (leaf node) that requires gradients
|
||||
let x = Var::ones((2, 2), DType::F64, device)?;
|
||||
let weights = Tensor::arange(1.0, 5.0, device)?.reshape((2, 2))?;
|
||||
|
||||
let y = x.matmul(&weights)?;
|
||||
let expected_y = Tensor::from_vec(vec![4.0, 6.0, 4.0, 6.0], (2, 2), device)?;
|
||||
candle_core::test_utils::assert_tensor_eq(&y, &expected_y)?;
|
||||
|
||||
let z = y.flip(&[1])?;
|
||||
let expected_z = Tensor::from_vec(vec![6.0, 4.0, 6.0, 4.0], (2, 2), device)?;
|
||||
candle_core::test_utils::assert_tensor_eq(&z, &expected_z)?;
|
||||
|
||||
let loss = z.sum_all()?;
|
||||
|
||||
let grad_store = loss.backward()?;
|
||||
let grad_x = grad_store.get_id(x.id()).unwrap();
|
||||
|
||||
let flipped_weights = weights.flip(&[1])?;
|
||||
let dloss_dy = Tensor::ones((2, 2), DType::F64, device)?;
|
||||
// dloss/dx = dloss/dy @ dy/dx = ones @ weight.flip.T
|
||||
let expected_grad = dloss_dy.matmul(&flipped_weights.t()?)?;
|
||||
candle_core::test_utils::assert_tensor_eq(grad_x, &expected_grad)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(
|
||||
simple_grad,
|
||||
simple_grad_cpu,
|
||||
|
@ -88,7 +88,7 @@ fn strided_blocks() -> Result<()> {
|
||||
}
|
||||
};
|
||||
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||
let tensor = tensor.i((.., 1))?;
|
||||
let tensor = tensor.i((.., 1))?.contiguous()?;
|
||||
match tensor.strided_blocks() {
|
||||
candle::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||
assert_eq!(start_offset, 0);
|
||||
@ -100,6 +100,20 @@ fn strided_blocks() -> Result<()> {
|
||||
}
|
||||
};
|
||||
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||
let tensor = tensor.i((.., 1))?;
|
||||
match tensor.strided_blocks() {
|
||||
candle::StridedBlocks::SingleBlock { .. } => {
|
||||
panic!("unexpected block structure")
|
||||
}
|
||||
candle::StridedBlocks::MultipleBlocks {
|
||||
block_len,
|
||||
block_start_index,
|
||||
} => {
|
||||
assert_eq!(block_len, 4);
|
||||
assert_eq!(block_start_index.collect::<Vec<_>>(), &[4, 16])
|
||||
}
|
||||
};
|
||||
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||
match tensor.t()?.strided_blocks() {
|
||||
candle::StridedBlocks::SingleBlock { .. } => {
|
||||
panic!("unexpected block structure")
|
||||
|
126
candle-core/tests/matmul_tests.rs
Normal file
126
candle-core/tests/matmul_tests.rs
Normal file
@ -0,0 +1,126 @@
|
||||
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
|
||||
|
||||
fn matmul(device: &Device) -> Result<()> {
|
||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
let a = Tensor::from_slice(&data, (2, 2), device)?;
|
||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
let b = Tensor::from_slice(&data, (2, 2), device)?;
|
||||
|
||||
let c = a.matmul(&b)?;
|
||||
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
|
||||
|
||||
let data = vec![1.0f32, 2.0];
|
||||
let a = Tensor::from_slice(&data, (2, 1), device)?;
|
||||
let data = vec![3.0f32, 4.0];
|
||||
let b = Tensor::from_slice(&data, (1, 2), device)?;
|
||||
let c = a.matmul(&b)?;
|
||||
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
|
||||
|
||||
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
|
||||
let a = Tensor::from_slice(&data, (2, 3), device)?;
|
||||
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
|
||||
let b = Tensor::from_slice(&data, (3, 2), device)?;
|
||||
let c = a.matmul(&b)?;
|
||||
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
|
||||
|
||||
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
|
||||
let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
|
||||
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
|
||||
let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
|
||||
let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];
|
||||
|
||||
let c = a.matmul(&b)?;
|
||||
assert_eq!(c.to_vec3::<f32>()?, &expected);
|
||||
|
||||
// Also perform the matmul on contiguous transposed versions.
|
||||
let a_tt = a.t()?.contiguous()?.t()?;
|
||||
assert!(!a_tt.is_contiguous());
|
||||
assert_eq!(a.dims(), a_tt.dims());
|
||||
assert_eq!(a_tt.stride(), &[6, 1, 2]);
|
||||
|
||||
let b_tt = b.t()?.contiguous()?.t()?;
|
||||
assert!(!b_tt.is_contiguous());
|
||||
assert_eq!(b.dims(), b_tt.dims());
|
||||
assert_eq!(b_tt.stride(), &[6, 1, 3]);
|
||||
|
||||
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
|
||||
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
|
||||
assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn matmul_bf16(device: &Device) -> Result<()> {
|
||||
if !device.supports_bf16() {
|
||||
return Ok(());
|
||||
}
|
||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
let a = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?;
|
||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
let b = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?;
|
||||
|
||||
let c = a.matmul(&b)?.to_dtype(DType::F32)?;
|
||||
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn broadcast_matmul(device: &Device) -> Result<()> {
|
||||
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
|
||||
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
|
||||
let out = lhs.broadcast_matmul(&rhs)?;
|
||||
assert_eq!(out.dims(), &[3, 6, 4, 2]);
|
||||
for idx1 in 0..3 {
|
||||
for idx2 in 0..6 {
|
||||
let out = out.i((idx1, idx2))?;
|
||||
let lhs = lhs.i((idx1, 0))?;
|
||||
let rhs = rhs.i(idx2)?;
|
||||
let out2 = lhs.matmul(&rhs);
|
||||
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
|
||||
// With cuda, we see errors of up to ~1e-12.
|
||||
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/candle/issues/1948
|
||||
fn squeeze_mm(device: &Device) -> Result<()> {
|
||||
let seq_len = 8_usize;
|
||||
let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
|
||||
let x = a.i((.., seq_len - 1, ..))?;
|
||||
let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
|
||||
let x = x.matmul(&w)?;
|
||||
assert_eq!(x.dims(), &[1, 32]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/candle/issues/1992
|
||||
fn mm_layout(device: &Device) -> Result<()> {
|
||||
let a = Tensor::arange(0f32, 16f32, device)?.reshape((1, 1, 4, 4))?;
|
||||
let b = Tensor::arange(0f32, 8f32, device)?.reshape((1, 1, 4, 2))?;
|
||||
let mm1 = a.matmul(&b)?;
|
||||
// Forces the layout to be:
|
||||
// shape: [1, 1, 4, 2], stride: [8, 2, 2, 1], start_offset: 0
|
||||
// This is still a contiguous matrix but matmul checks are only the two last dimensions have
|
||||
// non 1 sizes but matmul check may be reluctant to handle it.
|
||||
let b = b.transpose(1, 2)?.force_contiguous()?.transpose(1, 2)?;
|
||||
let mm2 = a.matmul(&b)?;
|
||||
let diff = (mm1 - mm2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
|
||||
test_device!(
|
||||
matmul_bf16,
|
||||
matmul_bf16_cpu,
|
||||
matmul_bf16_gpu,
|
||||
matmul_bf16_metal
|
||||
);
|
||||
test_device!(
|
||||
broadcast_matmul,
|
||||
broadcast_matmul_cpu,
|
||||
broadcast_matmul_gpu,
|
||||
broadcast_matmul_metal
|
||||
);
|
||||
test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal);
|
||||
test_device!(mm_layout, mm_layout_cpu, mm_layout_gpu, mm_layout_metal);
|
@ -43,6 +43,9 @@ res = torch.nn.functional.avg_pool2d(t, 2)
|
||||
print(res)
|
||||
*/
|
||||
fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
|
||||
if dev.is_metal() {
|
||||
return Ok(());
|
||||
}
|
||||
let t = Tensor::new(
|
||||
&[
|
||||
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
|
||||
|
37
candle-core/tests/pth.py
Normal file
37
candle-core/tests/pth.py
Normal file
@ -0,0 +1,37 @@
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
|
||||
# Write a trivial tensor to a pt file
|
||||
a= torch.tensor([[1,2,3,4], [5,6,7,8]])
|
||||
o = OrderedDict()
|
||||
o["test"] = a
|
||||
|
||||
# Write a trivial tensor to a pt file
|
||||
torch.save(o, "test.pt")
|
||||
|
||||
############################################################################################################
|
||||
# Write a trivial tensor to a pt file with a key
|
||||
torch.save({"model_state_dict": o}, "test_with_key.pt")
|
||||
|
||||
############################################################################################################
|
||||
# Create a tensor with fortran contiguous memory layout
|
||||
import numpy as np
|
||||
|
||||
# Step 1: Create a 3D NumPy array with Fortran order using a range of numbers
|
||||
# For example, creating a 2x3x4 array
|
||||
array_fortran = np.asfortranarray(np.arange(1, 2*3*4 + 1).reshape(2, 3, 4))
|
||||
|
||||
# Verify the memory order
|
||||
print("Is Fortran contiguous (F order):", array_fortran.flags['F_CONTIGUOUS']) # Should be True
|
||||
print("Is C contiguous (C order):", array_fortran.flags['C_CONTIGUOUS']) # Should be False
|
||||
|
||||
# Step 2: Convert the NumPy array to a PyTorch tensor
|
||||
tensor_fortran = torch.from_numpy(array_fortran)
|
||||
|
||||
# Verify the tensor layout
|
||||
print("Tensor stride:", tensor_fortran.stride()) # Stride will reflect the Fortran memory layout
|
||||
|
||||
# Step 3: Save the PyTorch tensor to a .pth file
|
||||
torch.save({"tensor_fortran": tensor_fortran}, 'fortran_tensor_3d.pth')
|
||||
|
||||
print("3D Tensor saved with Fortran layout.")
|
31
candle-core/tests/pth_tests.rs
Normal file
31
candle-core/tests/pth_tests.rs
Normal file
@ -0,0 +1,31 @@
|
||||
/// Regression test for pth files not loading on Windows.
|
||||
#[test]
|
||||
fn test_pth() {
|
||||
let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
|
||||
tensors.get("test").unwrap().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pth_with_key() {
|
||||
let tensors =
|
||||
candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
|
||||
.unwrap();
|
||||
tensors.get("test").unwrap().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pth_fortran_congiguous() {
|
||||
let tensors =
|
||||
candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
|
||||
let tensor = tensors.get("tensor_fortran").unwrap().unwrap();
|
||||
|
||||
assert_eq!(tensor.dims3().unwrap(), (2, 3, 4));
|
||||
|
||||
assert_eq!(
|
||||
tensor.to_vec3::<i64>().unwrap(),
|
||||
[
|
||||
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
|
||||
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
|
||||
]
|
||||
);
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,5 +1,31 @@
|
||||
use candle_core::{DType, Result, Tensor};
|
||||
|
||||
struct TmpFile(std::path::PathBuf);
|
||||
|
||||
impl TmpFile {
|
||||
fn create(base: &str) -> TmpFile {
|
||||
let filename = std::env::temp_dir().join(format!(
|
||||
"candle-{}-{}-{:?}",
|
||||
base,
|
||||
std::process::id(),
|
||||
std::thread::current().id(),
|
||||
));
|
||||
TmpFile(filename)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::AsRef<std::path::Path> for TmpFile {
|
||||
fn as_ref(&self) -> &std::path::Path {
|
||||
self.0.as_path()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TmpFile {
|
||||
fn drop(&mut self) {
|
||||
std::fs::remove_file(&self.0).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn npy() -> Result<()> {
|
||||
let npy = Tensor::read_npy("tests/test.npy")?;
|
||||
@ -22,3 +48,24 @@ fn npz() -> Result<()> {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn safetensors() -> Result<()> {
|
||||
use candle_core::safetensors::Load;
|
||||
|
||||
let tmp_file = TmpFile::create("st");
|
||||
let t = Tensor::arange(0f32, 24f32, &candle_core::Device::Cpu)?;
|
||||
t.save_safetensors("t", &tmp_file)?;
|
||||
// Load from file.
|
||||
let st = candle_core::safetensors::load(&tmp_file, &candle_core::Device::Cpu)?;
|
||||
let t2 = st.get("t").unwrap();
|
||||
let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0f32);
|
||||
// Load from bytes.
|
||||
let bytes = std::fs::read(tmp_file)?;
|
||||
let st = candle_core::safetensors::SliceSafetensors::new(&bytes)?;
|
||||
let t2 = st.get("t").unwrap().load(&candle_core::Device::Cpu);
|
||||
let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0f32);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
|
||||
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D};
|
||||
|
||||
fn zeros(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
|
||||
@ -29,6 +29,44 @@ fn ones(device: &Device) -> Result<()> {
|
||||
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
|
||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
|
||||
[
|
||||
[
|
||||
half::f16::from_f32(1.0),
|
||||
half::f16::from_f32(1.0),
|
||||
half::f16::from_f32(1.0)
|
||||
],
|
||||
[
|
||||
half::f16::from_f32(1.0),
|
||||
half::f16::from_f32(1.0),
|
||||
half::f16::from_f32(1.0)
|
||||
]
|
||||
],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::<half::bf16>()?,
|
||||
[
|
||||
[
|
||||
half::bf16::from_f32(1.0),
|
||||
half::bf16::from_f32(1.0),
|
||||
half::bf16::from_f32(1.0)
|
||||
],
|
||||
[
|
||||
half::bf16::from_f32(1.0),
|
||||
half::bf16::from_f32(1.0),
|
||||
half::bf16::from_f32(1.0)
|
||||
]
|
||||
],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn full(device: &Device) -> Result<()> {
|
||||
assert_eq!(
|
||||
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
|
||||
[[42, 42, 42], [42, 42, 42]],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -88,6 +126,40 @@ fn clamp(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn asort(device: &Device) -> Result<()> {
|
||||
let data = &[[3f32, 1., 4., 1.1, 5.], [2.1, 1., 7., 8., 2.]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
let indexes = tensor.arg_sort_last_dim(true)?;
|
||||
assert_eq!(
|
||||
indexes.to_vec2::<u32>()?,
|
||||
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
|
||||
);
|
||||
let indexes = tensor.arg_sort_last_dim(false)?;
|
||||
assert_eq!(
|
||||
indexes.to_vec2::<u32>()?,
|
||||
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
|
||||
);
|
||||
let (sorted, indexes) = tensor.sort_last_dim(true)?;
|
||||
assert_eq!(
|
||||
indexes.to_vec2::<u32>()?,
|
||||
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
|
||||
);
|
||||
assert_eq!(
|
||||
sorted.to_vec2::<f32>()?,
|
||||
[[1.0, 1.1, 3.0, 4.0, 5.0], [1.0, 2.0, 2.1, 7.0, 8.0]]
|
||||
);
|
||||
let (sorted, indexes) = tensor.sort_last_dim(false)?;
|
||||
assert_eq!(
|
||||
indexes.to_vec2::<u32>()?,
|
||||
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
|
||||
);
|
||||
assert_eq!(
|
||||
sorted.to_vec2::<f32>()?,
|
||||
[[5.0, 4.0, 3.0, 1.1, 1.0], [8.0, 7.0, 2.1, 2.0, 1.0]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn unary_op(device: &Device) -> Result<()> {
|
||||
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
@ -98,6 +170,9 @@ fn unary_op(device: &Device) -> Result<()> {
|
||||
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
|
||||
]
|
||||
);
|
||||
let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
|
||||
let max_diff = (tensor.gelu()? - t_f16)?.flatten_all()?.max(0)?;
|
||||
assert!(max_diff.to_vec0::<f32>()? < 5e-3);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
|
||||
[
|
||||
@ -112,6 +187,13 @@ fn unary_op(device: &Device) -> Result<()> {
|
||||
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.silu()?, 4)?,
|
||||
[
|
||||
[-0.1423, 0.7311, 3.9281, -0.0475, 0.3112],
|
||||
[2.53, -0.2553, -0.1205, 1.5447, 2.6395]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
|
||||
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
|
||||
@ -133,6 +215,27 @@ fn unary_op(device: &Device) -> Result<()> {
|
||||
test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
|
||||
[3000.0, 300.]
|
||||
);
|
||||
let tensor = Tensor::new(
|
||||
&[-1.01f32, -0.9, -0.1, 0.0, -0.0, 0.1, 0.9, 1.0, 1.1],
|
||||
device,
|
||||
)?;
|
||||
assert_eq!(
|
||||
tensor.sign()?.to_vec1::<f32>()?,
|
||||
[-1., -1., -1., 0., 0., 1., 1., 1., 1.]
|
||||
);
|
||||
let tensor = Tensor::new(&[-1.0f32, 0., -2., 3.], device)?;
|
||||
let y = tensor.elu(2.)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[-1.2642, 0.0000, -1.7293, 3.0000]
|
||||
);
|
||||
// This test failed on metal prior to the following PR:
|
||||
// https://github.com/huggingface/candle/pull/2490
|
||||
let y = tensor.reshape((2, 2))?.t()?.elu(2.)?.flatten_all()?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[-1.2642, -1.7293, 0.0000, 3.0000]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -605,6 +708,32 @@ fn broadcast(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn slice_set(device: &Device) -> Result<()> {
|
||||
let (b, h, max_t, d) = (2, 4, 7, 3);
|
||||
let cache = Tensor::zeros((b, h, max_t, d), DType::F32, device)?;
|
||||
let tensor = Tensor::randn(0f32, 1f32, (b, h, 4, d), device)?;
|
||||
cache.slice_set(&tensor, 2, 0)?;
|
||||
let cache_t = cache.narrow(2, 0, 4)?;
|
||||
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
cache.slice_set(&tensor, 2, 1)?;
|
||||
let cache_t = cache.narrow(2, 1, 4)?;
|
||||
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
let ones = Tensor::ones((b, h, 1, d), DType::F32, device)?;
|
||||
cache.slice_set(&ones, 2, 6)?;
|
||||
let diff = cache.narrow(2, 5, 1)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
let diff = (cache.narrow(2, 6, 1)? - 1.)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
// This used to create a deadlock rather than returning an actual error.
|
||||
assert!(cache.slice_set(&cache, 0, 0).is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cat(device: &Device) -> Result<()> {
|
||||
// 1D
|
||||
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
|
||||
@ -657,6 +786,31 @@ fn cat(device: &Device) -> Result<()> {
|
||||
[2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0]
|
||||
]
|
||||
);
|
||||
|
||||
// 3D
|
||||
let t1 = Tensor::arange(0, 48i64, device)?.reshape((2, 6, 4))?;
|
||||
let t2 = Tensor::arange(100, 124i64, device)?.reshape((2, 3, 4))?;
|
||||
let t3 = Tensor::arange(10000, 10032i64, device)?.reshape((2, 4, 4))?;
|
||||
|
||||
let t_cat = Tensor::cat(&[&t1, &t2, &t3], 1)?;
|
||||
|
||||
let t1 = t1.t()?.contiguous()?.t()?;
|
||||
let t2 = t2.t()?.contiguous()?.t()?;
|
||||
let t3 = t3.t()?.contiguous()?.t()?;
|
||||
let t_cat2 = Tensor::cat(&[&t1, &t2, &t3], 1)?;
|
||||
|
||||
let diff = t_cat.eq(&t_cat2)?.to_dtype(DType::F32)?.sum_all()?;
|
||||
assert_eq!(diff.to_vec0::<f32>()?, 104.0);
|
||||
assert_eq!(t_cat.i((0, 0, 0))?.to_vec0::<i64>()?, 0);
|
||||
assert_eq!(t_cat.i((0, 4, 0))?.to_vec0::<i64>()?, 16);
|
||||
assert_eq!(t_cat.i((0, 5, 0))?.to_vec0::<i64>()?, 20);
|
||||
assert_eq!(t_cat.i((1, 5, 0))?.to_vec0::<i64>()?, 44);
|
||||
assert_eq!(t_cat.i((0, 6, 0))?.to_vec0::<i64>()?, 100);
|
||||
assert_eq!(t_cat.i((1, 6, 0))?.to_vec0::<i64>()?, 112);
|
||||
assert_eq!(t_cat.i((0, 6, 1))?.to_vec0::<i64>()?, 101);
|
||||
assert_eq!(t_cat.i((0, 7, 1))?.to_vec0::<i64>()?, 105);
|
||||
assert_eq!(t_cat.i((0, 12, 1))?.to_vec0::<i64>()?, 10013);
|
||||
assert_eq!(t_cat.i((1, 12, 3))?.to_vec0::<i64>()?, 10031);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -667,6 +821,8 @@ fn embeddings(device: &Device) -> Result<()> {
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||
let hs = t.index_select(&ids, 0)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -694,44 +850,47 @@ fn index_select(device: &Device) -> Result<()> {
|
||||
[9.0, 10.0, 11.0]
|
||||
]
|
||||
);
|
||||
let hs = t.index_select(&ids, 1)?;
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 2.0, 1.0],
|
||||
[3.0, 5.0, 4.0],
|
||||
[6.0, 8.0, 7.0],
|
||||
[9.0, 11.0, 10.0]
|
||||
]
|
||||
);
|
||||
let hs = t.index_select(&ids, 0)?;
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
|
||||
);
|
||||
// Prior to https://github.com/huggingface/candle/pull/1022
|
||||
// There would be a bug where the last values in the result tensor would be set to 0.
|
||||
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
|
||||
let hs = t.index_select(&ids, 0)?;
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
[0.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
]
|
||||
);
|
||||
for dtype in [DType::U8, DType::U32, DType::I64] {
|
||||
let ids = ids.to_dtype(dtype)?;
|
||||
let hs = t.index_select(&ids, 1)?;
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 2.0, 1.0],
|
||||
[3.0, 5.0, 4.0],
|
||||
[6.0, 8.0, 7.0],
|
||||
[9.0, 11.0, 10.0]
|
||||
]
|
||||
);
|
||||
let hs = t.index_select(&ids, 0)?;
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
|
||||
);
|
||||
// Prior to https://github.com/huggingface/candle/pull/1022
|
||||
// There would be a bug where the last values in the result tensor would be set to 0.
|
||||
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
|
||||
let hs = t.index_select(&ids, 0)?;
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
[0.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
]
|
||||
);
|
||||
|
||||
// Test when selecting dim > 0 with ids size different from elem count of
|
||||
// target dim in source/input.
|
||||
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
|
||||
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
|
||||
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
|
||||
let hs = t.index_select(&ids, 1)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
|
||||
// Test when selecting dim > 0 with ids size different from elem count of
|
||||
// target dim in source/input.
|
||||
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
|
||||
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
|
||||
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
|
||||
let hs = t.index_select(&ids, 1)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -890,74 +1049,280 @@ fn gather(device: &Device) -> Result<()> {
|
||||
let ids = Tensor::new(&[[0u32, 2u32, 0u32], [0u32, 1u32, 1u32]], device)?;
|
||||
let hs = t.gather(&ids, 0)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn matmul(device: &Device) -> Result<()> {
|
||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
let a = Tensor::from_slice(&data, (2, 2), device)?;
|
||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
let b = Tensor::from_slice(&data, (2, 2), device)?;
|
||||
// Random data
|
||||
|
||||
let c = a.matmul(&b)?;
|
||||
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
|
||||
// Dim: 0
|
||||
let t = Tensor::new(
|
||||
&[
|
||||
[
|
||||
[108_f32, -47., 16., -56., -83., -130., 210.],
|
||||
[253., 95., 151., 228., -210., -123., -127.],
|
||||
[-9., -217., 2., -78., 163., 245., -204.],
|
||||
[-246., 79., -238., 88., -226., -184., 171.],
|
||||
[8., -48., -153., 234., -34., 166., -153.],
|
||||
[124., 0., -10., -61., -242., -15., -238.],
|
||||
],
|
||||
[
|
||||
[12., -64., -199., 244., -240., 156., -128.],
|
||||
[173., -57., 4., -198., 233., -110., 238.],
|
||||
[95., 82., 0., 240., 53., -211., 209.],
|
||||
[-122., 167., -212., 227., -144., 61., 118.],
|
||||
[-63., -146., 200., 244., 168., -167., 116.],
|
||||
[-125., -147., 110., -253., -178., -250., -18.],
|
||||
],
|
||||
[
|
||||
[57., 86., -50., 56., 92., 205., -78.],
|
||||
[-137., -156., -18., 248., -61., -239., 14.],
|
||||
[-248., -30., -50., -70., -251., 250., -83.],
|
||||
[-221., 67., 72., 59., -24., -154., 232.],
|
||||
[-144., -23., -74., 5., 93., 171., 205.],
|
||||
[46., -77., -38., -226., 246., 161., -17.],
|
||||
],
|
||||
[
|
||||
[-153., -231., -236., 161., 126., 2., -22.],
|
||||
[-229., -41., 209., 164., 234., 160., 57.],
|
||||
[223., 254., -186., -162., -46., -160., -102.],
|
||||
[65., 30., 213., -253., 59., 224., -154.],
|
||||
[-82., -203., -177., 17., 31., -256., -246.],
|
||||
[176., -135., -65., 54., -56., 210., 76.],
|
||||
],
|
||||
[
|
||||
[-10., -245., 168., 124., -14., -33., -178.],
|
||||
[25., -43., -39., 132., -89., 169., 179.],
|
||||
[187., -215., 32., -133., 87., -7., -168.],
|
||||
[-224., -215., -5., -230., -58., -162., 128.],
|
||||
[158., -137., -122., -100., -202., -83., 136.],
|
||||
[30., -185., -144., 250., 209., -40., 127.],
|
||||
],
|
||||
[
|
||||
[-196., 108., -245., 122., 146., -228., 62.],
|
||||
[-1., -66., 160., 137., 13., -172., -21.],
|
||||
[244., 199., -164., 28., 119., -175., 198.],
|
||||
[-62., 253., -162., 195., -95., -230., -211.],
|
||||
[123., -72., -26., -107., -139., 64., 245.],
|
||||
[11., -126., -182., 108., -12., 184., -127.],
|
||||
],
|
||||
[
|
||||
[-159., 126., 176., 161., 73., -111., -138.],
|
||||
[-187., 214., -217., -33., -223., -201., -212.],
|
||||
[-61., -120., -166., -172., -95., 53., 196.],
|
||||
[-33., 86., 134., -152., 154., -53., 74.],
|
||||
[186., -28., -154., -174., 141., -109., 217.],
|
||||
[82., 35., 252., 145., 181., 74., -87.],
|
||||
],
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
|
||||
let data = vec![1.0f32, 2.0];
|
||||
let a = Tensor::from_slice(&data, (2, 1), device)?;
|
||||
let data = vec![3.0f32, 4.0];
|
||||
let b = Tensor::from_slice(&data, (1, 2), device)?;
|
||||
let c = a.matmul(&b)?;
|
||||
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
|
||||
let ids = Tensor::new(
|
||||
&[
|
||||
[
|
||||
[6_u32, 6, 4, 3, 4, 4, 6],
|
||||
[3, 3, 2, 4, 4, 4, 6],
|
||||
[3, 3, 0, 2, 4, 6, 4],
|
||||
[2, 5, 1, 2, 6, 6, 1],
|
||||
[2, 1, 6, 5, 3, 2, 3],
|
||||
[6, 1, 0, 1, 0, 2, 6],
|
||||
],
|
||||
[
|
||||
[4, 6, 4, 3, 3, 3, 2],
|
||||
[4, 3, 2, 4, 4, 4, 6],
|
||||
[2, 3, 0, 2, 4, 6, 4],
|
||||
[6, 5, 1, 2, 6, 6, 1],
|
||||
[4, 1, 6, 5, 3, 2, 3],
|
||||
[1, 1, 0, 1, 0, 2, 6],
|
||||
],
|
||||
[
|
||||
[3, 6, 4, 3, 3, 3, 2],
|
||||
[2, 3, 2, 4, 4, 4, 6],
|
||||
[4, 3, 0, 2, 4, 6, 4],
|
||||
[0, 5, 1, 2, 6, 6, 1],
|
||||
[6, 1, 6, 5, 3, 2, 3],
|
||||
[4, 1, 0, 1, 0, 2, 6],
|
||||
],
|
||||
[
|
||||
[0, 6, 4, 3, 3, 3, 2],
|
||||
[5, 3, 2, 4, 4, 4, 6],
|
||||
[0, 3, 0, 2, 4, 6, 4],
|
||||
[3, 5, 1, 2, 6, 6, 1],
|
||||
[0, 1, 6, 5, 3, 2, 3],
|
||||
[3, 1, 0, 1, 0, 2, 6],
|
||||
],
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
|
||||
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
|
||||
let a = Tensor::from_slice(&data, (2, 3), device)?;
|
||||
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
|
||||
let b = Tensor::from_slice(&data, (3, 2), device)?;
|
||||
let c = a.matmul(&b)?;
|
||||
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
|
||||
let hs = t.gather(&ids, 0)?;
|
||||
assert_eq!(
|
||||
hs.to_vec3::<f32>()?,
|
||||
&[
|
||||
[
|
||||
[-159_f32, 126., 168., 161., -14., -33., -138.],
|
||||
[-229., -41., -18., 132., -89., 169., -212.],
|
||||
[223., 254., 2., -70., 87., 53., -168.],
|
||||
[-221., 253., -212., 59., 154., -53., 118.],
|
||||
[-144., -146., -154., -107., 31., 171., -246.],
|
||||
[82., -147., -10., -253., -242., 161., -87.]
|
||||
],
|
||||
[
|
||||
[-10., 126., 168., 161., 126., 2., -78.],
|
||||
[25., -41., -18., 132., -89., 169., -212.],
|
||||
[-248., 254., 2., -70., 87., 53., -168.],
|
||||
[-33., 253., -212., 59., 154., -53., 118.],
|
||||
[158., -146., -154., -107., 31., 171., -246.],
|
||||
[-125., -147., -10., -253., -242., 161., -87.]
|
||||
],
|
||||
[
|
||||
[-153., 126., 168., 161., 126., 2., -78.],
|
||||
[-137., -41., -18., 132., -89., 169., -212.],
|
||||
[187., 254., 2., -70., 87., 53., -168.],
|
||||
[-246., 253., -212., 59., 154., -53., 118.],
|
||||
[186., -146., -154., -107., 31., 171., -246.],
|
||||
[30., -147., -10., -253., -242., 161., -87.]
|
||||
],
|
||||
[
|
||||
[108., 126., 168., 161., 126., 2., -78.],
|
||||
[-1., -41., -18., 132., -89., 169., -212.],
|
||||
[-9., 254., 2., -70., 87., 53., -168.],
|
||||
[65., 253., -212., 59., 154., -53., 118.],
|
||||
[8., -146., -154., -107., 31., 171., -246.],
|
||||
[176., -147., -10., -253., -242., 161., -87.]
|
||||
]
|
||||
]
|
||||
);
|
||||
|
||||
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
|
||||
let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
|
||||
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
|
||||
let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
|
||||
let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];
|
||||
// Dim: 1
|
||||
let t = Tensor::new(
|
||||
&[
|
||||
[
|
||||
[-117_f32, -175., 69., -163.],
|
||||
[200., 242., -21., -67.],
|
||||
[179., 150., -126., -75.],
|
||||
[-118., 38., -138., -13.],
|
||||
[-221., 136., -185., 180.],
|
||||
[58., 182., -204., -149.],
|
||||
],
|
||||
[
|
||||
[3., -148., -58., -154.],
|
||||
[-43., 45., -108., 4.],
|
||||
[-69., -249., -71., -21.],
|
||||
[80., 110., -152., -235.],
|
||||
[-88., 7., 92., -250.],
|
||||
[-186., 207., -242., 98.],
|
||||
],
|
||||
[
|
||||
[238., 19., 64., -242.],
|
||||
[-150., -97., 218., 58.],
|
||||
[111., -233., 204., -212.],
|
||||
[-242., -232., 83., 42.],
|
||||
[153., 62., -251., 219.],
|
||||
[-117., 36., -119., 10.],
|
||||
],
|
||||
[
|
||||
[215., 159., -169., -27.],
|
||||
[-83., 101., -88., 169.],
|
||||
[-205., 93., 225., -64.],
|
||||
[-162., 240., 214., 23.],
|
||||
[-112., 6., 21., 245.],
|
||||
[-38., 113., 93., 215.],
|
||||
],
|
||||
[
|
||||
[91., -188., -148., 101.],
|
||||
[74., 203., -35., 55.],
|
||||
[-116., -130., -153., -96.],
|
||||
[58., 22., -45., -194.],
|
||||
[-221., -134., 73., 159.],
|
||||
[-203., -254., 31., 235.],
|
||||
],
|
||||
[
|
||||
[105., -53., 61., 186.],
|
||||
[-195., 234., 75., -1.],
|
||||
[51., 139., 160., -108.],
|
||||
[-173., -167., 161., 19.],
|
||||
[83., -246., 156., -222.],
|
||||
[109., 39., -149., 137.],
|
||||
],
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
|
||||
let c = a.matmul(&b)?;
|
||||
assert_eq!(c.to_vec3::<f32>()?, &expected);
|
||||
let ids = Tensor::new(
|
||||
&[
|
||||
[[4_u32, 4, 4, 2]],
|
||||
[[0, 4, 4, 3]],
|
||||
[[1, 5, 3, 4]],
|
||||
[[0, 3, 3, 2]],
|
||||
[[1, 1, 5, 2]],
|
||||
[[1, 4, 5, 4]],
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
|
||||
// Also perform the matmul on contiguous transposed versions.
|
||||
let a_tt = a.t()?.contiguous()?.t()?;
|
||||
assert!(!a_tt.is_contiguous());
|
||||
assert_eq!(a.dims(), a_tt.dims());
|
||||
assert_eq!(a_tt.stride(), &[6, 1, 2]);
|
||||
let hs = t.gather(&ids, 1)?;
|
||||
assert_eq!(
|
||||
hs.to_vec3::<f32>()?,
|
||||
&[
|
||||
[[-221., 136., -185., -75.]],
|
||||
[[3., 7., 92., -235.]],
|
||||
[[-150., 36., 83., 219.]],
|
||||
[[215., 240., 214., -64.]],
|
||||
[[74., 203., 31., -96.]],
|
||||
[[-195., -246., -149., -222.]]
|
||||
]
|
||||
);
|
||||
|
||||
let b_tt = b.t()?.contiguous()?.t()?;
|
||||
assert!(!b_tt.is_contiguous());
|
||||
assert_eq!(b.dims(), b_tt.dims());
|
||||
assert_eq!(b_tt.stride(), &[6, 1, 3]);
|
||||
// Dim: 2
|
||||
let t = Tensor::new(
|
||||
&[
|
||||
[[-162_f32, 202.], [-126., -39.], [35., -65.], [1., 80.]],
|
||||
[[37., 248.], [-191., 89.], [117., -40.], [-217., 220.]],
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
|
||||
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
|
||||
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
|
||||
assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
|
||||
Ok(())
|
||||
}
|
||||
let ids = Tensor::new(&[[[1_u32], [0], [1], [1]], [[0], [1], [0], [1]]], device)?;
|
||||
|
||||
let hs = t.gather(&ids, 2)?;
|
||||
assert_eq!(
|
||||
hs.to_vec3::<f32>()?,
|
||||
&[
|
||||
[[202.], [-126.], [-65.], [80.]],
|
||||
[[37.], [89.], [117.], [220.]]
|
||||
]
|
||||
);
|
||||
|
||||
let t = Tensor::new(
|
||||
&[
|
||||
[[-21_f32, -197.], [194., 122.]],
|
||||
[[255., -106.], [-191., 250.]],
|
||||
[[33., -117.], [43., 10.]],
|
||||
[[-130., 238.], [-217., -92.]],
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
|
||||
let ids = Tensor::new(
|
||||
&[
|
||||
[[0_u32, 1], [1, 0]],
|
||||
[[1, 0], [0, 1]],
|
||||
[[0, 1], [0, 1]],
|
||||
[[1, 0], [1, 0]],
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
|
||||
let hs = t.gather(&ids, 2)?;
|
||||
assert_eq!(
|
||||
hs.to_vec3::<f32>()?,
|
||||
&[
|
||||
[[-21., -197.], [122., 194.]],
|
||||
[[-106., 255.], [-191., 250.]],
|
||||
[[33., -117.], [43., 10.]],
|
||||
[[238., -130.], [-92., -217.]]
|
||||
]
|
||||
);
|
||||
|
||||
fn broadcast_matmul(device: &Device) -> Result<()> {
|
||||
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
|
||||
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
|
||||
let out = lhs.broadcast_matmul(&rhs)?;
|
||||
assert_eq!(out.dims(), &[3, 6, 4, 2]);
|
||||
for idx1 in 0..3 {
|
||||
for idx2 in 0..6 {
|
||||
let out = out.i((idx1, idx2))?;
|
||||
let lhs = lhs.i((idx1, 0))?;
|
||||
let rhs = rhs.i(idx2)?;
|
||||
let out2 = lhs.matmul(&rhs);
|
||||
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
|
||||
// With cuda, we see errors of up to ~1e-12.
|
||||
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1065,18 +1430,66 @@ fn broadcasting(device: &Device) -> Result<()> {
|
||||
fn randn(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
|
||||
assert_eq!(tensor.dims(), [5, 3]);
|
||||
// Check that the seed gets updated by checking that
|
||||
// a new series of numbers is generated each time
|
||||
let tensor2 = Tensor::randn(0f32, 1f32, (5, 3), device)?;
|
||||
assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
|
||||
let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
|
||||
assert_eq!(tensor.dims(), [5, 3]);
|
||||
// Check that the seed gets updated by checking that
|
||||
// a new series of numbers is generated each time
|
||||
let tensor2 = Tensor::rand(0f32, 1f32, (5, 3), device)?;
|
||||
assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
|
||||
// We do not expect deterministic elements at any index.
|
||||
// There once was a bug that had a deterministic zero element in evenly sized tensors.
|
||||
const N: usize = 2;
|
||||
let v = (0..100)
|
||||
.map(|_| Tensor::randn(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
assert!(
|
||||
(0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
|
||||
"There are deterministic values in the randn tensors"
|
||||
);
|
||||
let v = (0..100)
|
||||
.map(|_| Tensor::rand(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
assert!(
|
||||
(0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
|
||||
"There are deterministic values in the rand tensors"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn zero_dim(device: &Device) -> Result<()> {
|
||||
let t = Tensor::zeros((4, 0, 1), DType::F32, device)?;
|
||||
assert_eq!(t.dims3()?, (4, 0, 1));
|
||||
let t2 = Tensor::zeros((4, 3, 1), DType::F32, device)?;
|
||||
let t_cat = Tensor::cat(&[&t, &t2], 1)?;
|
||||
assert_eq!(t_cat.dims3()?, (4, 3, 1));
|
||||
let t_cat = Tensor::cat(&[&t, &t], 1)?;
|
||||
assert_eq!(t_cat.dims3()?, (4, 0, 1));
|
||||
let t_unary = t.sqrt()?;
|
||||
assert_eq!(t_unary.dims3()?, (4, 0, 1));
|
||||
let t_plus = (&t + 1.)?;
|
||||
assert_eq!(t_plus.dims3()?, (4, 0, 1));
|
||||
let t_mm = t2.matmul(&t.t()?)?;
|
||||
assert_eq!(t_mm.dims3()?, (4, 3, 0));
|
||||
let t_mm = t.matmul(&t2.t()?)?;
|
||||
assert_eq!(t_mm.dims3()?, (4, 0, 3));
|
||||
let t_mm = t.t()?.matmul(&t)?;
|
||||
assert_eq!(t_mm.dims3()?, (4, 1, 1));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
|
||||
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
|
||||
test_device!(full, full_cpu, full_gpu, full_metal);
|
||||
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
|
||||
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
|
||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
|
||||
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
|
||||
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
|
||||
test_device!(slice_set, ss_cpu, ss_gpu, ss_metal);
|
||||
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
|
||||
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
|
||||
test_device!(min, min_cpu, min_gpu, min_metal);
|
||||
@ -1088,13 +1501,6 @@ test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
|
||||
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
|
||||
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
|
||||
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
|
||||
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
|
||||
test_device!(
|
||||
broadcast_matmul,
|
||||
broadcast_matmul_cpu,
|
||||
broadcast_matmul_gpu,
|
||||
broadcast_matmul_metal
|
||||
);
|
||||
test_device!(
|
||||
broadcasting,
|
||||
broadcasting_cpu,
|
||||
@ -1123,7 +1529,9 @@ test_device!(
|
||||
);
|
||||
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
|
||||
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
|
||||
test_device!(asort, asort_cpu, asort_gpu, asort_metal);
|
||||
test_device!(var, var_cpu, var_gpu, var_metal);
|
||||
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
|
||||
|
||||
// There was originally a bug on the CPU implementation for randn
|
||||
// https://github.com/huggingface/candle/issues/381
|
||||
@ -1221,3 +1629,107 @@ fn cumsum() -> Result<()> {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// A helper function for floating point comparison. Both a and b must be 1D Tensor and contains the same amount of data.
|
||||
/// Assertion passes if the difference of all pairs of a and b is smaller than epsilon.
|
||||
fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
|
||||
let a_vec: Vec<f64> = a.to_vec1()?;
|
||||
let b_vec: Vec<f64> = b.to_vec1()?;
|
||||
|
||||
assert_eq!(a_vec.len(), b_vec.len());
|
||||
for (a, b) in a_vec.iter().zip(b_vec.iter()) {
|
||||
assert!((a - b).abs() < epsilon);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn log_sum_exp() -> Result<()> {
|
||||
let input = Tensor::new(
|
||||
&[
|
||||
[[1f64, 2., 3.], [4., 5., 6.]],
|
||||
[[-1000.0, -999.0, -1001.0], [1000.0, 999.0, 1001.0]],
|
||||
],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let output = input.log_sum_exp(D::Minus1)?;
|
||||
// The expectations obtained from pytorch.
|
||||
let expected = Tensor::new(&[[3.4076, 6.4076], [-998.5924, 1001.4076]], &Device::Cpu)?;
|
||||
assert_eq!(output.dims(), expected.dims());
|
||||
assert_close(&output.flatten_all()?, &expected.flatten_all()?, 0.00001)?;
|
||||
|
||||
assert_eq!(
|
||||
input.log_sum_exp((0, 1))?.to_vec1::<f64>()?,
|
||||
[1000.0, 999.0, 1001.0]
|
||||
);
|
||||
assert_eq!(
|
||||
input.log_sum_exp(())?.to_vec3::<f64>()?,
|
||||
input.to_vec3::<f64>()?
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pow() -> Result<()> {
|
||||
let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||
let rhs = (&lhs - 2.)?;
|
||||
let res = lhs.pow(&rhs)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&res, 3)?,
|
||||
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flip_1d() -> Result<()> {
|
||||
// 1D: [0, 1, 2, 3, 4]
|
||||
let t = Tensor::arange(0.0, 5.0, &Device::Cpu)?.reshape((5,))?;
|
||||
let flipped = t.flip(&[0])?;
|
||||
// Expected: [4, 3, 2, 1, 0]
|
||||
let expected = Tensor::from_vec(vec![4.0, 3.0, 2.0, 1.0, 0.0], (5,), &Device::Cpu)?;
|
||||
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flip_2d() -> Result<()> {
|
||||
// 2D:
|
||||
// [[0, 1, 2],
|
||||
// [3, 4, 5]]
|
||||
let t = Tensor::arange(0.0, 6.0, &Device::Cpu)?.reshape((2, 3))?;
|
||||
let flipped = t.flip(&[0, 1])?;
|
||||
// Expected:
|
||||
// [[5, 4, 3],
|
||||
// [2, 1, 0]]
|
||||
let expected = Tensor::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0, 0.0], (2, 3), &Device::Cpu)?;
|
||||
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flip_3d_channels() -> Result<()> {
|
||||
// 3D:
|
||||
// [[[0,1,2],
|
||||
// [3,4,5]],
|
||||
//
|
||||
// [[6,7,8],
|
||||
// [9,10,11]]]
|
||||
let t = Tensor::arange(0.0, 12.0, &Device::Cpu)?.reshape((2, 2, 3))?;
|
||||
let flipped = t.flip(&[2])?;
|
||||
// Expected:
|
||||
// [[[2,1,0],
|
||||
// [5,4,3]],
|
||||
//
|
||||
// [[8,7,6],
|
||||
// [11,10,9]]]
|
||||
let expected = Tensor::from_vec(
|
||||
vec![2.0, 1.0, 0.0, 5.0, 4.0, 3.0, 8.0, 7.0, 6.0, 11.0, 10.0, 9.0],
|
||||
(2, 2, 3),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
|
||||
Ok(())
|
||||
}
|
||||
|
BIN
candle-core/tests/test.pt
Normal file
BIN
candle-core/tests/test.pt
Normal file
Binary file not shown.
BIN
candle-core/tests/test_with_key.pt
Normal file
BIN
candle-core/tests/test_with_key.pt
Normal file
Binary file not shown.
@ -11,8 +11,8 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
byteorder = { workspace = true }
|
||||
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.1" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
hf-hub = { workspace = true}
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
memmap2 = { workspace = true }
|
||||
|
@ -78,7 +78,7 @@ impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> {
|
||||
match self.inner.inner.next() {
|
||||
Some(item) => items.push(item),
|
||||
None => {
|
||||
if self.return_last_incomplete_batch {
|
||||
if self.return_last_incomplete_batch && !items.is_empty() {
|
||||
break;
|
||||
}
|
||||
return None;
|
||||
@ -102,7 +102,7 @@ impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> {
|
||||
ys.push(y)
|
||||
}
|
||||
None => {
|
||||
if self.return_last_incomplete_batch {
|
||||
if self.return_last_incomplete_batch && !xs.is_empty() && !ys.is_empty() {
|
||||
break;
|
||||
}
|
||||
return None;
|
||||
@ -127,7 +127,7 @@ impl<I: Iterator<Item = Result<Tensor>>> Iterator for Batcher<IterResult1<I>> {
|
||||
match self.inner.inner.next() {
|
||||
Some(item) => items.push(item),
|
||||
None => {
|
||||
if self.return_last_incomplete_batch {
|
||||
if self.return_last_incomplete_batch && !items.is_empty() {
|
||||
break;
|
||||
}
|
||||
return None;
|
||||
@ -154,7 +154,7 @@ impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Iterator for Batcher<IterResu
|
||||
}
|
||||
Some(Err(err)) => errs.push(err),
|
||||
None => {
|
||||
if self.return_last_incomplete_batch {
|
||||
if self.return_last_incomplete_batch && !xs.is_empty() && !ys.is_empty() {
|
||||
break;
|
||||
}
|
||||
return None;
|
||||
|
@ -60,8 +60,8 @@ pub struct DatasetRandomIter<'a> {
|
||||
|
||||
impl<'a> DatasetRandomIter<'a> {
|
||||
pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
|
||||
use rand::rng;
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
let all_tokens = if valid {
|
||||
&ds.valid_tokens
|
||||
@ -69,13 +69,13 @@ impl<'a> DatasetRandomIter<'a> {
|
||||
&ds.train_tokens
|
||||
};
|
||||
let mut tokens = all_tokens.iter().collect::<Vec<_>>();
|
||||
tokens.shuffle(&mut thread_rng());
|
||||
tokens.shuffle(&mut rng());
|
||||
let current_tokens = tokens.pop().unwrap();
|
||||
let seq_len_in_bytes = seq_len * 2;
|
||||
let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
|
||||
.step_by(seq_len_in_bytes)
|
||||
.collect::<Vec<_>>();
|
||||
indexes_in_bytes.shuffle(&mut thread_rng());
|
||||
indexes_in_bytes.shuffle(&mut rng());
|
||||
Self {
|
||||
all_tokens,
|
||||
tokens,
|
||||
@ -87,26 +87,26 @@ impl<'a> DatasetRandomIter<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for DatasetRandomIter<'a> {
|
||||
impl Iterator for DatasetRandomIter<'_> {
|
||||
type Item = Result<(Tensor, Tensor)>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use rand::rng;
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
let seq_len = self.seq_len;
|
||||
if self.indexes_in_bytes.is_empty() {
|
||||
if self.tokens.is_empty() {
|
||||
self.tokens = self.all_tokens.iter().collect();
|
||||
self.tokens.shuffle(&mut thread_rng());
|
||||
self.tokens.shuffle(&mut rng());
|
||||
}
|
||||
self.current_tokens = self.tokens.pop().unwrap();
|
||||
let seq_len_in_bytes = self.seq_len * 2;
|
||||
self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
|
||||
.step_by(seq_len_in_bytes)
|
||||
.collect::<Vec<_>>();
|
||||
self.indexes_in_bytes.shuffle(&mut thread_rng());
|
||||
self.indexes_in_bytes.shuffle(&mut rng());
|
||||
}
|
||||
let start_idx = self.indexes_in_bytes.pop().unwrap();
|
||||
let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
|
||||
|
@ -72,6 +72,8 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
|
||||
if let parquet::record::Field::Group(subrow) = field {
|
||||
for (_name, field) in subrow.get_column_iter() {
|
||||
if let parquet::record::Field::Bytes(value) = field {
|
||||
// image-rs crate convention is to load in (width, height, channels) order
|
||||
// See: https://docs.rs/image/latest/image/trait.ImageDecoder.html#tymethod.dimensions
|
||||
let image = image::load_from_memory(value.data()).unwrap();
|
||||
buffer_images.extend(image.to_rgb8().as_raw());
|
||||
}
|
||||
@ -81,8 +83,10 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
|
||||
}
|
||||
}
|
||||
}
|
||||
let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
|
||||
.to_dtype(DType::U8)?
|
||||
// Reorder image-rs convention (width, height, channels) to candle/pytorch convolution convention (channels, height, width)
|
||||
let images = (Tensor::from_vec(buffer_images, (samples, 32, 32, 3), &Device::Cpu)?
|
||||
.to_dtype(DType::F32)?
|
||||
.permute((0, 3, 2, 1))?
|
||||
/ 255.)?;
|
||||
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
|
||||
Ok((images, labels))
|
||||
|
@ -89,7 +89,7 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
|
||||
|
||||
pub fn load() -> Result<crate::vision::Dataset> {
|
||||
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
|
||||
let dataset_id = "mnist".to_string();
|
||||
let dataset_id = "ylecun/mnist".to_string();
|
||||
let repo = Repo::with_revision(
|
||||
dataset_id,
|
||||
RepoType::Dataset,
|
||||
|
@ -11,53 +11,66 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.1" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
|
||||
candle-onnx = { path = "../candle-onnx", version = "0.3.1", optional = true }
|
||||
candle = { workspace = true }
|
||||
candle-datasets = { workspace = true, optional = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
candle-flash-attn = { workspace = true, optional = true }
|
||||
candle-onnx = { workspace = true, optional = true }
|
||||
|
||||
csv = "1.3.0"
|
||||
cudarc = { workspace = true, optional = true }
|
||||
half = { workspace = true, optional = true }
|
||||
hf-hub = { workspace = true, features = ["tokio"] }
|
||||
image = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
num-traits = { workspace = true }
|
||||
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
|
||||
palette = { version = "0.7.6", optional = true }
|
||||
enterpolation = { version = "0.2.1", optional = true}
|
||||
pyo3 = { version = "0.22.0", features = ["auto-initialize", "abi3-py311"], optional = true }
|
||||
rayon = { workspace = true }
|
||||
rubato = { version = "0.15.0", optional = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
symphonia = { version = "0.5.3", features = ["all"], optional = true }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
cpal = { version = "0.15.2", optional = true }
|
||||
pdf2image = { version = "0.1.2" , optional = true}
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
byteorder = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
hf-hub = { workspace = true, features=["tokio"]}
|
||||
imageproc = { workspace = true }
|
||||
memmap2 = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
rusttype = { workspace = true }
|
||||
ab_glyph = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-chrome = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
wav = { workspace = true }
|
||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||
tokio = "1.29.1"
|
||||
tokio = "1.43.0"
|
||||
|
||||
[build-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
bindgen_cuda = { version = "0.1.1", optional = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
||||
cudnn = ["candle/cudnn"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
|
||||
cudnn = ["candle/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"]
|
||||
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||
onnx = ["candle-onnx"]
|
||||
metal = ["candle/metal", "candle-nn/metal"]
|
||||
microphone = ["cpal", "rubato"]
|
||||
encodec = ["cpal", "symphonia", "rubato"]
|
||||
mimi = ["cpal", "symphonia", "rubato"]
|
||||
snac = ["cpal", "symphonia", "rubato"]
|
||||
depth_anything_v2 = ["palette", "enterpolation"]
|
||||
|
||||
[[example]]
|
||||
name = "llama_multiprocess"
|
||||
@ -74,3 +87,43 @@ required-features = ["onnx"]
|
||||
[[example]]
|
||||
name = "onnx_basics"
|
||||
required-features = ["onnx"]
|
||||
|
||||
[[example]]
|
||||
name = "whisper"
|
||||
required-features = ["symphonia"]
|
||||
|
||||
[[example]]
|
||||
name = "whisper-microphone"
|
||||
required-features = ["microphone"]
|
||||
|
||||
[[example]]
|
||||
name = "mnist-training"
|
||||
required-features = ["candle-datasets"]
|
||||
|
||||
[[example]]
|
||||
name = "llama2-c"
|
||||
required-features = ["candle-datasets"]
|
||||
|
||||
[[example]]
|
||||
name = "mimi"
|
||||
required-features = ["mimi"]
|
||||
|
||||
[[example]]
|
||||
name = "snac"
|
||||
required-features = ["snac"]
|
||||
|
||||
[[example]]
|
||||
name = "encodec"
|
||||
required-features = ["encodec"]
|
||||
|
||||
[[example]]
|
||||
name = "depth_anything_v2"
|
||||
required-features = ["depth_anything_v2"]
|
||||
|
||||
[[example]]
|
||||
name = "silero-vad"
|
||||
required-features = ["onnx"]
|
||||
|
||||
[[example]]
|
||||
name = "colpali"
|
||||
required-features = ["pdf2image"]
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user