feat: respond to auth only when expected

This commit is contained in:
2023-10-12 22:02:05 +01:00
parent a080f0bb0c
commit 2e38ac0d4f

View File

@ -1,12 +1,13 @@
import { v4 as uuid } from "uuid"; import { v4 as uuid } from "uuid";
import debug from "debug"; import debug from "debug";
import WebSocket from "isomorphic-ws"; import WebSocket from "isomorphic-ws";
import { unwrap, ExternalStore, unixNowMs } from "@snort/shared"; import { unwrap, ExternalStore, unixNowMs, dedupe } from "@snort/shared";
import { DefaultConnectTimeout } from "./const"; import { DefaultConnectTimeout } from "./const";
import { ConnectionStats } from "./connection-stats"; import { ConnectionStats } from "./connection-stats";
import { NostrEvent, ReqCommand, TaggedNostrEvent, u256 } from "./nostr"; import { NostrEvent, ReqCommand, ReqFilter, TaggedNostrEvent, u256 } from "./nostr";
import { RelayInfo } from "./relay-info"; import { RelayInfo } from "./relay-info";
import EventKind from "./event-kind";
export type AuthHandler = (challenge: string, relay: string) => Promise<NostrEvent | undefined>; export type AuthHandler = (challenge: string, relay: string) => Promise<NostrEvent | undefined>;
@ -46,9 +47,10 @@ export interface ConnectionStateSnapshot {
} }
export class Connection extends ExternalStore<ConnectionStateSnapshot> { export class Connection extends ExternalStore<ConnectionStateSnapshot> {
#log = debug("Connection"); #log: debug.Debugger;
#ephemeralCheck?: ReturnType<typeof setInterval>; #ephemeralCheck?: ReturnType<typeof setInterval>;
#activity: number = unixNowMs(); #activity: number = unixNowMs();
#expectAuth = false;
Id: string; Id: string;
Address: string; Address: string;
@ -89,6 +91,7 @@ export class Connection extends ExternalStore<ConnectionStateSnapshot> {
this.AwaitingAuth = new Map(); this.AwaitingAuth = new Map();
this.Auth = auth; this.Auth = auth;
this.Ephemeral = ephemeral; this.Ephemeral = ephemeral;
this.#log = debug("Connection").extend(addr);
} }
async Connect() { async Connect() {
@ -139,7 +142,7 @@ export class Connection extends ExternalStore<ConnectionStateSnapshot> {
OnOpen(wasReconnect: boolean) { OnOpen(wasReconnect: boolean) {
this.ConnectTimeout = DefaultConnectTimeout; this.ConnectTimeout = DefaultConnectTimeout;
this.#log(`[${this.Address}] Open!`); this.#log(`Open!`);
this.Down = false; this.Down = false;
this.#setupEphemeral(); this.#setupEphemeral();
this.OnConnected?.(wasReconnect); this.OnConnected?.(wasReconnect);
@ -155,27 +158,23 @@ export class Connection extends ExternalStore<ConnectionStateSnapshot> {
// remote server closed the connection, dont re-connect // remote server closed the connection, dont re-connect
if (e.code === 4000) { if (e.code === 4000) {
this.IsClosed = true; this.IsClosed = true;
this.#log(`[${this.Address}] Closed! (Remote)`); this.#log(`Closed! (Remote)`);
} else if (!this.IsClosed) { } else if (!this.IsClosed) {
this.ConnectTimeout = this.ConnectTimeout * 2; this.ConnectTimeout = this.ConnectTimeout * 2;
this.#log( this.#log(
`[${this.Address}] Closed (code=${e.code}), trying again in ${(this.ConnectTimeout / 1000) `Closed (code=${e.code}), trying again in ${(this.ConnectTimeout / 1000).toFixed(0).toLocaleString()} sec`,
.toFixed(0)
.toLocaleString()} sec`,
); );
this.ReconnectTimer = setTimeout(() => { this.ReconnectTimer = setTimeout(() => {
this.Connect(); this.Connect();
}, this.ConnectTimeout); }, this.ConnectTimeout);
this.Stats.Disconnects++; this.Stats.Disconnects++;
} else { } else {
this.#log(`[${this.Address}] Closed!`); this.#log(`Closed!`);
this.ReconnectTimer = undefined; this.ReconnectTimer = undefined;
} }
this.OnDisconnect?.(e.code); this.OnDisconnect?.(e.code);
this.#resetQueues(); this.#reset();
// reset connection Id on disconnect, for query-tracking
this.Id = uuid();
this.notifyChange(); this.notifyChange();
} }
@ -186,11 +185,15 @@ export class Connection extends ExternalStore<ConnectionStateSnapshot> {
const tag = msg[0] as string; const tag = msg[0] as string;
switch (tag) { switch (tag) {
case "AUTH": { case "AUTH": {
this.#onAuthAsync(msg[1] as string) if (this.#expectAuth) {
.then(() => this.#sendPendingRaw()) this.#onAuthAsync(msg[1] as string)
.catch(this.#log); .then(() => this.#sendPendingRaw())
this.Stats.EventsReceived++; .catch(this.#log);
this.notifyChange(); this.Stats.EventsReceived++;
this.notifyChange();
} else {
this.#log("Ignoring unexpected AUTH request");
}
break; break;
} }
case "EVENT": { case "EVENT": {
@ -208,7 +211,7 @@ export class Connection extends ExternalStore<ConnectionStateSnapshot> {
} }
case "OK": { case "OK": {
// feedback to broadcast call // feedback to broadcast call
this.#log(`${this.Address} OK: %O`, msg); this.#log(`OK: %O`, msg);
const id = msg[1] as string; const id = msg[1] as string;
const cb = this.EventsCallback.get(id); const cb = this.EventsCallback.get(id);
if (cb) { if (cb) {
@ -218,7 +221,7 @@ export class Connection extends ExternalStore<ConnectionStateSnapshot> {
break; break;
} }
case "NOTICE": { case "NOTICE": {
this.#log(`[${this.Address}] NOTICE: ${msg[1]}`); this.#log(`NOTICE: ${msg[1]}`);
break; break;
} }
default: { default: {
@ -306,12 +309,23 @@ export class Connection extends ExternalStore<ConnectionStateSnapshot> {
* @param cmd The REQ to send to the server * @param cmd The REQ to send to the server
*/ */
QueueReq(cmd: ReqCommand, cbSent: () => void) { QueueReq(cmd: ReqCommand, cbSent: () => void) {
const requestKinds = dedupe(
cmd
.slice(2)
.map(a => (a as ReqFilter).kinds ?? [])
.flat(),
);
const ExpectAuth = [EventKind.DirectMessage, EventKind.GiftWrap];
if (ExpectAuth.some(a => requestKinds.includes(a)) && !this.#expectAuth) {
this.#expectAuth = true;
this.#log("Setting expectAuth flag %o", requestKinds);
}
if (this.ActiveRequests.size >= this.#maxSubscriptions) { if (this.ActiveRequests.size >= this.#maxSubscriptions) {
this.PendingRequests.push({ this.PendingRequests.push({
cmd, cmd,
cb: cbSent, cb: cbSent,
}); });
this.#log("Queuing: %s %O", this.Address, cmd); this.#log("Queuing: %O", cmd);
} else { } else {
this.ActiveRequests.add(cmd[1]); this.ActiveRequests.add(cmd[1]);
this.#sendJson(cmd); this.#sendJson(cmd);
@ -359,13 +373,16 @@ export class Connection extends ExternalStore<ConnectionStateSnapshot> {
this.ActiveRequests.add(p.cmd[1]); this.ActiveRequests.add(p.cmd[1]);
this.#sendJson(p.cmd); this.#sendJson(p.cmd);
p.cb(); p.cb();
this.#log("Sent pending REQ %s %O", this.Address, p.cmd); this.#log("Sent pending REQ %O", p.cmd);
} }
} }
} }
} }
#resetQueues() { #reset() {
// reset connection Id on disconnect, for query-tracking
this.Id = uuid();
this.#expectAuth = false;
this.ActiveRequests.clear(); this.ActiveRequests.clear();
this.PendingRequests = []; this.PendingRequests = [];
this.PendingRaw = []; this.PendingRaw = [];
@ -452,12 +469,7 @@ export class Connection extends ExternalStore<ConnectionStateSnapshot> {
const lastActivity = unixNowMs() - this.#activity; const lastActivity = unixNowMs() - this.#activity;
if (lastActivity > 30_000 && !this.IsClosed) { if (lastActivity > 30_000 && !this.IsClosed) {
if (this.ActiveRequests.size > 0) { if (this.ActiveRequests.size > 0) {
this.#log( this.#log("Inactive connection has %d active requests! %O", this.ActiveRequests.size, this.ActiveRequests);
"%s Inactive connection has %d active requests! %O",
this.Address,
this.ActiveRequests.size,
this.ActiveRequests,
);
} else { } else {
this.Close(); this.Close();
} }