implement SkRateLimiterService with Leaky Bucket rate limiting

This commit is contained in:
Hazelnoot 2024-12-07 10:22:45 -05:00
parent f59af78d8a
commit ffc2737478
9 changed files with 1102 additions and 26 deletions

View file

@ -42,6 +42,13 @@ export default [
name: '__filename', name: '__filename',
message: 'Not in ESModule. Use `import.meta.url` instead.', message: 'Not in ESModule. Use `import.meta.url` instead.',
}], }],
// https://typescript-eslint.io/rules/prefer-nullish-coalescing/
'@typescript-eslint/prefer-nullish-coalescing': ['warn', {
ignorePrimitives: {
// Without this, the rule breaks for nullable booleans
boolean: true,
},
}],
}, },
}, },
{ {

View file

@ -14,6 +14,8 @@ import { AbuseReportNotificationService } from '@/core/AbuseReportNotificationSe
import { SystemWebhookService } from '@/core/SystemWebhookService.js'; import { SystemWebhookService } from '@/core/SystemWebhookService.js';
import { UserSearchService } from '@/core/UserSearchService.js'; import { UserSearchService } from '@/core/UserSearchService.js';
import { WebhookTestService } from '@/core/WebhookTestService.js'; import { WebhookTestService } from '@/core/WebhookTestService.js';
import { TimeService } from '@/core/TimeService.js';
import { EnvService } from '@/core/EnvService.js';
import { AccountMoveService } from './AccountMoveService.js'; import { AccountMoveService } from './AccountMoveService.js';
import { AccountUpdateService } from './AccountUpdateService.js'; import { AccountUpdateService } from './AccountUpdateService.js';
import { AnnouncementService } from './AnnouncementService.js'; import { AnnouncementService } from './AnnouncementService.js';
@ -381,6 +383,8 @@ const $SponsorsService: Provider = { provide: 'SponsorsService', useExisting: Sp
ChannelFollowingService, ChannelFollowingService,
RegistryApiService, RegistryApiService,
ReversiService, ReversiService,
TimeService,
EnvService,
ChartLoggerService, ChartLoggerService,
FederationChart, FederationChart,
@ -680,6 +684,8 @@ const $SponsorsService: Provider = { provide: 'SponsorsService', useExisting: Sp
ChannelFollowingService, ChannelFollowingService,
RegistryApiService, RegistryApiService,
ReversiService, ReversiService,
TimeService,
EnvService,
FederationChart, FederationChart,
NotesChart, NotesChart,

View file

@ -0,0 +1,20 @@
/*
* SPDX-FileCopyrightText: hazelnoot and other Sharkey contributors
* SPDX-License-Identifier: AGPL-3.0-only
*/
import { Injectable } from '@nestjs/common';
/**
* Provides access to the process environment variables.
* This exists for testing purposes, so that a test can mock the environment without corrupting state for other tests.
*/
@Injectable()
export class EnvService {
/**
* Passthrough to process.env
*/
public get env() {
return process.env;
}
}

View file

@ -0,0 +1,27 @@
/*
* SPDX-FileCopyrightText: hazelnoot and other Sharkey contributors
* SPDX-License-Identifier: AGPL-3.0-only
*/
import { Injectable } from '@nestjs/common';
/**
* Provides abstractions to access the current time.
* Exists for unit testing purposes, so that tests can "simulate" any given time for consistency.
*/
@Injectable()
export class TimeService {
/**
* Returns Date.now()
*/
public get now() {
return Date.now();
}
/**
* Returns a new Date instance.
*/
public get date() {
return new Date();
}
}

View file

@ -6,6 +6,7 @@
import { Module } from '@nestjs/common'; import { Module } from '@nestjs/common';
import { EndpointsModule } from '@/server/api/EndpointsModule.js'; import { EndpointsModule } from '@/server/api/EndpointsModule.js';
import { CoreModule } from '@/core/CoreModule.js'; import { CoreModule } from '@/core/CoreModule.js';
import { SkRateLimiterService } from '@/server/api/SkRateLimiterService.js';
import { ApiCallService } from './api/ApiCallService.js'; import { ApiCallService } from './api/ApiCallService.js';
import { FileServerService } from './FileServerService.js'; import { FileServerService } from './FileServerService.js';
import { HealthServerService } from './HealthServerService.js'; import { HealthServerService } from './HealthServerService.js';
@ -73,7 +74,10 @@ import { SigninWithPasskeyApiService } from './api/SigninWithPasskeyApiService.j
ApiLoggerService, ApiLoggerService,
ApiServerService, ApiServerService,
AuthenticateService, AuthenticateService,
RateLimiterService, {
provide: RateLimiterService,
useClass: SkRateLimiterService,
},
SigninApiService, SigninApiService,
SigninWithPasskeyApiService, SigninWithPasskeyApiService,
SigninService, SigninService,

View file

@ -8,6 +8,7 @@ import * as fs from 'node:fs';
import * as stream from 'node:stream/promises'; import * as stream from 'node:stream/promises';
import { Inject, Injectable } from '@nestjs/common'; import { Inject, Injectable } from '@nestjs/common';
import * as Sentry from '@sentry/node'; import * as Sentry from '@sentry/node';
import { LimiterInfo } from 'ratelimiter';
import { DI } from '@/di-symbols.js'; import { DI } from '@/di-symbols.js';
import { getIpHash } from '@/misc/get-ip-hash.js'; import { getIpHash } from '@/misc/get-ip-hash.js';
import type { MiLocalUser, MiUser } from '@/models/User.js'; import type { MiLocalUser, MiUser } from '@/models/User.js';
@ -18,6 +19,7 @@ import { createTemp } from '@/misc/create-temp.js';
import { bindThis } from '@/decorators.js'; import { bindThis } from '@/decorators.js';
import { RoleService } from '@/core/RoleService.js'; import { RoleService } from '@/core/RoleService.js';
import type { Config } from '@/config.js'; import type { Config } from '@/config.js';
import { isLimitInfo } from '@/server/api/SkRateLimiterService.js';
import { ApiError } from './error.js'; import { ApiError } from './error.js';
import { RateLimiterService } from './RateLimiterService.js'; import { RateLimiterService } from './RateLimiterService.js';
import { ApiLoggerService } from './ApiLoggerService.js'; import { ApiLoggerService } from './ApiLoggerService.js';
@ -68,12 +70,17 @@ export class ApiCallService implements OnApplicationShutdown {
} else if (err.code === 'RATE_LIMIT_EXCEEDED') { } else if (err.code === 'RATE_LIMIT_EXCEEDED') {
const info: unknown = err.info; const info: unknown = err.info;
const unixEpochInSeconds = Date.now(); const unixEpochInSeconds = Date.now();
if (typeof(info) === 'object' && info && 'resetMs' in info && typeof(info.resetMs) === 'number') { if (isLimitInfo(info)) {
// Number of seconds to wait before trying again. Left for backwards compatibility.
reply.header('Retry-After', info.resetSec.toString());
// Number of milliseconds to wait before trying again.
reply.header('X-RateLimit-Reset', info.resetMs.toString());
} else if (typeof(info) === 'object' && info && 'resetMs' in info && typeof(info.resetMs) === 'number') {
const cooldownInSeconds = Math.ceil((info.resetMs - unixEpochInSeconds) / 1000); const cooldownInSeconds = Math.ceil((info.resetMs - unixEpochInSeconds) / 1000);
// もしかするとマイナスになる可能性がなくはないのでマイナスだったら0にしておく // もしかするとマイナスになる可能性がなくはないのでマイナスだったら0にしておく
reply.header('Retry-After', Math.max(cooldownInSeconds, 0).toString(10)); reply.header('Retry-After', Math.max(cooldownInSeconds, 0).toString(10));
} else { } else {
this.logger.warn(`rate limit information has unexpected type ${typeof(err.info?.reset)}`); this.logger.warn(`rate limit information has unexpected type: ${JSON.stringify(info)}`);
} }
} else if (err.kind === 'client') { } else if (err.kind === 'client') {
reply.header('WWW-Authenticate', `Bearer realm="Misskey", error="invalid_request", error_description="${err.message}"`); reply.header('WWW-Authenticate', `Bearer realm="Misskey", error="invalid_request", error_description="${err.message}"`);
@ -168,7 +175,7 @@ export class ApiCallService implements OnApplicationShutdown {
return; return;
} }
this.authenticateService.authenticate(token).then(([user, app]) => { this.authenticateService.authenticate(token).then(([user, app]) => {
this.call(endpoint, user, app, body, null, request).then((res) => { this.call(endpoint, user, app, body, null, request, reply).then((res) => {
if (request.method === 'GET' && endpoint.meta.cacheSec && !token && !user) { if (request.method === 'GET' && endpoint.meta.cacheSec && !token && !user) {
reply.header('Cache-Control', `public, max-age=${endpoint.meta.cacheSec}`); reply.header('Cache-Control', `public, max-age=${endpoint.meta.cacheSec}`);
} }
@ -229,7 +236,7 @@ export class ApiCallService implements OnApplicationShutdown {
this.call(endpoint, user, app, fields, { this.call(endpoint, user, app, fields, {
name: multipartData.filename, name: multipartData.filename,
path: path, path: path,
}, request).then((res) => { }, request, reply).then((res) => {
this.send(reply, res); this.send(reply, res);
}).catch((err: ApiError) => { }).catch((err: ApiError) => {
this.#sendApiError(reply, err); this.#sendApiError(reply, err);
@ -304,6 +311,7 @@ export class ApiCallService implements OnApplicationShutdown {
path: string; path: string;
} | null, } | null,
request: FastifyRequest<{ Body: Record<string, unknown> | undefined, Querystring: Record<string, unknown> }>, request: FastifyRequest<{ Body: Record<string, unknown> | undefined, Querystring: Record<string, unknown> }>,
reply: FastifyReply,
) { ) {
const isSecure = user != null && token == null; const isSecure = user != null && token == null;
@ -339,19 +347,41 @@ export class ApiCallService implements OnApplicationShutdown {
if (factor > 0) { if (factor > 0) {
// Rate limit // Rate limit
await this.rateLimiterService.limit(limit as IEndpointMeta['limit'] & { key: NonNullable<string> }, limitActor, factor).catch(err => { const info = await this.rateLimiterService.limit(limit as IEndpointMeta['limit'] & { key: NonNullable<string> }, limitActor, factor)
if ('info' in err) { .then(info => {
// errはLimiter.LimiterInfoであることが期待される // We always want these headers, because clients need them for pacing.
throw new ApiError({ // Conditional check in case we somehow revert to the old limiter, which does not return info.
message: 'Rate limit exceeded. Please try again later.', if (info) {
code: 'RATE_LIMIT_EXCEEDED', // Number of seconds until the limit has fully reset.
id: 'd5826d14-3982-4d2e-8011-b9e9f02499ef', reply.header('X-RateLimit-Clear', info.fullResetSec.toString());
httpStatusCode: 429, // Number of calls that can be made before being limited.
}, err.info); reply.header('X-RateLimit-Remaining', info.remaining.toString());
} else {
throw new TypeError('information must be a rate-limiter information.'); // Only forward the info object if it's blocked, otherwise we'll reject *all* requests
} if (info.blocked) {
}); return info;
}
}
return undefined;
})
.catch(err => {
// The old limiter throws info instead of returning it.
if ('info' in err) {
return err.info as LimiterInfo;
} else {
throw err;
}
});
if (info) {
throw new ApiError({
message: 'Rate limit exceeded. Please try again later.',
code: 'RATE_LIMIT_EXCEEDED',
id: 'd5826d14-3982-4d2e-8011-b9e9f02499ef',
httpStatusCode: 429,
}, info);
}
} }
} }

View file

@ -10,28 +10,28 @@ import { DI } from '@/di-symbols.js';
import type Logger from '@/logger.js'; import type Logger from '@/logger.js';
import { LoggerService } from '@/core/LoggerService.js'; import { LoggerService } from '@/core/LoggerService.js';
import { bindThis } from '@/decorators.js'; import { bindThis } from '@/decorators.js';
import type { LimitInfo } from '@/server/api/SkRateLimiterService.js';
import { EnvService } from '@/core/EnvService.js';
import type { IEndpointMeta } from './endpoints.js'; import type { IEndpointMeta } from './endpoints.js';
@Injectable() @Injectable()
export class RateLimiterService { export class RateLimiterService {
private logger: Logger; protected readonly logger: Logger;
private disabled = false; protected readonly disabled: boolean;
constructor( constructor(
@Inject(DI.redis) @Inject(DI.redis)
private redisClient: Redis.Redis, protected readonly redisClient: Redis.Redis,
private loggerService: LoggerService, private loggerService: LoggerService,
envService: EnvService,
) { ) {
this.logger = this.loggerService.getLogger('limiter'); this.logger = this.loggerService.getLogger('limiter');
this.disabled = envService.env.NODE_ENV !== 'production';
if (process.env.NODE_ENV !== 'production') {
this.disabled = true;
}
} }
@bindThis @bindThis
public limit(limitation: IEndpointMeta['limit'] & { key: NonNullable<string> }, actor: string, factor = 1) { public limit(limitation: IEndpointMeta['limit'] & { key: NonNullable<string> }, actor: string, factor = 1): Promise<LimitInfo | void> {
return new Promise<void>((ok, reject) => { return new Promise<void>((ok, reject) => {
if (this.disabled) ok(); if (this.disabled) ok();

View file

@ -0,0 +1,279 @@
/*
* SPDX-FileCopyrightText: hazelnoot and other Sharkey contributors
* SPDX-License-Identifier: AGPL-3.0-only
*/
import { Injectable } from '@nestjs/common';
import Redis from 'ioredis';
import type { IEndpointMeta } from '@/server/api/endpoints.js';
import { LoggerService } from '@/core/LoggerService.js';
import { TimeService } from '@/core/TimeService.js';
import { EnvService } from '@/core/EnvService.js';
import { RateLimiterService } from './RateLimiterService.js';
/**
* Metadata about the current status of a rate limiter
*/
export interface LimitInfo {
/**
* True if the limit has been reached, and the call should be blocked.
*/
blocked: boolean;
/**
* Number of calls that can be made before the limit is triggered.
*/
remaining: number;
/**
* Time in seconds until the next call can be made, or zero if the next call can be made immediately.
* Rounded up to the nearest second.
*/
resetSec: number;
/**
* Time in milliseconds until the next call can be made, or zero if the next call can be made immediately.
* Rounded up to the nearest milliseconds.
*/
resetMs: number;
/**
* Time in seconds until the limit has fully reset.
* Rounded up to the nearest second.
*/
fullResetSec: number;
/**
* Time in milliseconds until the limit has fully reset.
* Rounded up to the nearest millisecond.
*/
fullResetMs: number;
}
export function isLimitInfo(info: unknown): info is LimitInfo {
if (info == null) return false;
if (typeof(info) !== 'object') return false;
if (!('blocked' in info) || typeof(info.blocked) !== 'boolean') return false;
if (!('remaining' in info) || typeof(info.remaining) !== 'number') return false;
if (!('resetSec' in info) || typeof(info.resetSec) !== 'number') return false;
if (!('resetMs' in info) || typeof(info.resetMs) !== 'number') return false;
if (!('fullResetSec' in info) || typeof(info.fullResetSec) !== 'number') return false;
if (!('fullResetMs' in info) || typeof(info.fullResetMs) !== 'number') return false;
return true;
}
/**
* Rate limit based on "leaky bucket" logic.
* The bucket count increases with each call, and decreases gradually at a given rate.
* The subject is blocked until the bucket count drops below the limit.
*/
export interface RateLimit {
/**
* Unique key identifying the particular resource (or resource group) being limited.
*/
key: string;
/**
* Constant value identifying the type of rate limit.
*/
type: 'bucket';
/**
* Size of the bucket, in number of requests.
* The subject will be blocked when the number of calls exceeds this size.
*/
size: number;
/**
* How often the bucket should "drip" and reduce the counter, measured in milliseconds.
* Defaults to 1000 (1 second).
*/
dripRate?: number;
/**
* Amount to reduce the counter on each drip.
* Defaults to 1.
*/
dripSize?: number;
}
export type SupportedRateLimit = RateLimit | LegacyRateLimit;
export type LegacyRateLimit = IEndpointMeta['limit'] & { key: NonNullable<string>, type: undefined | 'legacy' };
export function isLegacyRateLimit(limit: SupportedRateLimit): limit is LegacyRateLimit {
return limit.type === undefined || limit.type === 'legacy';
}
export function hasMinLimit(limit: LegacyRateLimit): limit is LegacyRateLimit & { minInterval: number } {
return !!limit.minInterval;
}
@Injectable()
export class SkRateLimiterService extends RateLimiterService {
constructor(
private readonly timeService: TimeService,
redisClient: Redis.Redis,
loggerService: LoggerService,
envService: EnvService,
) {
super(redisClient, loggerService, envService);
}
public async limit(limit: SupportedRateLimit, actor: string, factor = 1): Promise<LimitInfo> {
if (this.disabled) {
return {
blocked: false,
remaining: Number.MAX_SAFE_INTEGER,
resetSec: 0,
resetMs: 0,
fullResetSec: 0,
fullResetMs: 0,
};
}
if (isLegacyRateLimit(limit)) {
return await this.limitLegacy(limit, actor, factor);
} else {
return await this.limitBucket(limit, actor, factor);
}
}
private async limitLegacy(limit: LegacyRateLimit, actor: string, factor: number): Promise<LimitInfo> {
const promises: Promise<LimitInfo | null>[] = [];
// The "min" limit - if present - is handled directly.
if (hasMinLimit(limit)) {
promises.push(
this.limitMin(limit, actor, factor),
);
}
// Convert the "max" limit into a leaky bucket with 1 drip / second rate.
if (limit.max && limit.duration) {
promises.push(
this.limitBucket({
type: 'bucket',
key: limit.key,
size: limit.max,
dripRate: Math.round(limit.duration / limit.max),
}, actor, factor),
);
}
const [lim1, lim2] = await Promise.all(promises);
return {
blocked: (lim1?.blocked || lim2?.blocked) ?? false,
remaining: Math.min(lim1?.remaining ?? 1, lim2?.remaining ?? 1),
resetSec: Math.max(lim1?.resetSec ?? 0, lim2?.resetSec ?? 0),
resetMs: Math.max(lim1?.resetMs ?? 0, lim2?.resetMs ?? 0),
fullResetSec: Math.max(lim1?.fullResetSec ?? 0, lim2?.fullResetSec ?? 0),
fullResetMs: Math.max(lim1?.fullResetMs ?? 0, lim2?.fullResetMs ?? 0),
};
}
private async limitMin(limit: LegacyRateLimit & { minInterval: number }, actor: string, factor: number): Promise<LimitInfo | null> {
const counter = await this.getLimitCounter(limit, actor, 'min');
const maxCalls = Math.max(Math.ceil(factor), 1);
// Update expiration
if (counter.c >= maxCalls) {
const isCleared = this.timeService.now - counter.t >= limit.minInterval;
if (isCleared) {
counter.c = 0;
}
}
const blocked = counter.c >= maxCalls;
if (!blocked) {
counter.c++;
counter.t = this.timeService.now;
}
// Calculate limit status
const remaining = Math.max(maxCalls - counter.c, 0);
const fullResetMs = Math.max(Math.ceil(limit.minInterval - (this.timeService.now - counter.t)), 0);
const fullResetSec = Math.ceil(fullResetMs / 1000);
const resetMs = remaining < 1 ? fullResetMs : 0;
const resetSec = remaining < 1 ? fullResetSec : 0;
const limitInfo: LimitInfo = { blocked, remaining, resetSec, resetMs, fullResetSec, fullResetMs,
};
// Update the limit counter, but not if blocked
if (!blocked) {
// Don't await, or we will slow down the API.
this.setLimitCounter(limit, actor, counter, resetMs, 'min')
.catch(err => this.logger.error(`Failed to update limit ${limit.key}:min for ${actor}:`, err));
}
return limitInfo;
}
private async limitBucket(limit: RateLimit, actor: string, factor: number): Promise<LimitInfo> {
const counter = await this.getLimitCounter(limit, actor);
const dripRate = (limit.dripRate ?? 1000);
const dripSize = (limit.dripSize ?? 1);
const bucketSize = (limit.size * factor);
// Update drips
if (counter.c > 0) {
const dripsSinceLastTick = Math.floor((this.timeService.now - counter.t) / dripRate) * dripSize;
counter.c = Math.max(counter.c - dripsSinceLastTick, 0);
}
const blocked = counter.c >= bucketSize;
if (!blocked) {
counter.c++;
counter.t = this.timeService.now;
}
// Calculate limit status
const remaining = Math.max(bucketSize - counter.c, 0);
const resetMs = remaining > 0 ? 0 : Math.max(dripRate - (this.timeService.now - counter.t), 0);
const resetSec = Math.ceil(resetMs / 1000);
const fullResetMs = Math.ceil(counter.c / dripSize) * dripRate;
const fullResetSec = Math.ceil(fullResetMs / 1000);
const limitInfo: LimitInfo = { blocked, remaining, resetSec, resetMs, fullResetSec, fullResetMs };
// Update the limit counter, but not if blocked
if (!blocked) {
// Don't await, or we will slow down the API.
this.setLimitCounter(limit, actor, counter, fullResetMs)
.catch(err => this.logger.error(`Failed to update limit ${limit.key} for ${actor}:`, err));
}
return limitInfo;
}
private async getLimitCounter(limit: SupportedRateLimit, actor: string, subject?: string): Promise<LimitCounter> {
const key = createLimitKey(limit, actor, subject);
const value = await this.redisClient.get(key);
if (value == null) {
return { t: 0, c: 0 };
}
return JSON.parse(value);
}
private async setLimitCounter(limit: SupportedRateLimit, actor: string, counter: LimitCounter, expirationMs: number, subject?: string): Promise<void> {
const key = createLimitKey(limit, actor, subject);
const value = JSON.stringify(counter);
await this.redisClient.set(key, value, 'PX', expirationMs);
}
}
function createLimitKey(limit: SupportedRateLimit, actor: string, subject?: string): string {
if (subject) {
return `rl_${actor}_${limit.key}_${subject}`;
} else {
return `rl_${actor}_${limit.key}`;
}
}
export interface LimitCounter {
/** Timestamp */
t: number;
/** Counter */
c: number;
}

View file

@ -0,0 +1,703 @@
/*
* SPDX-FileCopyrightText: hazelnoot and other Sharkey contributors
* SPDX-License-Identifier: AGPL-3.0-only
*/
import { KEYWORD } from 'color-convert/conversions.js';
import type Redis from 'ioredis';
import { LegacyRateLimit, LimitCounter, RateLimit, SkRateLimiterService } from '@/server/api/SkRateLimiterService.js';
import { LoggerService } from '@/core/LoggerService.js';
/* eslint-disable @typescript-eslint/no-non-null-assertion */
/* eslint-disable @typescript-eslint/no-unnecessary-condition */
describe(SkRateLimiterService, () => {
let mockTimeService: { now: number, date: Date } = null!;
let mockRedisGet: ((key: string) => string | null) | undefined = undefined;
let mockRedisSet: ((args: unknown[]) => void) | undefined = undefined;
let mockEnvironment: Record<string, string | undefined> = null!;
let serviceUnderTest: () => SkRateLimiterService = null!;
let loggedMessages: { level: string, data: unknown[] }[] = [];
beforeEach(() => {
mockTimeService = {
now: 0,
get date() {
return new Date(mockTimeService.now);
},
};
mockRedisGet = undefined;
mockRedisSet = undefined;
const mockRedisClient = {
get(key: string) {
if (mockRedisGet) return Promise.resolve(mockRedisGet(key));
else return Promise.resolve(null);
},
set(...args: unknown[]): Promise<void> {
if (mockRedisSet) mockRedisSet(args);
return Promise.resolve();
},
} as unknown as Redis.Redis;
mockEnvironment = Object.create(process.env);
mockEnvironment.NODE_ENV = 'production';
const mockEnvService = {
env: mockEnvironment,
};
loggedMessages = [];
const mockLogService = {
getLogger() {
return {
createSubLogger(context: string, color?: KEYWORD) {
return mockLogService.getLogger(context, color);
},
error(...data: unknown[]) {
loggedMessages.push({ level: 'error', data });
},
warn(...data: unknown[]) {
loggedMessages.push({ level: 'warn', data });
},
succ(...data: unknown[]) {
loggedMessages.push({ level: 'succ', data });
},
debug(...data: unknown[]) {
loggedMessages.push({ level: 'debug', data });
},
info(...data: unknown[]) {
loggedMessages.push({ level: 'info', data });
},
};
},
} as unknown as LoggerService;
let service: SkRateLimiterService | undefined = undefined;
serviceUnderTest = () => {
return service ??= new SkRateLimiterService(mockTimeService, mockRedisClient, mockLogService, mockEnvService);
};
});
function expectNoUnhandledErrors() {
const unhandledErrors = loggedMessages.filter(m => m.level === 'error');
if (unhandledErrors.length > 0) {
throw new Error(`Test failed: got unhandled errors ${unhandledErrors.join('\n')}`);
}
}
describe('limit', () => {
const actor = 'actor';
const key = 'test';
let counter: LimitCounter | undefined = undefined;
let minCounter: LimitCounter | undefined = undefined;
beforeEach(() => {
counter = undefined;
minCounter = undefined;
mockRedisGet = (key: string) => {
if (key === 'rl_actor_test' && counter) {
return JSON.stringify(counter);
}
if (key === 'rl_actor_test_min' && minCounter) {
return JSON.stringify(minCounter);
}
return null;
};
mockRedisSet = (args: unknown[]) => {
const [key, value] = args;
if (key === 'rl_actor_test') {
if (value == null) counter = undefined;
else if (typeof(value) === 'string') counter = JSON.parse(value);
else throw new Error('invalid redis call');
}
if (key === 'rl_actor_test_min') {
if (value == null) minCounter = undefined;
else if (typeof(value) === 'string') minCounter = JSON.parse(value);
else throw new Error('invalid redis call');
}
};
});
it('should bypass in non-production', async () => {
mockEnvironment.NODE_ENV = 'test';
const info = await serviceUnderTest().limit({ key: 'l', type: undefined, max: 0 }, 'actor');
expect(info.blocked).toBeFalsy();
expect(info.remaining).toBe(Number.MAX_SAFE_INTEGER);
expect(info.resetSec).toBe(0);
expect(info.resetMs).toBe(0);
expect(info.fullResetSec).toBe(0);
expect(info.fullResetMs).toBe(0);
});
describe('with bucket limit', () => {
let limit: RateLimit = null!;
beforeEach(() => {
limit = {
type: 'bucket',
key: 'test',
size: 1,
};
});
it('should allow when limit is not reached', async () => {
const info = await serviceUnderTest().limit(limit, actor);
expect(info.blocked).toBeFalsy();
});
it('should not error when allowed', async () => {
await serviceUnderTest().limit(limit, actor);
expectNoUnhandledErrors();
});
it('should return correct info when allowed', async () => {
limit.size = 2;
counter = { c: 1, t: 0 };
const info = await serviceUnderTest().limit(limit, actor);
expect(info.remaining).toBe(0);
expect(info.resetSec).toBe(1);
expect(info.resetMs).toBe(1000);
expect(info.fullResetSec).toBe(2);
expect(info.fullResetMs).toBe(2000);
});
it('should increment counter when called', async () => {
await serviceUnderTest().limit(limit, actor);
expect(counter).not.toBeUndefined();
expect(counter?.c).toBe(1);
});
it('should set timestamp when called', async () => {
mockTimeService.now = 1000;
await serviceUnderTest().limit(limit, actor);
expect(counter).not.toBeUndefined();
expect(counter?.t).toBe(1000);
});
it('should decrement counter when dripRate has passed', async () => {
counter = { c: 2, t: 0 };
mockTimeService.now = 2000;
await serviceUnderTest().limit(limit, actor);
expect(counter).not.toBeUndefined();
expect(counter?.c).toBe(1); // 2 (starting) - 2 (2x1 drip) + 1 (call) = 1
});
it('should decrement counter by dripSize', async () => {
counter = { c: 2, t: 0 };
limit.dripSize = 2;
mockTimeService.now = 1000;
await serviceUnderTest().limit(limit, actor);
expect(counter).not.toBeUndefined();
expect(counter?.c).toBe(1); // 2 (starting) - 2 (1x2 drip) + 1 (call) = 1
});
it('should maintain counter between calls over time', async () => {
limit.size = 5;
await serviceUnderTest().limit(limit, actor); // 0 + 1 = 1
mockTimeService.now += 1000; // 1 - 1 = 0
await serviceUnderTest().limit(limit, actor); // 0 + 1 = 1
await serviceUnderTest().limit(limit, actor); // 1 + 1 = 2
await serviceUnderTest().limit(limit, actor); // 2 + 1 = 3
mockTimeService.now += 1000; // 3 - 1 = 2
mockTimeService.now += 1000; // 2 - 1 = 1
await serviceUnderTest().limit(limit, actor); // 1 + 1 = 2
expect(counter?.c).toBe(2);
expect(counter?.t).toBe(3000);
});
it('should log error and continue when update fails', async () => {
mockRedisSet = () => {
throw new Error('test error');
};
await serviceUnderTest().limit(limit, actor);
const matchingError = loggedMessages
.find(m => m.level === 'error' && m.data
.some(d => typeof(d) === 'string' && d.includes('Failed to update limit')));
expect(matchingError).toBeTruthy();
});
it('should block when bucket is filled', async () => {
counter = { c: 1, t: 0 };
const info = await serviceUnderTest().limit(limit, actor);
expect(info.blocked).toBeTruthy();
});
it('should calculate correct info when blocked', async () => {
counter = { c: 1, t: 0 };
const info = await serviceUnderTest().limit(limit, actor);
expect(info.resetSec).toBe(1);
expect(info.resetMs).toBe(1000);
expect(info.fullResetSec).toBe(1);
expect(info.fullResetMs).toBe(1000);
});
it('should allow when bucket is filled but should drip', async () => {
counter = { c: 1, t: 0 };
mockTimeService.now = 1000;
const info = await serviceUnderTest().limit(limit, actor);
expect(info.blocked).toBeFalsy();
});
it('should scale limit by factor', async () => {
counter = { c: 1, t: 0 };
const i1 = await serviceUnderTest().limit(limit, actor, 2); // 1 + 1 = 2
const i2 = await serviceUnderTest().limit(limit, actor, 2); // 2 + 1 = 3
expect(i1.blocked).toBeFalsy();
expect(i2.blocked).toBeTruthy();
});
it('should set key expiration', async () => {
mockRedisSet = args => {
expect(args[2]).toBe('PX');
expect(args[3]).toBe(1000);
};
await serviceUnderTest().limit(limit, actor);
});
it('should not increment when already blocked', async () => {
counter = { c: 1, t: 0 };
mockTimeService.now += 100;
await serviceUnderTest().limit(limit, actor);
expect(counter?.c).toBe(1);
expect(counter?.t).toBe(0);
});
});
describe('with min interval', () => {
let limit: MutableLegacyRateLimit = null!;
beforeEach(() => {
limit = {
type: undefined,
key,
minInterval: 1000,
};
});
it('should allow when limit is not reached', async () => {
const info = await serviceUnderTest().limit(limit, actor);
expect(info.blocked).toBeFalsy();
});
it('should not error when allowed', async () => {
await serviceUnderTest().limit(limit, actor);
expectNoUnhandledErrors();
});
it('should calculate correct info when allowed', async () => {
const info = await serviceUnderTest().limit(limit, actor);
expect(info.remaining).toBe(0);
expect(info.resetSec).toBe(1);
expect(info.resetMs).toBe(1000);
expect(info.fullResetSec).toBe(1);
expect(info.fullResetMs).toBe(1000);
});
it('should increment counter when called', async () => {
await serviceUnderTest().limit(limit, actor);
expect(minCounter).not.toBeUndefined();
expect(minCounter?.c).toBe(1);
});
it('should set timestamp when called', async () => {
mockTimeService.now = 1000;
await serviceUnderTest().limit(limit, actor);
expect(minCounter).not.toBeUndefined();
expect(minCounter?.t).toBe(1000);
});
it('should decrement counter when minInterval has passed', async () => {
minCounter = { c: 1, t: 0 };
mockTimeService.now = 1000;
await serviceUnderTest().limit(limit, actor);
expect(minCounter).not.toBeUndefined();
expect(minCounter?.c).toBe(1); // 1 (starting) - 1 (interval) + 1 (call) = 1
});
it('should reset counter entirely', async () => {
minCounter = { c: 2, t: 0 };
mockTimeService.now = 1000;
await serviceUnderTest().limit(limit, actor);
expect(minCounter).not.toBeUndefined();
expect(minCounter?.c).toBe(1); // 2 (starting) - 2 (interval) + 1 (call) = 1
});
it('should maintain counter between calls over time', async () => {
await serviceUnderTest().limit(limit, actor); // 0 + 1 = 1
mockTimeService.now += 1000; // 1 - 1 = 0
await serviceUnderTest().limit(limit, actor); // 0 + 1 = 1
await serviceUnderTest().limit(limit, actor); // blocked
await serviceUnderTest().limit(limit, actor); // blocked
mockTimeService.now += 1000; // 1 - 1 = 0
mockTimeService.now += 1000; // 0 - 1 = 0
await serviceUnderTest().limit(limit, actor); // 0 + 1 = 1
expect(minCounter?.c).toBe(1);
expect(minCounter?.t).toBe(3000);
});
it('should log error and continue when update fails', async () => {
mockRedisSet = () => {
throw new Error('test error');
};
await serviceUnderTest().limit(limit, actor);
const matchingError = loggedMessages
.find(m => m.level === 'error' && m.data
.some(d => typeof(d) === 'string' && d.includes('Failed to update limit')));
expect(matchingError).toBeTruthy();
});
it('should block when interval exceeded', async () => {
minCounter = { c: 1, t: 0 };
const info = await serviceUnderTest().limit(limit, actor);
expect(info.blocked).toBeTruthy();
});
it('should calculate correct info when blocked', async () => {
minCounter = { c: 1, t: 0 };
const info = await serviceUnderTest().limit(limit, actor);
expect(info.resetSec).toBe(1);
expect(info.resetMs).toBe(1000);
expect(info.fullResetSec).toBe(1);
expect(info.fullResetMs).toBe(1000);
});
it('should allow when bucket is filled but interval has passed', async () => {
minCounter = { c: 1, t: 0 };
mockTimeService.now = 1000;
const info = await serviceUnderTest().limit(limit, actor);
expect(info.blocked).toBeFalsy();
});
it('should scale limit by factor', async () => {
minCounter = { c: 1, t: 0 };
const i1 = await serviceUnderTest().limit(limit, actor, 2); // 1 + 1 = 2
const i2 = await serviceUnderTest().limit(limit, actor, 2); // 2 + 1 = 3
expect(i1.blocked).toBeFalsy();
expect(i2.blocked).toBeTruthy();
});
it('should set key expiration', async () => {
mockRedisSet = args => {
expect(args[2]).toBe('PX');
expect(args[3]).toBe(1000);
};
await serviceUnderTest().limit(limit, actor);
});
it('should not increment when already blocked', async () => {
minCounter = { c: 1, t: 0 };
mockTimeService.now += 100;
await serviceUnderTest().limit(limit, actor);
expect(minCounter?.c).toBe(1);
expect(minCounter?.t).toBe(0);
});
});
describe('with legacy limit', () => {
let limit: MutableLegacyRateLimit = null!;
beforeEach(() => {
limit = {
type: undefined,
key,
max: 1,
duration: 1000,
};
});
it('should allow when limit is not reached', async () => {
const info = await serviceUnderTest().limit(limit, actor);
expect(info.blocked).toBeFalsy();
});
it('should not error when allowed', async () => {
await serviceUnderTest().limit(limit, actor);
expectNoUnhandledErrors();
});
it('should infer dripRate from duration', async () => {
limit.max = 10;
limit.duration = 10000;
counter = { c: 10, t: 0 };
const i1 = await serviceUnderTest().limit(limit, actor);
mockTimeService.now += 1000;
const i2 = await serviceUnderTest().limit(limit, actor);
mockTimeService.now += 2000;
const i3 = await serviceUnderTest().limit(limit, actor);
const i4 = await serviceUnderTest().limit(limit, actor);
const i5 = await serviceUnderTest().limit(limit, actor);
mockTimeService.now += 2000;
const i6 = await serviceUnderTest().limit(limit, actor);
expect(i1.blocked).toBeTruthy();
expect(i2.blocked).toBeFalsy();
expect(i3.blocked).toBeFalsy();
expect(i4.blocked).toBeFalsy();
expect(i5.blocked).toBeTruthy();
expect(i6.blocked).toBeFalsy();
});
it('should calculate correct info when allowed', async () => {
limit.max = 10;
limit.duration = 10000;
counter = { c: 10, t: 0 };
mockTimeService.now += 2000;
const info = await serviceUnderTest().limit(limit, actor);
expect(info.remaining).toBe(1);
expect(info.resetSec).toBe(0);
expect(info.resetMs).toBe(0);
expect(info.fullResetSec).toBe(9);
expect(info.fullResetMs).toBe(9000);
});
it('should calculate correct info when blocked', async () => {
limit.max = 10;
limit.duration = 10000;
counter = { c: 10, t: 0 };
const info = await serviceUnderTest().limit(limit, actor);
expect(info.remaining).toBe(0);
expect(info.resetSec).toBe(1);
expect(info.resetMs).toBe(1000);
expect(info.fullResetSec).toBe(10);
expect(info.fullResetMs).toBe(10000);
});
it('should allow when bucket is filled but interval has passed', async () => {
counter = { c: 10, t: 0 };
mockTimeService.now = 1000;
const info = await serviceUnderTest().limit(limit, actor);
expect(info.blocked).toBeTruthy();
});
it('should scale limit by factor', async () => {
counter = { c: 10, t: 0 };
const info = await serviceUnderTest().limit(limit, actor, 2); // 10 + 1 = 11
expect(info.blocked).toBeTruthy();
});
it('should set key expiration', async () => {
mockRedisSet = args => {
expect(args[2]).toBe('PX');
expect(args[3]).toBe(1000);
};
await serviceUnderTest().limit(limit, actor);
});
it('should not increment when already blocked', async () => {
counter = { c: 1, t: 0 };
mockTimeService.now += 100;
await serviceUnderTest().limit(limit, actor);
expect(counter?.c).toBe(1);
expect(counter?.t).toBe(0);
});
});
describe('with legacy limit and min interval', () => {
let limit: MutableLegacyRateLimit = null!;
beforeEach(() => {
limit = {
type: undefined,
key,
max: 5,
duration: 5000,
minInterval: 1000,
};
});
it('should allow when limit is not reached', async () => {
const info = await serviceUnderTest().limit(limit, actor);
expect(info.blocked).toBeFalsy();
});
it('should not error when allowed', async () => {
await serviceUnderTest().limit(limit, actor);
expectNoUnhandledErrors();
});
it('should block when limit exceeded', async () => {
counter = { c: 5, t: 0 };
const info = await serviceUnderTest().limit(limit, actor);
expect(info.blocked).toBeTruthy();
});
it('should block when minInterval exceeded', async () => {
minCounter = { c: 1, t: 0 };
const info = await serviceUnderTest().limit(limit, actor);
expect(info.blocked).toBeTruthy();
});
it('should calculate correct info when allowed', async () => {
counter = { c: 1, t: 0 };
const info = await serviceUnderTest().limit(limit, actor);
expect(info.remaining).toBe(0);
expect(info.resetSec).toBe(1);
expect(info.resetMs).toBe(1000);
expect(info.fullResetSec).toBe(2);
expect(info.fullResetMs).toBe(2000);
});
it('should calculate correct info when blocked by limit', async () => {
counter = { c: 5, t: 0 };
const info = await serviceUnderTest().limit(limit, actor);
expect(info.remaining).toBe(0);
expect(info.resetSec).toBe(1);
expect(info.resetMs).toBe(1000);
expect(info.fullResetSec).toBe(5);
expect(info.fullResetMs).toBe(5000);
});
it('should calculate correct info when blocked by minInterval', async () => {
minCounter = { c: 1, t: 0 };
const info = await serviceUnderTest().limit(limit, actor);
expect(info.remaining).toBe(0);
expect(info.resetSec).toBe(1);
expect(info.resetMs).toBe(1000);
expect(info.fullResetSec).toBe(1);
expect(info.fullResetMs).toBe(1000);
});
it('should allow when counter is filled but interval has passed', async () => {
counter = { c: 5, t: 0 };
mockTimeService.now = 1000;
const info = await serviceUnderTest().limit(limit, actor);
expect(info.blocked).toBeFalsy();
});
it('should allow when minCounter is filled but interval has passed', async () => {
minCounter = { c: 1, t: 0 };
mockTimeService.now = 1000;
const info = await serviceUnderTest().limit(limit, actor);
expect(info.blocked).toBeFalsy();
});
it('should scale limit by factor', async () => {
minCounter = { c: 5, t: 0 };
const info = await serviceUnderTest().limit(limit, actor, 2);
expect(info.blocked).toBeTruthy();
});
it('should set key expiration', async () => {
mockRedisSet = args => {
expect(args[2]).toBe('PX');
expect(args[3]).toBe(1000);
};
await serviceUnderTest().limit(limit, actor);
});
it('should not increment when already blocked', async () => {
counter = { c: 5, t: 0 };
minCounter = { c: 1, t: 0 };
mockTimeService.now += 100;
await serviceUnderTest().limit(limit, actor);
expect(counter?.c).toBe(5);
expect(counter?.t).toBe(0);
expect(minCounter?.c).toBe(1);
expect(minCounter?.t).toBe(0);
});
});
});
});
// The same thing, but mutable
interface MutableLegacyRateLimit extends LegacyRateLimit {
key: string;
duration?: number;
max?: number;
minInterval?: number;
}