diff --git a/src/RelayXor.cpp b/src/RelayXor.cpp
index 6b0d1cd..d622c77 100644
--- a/src/RelayXor.cpp
+++ b/src/RelayXor.cpp
@@ -1,163 +1,19 @@
 #include "RelayServer.h"
 #include "DBQuery.h"
 #include "QueryScheduler.h"
-#include "transport.h"
+#include "xor.h"
 
 
 struct XorViews {
-    struct Elem {
-        char data[5 * 8];
-
-        Elem() {
-            memset(data, '\0', sizeof(data));
-        }
-
-        Elem(uint64_t created, std::string_view id, uint64_t idSize) {
-            memset(data, '\0', sizeof(data));
-            data[3] = (created >> (4*8)) & 0xFF;
-            data[4] = (created >> (3*8)) & 0xFF;
-            data[5] = (created >> (2*8)) & 0xFF;
-            data[6] = (created >> (1*8)) & 0xFF;
-            data[7] = (created >> (0*8)) & 0xFF;
-            memcpy(data + 8, id.data(), idSize);
-        }
-
-        Elem(std::string_view id) {
-            memset(data, '\0', sizeof(data));
-            memcpy(data + 3, id.data(), id.size());
-        }
-
-        std::string_view getCompare(uint64_t idSize) const {
-            return std::string_view(data + 3, idSize + 5);
-        }
-
-        std::string_view getId(uint64_t idSize) const {
-            return std::string_view(data + 8, idSize);
-        }
-
-        std::string_view getIdPadded() const {
-            return std::string_view(data + 8, 32);
-        }
-
-        std::string_view getFull() const {
-            return std::string_view(data, sizeof(data));
-        }
-
-        bool isZero() {
-            uint64_t *ours = reinterpret_cast<uint64_t *>(data + 8);
-            return ours[0] == 0 && ours[1] == 0 && ours[2] == 0 && ours[3] == 0;
-        }
-
-        void doXor(const Elem &other) {
-            uint64_t *ours = reinterpret_cast<uint64_t *>(data + 8);
-            const uint64_t *theirs = reinterpret_cast<const uint64_t *>(other.data + 8);
-
-            ours[0] ^= theirs[0];
-            ours[1] ^= theirs[1];
-            ours[2] ^= theirs[2];
-            ours[3] ^= theirs[3];
-        }
-
-        bool operator==(const Elem &o) const {
-            return o.getIdPadded() == getIdPadded();
-        }
-    };
-
-    struct View {
-        uint64_t idSize;
+    struct UserView { // per-subscription state: the reconciliation view plus the query that opened it
+        XorView v;
         std::string initialQuery;
-
-        std::vector<Elem> elems;
-        bool ready = false;
-
-        View(uint64_t idSize, const std::string &initialQuery) : idSize(idSize), initialQuery(initialQuery) {
-            if (idSize < 8 || idSize > 32) throw herr("idSize out of range");
-        }
-
-        void addElem(uint64_t createdAt, std::string_view id) {
-            elems.emplace_back(createdAt, id, idSize);
-        }
-
-        void finalise() {
-            std::reverse(elems.begin(), elems.end()); // pushed in approximately descending order, so hopefully this speeds up the sort
-
-            std::sort(elems.begin(), elems.end(), [&](const auto &a, const auto &b) { return a.getCompare(idSize) < b.getCompare(idSize); });
-
-            ready = true;
-
-            handleQuery(initialQuery);
-            initialQuery = "";
-        }
-
-        std::string handleQuery(std::string_view query) { // FIXME: this can throw
-            std::string output;
-
-            auto cmp = [&](const auto &a, const auto &b){ return a.getCompare(idSize) < b.getCompare(idSize); };
-
-            while (query.size()) {
-                uint64_t lowerLength = decodeVarInt(query);
-                if (lowerLength > idSize + 5) throw herr("lower too long");
-                Elem lowerKey(getBytes(query, lowerLength));
-
-                uint64_t upperLength = decodeVarInt(query);
-                if (upperLength > idSize + 5) throw herr("upper too long");
-                Elem upperKey(getBytes(query, upperLength));
-
-                auto lower = std::lower_bound(elems.begin(), elems.end(), lowerKey, cmp);
-                auto upper = std::upper_bound(elems.begin(), elems.end(), upperKey, cmp); // FIXME: start at lower?
-
-                uint64_t mode = decodeVarInt(query); // 0 = range, 8 and above = n-8 inline IDs
-
-                if (mode == 0) {
-                    Elem theirXorSet(getBytes(query, idSize));
-
-                    Elem ourXorSet;
-                    for (auto i = lower; i < upper; ++i) ourXorSet.doXor(*i);
-
-                    if (theirXorSet.getId(idSize) != ourXorSet.getId(idSize)) {
-                        // Split our range
-                    }
-                } else if (mode >= 8) {
-                    flat_hash_map<Elem, bool> theirElems;
-                    for (uint64_t i = 0; i < mode - 8; i++) theirElems.emplace(getBytes(query, idSize), false);
-
-                    for (auto it = lower; lower < upper; ++it) {
-                        auto e = theirElems.find(*it);
-
-                        if (e == theirElems.end()) {
-                            // Id exists on our side, but not their side
-                        } else {
-                            // Id exists on both sides
-                            e->second = true;
-                        }
-                    }
-
-                    for (const auto &[k, v] : theirElems) {
-                        if (!v) {
-                            // Id exists on their side, but not our side
-                        }
-                    }
-                }
-            }
-
-            return output;
-        }
-
-        std::string xorRange(uint64_t start, uint64_t len) {
-            Elem output;
-
-            for (uint64_t i = 0; i < len; i++) {
-                output.doXor(elems[i]);
-            }
-
-            return std::string(output.getId(idSize));
-        }
     };
 
-    using ConnViews = flat_hash_map<SubId, View>;
-    flat_hash_map<uint64_t, ConnViews> conns; // connId -> subId -> View
+    using ConnViews = flat_hash_map<SubId, UserView>;
+    flat_hash_map<uint64_t, ConnViews> conns; // connId -> subId -> XorView
 
-    bool addView(uint64_t connId, const SubId &subId, uint64_t idSize, const std::string &query) {
+    bool addView(uint64_t connId, const SubId &subId, uint64_t idSize, const std::string &initialQuery) {
         {
             auto *existing = findView(connId, subId);
             if (existing) removeView(connId, subId);
@@ -170,12 +26,12 @@ struct XorViews {
             return false;
         }
 
-        connViews.try_emplace(subId, idSize, query);
+        connViews.try_emplace(subId, UserView{ XorView(idSize), initialQuery });
 
         return true;
     }
 
-    View *findView(uint64_t connId, const SubId &subId) {
+    UserView *findView(uint64_t connId, const SubId &subId) {
         auto f1 = conns.find(connId);
         if (f1 == conns.end()) return nullptr;
 
@@ -211,7 +67,7 @@ void RelayServer::runXor(ThreadPool::Thread &thr) {
 
         for (auto levId : levIds) {
             auto ev = lookupEventByLevId(txn, levId);
-            view->addElem(ev.flat_nested()->created_at(), sv(ev.flat_nested()->id()).substr(0, view->idSize));
+            view->v.addElem(ev.flat_nested()->created_at(), sv(ev.flat_nested()->id()).substr(0, view->v.idSize));
         }
     };
 
@@ -219,15 +75,19 @@ void RelayServer::runXor(ThreadPool::Thread &thr) {
         auto *view = views.findView(sub.connId, sub.subId);
         if (!view) return;
 
-        view->finalise();
-        /*
-        sendToConn(sub.connId, tao::json::to_string(tao::json::value::array({
-            "XOR-RES",
-            sub.subId.str(),
-            to_hex(view->xorRange(0, view->elems.size())),
-            view->elems.size(),
-        })));
-        */
+        view->v.finalise();
+
+        std::vector<std::string> haveIds, needIds;
+        auto resp = view->v.handleQuery(view->initialQuery, haveIds, needIds);
+
+        sendToConn(sub.connId, tao::json::to_string(tao::json::value::array({
+            "XOR-RES",
+            sub.subId.str(),
+            to_hex(resp),
+            // FIXME: haveIds
+        })));
+
+        view->initialQuery = "";
     };
 
     while(1) {
@@ -267,13 +127,3 @@ void RelayServer::runXor(ThreadPool::Thread &thr) {
         txn.abort();
     }
 }
-
-
-namespace std {
-    // inject specialization of std::hash
-    template<> struct hash<XorViews::Elem> {
-        std::size_t operator()(XorViews::Elem const &p) const {
-            return phmap::HashState().combine(0, p.getFull());
-        }
-    };
-}
diff --git a/src/xor.h b/src/xor.h
new file mode 100644
index 0000000..454f0c8
--- /dev/null
+++ b/src/xor.h
@@ -0,0 +1,217 @@
+#pragma once
+
+#include "transport.h"
+
+
+struct XorElem { // 40-byte record: bytes 0-2 stay zero, 3-7 = low 40 bits of created (big-endian), 8-39 = id, zero-padded
+    char data[5 * 8];
+
+    XorElem() {
+        memset(data, '\0', sizeof(data));
+    }
+
+    XorElem(uint64_t created, std::string_view id, uint64_t idSize) { // copies idSize bytes from id, so id must hold at least that many
+        memset(data, '\0', sizeof(data));
+        data[3] = (created >> (4*8)) & 0xFF;
+        data[4] = (created >> (3*8)) & 0xFF;
+        data[5] = (created >> (2*8)) & 0xFF;
+        data[6] = (created >> (1*8)) & 0xFF;
+        data[7] = (created >> (0*8)) & 0xFF;
+        memcpy(data + 8, id.data(), idSize);
+    }
+
+    XorElem(std::string_view id) { // range-bound key: bytes are copied to the start of the compare region (offset 3); caller must keep id.size() <= 37
+        memset(data, '\0', sizeof(data));
+        memcpy(data + 3, id.data(), id.size());
+    }
+
+    std::string_view getCompare(uint64_t idSize) const { // sort key: the 5 timestamp bytes followed by idSize id bytes
+        return std::string_view(data + 3, idSize + 5);
+    }
+
+    std::string_view getId(uint64_t idSize) const {
+        return std::string_view(data + 8, idSize);
+    }
+
+    std::string_view getIdPadded() const {
+        return std::string_view(data + 8, 32);
+    }
+
+    std::string_view getFull() const {
+        return std::string_view(data, sizeof(data));
+    }
+
+    bool isZero() const {
+        const uint64_t *ours = reinterpret_cast<const uint64_t *>(data + 8); // NOTE(review): assumes data + 8 is suitably aligned for uint64_t access
+        return ours[0] == 0 && ours[1] == 0 && ours[2] == 0 && ours[3] == 0;
+    }
+
+    void doXor(const XorElem &other) { // XORs only the 32 id bytes; the timestamp bytes are left untouched
+        uint64_t *ours = reinterpret_cast<uint64_t *>(data + 8);
+        const uint64_t *theirs = reinterpret_cast<const uint64_t *>(other.data + 8);
+
+        ours[0] ^= theirs[0];
+        ours[1] ^= theirs[1];
+        ours[2] ^= theirs[2];
+        ours[3] ^= theirs[3];
+    }
+
+    bool operator==(const XorElem &o) const { // equality is on the padded id only; the timestamp is ignored
+        return o.getIdPadded() == getIdPadded();
+    }
+};
+
+struct XorView { // set of XorElems which, once finalise() has sorted it, can answer reconciliation queries
+    uint64_t idSize;
+
+    std::vector<XorElem> elems;
+    bool ready = false;
+
+    XorView(uint64_t idSize) : idSize(idSize) {
+        if (idSize < 8 || idSize > 32) throw herr("idSize out of range");
+    }
+
+    void addElem(uint64_t createdAt, std::string_view id) {
+        elems.emplace_back(createdAt, id, idSize);
+    }
+
+    void finalise() {
+        std::reverse(elems.begin(), elems.end()); // pushed in approximately descending order, so hopefully this speeds up the sort
+
+        std::sort(elems.begin(), elems.end(), [&](const auto &a, const auto &b) { return a.getCompare(idSize) < b.getCompare(idSize); });
+
+        ready = true;
+    }
+
+    // FIXME: try/catch everywhere that calls this
+    std::string handleQuery(std::string_view query, std::vector<std::string> &haveIds, std::vector<std::string> &needIds) { // elems must already be sorted (see finalise); throws herr on over-long bounds
+        std::string output;
+
+        auto cmp = [&](const auto &a, const auto &b){ return a.getCompare(idSize) < b.getCompare(idSize); };
+
+        while (query.size()) { // each record: varint-prefixed lower key, varint-prefixed upper key, varint mode, then a mode-specific payload
+            uint64_t lowerLength = decodeVarInt(query);
+            if (lowerLength > idSize + 5) throw herr("lower too long");
+            XorElem lowerKey(getBytes(query, lowerLength));
+
+            uint64_t upperLength = decodeVarInt(query);
+            if (upperLength > idSize + 5) throw herr("upper too long");
+            XorElem upperKey(getBytes(query, upperLength));
+
+            auto lower = std::lower_bound(elems.begin(), elems.end(), lowerKey, cmp);
+            auto upper = std::upper_bound(lower, elems.end(), upperKey, cmp); // searching from lower guarantees upper >= lower even if the peer sends upper < lower
+
+            uint64_t mode = decodeVarInt(query); // 0 = range, 8 and above = n-8 inline IDs
+
+            if (mode == 0) {
+                XorElem theirXorSet(0, getBytes(query, idSize), idSize);
+
+                XorElem ourXorSet;
+                for (auto i = lower; i < upper; ++i) ourXorSet.doXor(*i);
+
+                if (theirXorSet.getId(idSize) != ourXorSet.getId(idSize)) {
+                    // Split our range
+                    uint64_t numElems = upper - lower;
+                    const uint64_t buckets = 16;
+
+                    if (numElems < buckets * 2) { // small range: send the ids inline, using the same n+8 count encoding as mode
+                        output += encodeVarInt(numElems + 8);
+                        for (auto it = lower; it < upper; ++it) output += it->getId(idSize);
+                    } else {
+                        uint64_t elemsPerBucket = numElems / buckets;
+                        uint64_t bucketsWithExtra = numElems % buckets;
+                        auto curr = lower;
+
+                        for (uint64_t i = 0; i < buckets; i++) {
+                            {
+                                auto k = getLowerKey(curr);
+                                output += encodeVarInt(k.size());
+                                output += k;
+                            }
+
+                            auto bucketEnd = curr + elemsPerBucket;
+                            if (i < bucketsWithExtra) bucketEnd++; // the first numElems % buckets buckets take one extra element
+
+                            XorElem bucketXorSet;
+                            for (; curr != bucketEnd; curr++) {
+                                bucketXorSet.doXor(*curr);
+                            }
+
+                            output += bucketXorSet.getId(idSize);
+
+                            {
+                                auto k = getUpperKey(curr); // curr can equal elems.end() after the last bucket; getUpperKey handles that without dereferencing
+                                output += encodeVarInt(k.size());
+                                output += k;
+                            }
+                        }
+                    }
+                }
+            } else if (mode >= 8) {
+                flat_hash_map<XorElem, bool> theirElems; // id -> whether we also hold it
+                for (uint64_t i = 0; i < mode - 8; i++) {
+                    theirElems.emplace(XorElem(0, getBytes(query, idSize), idSize), false);
+                }
+
+                for (auto it = lower; it < upper; ++it) {
+                    auto e = theirElems.find(*it);
+
+                    if (e == theirElems.end()) {
+                        // Id exists on our side, but not their side
+                        haveIds.emplace_back(it->getId(idSize));
+                    } else {
+                        // Id exists on both sides
+                        e->second = true;
+                    }
+                }
+
+                for (const auto &[k, v] : theirElems) {
+                    if (!v) {
+                        // Id exists on their side, but not our side
+                        needIds.emplace_back(k.getId(idSize));
+                    }
+                }
+            }
+        }
+
+        return output;
+    }
+
+    std::string getLowerKey(std::vector<XorElem>::iterator it) { // it must be dereferenceable unless it == elems.begin()
+        if (it == elems.begin()) return std::string(1, '\0');
+        return minimalKeyDiff(it->getCompare(idSize), std::prev(it)->getCompare(idSize));
+    }
+
+    std::string getUpperKey(std::vector<XorElem>::iterator it) { // it must not be elems.begin(); elems.end() yields a single 0xFF byte
+        if (it == elems.end()) return std::string(1, '\xFF');
+        return minimalKeyDiff(it->getCompare(idSize), std::prev(it)->getCompare(idSize));
+    }
+
+    std::string minimalKeyDiff(std::string_view key, std::string_view prevKey) { // shortest prefix of key that differs from prevKey; throws if the two compare keys are identical
+        for (uint64_t i = 0; i < idSize + 5; i++) {
+            if (key[i] != prevKey[i]) return std::string(key.substr(0, i + 1));
+        }
+
+        throw herr("couldn't compute shared prefix");
+    }
+
+    std::string xorRange(uint64_t start, uint64_t len) { // XOR of the ids of elems[start, start + len); caller must keep the range in bounds
+        XorElem output;
+
+        for (uint64_t i = 0; i < len; i++) {
+            output.doXor(elems[start + i]);
+        }
+
+        return std::string(output.getId(idSize));
+    }
+};
+
+
+namespace std {
+    // inject specialization of std::hash
+    template<> struct hash<XorElem> {
+        std::size_t operator()(XorElem const &p) const {
+            return phmap::HashState().combine(0, p.getIdPadded()); // hash exactly what operator== compares, so equal ids with different timestamps land in the same bucket
+        }
+    };
+}