From 407b2423af31ecaf44035f66a180a0bbc40e3aaa Mon Sep 17 00:00:00 2001 From: Hazelnoot Date: Tue, 10 Dec 2024 19:01:35 -0500 Subject: [PATCH] fix redis transaction implementation --- packages/backend/src/core/CoreModule.ts | 6 + .../backend/src/core/RedisConnectionPool.ts | 103 ++++++ packages/backend/src/core/TimeoutService.ts | 76 ++++ packages/backend/src/misc/rate-limit-utils.ts | 19 +- .../src/server/api/SkRateLimiterService.ts | 238 ++++++------ .../server/api/SkRateLimiterServiceTests.ts | 346 +++++++----------- 6 files changed, 439 insertions(+), 349 deletions(-) create mode 100644 packages/backend/src/core/RedisConnectionPool.ts create mode 100644 packages/backend/src/core/TimeoutService.ts diff --git a/packages/backend/src/core/CoreModule.ts b/packages/backend/src/core/CoreModule.ts index b18db7f366..caf135ae4b 100644 --- a/packages/backend/src/core/CoreModule.ts +++ b/packages/backend/src/core/CoreModule.ts @@ -155,6 +155,8 @@ import { QueueModule } from './QueueModule.js'; import { QueueService } from './QueueService.js'; import { LoggerService } from './LoggerService.js'; import { SponsorsService } from './SponsorsService.js'; +import { RedisConnectionPool } from './RedisConnectionPool.js'; +import { TimeoutService } from './TimeoutService.js'; import type { Provider } from '@nestjs/common'; //#region 文字列ベースでのinjection用(循環参照対応のため) @@ -383,6 +385,8 @@ const $SponsorsService: Provider = { provide: 'SponsorsService', useExisting: Sp ChannelFollowingService, RegistryApiService, ReversiService, + RedisConnectionPool, + TimeoutService, TimeService, EnvService, @@ -684,6 +688,8 @@ const $SponsorsService: Provider = { provide: 'SponsorsService', useExisting: Sp ChannelFollowingService, RegistryApiService, ReversiService, + RedisConnectionPool, + TimeoutService, TimeService, EnvService, diff --git a/packages/backend/src/core/RedisConnectionPool.ts b/packages/backend/src/core/RedisConnectionPool.ts new file mode 100644 index 0000000000..7ebefdfcb3 --- /dev/null +++ b/packages/backend/src/core/RedisConnectionPool.ts @@ -0,0 +1,103 @@ +/* + * SPDX-FileCopyrightText: hazelnoot and other Sharkey contributors + * SPDX-License-Identifier: AGPL-3.0-only + */ + +import { Inject, Injectable, OnApplicationShutdown } from '@nestjs/common'; +import Redis, { RedisOptions } from 'ioredis'; +import { DI } from '@/di-symbols.js'; +import type { Config } from '@/config.js'; +import Logger from '@/logger.js'; +import { Timeout, TimeoutService } from '@/core/TimeoutService.js'; +import { LoggerService } from './LoggerService.js'; + +/** + * Target number of connections to keep open and ready for use. + * The pool may grow beyond this during bursty traffic, but it will always shrink back to this number. + * The pool may remain below this number is the server never experiences enough traffic to consume this many clients. + */ +export const poolSize = 16; + +/** + * How often to drop an idle connection from the pool. + * This will never shrink the pool below poolSize. + */ +export const poolShrinkInterval = 5 * 1000; // 5 seconds + +@Injectable() +export class RedisConnectionPool implements OnApplicationShutdown { + private readonly poolShrinkTimer: Timeout; + private readonly pool: Redis.Redis[] = []; + private readonly logger: Logger; + private readonly redisOptions: RedisOptions; + + constructor(@Inject(DI.config) config: Config, loggerService: LoggerService, timeoutService: TimeoutService) { + this.logger = loggerService.getLogger('redis-pool'); + this.poolShrinkTimer = timeoutService.setInterval(() => this.shrinkPool(), poolShrinkInterval); + this.redisOptions = { + ...config.redis, + + // Set lazyConnect so that we can await() the connection manually. + // This helps to avoid a "stampede" of new connections (which are processed in the background!) under bursty conditions. + lazyConnect: true, + enableOfflineQueue: false, + }; + } + + /** + * Gets a Redis connection from the pool, or creates a new connection if the pool is empty. + * The returned object MUST be returned with a call to free(), even in the case of exceptions! + * Use a try...finally block for safe handling. + */ + public async alloc(): Promise { + let redis = this.pool.pop(); + + // The pool may be empty if we're under heavy load and/or we haven't opened all connections. + // Just construct a new instance, which will eventually be added to the pool. + // Excess clients will be disposed eventually. + if (!redis) { + redis = new Redis.Redis(this.redisOptions); + await redis.connect(); + } + + return redis; + } + + /** + * Returns a Redis connection to the pool. + * The instance MUST not be used after returning! + * Use a try...finally block for safe handling. + */ + public async free(redis: Redis.Redis): Promise { + // https://redis.io/docs/latest/commands/reset/ + await redis.reset(); + + this.pool.push(redis); + } + + public async onApplicationShutdown(): Promise { + // Cancel timer, otherwise it will cause a memory leak + clearInterval(this.poolShrinkTimer); + + // Disconnect all remaining instances + while (this.pool.length > 0) { + await this.dropClient(); + } + } + + private async shrinkPool(): Promise { + this.logger.debug(`Pool size is ${this.pool.length}`); + if (this.pool.length > poolSize) { + await this.dropClient(); + } + } + + private async dropClient(): Promise { + try { + const redis = this.pool.pop(); + await redis?.quit(); + } catch (err) { + this.logger.warn(`Error disconnecting from redis: ${err}`, { err }); + } + } +} diff --git a/packages/backend/src/core/TimeoutService.ts b/packages/backend/src/core/TimeoutService.ts new file mode 100644 index 0000000000..093b9a7b04 --- /dev/null +++ b/packages/backend/src/core/TimeoutService.ts @@ -0,0 +1,76 @@ +/* + * SPDX-FileCopyrightText: hazelnoot and other Sharkey contributors + * SPDX-License-Identifier: AGPL-3.0-only + */ + +/** + * Provides access to setTimeout, setInterval, and related functions. + * Used to support deterministic unit testing. + */ +export class TimeoutService { + /** + * Returns a promise that resolves after the specified timeout in milliseconds. + */ + public delay(timeout: number): Promise { + return new Promise(resolve => { + this.setTimeout(resolve, timeout); + }); + } + + /** + * Passthrough to node's setTimeout + */ + public setTimeout(handler: TimeoutHandler, timeout?: number): Timeout { + return setTimeout(() => handler(), timeout); + } + + /** + * Passthrough to node's setInterval + */ + public setInterval(handler: TimeoutHandler, timeout?: number): Timeout { + return setInterval(() => handler(), timeout); + } + + /** + * Passthrough to node's clearTimeout + */ + public clearTimeout(timeout: Timeout) { + clearTimeout(timeout); + } + + /** + * Passthrough to node's clearInterval + */ + public clearInterval(timeout: Timeout) { + clearInterval(timeout); + } +} + +/** + * Function to be called when a timer or interval elapses. + */ +export type TimeoutHandler = () => void; + +/** + * A fucked TS issue causes the DOM setTimeout to get merged with Node setTimeout, creating a "quantum method" that returns either "number" or "NodeJS.Timeout" depending on how it's called. + * This would be fine, except it always matches the *wrong type*! + * The result is this "impossible" scenario: + * + * ```typescript + * // Test evaluates to "false", because the method's return type is not equal to itself. + * type Test = ReturnType extends ReturnType ? true : false; + * + * // This is a compiler error, because the type is broken and triggers some internal TS bug. + * const timeout = setTimeout(handler); + * clearTimeout(timeout); // compiler error here, because even type inference doesn't work. + * + * // This fails to compile. + * function test(handler, timeout): ReturnType { + * return setTimeout(handler, timeout); + * } + * ``` + * + * The bug is marked as "wontfix" by TS devs, so we have to work around it ourselves. -_- + * By forcing the return type to *explicitly* include both types, we at least make it possible to work with the resulting token. + */ +export type Timeout = NodeJS.Timeout | number; diff --git a/packages/backend/src/misc/rate-limit-utils.ts b/packages/backend/src/misc/rate-limit-utils.ts index 9909bb97fa..cc13111390 100644 --- a/packages/backend/src/misc/rate-limit-utils.ts +++ b/packages/backend/src/misc/rate-limit-utils.ts @@ -117,12 +117,27 @@ export interface LimitInfo { fullResetMs: number; } +export const disabledLimitInfo: Readonly = Object.freeze({ + blocked: false, + remaining: Number.MAX_SAFE_INTEGER, + resetSec: 0, + resetMs: 0, + fullResetSec: 0, + fullResetMs: 0, +}); + export function isLegacyRateLimit(limit: RateLimit): limit is LegacyRateLimit { return limit.type === undefined; } -export function hasMinLimit(limit: LegacyRateLimit): limit is LegacyRateLimit & { minInterval: number } { - return !!limit.minInterval; +export type MaxLegacyLimit = LegacyRateLimit & { duration: number, max: number }; +export function hasMaxLimit(limit: LegacyRateLimit): limit is MaxLegacyLimit { + return limit.max != null && limit.duration != null; +} + +export type MinLegacyLimit = LegacyRateLimit & { minInterval: number }; +export function hasMinLimit(limit: LegacyRateLimit): limit is MinLegacyLimit { + return limit.minInterval != null; } export function sendRateLimitHeaders(reply: FastifyReply, info: LimitInfo): void { diff --git a/packages/backend/src/server/api/SkRateLimiterService.ts b/packages/backend/src/server/api/SkRateLimiterService.ts index b11d1556ba..71681aadc9 100644 --- a/packages/backend/src/server/api/SkRateLimiterService.ts +++ b/packages/backend/src/server/api/SkRateLimiterService.ts @@ -7,8 +7,9 @@ import { Inject, Injectable } from '@nestjs/common'; import Redis from 'ioredis'; import { TimeService } from '@/core/TimeService.js'; import { EnvService } from '@/core/EnvService.js'; -import { DI } from '@/di-symbols.js'; -import { BucketRateLimit, LegacyRateLimit, LimitInfo, RateLimit, hasMinLimit, isLegacyRateLimit, Keyed } from '@/misc/rate-limit-utils.js'; +import { BucketRateLimit, LegacyRateLimit, LimitInfo, RateLimit, hasMinLimit, isLegacyRateLimit, Keyed, hasMaxLimit, disabledLimitInfo, MaxLegacyLimit, MinLegacyLimit } from '@/misc/rate-limit-utils.js'; +import { RedisConnectionPool } from '@/core/RedisConnectionPool.js'; +import { TimeoutService } from '@/core/TimeoutService.js'; @Injectable() export class SkRateLimiterService { @@ -18,8 +19,11 @@ export class SkRateLimiterService { @Inject(TimeService) private readonly timeService: TimeService, - @Inject(DI.redis) - private readonly redisClient: Redis.Redis, + @Inject(TimeoutService) + private readonly timeoutService: TimeoutService, + + @Inject(RedisConnectionPool) + private readonly redisPool: RedisConnectionPool, @Inject(EnvService) envService: EnvService, @@ -29,117 +33,110 @@ export class SkRateLimiterService { public async limit(limit: Keyed, actor: string, factor = 1): Promise { if (this.disabled || factor === 0) { - return { - blocked: false, - remaining: Number.MAX_SAFE_INTEGER, - resetSec: 0, - resetMs: 0, - fullResetSec: 0, - fullResetMs: 0, - }; + return disabledLimitInfo; } if (factor < 0) { throw new Error(`Rate limit factor is zero or negative: ${factor}`); } - return await this.tryLimit(limit, actor, factor); + const redis = await this.redisPool.alloc(); + try { + return await this.tryLimit(redis, limit, actor, factor); + } finally { + await this.redisPool.free(redis); + } } - private async tryLimit(limit: Keyed, actor: string, factor: number, retry = 1): Promise { + private async tryLimit(redis: Redis.Redis, limit: Keyed, actor: string, factor: number, retry = 0): Promise { try { + if (retry > 0) { + // Real-world testing showed the need for backoff to "spread out" bursty traffic. + const backoff = Math.round(Math.pow(2, retry + Math.random())); + await this.timeoutService.delay(backoff); + } + if (isLegacyRateLimit(limit)) { - return await this.limitLegacy(limit, actor, factor); + return await this.limitLegacy(redis, limit, actor, factor); } else { - return await this.limitBucket(limit, actor, factor); + return await this.limitBucket(redis, limit, actor, factor); } } catch (err) { // We may experience collision errors from optimistic locking. // This is expected, so we should retry a few times before giving up. // https://redis.io/docs/latest/develop/interact/transactions/#optimistic-locking-using-check-and-set - if (err instanceof TransactionError && retry < 3) { - return await this.tryLimit(limit, actor, factor, retry + 1); + if (err instanceof ConflictError && retry < 4) { + // We can reuse the same connection to reduce pool contention, but we have to reset it first. + await redis.reset(); + return await this.tryLimit(redis, limit, actor, factor, retry + 1); } throw err; } } - private async limitLegacy(limit: Keyed, actor: string, factor: number): Promise { - const promises: Promise[] = []; - - // The "min" limit - if present - is handled directly. - if (hasMinLimit(limit)) { - promises.push( - this.limitMin(limit, actor, factor), - ); + private async limitLegacy(redis: Redis.Redis, limit: Keyed, actor: string, factor: number): Promise { + if (hasMaxLimit(limit)) { + return await this.limitMaxLegacy(redis, limit, actor, factor); + } else if (hasMinLimit(limit)) { + return await this.limitMinLegacy(redis, limit, actor, factor); + } else { + return disabledLimitInfo; } - - // Convert the "max" limit into a leaky bucket with 1 drip / second rate. - if (limit.max != null && limit.duration != null) { - promises.push( - this.limitBucket({ - type: 'bucket', - key: limit.key, - size: limit.max, - dripRate: Math.max(Math.round(limit.duration / limit.max), 1), - }, actor, factor), - ); - } - - const [lim1, lim2] = await Promise.all(promises); - return { - blocked: (lim1?.blocked || lim2?.blocked) ?? false, - remaining: Math.min(lim1?.remaining ?? Number.MAX_SAFE_INTEGER, lim2?.remaining ?? Number.MAX_SAFE_INTEGER), - resetSec: Math.max(lim1?.resetSec ?? 0, lim2?.resetSec ?? 0), - resetMs: Math.max(lim1?.resetMs ?? 0, lim2?.resetMs ?? 0), - fullResetSec: Math.max(lim1?.fullResetSec ?? 0, lim2?.fullResetSec ?? 0), - fullResetMs: Math.max(lim1?.fullResetMs ?? 0, lim2?.fullResetMs ?? 0), - }; } - private async limitMin(limit: Keyed & { minInterval: number }, actor: string, factor: number): Promise { - if (limit.minInterval === 0) return null; + private async limitMaxLegacy(redis: Redis.Redis, limit: Keyed, actor: string, factor: number): Promise { + if (limit.duration === 0) return disabledLimitInfo; + if (limit.duration < 0) throw new Error(`Invalid rate limit ${limit.key}: duration is negative (${limit.duration})`); + if (limit.max < 1) throw new Error(`Invalid rate limit ${limit.key}: max is less than 1 (${limit.max})`); + + // Derive initial dripRate from minInterval OR duration/max. + const initialDripRate = Math.max(limit.minInterval ?? Math.round(limit.duration / limit.max), 1); + + // Calculate dripSize to reach max at exactly duration + const dripSize = Math.max(Math.round(limit.max / (limit.duration / initialDripRate)), 1); + + // Calculate final dripRate from dripSize and duration/max + const dripRate = Math.max(Math.round(limit.duration / (limit.max / dripSize)), 1); + + const bucketLimit: Keyed = { + type: 'bucket', + key: limit.key, + size: limit.max, + dripRate, + dripSize, + }; + return await this.limitBucket(redis, bucketLimit, actor, factor); + } + + private async limitMinLegacy(redis: Redis.Redis, limit: Keyed, actor: string, factor: number): Promise { + if (limit.minInterval === 0) return disabledLimitInfo; if (limit.minInterval < 0) throw new Error(`Invalid rate limit ${limit.key}: minInterval is negative (${limit.minInterval})`); - const minInterval = Math.max(Math.ceil(limit.minInterval * factor), 0); - const expirationSec = Math.max(Math.ceil(minInterval / 1000), 1); - - // Check for window clear - const counter = await this.getLimitCounter(limit, actor, 'min'); - if (counter.counter > 0) { - const isCleared = this.timeService.now - counter.timestamp >= minInterval; - if (isCleared) { - counter.counter = 0; - } - } - - // Increment the limit, then synchronize with redis - const blocked = counter.counter > 0; - if (!blocked) { - counter.counter++; - counter.timestamp = this.timeService.now; - await this.updateLimitCounter(limit, actor, 'min', expirationSec, counter); - } - - // Calculate limit status - const resetMs = Math.max(minInterval - (this.timeService.now - counter.timestamp), 0); - const resetSec = Math.ceil(resetMs / 1000); - return { blocked, remaining: 0, resetSec, resetMs, fullResetSec: resetSec, fullResetMs: resetMs }; + const dripRate = Math.max(Math.round(limit.minInterval), 1); + const bucketLimit: Keyed = { + type: 'bucket', + key: limit.key, + size: 1, + dripRate, + dripSize: 1, + }; + return await this.limitBucket(redis, bucketLimit, actor, factor); } - private async limitBucket(limit: Keyed, actor: string, factor: number): Promise { + private async limitBucket(redis: Redis.Redis, limit: Keyed, actor: string, factor: number): Promise { if (limit.size < 1) throw new Error(`Invalid rate limit ${limit.key}: size is less than 1 (${limit.size})`); if (limit.dripRate != null && limit.dripRate < 1) throw new Error(`Invalid rate limit ${limit.key}: dripRate is less than 1 (${limit.dripRate})`); if (limit.dripSize != null && limit.dripSize < 1) throw new Error(`Invalid rate limit ${limit.key}: dripSize is less than 1 (${limit.dripSize})`); + const redisKey = createLimitKey(limit, actor); const bucketSize = Math.max(Math.ceil(limit.size / factor), 1); const dripRate = Math.ceil(limit.dripRate ?? 1000); const dripSize = Math.ceil(limit.dripSize ?? 1); const expirationSec = Math.max(Math.ceil(bucketSize / dripRate), 1); // Simulate bucket drips - const counter = await this.getLimitCounter(limit, actor, 'bucket'); + const counter = await this.getLimitCounter(redis, redisKey); if (counter.counter > 0) { const dripsSinceLastTick = Math.floor((this.timeService.now - counter.timestamp) / dripRate) * dripSize; counter.counter = Math.max(counter.counter - dripsSinceLastTick, 0); @@ -150,7 +147,7 @@ export class SkRateLimiterService { if (!blocked) { counter.counter++; counter.timestamp = this.timeService.now; - await this.updateLimitCounter(limit, actor, 'bucket', expirationSec, counter); + await this.updateLimitCounter(redis, redisKey, expirationSec, counter); } // Calculate how much time is needed to free up a bucket slot @@ -167,60 +164,49 @@ export class SkRateLimiterService { return { blocked, remaining, resetSec, resetMs, fullResetSec, fullResetMs }; } - private async getLimitCounter(limit: Keyed, actor: string, subject: string): Promise { - const timestampKey = createLimitKey(limit, actor, subject, 't'); - const counterKey = createLimitKey(limit, actor, subject, 'c'); + private async getLimitCounter(redis: Redis.Redis, key: string): Promise { + const counter: LimitCounter = { counter: 0, timestamp: 0 }; - const [timestamp, counter] = await this.executeRedis( - [ - ['get', timestampKey], - ['get', counterKey], - ], - [ - timestampKey, - counterKey, - ], - ); + // Watch the key BEFORE reading it! + await redis.watch(key); + const data = await redis.get(key); - return { - timestamp: timestamp ? parseInt(timestamp) : 0, - counter: counter ? parseInt(counter) : 0, - }; - } - - private async updateLimitCounter(limit: Keyed, actor: string, subject: string, expirationSec: number, counter: LimitCounter): Promise { - const timestampKey = createLimitKey(limit, actor, subject, 't'); - const counterKey = createLimitKey(limit, actor, subject, 'c'); - - await this.executeRedis( - [ - ['set', timestampKey, counter.timestamp.toString(), 'EX', expirationSec], - ['set', counterKey, counter.counter.toString(), 'EX', expirationSec], - ], - [ - timestampKey, - counterKey, - ], - ); - } - - private async executeRedis(batch: RedisBatch, watch: string[]): Promise> { - const results = await this.redisClient - .multi(batch) - .watch(watch) - .exec(); - - // Transaction error - if (!results) { - throw new TransactionError('Redis error: transaction conflict'); + // Data may be missing or corrupt if the key doesn't exist. + // This is an expected edge case. + if (data) { + const parts = data.split(':'); + if (parts.length === 2) { + counter.counter = parseInt(parts[0]); + counter.timestamp = parseInt(parts[1]); + } } - // The entire call failed + return counter; + } + + private async updateLimitCounter(redis: Redis.Redis, key: string, expirationSec: number, counter: LimitCounter): Promise { + const data = `${counter.counter}:${counter.timestamp}`; + + await this.executeRedisMulti( + redis, + [['set', key, data, 'EX', expirationSec]], + ); + } + + private async executeRedisMulti(redis: Redis.Redis, batch: RedisBatch): Promise> { + const results = await redis.multi(batch).exec(); + + // Transaction conflict (retryable) + if (!results) { + throw new ConflictError('Redis error: transaction conflict'); + } + + // Transaction failed (fatal) if (results.length !== batch.length) { throw new Error('Redis error: failed to execute batch'); } - // A particular command failed + // Command failed (fatal) const errors = results.map(r => r[0]).filter(e => e != null); if (errors.length > 0) { throw new AggregateError(errors, `Redis error: failed to execute command(s): '${errors.join('\', \'')}'`); @@ -233,11 +219,11 @@ export class SkRateLimiterService { type RedisBatch = [string, ...unknown[]][] & { length: Num }; type RedisResults = (string | null)[] & { length: Num }; -function createLimitKey(limit: Keyed, actor: string, subject: string, value: string): string { - return `rl_${actor}_${limit.key}_${subject}_${value}`; +function createLimitKey(limit: Keyed, actor: string): string { + return `rl_${actor}_${limit.key}`; } -class TransactionError extends Error {} +class ConflictError extends Error {} interface LimitCounter { timestamp: number; diff --git a/packages/backend/test/unit/server/api/SkRateLimiterServiceTests.ts b/packages/backend/test/unit/server/api/SkRateLimiterServiceTests.ts index 90030495ed..deb6b9f80e 100644 --- a/packages/backend/test/unit/server/api/SkRateLimiterServiceTests.ts +++ b/packages/backend/test/unit/server/api/SkRateLimiterServiceTests.ts @@ -6,6 +6,8 @@ import type Redis from 'ioredis'; import { SkRateLimiterService } from '@/server/api/SkRateLimiterService.js'; import { BucketRateLimit, Keyed, LegacyRateLimit } from '@/misc/rate-limit-utils.js'; +import { RedisConnectionPool } from '@/core/RedisConnectionPool.js'; +import { Timeout, TimeoutHandler, TimeoutService } from '@/core/TimeoutService.js'; /* eslint-disable @typescript-eslint/no-non-null-assertion */ @@ -24,28 +26,50 @@ describe(SkRateLimiterService, () => { }, }; + function callMockRedis(command: [string, ...unknown[]]) { + const handlerResults = mockRedis.map(handler => handler(command)); + const finalResult = handlerResults.findLast(result => result != null); + return finalResult ?? [null, null]; + } + + // I apologize to anyone who tries to read this later 🥲 mockRedis = []; mockRedisExec = (batch) => { const results: [Error | null, unknown][] = batch.map(command => { - const handlerResults = mockRedis.map(handler => handler(command)); - const finalResult = handlerResults.findLast(result => result != null); - return finalResult ?? [new Error('test error: no handler'), null]; + return callMockRedis(command); }); return Promise.resolve(results); }; const mockRedisClient = { + watch(...args: unknown[]) { + const result = callMockRedis(['watch', ...args]); + return Promise.resolve(result[0] ?? result[1]); + }, + get(...args: unknown[]) { + const result = callMockRedis(['get', ...args]); + return Promise.resolve(result[0] ?? result[1]); + }, + set(...args: unknown[]) { + const result = callMockRedis(['set', ...args]); + return Promise.resolve(result[0] ?? result[1]); + }, multi(batch: [string, ...unknown[]][]) { return { - watch() { - return { - exec() { - return mockRedisExec(batch); - }, - }; + exec() { + return mockRedisExec(batch); }, }; }, + reset() { + return Promise.resolve(); + }, } as unknown as Redis.Redis; + const mockRedisPool = { + alloc() { + return Promise.resolve(mockRedisClient); + }, + free() {}, + } as unknown as RedisConnectionPool; mockEnvironment = Object.create(process.env); mockEnvironment.NODE_ENV = 'production'; @@ -53,9 +77,22 @@ describe(SkRateLimiterService, () => { env: mockEnvironment, }; + const mockTimeoutService = new class extends TimeoutService { + setTimeout(handler: TimeoutHandler): Timeout { + handler(); + return 0; + } + setInterval(handler: TimeoutHandler): Timeout { + handler(); + return 0; + } + clearTimeout() {} + clearInterval() {} + }; + let service: SkRateLimiterService | undefined = undefined; serviceUnderTest = () => { - return service ??= new SkRateLimiterService(mockTimeService, mockRedisClient, mockEnvService); + return service ??= new SkRateLimiterService(mockTimeService, mockTimeoutService, mockRedisPool, mockEnvService); }; }); @@ -65,56 +102,22 @@ describe(SkRateLimiterService, () => { let limitCounter: number | undefined = undefined; let limitTimestamp: number | undefined = undefined; - let minCounter: number | undefined = undefined; - let minTimestamp: number | undefined = undefined; beforeEach(() => { limitCounter = undefined; limitTimestamp = undefined; - minCounter = undefined; - minTimestamp = undefined; mockRedis.push(([command, ...args]) => { - if (command === 'set' && args[0] === 'rl_actor_test_bucket_t') { - limitTimestamp = parseInt(args[1] as string); + if (command === 'set' && args[0] === 'rl_actor_test') { + const parts = (args[1] as string).split(':'); + limitCounter = parseInt(parts[0] as string); + limitTimestamp = parseInt(parts[1] as string); return [null, args[1]]; } - if (command === 'get' && args[0] === 'rl_actor_test_bucket_t') { - return [null, limitTimestamp?.toString() ?? null]; + if (command === 'get' && args[0] === 'rl_actor_test') { + const data = `${limitCounter ?? 0}:${limitTimestamp ?? 0}`; + return [null, data]; } - // if (command === 'incr' && args[0] === 'rl_actor_test_bucket_c') { - // limitCounter = (limitCounter ?? 0) + 1; - // return [null, null]; - // } - if (command === 'set' && args[0] === 'rl_actor_test_bucket_c') { - limitCounter = parseInt(args[1] as string); - return [null, args[1]]; - } - if (command === 'get' && args[0] === 'rl_actor_test_bucket_c') { - return [null, limitCounter?.toString() ?? null]; - } - - if (command === 'set' && args[0] === 'rl_actor_test_min_t') { - minTimestamp = parseInt(args[1] as string); - return [null, args[1]]; - } - if (command === 'get' && args[0] === 'rl_actor_test_min_t') { - return [null, minTimestamp?.toString() ?? null]; - } - // if (command === 'incr' && args[0] === 'rl_actor_test_min_c') { - // minCounter = (minCounter ?? 0) + 1; - // return [null, null]; - // } - if (command === 'set' && args[0] === 'rl_actor_test_min_c') { - minCounter = parseInt(args[1] as string); - return [null, args[1]]; - } - if (command === 'get' && args[0] === 'rl_actor_test_min_c') { - return [null, minCounter?.toString() ?? null]; - } - // if (command === 'expire') { - // return [null, null]; - // } return null; }); @@ -266,19 +269,7 @@ describe(SkRateLimiterService, () => { await serviceUnderTest().limit(limit, actor); - expect(commands).toContainEqual(['set', 'rl_actor_test_bucket_c', '1', 'EX', 1]); - }); - - it('should set timestamp expiration', async () => { - const commands: unknown[][] = []; - mockRedis.push(command => { - commands.push(command); - return null; - }); - - await serviceUnderTest().limit(limit, actor); - - expect(commands).toContainEqual(['set', 'rl_actor_test_bucket_t', '0', 'EX', 1]); + expect(commands).toContainEqual(['set', 'rl_actor_test', '1:0', 'EX', 1]); }); it('should not increment when already blocked', async () => { @@ -379,12 +370,12 @@ describe(SkRateLimiterService, () => { it('should retry when redis conflicts', async () => { let numCalls = 0; - const realMockRedisExec = mockRedisExec; + const originalExec = mockRedisExec; mockRedisExec = () => { - if (numCalls > 0) { - mockRedisExec = realMockRedisExec; - } numCalls++; + if (numCalls > 1) { + mockRedisExec = originalExec; + } return Promise.resolve(null); }; @@ -393,7 +384,7 @@ describe(SkRateLimiterService, () => { expect(numCalls).toBe(2); }); - it('should bail out after 3 tries', async () => { + it('should bail out after 5 tries', async () => { let numCalls = 0; mockRedisExec = () => { numCalls++; @@ -403,7 +394,7 @@ describe(SkRateLimiterService, () => { const promise = serviceUnderTest().limit(limit, actor); await expect(promise).rejects.toThrow(/transaction conflict/); - expect(numCalls).toBe(3); + expect(numCalls).toBe(5); }); it('should apply correction if extra calls slip through', async () => { @@ -450,8 +441,8 @@ describe(SkRateLimiterService, () => { it('should increment counter when called', async () => { await serviceUnderTest().limit(limit, actor); - expect(minCounter).not.toBeUndefined(); - expect(minCounter).toBe(1); + expect(limitCounter).not.toBeUndefined(); + expect(limitCounter).toBe(1); }); it('should set timestamp when called', async () => { @@ -459,30 +450,19 @@ describe(SkRateLimiterService, () => { await serviceUnderTest().limit(limit, actor); - expect(minCounter).not.toBeUndefined(); - expect(minTimestamp).toBe(1000); + expect(limitCounter).not.toBeUndefined(); + expect(limitTimestamp).toBe(1000); }); it('should decrement counter when minInterval has passed', async () => { - minCounter = 1; - minTimestamp = 0; + limitCounter = 1; + limitTimestamp = 0; mockTimeService.now = 1000; await serviceUnderTest().limit(limit, actor); - expect(minCounter).not.toBeUndefined(); - expect(minCounter).toBe(1); // 1 (starting) - 1 (interval) + 1 (call) = 1 - }); - - it('should reset counter entirely', async () => { - minCounter = 2; - minTimestamp = 0; - mockTimeService.now = 1000; - - await serviceUnderTest().limit(limit, actor); - - expect(minCounter).not.toBeUndefined(); - expect(minCounter).toBe(1); // 2 (starting) - 2 (interval) + 1 (call) = 1 + expect(limitCounter).not.toBeUndefined(); + expect(limitCounter).toBe(1); // 1 (starting) - 1 (interval) + 1 (call) = 1 }); it('should maintain counter between calls over time', async () => { @@ -495,13 +475,13 @@ describe(SkRateLimiterService, () => { mockTimeService.now += 1000; // 0 - 1 = 0 await serviceUnderTest().limit(limit, actor); // 0 + 1 = 1 - expect(minCounter).toBe(1); - expect(minTimestamp).toBe(3000); + expect(limitCounter).toBe(1); + expect(limitTimestamp).toBe(3000); }); it('should block when interval exceeded', async () => { - minCounter = 1; - minTimestamp = 0; + limitCounter = 1; + limitTimestamp = 0; const info = await serviceUnderTest().limit(limit, actor); @@ -509,8 +489,8 @@ describe(SkRateLimiterService, () => { }); it('should calculate correct info when blocked', async () => { - minCounter = 1; - minTimestamp = 0; + limitCounter = 1; + limitTimestamp = 0; const info = await serviceUnderTest().limit(limit, actor); @@ -521,8 +501,8 @@ describe(SkRateLimiterService, () => { }); it('should allow when bucket is filled but interval has passed', async () => { - minCounter = 1; - minTimestamp = 0; + limitCounter = 1; + limitTimestamp = 0; mockTimeService.now = 1000; const info = await serviceUnderTest().limit(limit, actor); @@ -531,8 +511,8 @@ describe(SkRateLimiterService, () => { }); it('should scale interval by factor', async () => { - minCounter = 1; - minTimestamp = 0; + limitCounter = 1; + limitTimestamp = 0; mockTimeService.now += 500; const info = await serviceUnderTest().limit(limit, actor, 0.5); @@ -549,30 +529,18 @@ describe(SkRateLimiterService, () => { await serviceUnderTest().limit(limit, actor); - expect(commands).toContainEqual(['set', 'rl_actor_test_min_c', '1', 'EX', 1]); - }); - - it('should set timestamp expiration', async () => { - const commands: unknown[][] = []; - mockRedis.push(command => { - commands.push(command); - return null; - }); - - await serviceUnderTest().limit(limit, actor); - - expect(commands).toContainEqual(['set', 'rl_actor_test_min_t', '0', 'EX', 1]); + expect(commands).toContainEqual(['set', 'rl_actor_test', '1:0', 'EX', 1]); }); it('should not increment when already blocked', async () => { - minCounter = 1; - minTimestamp = 0; + limitCounter = 1; + limitTimestamp = 0; mockTimeService.now += 100; await serviceUnderTest().limit(limit, actor); - expect(minCounter).toBe(1); - expect(minTimestamp).toBe(0); + expect(limitCounter).toBe(1); + expect(limitTimestamp).toBe(0); }); it('should skip if factor is zero', async () => { @@ -605,17 +573,17 @@ describe(SkRateLimiterService, () => { await expect(promise).rejects.toThrow(/minInterval is negative/); }); - it('should not apply correction to extra calls', async () => { - minCounter = 2; + it('should apply correction if extra calls slip through', async () => { + limitCounter = 2; const info = await serviceUnderTest().limit(limit, actor); expect(info.blocked).toBeTruthy(); expect(info.remaining).toBe(0); - expect(info.resetMs).toBe(1000); - expect(info.resetSec).toBe(1); - expect(info.fullResetMs).toBe(1000); - expect(info.fullResetSec).toBe(1); + expect(info.resetMs).toBe(2000); + expect(info.resetSec).toBe(2); + expect(info.fullResetMs).toBe(2000); + expect(info.fullResetSec).toBe(2); }); }); @@ -720,19 +688,7 @@ describe(SkRateLimiterService, () => { await serviceUnderTest().limit(limit, actor); - expect(commands).toContainEqual(['set', 'rl_actor_test_bucket_c', '1', 'EX', 1]); - }); - - it('should set timestamp expiration', async () => { - const commands: unknown[][] = []; - mockRedis.push(command => { - commands.push(command); - return null; - }); - - await serviceUnderTest().limit(limit, actor); - - expect(commands).toContainEqual(['set', 'rl_actor_test_bucket_t', '0', 'EX', 1]); + expect(commands).toContainEqual(['set', 'rl_actor_test', '1:0', 'EX', 1]); }); it('should not increment when already blocked', async () => { @@ -774,12 +730,21 @@ describe(SkRateLimiterService, () => { await expect(promise).rejects.toThrow(/factor is zero or negative/); }); + it('should skip if duration is zero', async () => { + limit.duration = 0; + + const info = await serviceUnderTest().limit(limit, actor); + + expect(info.blocked).toBeFalsy(); + expect(info.remaining).toBe(Number.MAX_SAFE_INTEGER); + }); + it('should throw if max is zero', async () => { limit.max = 0; const promise = serviceUnderTest().limit(limit, actor); - await expect(promise).rejects.toThrow(/size is less than 1/); + await expect(promise).rejects.toThrow(/max is less than 1/); }); it('should throw if max is negative', async () => { @@ -787,7 +752,7 @@ describe(SkRateLimiterService, () => { const promise = serviceUnderTest().limit(limit, actor); - await expect(promise).rejects.toThrow(/size is less than 1/); + await expect(promise).rejects.toThrow(/max is less than 1/); }); it('should apply correction if extra calls slip through', async () => { @@ -811,7 +776,7 @@ describe(SkRateLimiterService, () => { limit = { type: undefined, key, - max: 5, + max: 10, duration: 5000, minInterval: 1000, }; @@ -824,7 +789,7 @@ describe(SkRateLimiterService, () => { }); it('should block when limit exceeded', async () => { - limitCounter = 5; + limitCounter = 10; limitTimestamp = 0; const info = await serviceUnderTest().limit(limit, actor); @@ -832,30 +797,8 @@ describe(SkRateLimiterService, () => { expect(info.blocked).toBeTruthy(); }); - it('should block when minInterval exceeded', async () => { - minCounter = 1; - minTimestamp = 0; - - const info = await serviceUnderTest().limit(limit, actor); - - expect(info.blocked).toBeTruthy(); - }); - it('should calculate correct info when allowed', async () => { - limitCounter = 1; - limitTimestamp = 0; - - const info = await serviceUnderTest().limit(limit, actor); - - expect(info.remaining).toBe(0); - expect(info.resetSec).toBe(1); - expect(info.resetMs).toBe(1000); - expect(info.fullResetSec).toBe(2); - expect(info.fullResetMs).toBe(2000); - }); - - it('should calculate correct info when blocked by limit', async () => { - limitCounter = 5; + limitCounter = 9; limitTimestamp = 0; const info = await serviceUnderTest().limit(limit, actor); @@ -867,17 +810,17 @@ describe(SkRateLimiterService, () => { expect(info.fullResetMs).toBe(5000); }); - it('should calculate correct info when blocked by minInterval', async () => { - minCounter = 1; - minTimestamp = 0; + it('should calculate correct info when blocked', async () => { + limitCounter = 10; + limitTimestamp = 0; const info = await serviceUnderTest().limit(limit, actor); expect(info.remaining).toBe(0); expect(info.resetSec).toBe(1); expect(info.resetMs).toBe(1000); - expect(info.fullResetSec).toBe(1); - expect(info.fullResetMs).toBe(1000); + expect(info.fullResetSec).toBe(5); + expect(info.fullResetMs).toBe(5000); }); it('should allow when counter is filled but interval has passed', async () => { @@ -890,21 +833,23 @@ describe(SkRateLimiterService, () => { expect(info.blocked).toBeFalsy(); }); - it('should allow when minCounter is filled but interval has passed', async () => { - minCounter = 1; - minTimestamp = 0; - mockTimeService.now = 1000; + it('should drip according to minInterval', async () => { + limitCounter = 10; + limitTimestamp = 0; + mockTimeService.now += 1000; - const info = await serviceUnderTest().limit(limit, actor); + const i1 = await serviceUnderTest().limit(limit, actor); + const i2 = await serviceUnderTest().limit(limit, actor); + const i3 = await serviceUnderTest().limit(limit, actor); - expect(info.blocked).toBeFalsy(); + expect(i1.blocked).toBeFalsy(); + expect(i2.blocked).toBeFalsy(); + expect(i3.blocked).toBeTruthy(); }); it('should scale limit and interval by factor', async () => { limitCounter = 5; limitTimestamp = 0; - minCounter = 1; - minTimestamp = 0; mockTimeService.now += 500; const info = await serviceUnderTest().limit(limit, actor, 0.5); @@ -912,7 +857,7 @@ describe(SkRateLimiterService, () => { expect(info.blocked).toBeFalsy(); }); - it('should set bucket counter expiration', async () => { + it('should set counter expiration', async () => { const commands: unknown[][] = []; mockRedis.push(command => { commands.push(command); @@ -921,63 +866,22 @@ describe(SkRateLimiterService, () => { await serviceUnderTest().limit(limit, actor); - expect(commands).toContainEqual(['set', 'rl_actor_test_bucket_c', '1', 'EX', 1]); - }); - - it('should set bucket timestamp expiration', async () => { - const commands: unknown[][] = []; - mockRedis.push(command => { - commands.push(command); - return null; - }); - - await serviceUnderTest().limit(limit, actor); - - expect(commands).toContainEqual(['set', 'rl_actor_test_bucket_t', '0', 'EX', 1]); - }); - - it('should set min counter expiration', async () => { - const commands: unknown[][] = []; - mockRedis.push(command => { - commands.push(command); - return null; - }); - - await serviceUnderTest().limit(limit, actor); - - expect(commands).toContainEqual(['set', 'rl_actor_test_min_c', '1', 'EX', 1]); - }); - - it('should set min timestamp expiration', async () => { - const commands: unknown[][] = []; - mockRedis.push(command => { - commands.push(command); - return null; - }); - - await serviceUnderTest().limit(limit, actor); - - expect(commands).toContainEqual(['set', 'rl_actor_test_min_t', '0', 'EX', 1]); + expect(commands).toContainEqual(['set', 'rl_actor_test', '1:0', 'EX', 1]); }); it('should not increment when already blocked', async () => { - limitCounter = 5; + limitCounter = 10; limitTimestamp = 0; - minCounter = 1; - minTimestamp = 0; mockTimeService.now += 100; await serviceUnderTest().limit(limit, actor); - expect(limitCounter).toBe(5); + expect(limitCounter).toBe(10); expect(limitTimestamp).toBe(0); - expect(minCounter).toBe(1); - expect(minTimestamp).toBe(0); }); it('should apply correction if extra calls slip through', async () => { - limitCounter = 6; - minCounter = 6; + limitCounter = 12; const info = await serviceUnderTest().limit(limit, actor);