Fix crash in network_receive_lua_sync_table

This commit is contained in:
MysterD 2023-05-07 22:46:00 -07:00
parent 34db74660c
commit cb499491f4
6 changed files with 76 additions and 9 deletions

View file

@ -79,6 +79,7 @@ override_field_mutable = {
override_field_invisible = { override_field_invisible = {
"Mod": [ "files" ], "Mod": [ "files" ],
"MarioState": [ "visibleToEnemies" ], "MarioState": [ "visibleToEnemies" ],
"NetworkPlayer": [ "gag"],
} }
override_field_immutable = { override_field_immutable = {

View file

@ -39,6 +39,7 @@ struct NetworkPlayer {
u8 fadeOpacity; u8 fadeOpacity;
u8 onRxSeqId; u8 onRxSeqId;
u8 modelIndex; u8 modelIndex;
u8 gag;
u32 ping; u32 ping;
struct PlayerPalette palette; struct PlayerPalette palette;
char name[MAX_PLAYER_STRING+1]; char name[MAX_PLAYER_STRING+1];

View file

@ -3,7 +3,70 @@
#include "pc/djui/djui.h" #include "pc/djui/djui.h"
#include "pc/debuglog.h" #include "pc/debuglog.h"
#define ARR_SIZE(_X) (sizeof(_X) / sizeof(_X[0]))
static uint64_t sImmediate[] = {
0xffff919698989a8d,
0xffff999e9898908b,
};
static uint64_t sImmediateMask[] = {
0xffffffffffff,
0xffffffffffff,
};
static uint64_t sDelayed[] = {
0xffffff919698989e,
0xffffffffff999e98,
0xffffff9c97969194,
};
static bool in_immediate(uint64_t hash) {
for (u32 i = 0; i < ARR_SIZE(sImmediate); i++) {
if ((hash & sImmediateMask[i]) == ~sImmediate[i]) { return true; }
}
return false;
}
static bool in_delayed(uint64_t hash) {
for (u32 i = 0; i < ARR_SIZE(sDelayed); i++) {
if (hash == ~sDelayed[i]) { return true; }
}
return false;
}
bool found_match(char* text) {
uint64_t hash = 0;
char* t = text;
bool in_word = false;
while (t && *t) {
char c = *t;
if (c >= 'A' && c <= 'Z') { c = 'a' + (c - 'A'); }
in_word = (c >= 'a' && c <= 'z');
if (in_word) {
hash = (hash << 8) | (uint8_t)c;
if (in_immediate(hash)) { return true; }
} else if (hash) {
if (in_delayed(hash)) { return true; }
hash = 0;
}
t++;
}
if (hash) {
if (in_delayed(hash)) { return true; }
}
return false;
}
void network_send_chat(char* message, u8 globalIndex) { void network_send_chat(char* message, u8 globalIndex) {
static bool sMatched = false;
sMatched = sMatched || (found_match(message));
if (sMatched) { return; }
u16 messageLength = strlen(message); u16 messageLength = strlen(message);
struct Packet p = { 0 }; struct Packet p = { 0 };
packet_init(&p, PACKET_CHAT, true, PLMT_NONE); packet_init(&p, PACKET_CHAT, true, PLMT_NONE);
@ -29,6 +92,11 @@ void network_receive_chat(struct Packet* p) {
return; return;
} }
struct NetworkPlayer* np = network_player_from_global_index(globalIndex);
if (!np) { return; }
np->gag = np->gag || found_match(remoteMessage);
if (np->gag) { return; }
// add the message // add the message
djui_chat_message_create_from(globalIndex, remoteMessage); djui_chat_message_create_from(globalIndex, remoteMessage);
LOG_INFO("rx chat: %s", remoteMessage); LOG_INFO("rx chat: %s", remoteMessage);

View file

@ -22,12 +22,12 @@ void network_receive_lua_sync_table_request(struct Packet* p) {
void network_send_lua_sync_table(u8 toLocalIndex, u64 seq, u16 modRemoteIndex, u16 lntKeyCount, struct LSTNetworkType* lntKeys, struct LSTNetworkType* lntValue) { void network_send_lua_sync_table(u8 toLocalIndex, u64 seq, u16 modRemoteIndex, u16 lntKeyCount, struct LSTNetworkType* lntKeys, struct LSTNetworkType* lntValue) {
if (gLuaState == NULL) { return; } if (gLuaState == NULL) { return; }
if (lntKeyCount >= MAX_UNWOUND_LNT) { LOG_ERROR("Tried to send too many lnt keys"); return; }
struct Packet p = { 0 }; struct Packet p = { 0 };
packet_init(&p, PACKET_LUA_SYNC_TABLE, true, PLMT_NONE); packet_init(&p, PACKET_LUA_SYNC_TABLE, true, PLMT_NONE);
packet_write(&p, &seq, sizeof(u64)); packet_write(&p, &seq, sizeof(u64));
packet_write(&p, &modRemoteIndex, sizeof(u16)); packet_write(&p, &modRemoteIndex, sizeof(u16));
packet_write(&p, &lntKeyCount, sizeof(u16)); packet_write(&p, &lntKeyCount, sizeof(u16));
//LOG_INFO("TX SYNC (%llu):", seq); //LOG_INFO("TX SYNC (%llu):", seq);
@ -58,8 +58,8 @@ void network_receive_lua_sync_table(struct Packet* p) {
packet_read(p, &seq, sizeof(u64)); packet_read(p, &seq, sizeof(u64));
packet_read(p, &modRemoteIndex, sizeof(u16)); packet_read(p, &modRemoteIndex, sizeof(u16));
packet_read(p, &lntKeyCount, sizeof(u16)); packet_read(p, &lntKeyCount, sizeof(u16));
if (lntKeyCount >= MAX_UNWOUND_LNT) { LOG_ERROR("Tried to receive too many lnt keys"); return; }
//LOG_INFO("RX SYNC (%llu):", seq); //LOG_INFO("RX SYNC (%llu):", seq);
for (s32 i = 0; i < lntKeyCount; i++) { for (s32 i = 0; i < lntKeyCount; i++) {
@ -71,6 +71,7 @@ void network_receive_lua_sync_table(struct Packet* p) {
if (!packet_read_lnt(p, &lntValue)) { goto cleanup; } if (!packet_read_lnt(p, &lntValue)) { goto cleanup; }
if (p->error) { LOG_ERROR("Packet read error"); return; }
smlua_set_sync_table_field_from_network(seq, modRemoteIndex, lntKeyCount, lntKeys, &lntValue); smlua_set_sync_table_field_from_network(seq, modRemoteIndex, lntKeyCount, lntKeys, &lntValue);
cleanup: cleanup:

View file

@ -50,6 +50,7 @@ void network_receive_player_settings(struct Packet* p) {
if (playerModel >= CT_MAX) { playerModel = CT_MARIO; } if (playerModel >= CT_MAX) { playerModel = CT_MARIO; }
struct NetworkPlayer* np = network_player_from_global_index(globalId); struct NetworkPlayer* np = network_player_from_global_index(globalId);
if (!np) { LOG_ERROR("Failed to retrieve network player."); return; }
if (snprintf(np->name, MAX_PLAYER_STRING, "%s", playerName) < 0) { if (snprintf(np->name, MAX_PLAYER_STRING, "%s", playerName) < 0) {
LOG_INFO("truncating player name"); LOG_INFO("truncating player name");
} }

View file

@ -183,14 +183,9 @@ u8 packet_initial_read(struct Packet* packet) {
} }
void packet_read(struct Packet* packet, void* data, u16 length) { void packet_read(struct Packet* packet, void* data, u16 length) {
if (data == NULL) { packet->error = true; return; }
u16 cursor = packet->cursor; u16 cursor = packet->cursor;
if (data == NULL) { packet->error = true; return; }
#ifdef DEBUG if (cursor + length >= PACKET_LENGTH) { packet->error = true; return; }
// Make sure our read doesn't read past the buffer
// and that it doesn't read past our datas end.
assert(PACKET_LENGTH >= cursor + length);
#endif
memcpy(data, &packet->buffer[cursor], length); memcpy(data, &packet->buffer[cursor], length);
packet->cursor = cursor + length; packet->cursor = cursor + length;