fix rate limit storage in redis

This commit is contained in:
Hazelnoot 2024-12-07 12:15:38 -05:00
parent 8239ce4282
commit 32635ecc25
2 changed files with 29 additions and 30 deletions

View file

@ -209,7 +209,7 @@ export class SkRateLimiterService extends RateLimiterService {
// Update the limit counter, but not if blocked // Update the limit counter, but not if blocked
if (!blocked) { if (!blocked) {
// Don't await, or we will slow down the API. // Don't await, or we will slow down the API.
this.setLimitCounter(limit, actor, counter, resetMs, 'min') this.setLimitCounter(limit, actor, counter, fullResetSec, 'min')
.catch(err => this.logger.error(`Failed to update limit ${limit.key}:min for ${actor}:`, err)); .catch(err => this.logger.error(`Failed to update limit ${limit.key}:min for ${actor}:`, err));
} }
@ -217,7 +217,7 @@ export class SkRateLimiterService extends RateLimiterService {
} }
private async limitBucket(limit: RateLimit, actor: string, factor: number): Promise<LimitInfo> { private async limitBucket(limit: RateLimit, actor: string, factor: number): Promise<LimitInfo> {
const counter = await this.getLimitCounter(limit, actor); const counter = await this.getLimitCounter(limit, actor, 'bucket');
const dripRate = (limit.dripRate ?? 1000); const dripRate = (limit.dripRate ?? 1000);
const dripSize = (limit.dripSize ?? 1); const dripSize = (limit.dripSize ?? 1);
const bucketSize = (limit.size * factor); const bucketSize = (limit.size * factor);
@ -245,14 +245,14 @@ export class SkRateLimiterService extends RateLimiterService {
// Update the limit counter, but not if blocked // Update the limit counter, but not if blocked
if (!blocked) { if (!blocked) {
// Don't await, or we will slow down the API. // Don't await, or we will slow down the API.
this.setLimitCounter(limit, actor, counter, fullResetMs) this.setLimitCounter(limit, actor, counter, fullResetSec, 'bucket')
.catch(err => this.logger.error(`Failed to update limit ${limit.key} for ${actor}:`, err)); .catch(err => this.logger.error(`Failed to update limit ${limit.key} for ${actor}:`, err));
} }
return limitInfo; return limitInfo;
} }
private async getLimitCounter(limit: SupportedRateLimit, actor: string, subject?: string): Promise<LimitCounter> { private async getLimitCounter(limit: SupportedRateLimit, actor: string, subject: string): Promise<LimitCounter> {
const key = createLimitKey(limit, actor, subject); const key = createLimitKey(limit, actor, subject);
const value = await this.redisClient.get(key); const value = await this.redisClient.get(key);
@ -263,19 +263,16 @@ export class SkRateLimiterService extends RateLimiterService {
return JSON.parse(value); return JSON.parse(value);
} }
private async setLimitCounter(limit: SupportedRateLimit, actor: string, counter: LimitCounter, expirationMs: number, subject?: string): Promise<void> { private async setLimitCounter(limit: SupportedRateLimit, actor: string, counter: LimitCounter, expiration: number, subject: string): Promise<void> {
const key = createLimitKey(limit, actor, subject); const key = createLimitKey(limit, actor, subject);
const value = JSON.stringify(counter); const value = JSON.stringify(counter);
await this.redisClient.set(key, value, 'PX', expirationMs); const expirationSec = Math.max(expiration, 1);
await this.redisClient.set(key, value, 'EX', expirationSec);
} }
} }
function createLimitKey(limit: SupportedRateLimit, actor: string, subject?: string): string { function createLimitKey(limit: SupportedRateLimit, actor: string, subject: string): string {
if (subject) { return `rl_${actor}_${limit.key}_${subject}`;
return `rl_${actor}_${limit.key}_${subject}`;
} else {
return `rl_${actor}_${limit.key}`;
}
} }
export interface LimitCounter { export interface LimitCounter {

View file

@ -4,6 +4,7 @@
*/ */
import { KEYWORD } from 'color-convert/conversions.js'; import { KEYWORD } from 'color-convert/conversions.js';
import { jest } from '@jest/globals';
import type Redis from 'ioredis'; import type Redis from 'ioredis';
import { LegacyRateLimit, LimitCounter, RateLimit, SkRateLimiterService } from '@/server/api/SkRateLimiterService.js'; import { LegacyRateLimit, LimitCounter, RateLimit, SkRateLimiterService } from '@/server/api/SkRateLimiterService.js';
import { LoggerService } from '@/core/LoggerService.js'; import { LoggerService } from '@/core/LoggerService.js';
@ -98,7 +99,7 @@ describe(SkRateLimiterService, () => {
minCounter = undefined; minCounter = undefined;
mockRedisGet = (key: string) => { mockRedisGet = (key: string) => {
if (key === 'rl_actor_test' && counter) { if (key === 'rl_actor_test_bucket' && counter) {
return JSON.stringify(counter); return JSON.stringify(counter);
} }
@ -112,7 +113,7 @@ describe(SkRateLimiterService, () => {
mockRedisSet = (args: unknown[]) => { mockRedisSet = (args: unknown[]) => {
const [key, value] = args; const [key, value] = args;
if (key === 'rl_actor_test') { if (key === 'rl_actor_test_bucket') {
if (value == null) counter = undefined; if (value == null) counter = undefined;
else if (typeof(value) === 'string') counter = JSON.parse(value); else if (typeof(value) === 'string') counter = JSON.parse(value);
else throw new Error('invalid redis call'); else throw new Error('invalid redis call');
@ -280,12 +281,12 @@ describe(SkRateLimiterService, () => {
}); });
it('should set key expiration', async () => { it('should set key expiration', async () => {
mockRedisSet = args => { const mock = jest.fn(mockRedisSet);
expect(args[2]).toBe('PX'); mockRedisSet = mock;
expect(args[3]).toBe(1000);
};
await serviceUnderTest().limit(limit, actor); await serviceUnderTest().limit(limit, actor);
expect(mock).toHaveBeenCalledWith(['rl_actor_test_bucket', '{"t":0,"c":1}', 'EX', 1]);
}); });
it('should not increment when already blocked', async () => { it('should not increment when already blocked', async () => {
@ -434,12 +435,12 @@ describe(SkRateLimiterService, () => {
}); });
it('should set key expiration', async () => { it('should set key expiration', async () => {
mockRedisSet = args => { const mock = jest.fn(mockRedisSet);
expect(args[2]).toBe('PX'); mockRedisSet = mock;
expect(args[3]).toBe(1000);
};
await serviceUnderTest().limit(limit, actor); await serviceUnderTest().limit(limit, actor);
expect(mock).toHaveBeenCalledWith(['rl_actor_test_min', '{"t":0,"c":1}', 'EX', 1]);
}); });
it('should not increment when already blocked', async () => { it('should not increment when already blocked', async () => {
@ -547,12 +548,12 @@ describe(SkRateLimiterService, () => {
}); });
it('should set key expiration', async () => { it('should set key expiration', async () => {
mockRedisSet = args => { const mock = jest.fn(mockRedisSet);
expect(args[2]).toBe('PX'); mockRedisSet = mock;
expect(args[3]).toBe(1000);
};
await serviceUnderTest().limit(limit, actor); await serviceUnderTest().limit(limit, actor);
expect(mock).toHaveBeenCalledWith(['rl_actor_test_bucket', '{"t":0,"c":1}', 'EX', 1]);
}); });
it('should not increment when already blocked', async () => { it('should not increment when already blocked', async () => {
@ -670,12 +671,13 @@ describe(SkRateLimiterService, () => {
}); });
it('should set key expiration', async () => { it('should set key expiration', async () => {
mockRedisSet = args => { const mock = jest.fn(mockRedisSet);
expect(args[2]).toBe('PX'); mockRedisSet = mock;
expect(args[3]).toBe(1000);
};
await serviceUnderTest().limit(limit, actor); await serviceUnderTest().limit(limit, actor);
expect(mock).toHaveBeenCalledWith(['rl_actor_test_bucket', '{"t":0,"c":1}', 'EX', 1]);
expect(mock).toHaveBeenCalledWith(['rl_actor_test_min', '{"t":0,"c":1}', 'EX', 1]);
}); });
it('should not increment when already blocked', async () => { it('should not increment when already blocked', async () => {