Added Lua allow-list for cobject pointers

This commit is contained in:
MysterD 2022-01-23 16:35:43 -08:00
parent fe11e25e0b
commit b45c61a605
8 changed files with 112 additions and 1 deletions

View file

@ -513,6 +513,7 @@
<ClCompile Include="..\src\pc\ini.c" /> <ClCompile Include="..\src\pc\ini.c" />
<ClCompile Include="..\src\pc\lua\smlua.c" /> <ClCompile Include="..\src\pc\lua\smlua.c" />
<ClCompile Include="..\src\pc\lua\smlua_cobject.c" /> <ClCompile Include="..\src\pc\lua\smlua_cobject.c" />
<ClCompile Include="..\src\pc\lua\smlua_cobject_allowlist.c" />
<ClCompile Include="..\src\pc\lua\smlua_constants_autogen.c" /> <ClCompile Include="..\src\pc\lua\smlua_constants_autogen.c" />
<ClCompile Include="..\src\pc\lua\smlua_functions.c" /> <ClCompile Include="..\src\pc\lua\smlua_functions.c" />
<ClCompile Include="..\src\pc\lua\smlua_functions_autogen.c" /> <ClCompile Include="..\src\pc\lua\smlua_functions_autogen.c" />
@ -968,6 +969,7 @@
<ClInclude Include="..\src\pc\djui\djui_types.h" /> <ClInclude Include="..\src\pc\djui\djui_types.h" />
<ClInclude Include="..\src\pc\lua\smlua.h" /> <ClInclude Include="..\src\pc\lua\smlua.h" />
<ClInclude Include="..\src\pc\lua\smlua_cobject.h" /> <ClInclude Include="..\src\pc\lua\smlua_cobject.h" />
<ClInclude Include="..\src\pc\lua\smlua_cobject_allowlist.h" />
<ClInclude Include="..\src\pc\lua\smlua_functions.h" /> <ClInclude Include="..\src\pc\lua\smlua_functions.h" />
<ClInclude Include="..\src\pc\lua\smlua_functions_autogen.h" /> <ClInclude Include="..\src\pc\lua\smlua_functions_autogen.h" />
<ClInclude Include="..\src\pc\lua\smlua_hooks.h" /> <ClInclude Include="..\src\pc\lua\smlua_hooks.h" />

View file

@ -4860,6 +4860,9 @@
<ClCompile Include="..\src\pc\djui\djui_panel_host_mods.c"> <ClCompile Include="..\src\pc\djui\djui_panel_host_mods.c">
<Filter>Source Files\src\pc\djui\panel</Filter> <Filter>Source Files\src\pc\djui\panel</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="..\src\pc\lua\smlua_cobject_allowlist.c">
<Filter>Source Files\src\pc\lua</Filter>
</ClCompile>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ClCompile Include="..\actors\common0.h"> <ClCompile Include="..\actors\common0.h">
@ -6001,5 +6004,8 @@
<ClInclude Include="..\src\pc\djui\djui_panel_host_mods.h"> <ClInclude Include="..\src\pc\djui\djui_panel_host_mods.h">
<Filter>Source Files\src\pc\djui\panel</Filter> <Filter>Source Files\src\pc\djui\panel</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="..\src\pc\lua\smlua_cobject_allowlist.h">
<Filter>Source Files\src\pc\lua</Filter>
</ClInclude>
</ItemGroup> </ItemGroup>
</Project> </Project>

View file

@ -65,6 +65,8 @@ static void smlua_init_mario_states(void) {
void smlua_init(void) { void smlua_init(void) {
smlua_shutdown(); smlua_shutdown();
smlua_cobject_allowlist_init();
gLuaState = luaL_newstate(); gLuaState = luaL_newstate();
lua_State* L = gLuaState; lua_State* L = gLuaState;
@ -108,6 +110,7 @@ void smlua_update(void) {
} }
void smlua_shutdown(void) { void smlua_shutdown(void) {
smlua_cobject_allowlist_shutdown();
lua_State* L = gLuaState; lua_State* L = gLuaState;
if (L != NULL) { if (L != NULL) {
lua_close(L); lua_close(L);

View file

@ -5,9 +5,11 @@
#include <lualib.h> #include <lualib.h>
#include <lauxlib.h> #include <lauxlib.h>
#include <stdbool.h>
#include "types.h" #include "types.h"
#include "smlua_cobject.h" #include "smlua_cobject.h"
#include "smlua_cobject_allowlist.h"
#include "smlua_utils.h" #include "smlua_utils.h"
#include "smlua_functions.h" #include "smlua_functions.h"
#include "smlua_functions_autogen.h" #include "smlua_functions_autogen.h"

View file

@ -313,6 +313,11 @@ static int smlua__get_field(lua_State* L) {
return 0; return 0;
} }
if (!smlua_cobject_allowlist_contains(lot, pointer)) {
LOG_LUA("_get_field received a pointer not in allow list. '%u', '%llu", lot, (u64)pointer);
return 0;
}
struct LuaObjectField* data = smlua_get_object_field(&sLuaObjectTable[lot], key); struct LuaObjectField* data = smlua_get_object_field(&sLuaObjectTable[lot], key);
if (data == NULL) { if (data == NULL) {
LOG_LUA("_get_field on invalid key '%s', lot '%d'", key, lot); LOG_LUA("_get_field on invalid key '%s', lot '%d'", key, lot);
@ -353,6 +358,11 @@ static int smlua__set_field(lua_State* L) {
return 0; return 0;
} }
if (!smlua_cobject_allowlist_contains(lot, pointer)) {
LOG_LUA("_set_field received a pointer not in allow list. '%u', '%llu", lot, (u64)pointer);
return 0;
}
struct LuaObjectField* data = smlua_get_object_field(&sLuaObjectTable[lot], key); struct LuaObjectField* data = smlua_get_object_field(&sLuaObjectTable[lot], key);
if (data == NULL) { if (data == NULL) {
LOG_LUA("_set_field on invalid key '%s'", key); LOG_LUA("_set_field on invalid key '%s'", key);

View file

@ -0,0 +1,68 @@
#include <stdio.h>
#include "smlua.h"
#pragma pack(1)
struct CObjectAllowListNode {
u64 pointer;
struct CObjectAllowListNode* next;
};
static struct CObjectAllowListNode* sAllowList[LOT_MAX] = { 0 };
static u16 sCachedAllowed[LOT_MAX] = { 0 };
void smlua_cobject_allowlist_init(void) {
smlua_cobject_allowlist_shutdown();
}
void smlua_cobject_allowlist_shutdown(void) {
for (int i = 0; i < LOT_MAX; i++) {
sCachedAllowed[i] = 0;
struct CObjectAllowListNode* node = sAllowList[i];
while (node != NULL) {
struct CObjectAllowListNode* nextNode = node->next;
free(node);
node = nextNode;
}
sAllowList[i] = NULL;
}
}
void smlua_cobject_allowlist_add(enum LuaObjectType objectType, u64 pointer) {
if (pointer == 0) { return; }
if (objectType == LOT_NONE || objectType >= LOT_MAX) { return; }
if (sCachedAllowed[objectType] == pointer) { return; }
sCachedAllowed[objectType] = pointer;
struct CObjectAllowListNode* curNode = sAllowList[objectType];
struct CObjectAllowListNode* prevNode = NULL;
while (curNode != NULL) {
if (pointer == curNode->pointer) { return; }
if (pointer < curNode->pointer) { break; }
prevNode = curNode;
curNode = curNode->next;
}
struct CObjectAllowListNode* node = malloc(sizeof(struct CObjectAllowListNode));
node->pointer = pointer;
node->next = curNode;
if (prevNode == NULL) {
sAllowList[objectType] = node;
} else {
prevNode->next = node;
}
}
bool smlua_cobject_allowlist_contains(enum LuaObjectType objectType, u64 pointer) {
if (pointer == 0) { return false; }
if (objectType == LOT_NONE || objectType >= LOT_MAX) { return false; }
if (sCachedAllowed[objectType] == pointer) { return true; }
struct CObjectAllowListNode* node = sAllowList[objectType];
while (node != NULL) {
if (pointer == node->pointer) { return true; }
if (pointer < node->pointer) { return false; }
node = node->next;
}
return false;
}

View file

@ -0,0 +1,9 @@
#ifndef SMLUA_COBJECT_ALLOWLIST_H
#define SMLUA_COBJECT_ALLOWLIST_H
void smlua_cobject_allowlist_init(void);
void smlua_cobject_allowlist_shutdown(void);
void smlua_cobject_allowlist_add(enum LuaObjectType objectType, u64 pointer);
bool smlua_cobject_allowlist_contains(enum LuaObjectType objectType, u64 pointer);
#endif

View file

@ -71,6 +71,7 @@ void* smlua_to_cobject(lua_State* L, int index, enum LuaObjectType lot) {
return 0; return 0;
} }
// get LOT
lua_getfield(L, index, "_lot"); lua_getfield(L, index, "_lot");
enum LuaObjectType objLot = smlua_to_integer(L, -1); enum LuaObjectType objLot = smlua_to_integer(L, -1);
lua_pop(L, 1); lua_pop(L, 1);
@ -83,11 +84,18 @@ void* smlua_to_cobject(lua_State* L, int index, enum LuaObjectType lot) {
return NULL; return NULL;
} }
// get pointer
lua_getfield(L, index, "_pointer"); lua_getfield(L, index, "_pointer");
void* pointer = (void*)smlua_to_integer(L, -1); void* pointer = (void*)smlua_to_integer(L, -1);
lua_pop(L, 1); lua_pop(L, 1);
if (!gSmLuaConvertSuccess) { return NULL; } if (!gSmLuaConvertSuccess) { return NULL; }
// TODO: check address whitelists
// check allowlist
if (!smlua_cobject_allowlist_contains(lot, (u64)pointer)) {
LOG_LUA("LUA: smlua_to_cobject received a pointer not in allow list. '%u', '%llu", lot, (u64)pointer);
gSmLuaConvertSuccess = false;
return NULL;
}
if (pointer == NULL) { if (pointer == NULL) {
LOG_LUA("LUA: smlua_to_cobject received null pointer."); LOG_LUA("LUA: smlua_to_cobject received null pointer.");
@ -107,6 +115,9 @@ void smlua_push_object(lua_State* L, enum LuaObjectType lot, void* p) {
lua_pushnil(L); lua_pushnil(L);
return; return;
} }
// add to allowlist
smlua_cobject_allowlist_add(lot, (u64)p);
lua_newtable(L); lua_newtable(L);
int t = lua_gettop(L); int t = lua_gettop(L);
smlua_push_integer_field(t, "_lot", lot); smlua_push_integer_field(t, "_lot", lot);