diff --git a/src/apps/relay/RelayIngester.cpp b/src/apps/relay/RelayIngester.cpp index 21e8984..560750e 100644 --- a/src/apps/relay/RelayIngester.cpp +++ b/src/apps/relay/RelayIngester.cpp @@ -20,11 +20,10 @@ void RelayServer::runIngester(ThreadPool::Thread &thr) { if (cfg().relay__logging__dumpInAll) LI << "[" << msg->connId << "] dumpInAll: " << msg->payload; - if (!payload.is_array()) throw herr("message is not an array"); - auto &arr = payload.get_array(); - if (arr.size() < 2) throw herr("bad message"); + auto &arr = jsonGetArray(payload, "message is not an array"); + if (arr.size() < 2) throw herr("too few array elements"); - auto &cmd = arr[0].get_string(); + auto &cmd = jsonGetString(arr[0], "first element not a command like REQ"); if (cmd == "EVENT") { if (cfg().relay__logging__dumpInEvents) LI << "[" << msg->connId << "] dumpInEvent: " << msg->payload; @@ -32,7 +31,8 @@ void RelayServer::runIngester(ThreadPool::Thread &thr) { try { ingesterProcessEvent(txn, msg->connId, msg->ipAddr, secpCtx, arr[1], writerMsgs); } catch (std::exception &e) { - sendOKResponse(msg->connId, arr[1].at("id").get_string(), false, std::string("invalid: ") + e.what()); + sendOKResponse(msg->connId, arr[1].at("id").is_string() ? arr[1].at("id").get_string() : "?", + false, std::string("invalid: ") + e.what()); if (cfg().relay__logging__invalidEvents) LI << "Rejected invalid event: " << e.what(); } } else if (cmd == "REQ") { @@ -93,13 +93,20 @@ void RelayServer::ingesterProcessEvent(lmdb::txn &txn, uint64_t connId, std::str PackedEventView packed(packedStr); { - for (const auto &tagArr : origJson.at("tags").get_array()) { - auto tag = tagArr.get_array(); - if (tag.size() == 1 && tag.at(0).get_string() == "-") { - LI << "Protected event, skipping"; - sendOKResponse(connId, to_hex(packed.id()), false, "blocked: event marked as protected"); - return; + bool foundProtected = false; + + packed.foreachTag([&](char tagName, std::string_view tagVal){ + if (tagName == '-') { + foundProtected = true; + return false; } + return true; + }); + + if (foundProtected) { + LI << "Protected event, skipping"; + sendOKResponse(connId, to_hex(packed.id()), false, "blocked: event marked as protected"); + return; } } @@ -119,7 +126,7 @@ void RelayServer::ingesterProcessReq(lmdb::txn &txn, uint64_t connId, const tao: if (arr.get_array().size() < 2 + 1) throw herr("arr too small"); if (arr.get_array().size() > 2 + 20) throw herr("arr too big"); - Subscription sub(connId, arr[1].get_string(), NostrFilterGroup(arr)); + Subscription sub(connId, jsonGetString(arr[1], "REQ subscription id was not a string"), NostrFilterGroup(arr)); tpReqWorker.dispatch(connId, MsgReqWorker{MsgReqWorker::NewSub{std::move(sub)}}); } @@ -139,7 +146,7 @@ void RelayServer::ingesterProcessNegentropy(lmdb::txn &txn, Decompressor &decomp auto filterJson = arr.at(2); NostrFilterGroup filter = NostrFilterGroup::unwrapped(filterJson, maxFilterLimit); - Subscription sub(connId, arr[1].get_string(), std::move(filter)); + Subscription sub(connId, jsonGetString(arr[1], "NEG-OPEN subscription id was not a string"), std::move(filter)); if (filterJson.is_object()) { filterJson.get_object().erase("since"); diff --git a/src/apps/relay/RelayServer.h b/src/apps/relay/RelayServer.h index 0224bd0..083a2e6 100644 --- a/src/apps/relay/RelayServer.h +++ b/src/apps/relay/RelayServer.h @@ -16,6 +16,7 @@ #include "ThreadPool.h" #include "events.h" #include "filters.h" +#include "jsonParseUtils.h" #include "Decompressor.h" diff --git a/src/events.cpp b/src/events.cpp index 4510267..bf3f4d2 100644 --- a/src/events.cpp +++ b/src/events.cpp @@ -2,6 +2,7 @@ #include #include "events.h" +#include "jsonParseUtils.h" std::string nostrJsonToPackedEvent(const tao::json::value &v) { @@ -9,14 +10,16 @@ std::string nostrJsonToPackedEvent(const tao::json::value &v) { // Extract values from JSON, add strings to builder - auto id = from_hex(v.at("id").get_string(), false); - auto pubkey = from_hex(v.at("pubkey").get_string(), false); - uint64_t created_at = v.at("created_at").get_unsigned(); - uint64_t kind = v.at("kind").get_unsigned(); + auto id = from_hex(jsonGetString(v.at("id"), "event id field was not a string"), false); + auto pubkey = from_hex(jsonGetString(v.at("pubkey"), "event pubkey field was not a string"), false); + uint64_t created_at = jsonGetUnsigned(v.at("created_at"), "event created_at field was not an integer"); + uint64_t kind = jsonGetUnsigned(v.at("kind"), "event kind field was not an integer"); if (id.size() != 32) throw herr("unexpected id size"); if (pubkey.size() != 32) throw herr("unexpected pubkey size"); + jsonGetString(v.at("content"), "event content field was not a string"); + uint64_t expiration = 0; if (isReplaceableKind(kind)) { @@ -24,13 +27,13 @@ std::string nostrJsonToPackedEvent(const tao::json::value &v) { tagBuilder.add('d', ""); } - if (v.at("tags").get_array().size() > cfg().events__maxNumTags) throw herr("too many tags: ", v.at("tags").get_array().size()); + if (jsonGetArray(v.at("tags"), "tags field not an array").size() > cfg().events__maxNumTags) throw herr("too many tags: ", v.at("tags").get_array().size()); for (auto &tagArr : v.at("tags").get_array()) { - auto &tag = tagArr.get_array(); + auto &tag = jsonGetArray(tagArr, "tag in tags field was not an array"); if (tag.size() < 1) throw herr("too few fields in tag"); - auto tagName = tag.at(0).get_string(); - auto tagVal = tag.size() >= 2 ? tag.at(1).get_string() : ""; + auto tagName = jsonGetString(tag.at(0), "tag name was not a string"); + auto tagVal = tag.size() >= 2 ? jsonGetString(tag.at(1), "tag val was not a string") : ""; if (tagName == "e" || tagName == "p") { tagVal = from_hex(tagVal, false); @@ -105,7 +108,7 @@ void verifyNostrEvent(secp256k1_context *secpCtx, PackedEventView packed, const auto hash = nostrHash(origJson); if (hash != packed.id()) throw herr("bad event id"); - bool valid = verifySig(secpCtx, from_hex(origJson.at("sig").get_string(), false), packed.id(), packed.pubkey()); + bool valid = verifySig(secpCtx, from_hex(jsonGetString(origJson.at("sig"), "event sig was not a string"), false), packed.id(), packed.pubkey()); if (!valid) throw herr("bad signature"); } @@ -130,6 +133,7 @@ void verifyEventTimestamp(PackedEventView packed) { if (packed.expiration() > 1 && packed.expiration() <= now) throw herr("event expired"); } + void parseAndVerifyEvent(const tao::json::value &origJson, secp256k1_context *secpCtx, bool verifyMsg, bool verifyTime, std::string &packedStr, std::string &jsonStr) { packedStr = nostrJsonToPackedEvent(origJson); PackedEventView packed(packedStr); diff --git a/src/filters.h b/src/filters.h index b849494..d9ee320 100644 --- a/src/filters.h +++ b/src/filters.h @@ -121,6 +121,8 @@ struct NostrFilter { explicit NostrFilter(const tao::json::value &filterObj, uint64_t maxFilterLimit) { uint64_t numMajorFields = 0; + if (!filterObj.is_object()) throw herr("provided filter is not an object"); + for (const auto &[k, v] : filterObj.get_object()) { if (v.is_array() && v.get_array().size() == 0) { neverMatch = true; diff --git a/src/jsonParseUtils.h b/src/jsonParseUtils.h new file mode 100644 index 0000000..6d67ee1 --- /dev/null +++ b/src/jsonParseUtils.h @@ -0,0 +1,21 @@ +#pragma once + +#include + +#include "golpe.h" + + +inline const std::string &jsonGetString(const tao::json::value &v, std::string_view errMsg) { + if (v.is_string()) return v.get_string(); + throw herr(errMsg); +} + +inline uint64_t jsonGetUnsigned(const tao::json::value &v, std::string_view errMsg) { + if (v.is_unsigned()) return v.get_unsigned(); + throw herr(errMsg); +} + +inline const std::vector &jsonGetArray(const tao::json::value &v, std::string_view errMsg) { + if (v.is_array()) return v.get_array(); + throw herr(errMsg); +}