diff --git a/packages/backend/src/server/api/StreamingApiServerService.ts b/packages/backend/src/server/api/StreamingApiServerService.ts index b8f448477b..7ac1bcf469 100644 --- a/packages/backend/src/server/api/StreamingApiServerService.ts +++ b/packages/backend/src/server/api/StreamingApiServerService.ts @@ -19,7 +19,12 @@ import { ChannelFollowingService } from '@/core/ChannelFollowingService.js'; import { AuthenticateService, AuthenticationError } from './AuthenticateService.js'; import MainStreamConnection from './stream/Connection.js'; import { ChannelsService } from './stream/ChannelsService.js'; +import { RateLimiterService } from './RateLimiterService.js'; +import { RoleService } from '@/core/RoleService.js'; +import { getIpHash } from '@/misc/get-ip-hash.js'; +import ms from 'ms'; import type * as http from 'node:http'; +import type { IEndpointMeta } from './endpoints.js'; @Injectable() export class StreamingApiServerService { @@ -41,9 +46,32 @@ export class StreamingApiServerService { private notificationService: NotificationService, private usersService: UserService, private channelFollowingService: ChannelFollowingService, + private rateLimiterService: RateLimiterService, + private roleService: RoleService, ) { } + @bindThis + private async rateLimitThis( + user: MiLocalUser | null | undefined, + requestIp: string | undefined, + limit: IEndpointMeta['limit'] & { key: NonNullable }, + ) : Promise { + let limitActor: string; + if (user) { + limitActor = user.id; + } else { + limitActor = getIpHash(requestIp || 'wtf'); + } + + const factor = user ? (await this.roleService.getUserPolicies(user.id)).rateLimitFactor : 1; + + if (factor <= 0) return false; + + // Rate limit + return await this.rateLimiterService.limit(limit, limitActor, factor).then(() => { return false }).catch(err => { return true }); + } + @bindThis public attach(server: http.Server): void { this.#wss = new WebSocket.WebSocketServer({ @@ -57,6 +85,17 @@ export class StreamingApiServerService { return; } + if (await this.rateLimitThis(null, request.socket.remoteAddress, { + key: 'wsconnect', + duration: ms('1min'), + max: 20, + minInterval: ms('1sec'), + })) { + socket.write('HTTP/1.1 429 Rate Limit Exceeded\r\n\r\n'); + socket.destroy(); + return; + } + const q = new URL(request.url, `http://${request.headers.host}`).searchParams; let user: MiLocalUser | null = null; @@ -94,6 +133,14 @@ export class StreamingApiServerService { return; } + const rateLimiter = () => { + return this.rateLimitThis(user, request.socket.remoteAddress, { + key: 'wsmessage', + duration: ms('1sec'), + max: 100, + }); + }; + const stream = new MainStreamConnection( this.channelsService, this.noteReadService, @@ -101,6 +148,7 @@ export class StreamingApiServerService { this.cacheService, this.channelFollowingService, user, app, + rateLimiter, ); await stream.init(); diff --git a/packages/backend/src/server/api/stream/Connection.ts b/packages/backend/src/server/api/stream/Connection.ts index 7dd7db24e5..dfc6f0d298 100644 --- a/packages/backend/src/server/api/stream/Connection.ts +++ b/packages/backend/src/server/api/stream/Connection.ts @@ -25,6 +25,7 @@ import type Channel from './channel.js'; export default class Connection { public user?: MiUser; public token?: MiAccessToken; + private rateLimiter?: () => Promise; private wsConnection: WebSocket.WebSocket; public subscriber: StreamEventEmitter; private channels: Channel[] = []; @@ -48,9 +49,11 @@ export default class Connection { user: MiUser | null | undefined, token: MiAccessToken | null | undefined, + rateLimiter: () => Promise, ) { if (user) this.user = user; if (token) this.token = token; + if (rateLimiter) this.rateLimiter = rateLimiter; } @bindThis @@ -103,6 +106,10 @@ export default class Connection { private async onWsConnectionMessage(data: WebSocket.RawData) { let obj: Record; + if (this.rateLimiter && await this.rateLimiter()) { + return; + } + try { obj = JSON.parse(data.toString()); } catch (e) {