diff --git a/packages/backend/src/server/api/SkRateLimiterService.ts b/packages/backend/src/server/api/SkRateLimiterService.ts index 3e4b125e79..763de0029b 100644 --- a/packages/backend/src/server/api/SkRateLimiterService.ts +++ b/packages/backend/src/server/api/SkRateLimiterService.ts @@ -209,7 +209,7 @@ export class SkRateLimiterService extends RateLimiterService { // Update the limit counter, but not if blocked if (!blocked) { // Don't await, or we will slow down the API. - this.setLimitCounter(limit, actor, counter, resetMs, 'min') + this.setLimitCounter(limit, actor, counter, fullResetSec, 'min') .catch(err => this.logger.error(`Failed to update limit ${limit.key}:min for ${actor}:`, err)); } @@ -217,7 +217,7 @@ export class SkRateLimiterService extends RateLimiterService { } private async limitBucket(limit: RateLimit, actor: string, factor: number): Promise { - const counter = await this.getLimitCounter(limit, actor); + const counter = await this.getLimitCounter(limit, actor, 'bucket'); const dripRate = (limit.dripRate ?? 1000); const dripSize = (limit.dripSize ?? 1); const bucketSize = (limit.size * factor); @@ -245,14 +245,14 @@ export class SkRateLimiterService extends RateLimiterService { // Update the limit counter, but not if blocked if (!blocked) { // Don't await, or we will slow down the API. - this.setLimitCounter(limit, actor, counter, fullResetMs) + this.setLimitCounter(limit, actor, counter, fullResetSec, 'bucket') .catch(err => this.logger.error(`Failed to update limit ${limit.key} for ${actor}:`, err)); } return limitInfo; } - private async getLimitCounter(limit: SupportedRateLimit, actor: string, subject?: string): Promise { + private async getLimitCounter(limit: SupportedRateLimit, actor: string, subject: string): Promise { const key = createLimitKey(limit, actor, subject); const value = await this.redisClient.get(key); @@ -263,19 +263,16 @@ export class SkRateLimiterService extends RateLimiterService { return JSON.parse(value); } - private async setLimitCounter(limit: SupportedRateLimit, actor: string, counter: LimitCounter, expirationMs: number, subject?: string): Promise { + private async setLimitCounter(limit: SupportedRateLimit, actor: string, counter: LimitCounter, expiration: number, subject: string): Promise { const key = createLimitKey(limit, actor, subject); const value = JSON.stringify(counter); - await this.redisClient.set(key, value, 'PX', expirationMs); + const expirationSec = Math.max(expiration, 1); + await this.redisClient.set(key, value, 'EX', expirationSec); } } -function createLimitKey(limit: SupportedRateLimit, actor: string, subject?: string): string { - if (subject) { - return `rl_${actor}_${limit.key}_${subject}`; - } else { - return `rl_${actor}_${limit.key}`; - } +function createLimitKey(limit: SupportedRateLimit, actor: string, subject: string): string { + return `rl_${actor}_${limit.key}_${subject}`; } export interface LimitCounter { diff --git a/packages/backend/test/unit/server/api/SkRateLimiterServiceTests.ts b/packages/backend/test/unit/server/api/SkRateLimiterServiceTests.ts index 8554aa39ef..711894095d 100644 --- a/packages/backend/test/unit/server/api/SkRateLimiterServiceTests.ts +++ b/packages/backend/test/unit/server/api/SkRateLimiterServiceTests.ts @@ -4,6 +4,7 @@ */ import { KEYWORD } from 'color-convert/conversions.js'; +import { jest } from '@jest/globals'; import type Redis from 'ioredis'; import { LegacyRateLimit, LimitCounter, RateLimit, SkRateLimiterService } from '@/server/api/SkRateLimiterService.js'; import { LoggerService } from '@/core/LoggerService.js'; @@ -98,7 +99,7 @@ describe(SkRateLimiterService, () => { minCounter = undefined; mockRedisGet = (key: string) => { - if (key === 'rl_actor_test' && counter) { + if (key === 'rl_actor_test_bucket' && counter) { return JSON.stringify(counter); } @@ -112,7 +113,7 @@ describe(SkRateLimiterService, () => { mockRedisSet = (args: unknown[]) => { const [key, value] = args; - if (key === 'rl_actor_test') { + if (key === 'rl_actor_test_bucket') { if (value == null) counter = undefined; else if (typeof(value) === 'string') counter = JSON.parse(value); else throw new Error('invalid redis call'); @@ -280,12 +281,12 @@ describe(SkRateLimiterService, () => { }); it('should set key expiration', async () => { - mockRedisSet = args => { - expect(args[2]).toBe('PX'); - expect(args[3]).toBe(1000); - }; + const mock = jest.fn(mockRedisSet); + mockRedisSet = mock; await serviceUnderTest().limit(limit, actor); + + expect(mock).toHaveBeenCalledWith(['rl_actor_test_bucket', '{"t":0,"c":1}', 'EX', 1]); }); it('should not increment when already blocked', async () => { @@ -434,12 +435,12 @@ describe(SkRateLimiterService, () => { }); it('should set key expiration', async () => { - mockRedisSet = args => { - expect(args[2]).toBe('PX'); - expect(args[3]).toBe(1000); - }; + const mock = jest.fn(mockRedisSet); + mockRedisSet = mock; await serviceUnderTest().limit(limit, actor); + + expect(mock).toHaveBeenCalledWith(['rl_actor_test_min', '{"t":0,"c":1}', 'EX', 1]); }); it('should not increment when already blocked', async () => { @@ -547,12 +548,12 @@ describe(SkRateLimiterService, () => { }); it('should set key expiration', async () => { - mockRedisSet = args => { - expect(args[2]).toBe('PX'); - expect(args[3]).toBe(1000); - }; + const mock = jest.fn(mockRedisSet); + mockRedisSet = mock; await serviceUnderTest().limit(limit, actor); + + expect(mock).toHaveBeenCalledWith(['rl_actor_test_bucket', '{"t":0,"c":1}', 'EX', 1]); }); it('should not increment when already blocked', async () => { @@ -670,12 +671,13 @@ describe(SkRateLimiterService, () => { }); it('should set key expiration', async () => { - mockRedisSet = args => { - expect(args[2]).toBe('PX'); - expect(args[3]).toBe(1000); - }; + const mock = jest.fn(mockRedisSet); + mockRedisSet = mock; await serviceUnderTest().limit(limit, actor); + + expect(mock).toHaveBeenCalledWith(['rl_actor_test_bucket', '{"t":0,"c":1}', 'EX', 1]); + expect(mock).toHaveBeenCalledWith(['rl_actor_test_min', '{"t":0,"c":1}', 'EX', 1]); }); it('should not increment when already blocked', async () => {