diff --git a/build-windows-visual-studio/sm64ex.vcxproj b/build-windows-visual-studio/sm64ex.vcxproj
index f67eaf21..e5f56dbc 100644
--- a/build-windows-visual-studio/sm64ex.vcxproj
+++ b/build-windows-visual-studio/sm64ex.vcxproj
@@ -513,6 +513,7 @@
+
@@ -968,6 +969,7 @@
+
diff --git a/build-windows-visual-studio/sm64ex.vcxproj.filters b/build-windows-visual-studio/sm64ex.vcxproj.filters
index 69402c77..ae381596 100644
--- a/build-windows-visual-studio/sm64ex.vcxproj.filters
+++ b/build-windows-visual-studio/sm64ex.vcxproj.filters
@@ -4860,6 +4860,9 @@
Source Files\src\pc\djui\panel
+
+ Source Files\src\pc\lua
+
@@ -6001,5 +6004,8 @@
Source Files\src\pc\djui\panel
+
+ Source Files\src\pc\lua
+
\ No newline at end of file
diff --git a/src/pc/lua/smlua.c b/src/pc/lua/smlua.c
index 83c3808f..f93e4d22 100644
--- a/src/pc/lua/smlua.c
+++ b/src/pc/lua/smlua.c
@@ -65,6 +65,8 @@ static void smlua_init_mario_states(void) {
void smlua_init(void) {
smlua_shutdown();
+ smlua_cobject_allowlist_init();
+
gLuaState = luaL_newstate();
lua_State* L = gLuaState;
@@ -108,6 +110,7 @@ void smlua_update(void) {
}
void smlua_shutdown(void) {
+ smlua_cobject_allowlist_shutdown();
lua_State* L = gLuaState;
if (L != NULL) {
lua_close(L);
diff --git a/src/pc/lua/smlua.h b/src/pc/lua/smlua.h
index d404ec4c..04e8afc0 100644
--- a/src/pc/lua/smlua.h
+++ b/src/pc/lua/smlua.h
@@ -5,9 +5,11 @@
#include
#include
+#include
#include "types.h"
#include "smlua_cobject.h"
+#include "smlua_cobject_allowlist.h"
#include "smlua_utils.h"
#include "smlua_functions.h"
#include "smlua_functions_autogen.h"
diff --git a/src/pc/lua/smlua_cobject.c b/src/pc/lua/smlua_cobject.c
index 0b07e659..2e32ab4a 100644
--- a/src/pc/lua/smlua_cobject.c
+++ b/src/pc/lua/smlua_cobject.c
@@ -313,6 +313,11 @@ static int smlua__get_field(lua_State* L) {
return 0;
}
+ if (!smlua_cobject_allowlist_contains(lot, pointer)) {
+ LOG_LUA("_get_field received a pointer not in allow list. '%u', '%llu", lot, (u64)pointer);
+ return 0;
+ }
+
struct LuaObjectField* data = smlua_get_object_field(&sLuaObjectTable[lot], key);
if (data == NULL) {
LOG_LUA("_get_field on invalid key '%s', lot '%d'", key, lot);
@@ -353,6 +358,11 @@ static int smlua__set_field(lua_State* L) {
return 0;
}
+ if (!smlua_cobject_allowlist_contains(lot, pointer)) {
+ LOG_LUA("_set_field received a pointer not in allow list. '%u', '%llu", lot, (u64)pointer);
+ return 0;
+ }
+
struct LuaObjectField* data = smlua_get_object_field(&sLuaObjectTable[lot], key);
if (data == NULL) {
LOG_LUA("_set_field on invalid key '%s'", key);
diff --git a/src/pc/lua/smlua_cobject_allowlist.c b/src/pc/lua/smlua_cobject_allowlist.c
new file mode 100644
index 00000000..e131a563
--- /dev/null
+++ b/src/pc/lua/smlua_cobject_allowlist.c
@@ -0,0 +1,68 @@
+#include
+#include "smlua.h"
+
+#pragma pack(1)
+struct CObjectAllowListNode {
+ u64 pointer;
+ struct CObjectAllowListNode* next;
+};
+
+static struct CObjectAllowListNode* sAllowList[LOT_MAX] = { 0 };
+static u16 sCachedAllowed[LOT_MAX] = { 0 };
+
+void smlua_cobject_allowlist_init(void) {
+ smlua_cobject_allowlist_shutdown();
+}
+
+void smlua_cobject_allowlist_shutdown(void) {
+ for (int i = 0; i < LOT_MAX; i++) {
+ sCachedAllowed[i] = 0;
+ struct CObjectAllowListNode* node = sAllowList[i];
+ while (node != NULL) {
+ struct CObjectAllowListNode* nextNode = node->next;
+ free(node);
+ node = nextNode;
+ }
+ sAllowList[i] = NULL;
+ }
+}
+
+void smlua_cobject_allowlist_add(enum LuaObjectType objectType, u64 pointer) {
+ if (pointer == 0) { return; }
+ if (objectType == LOT_NONE || objectType >= LOT_MAX) { return; }
+
+ if (sCachedAllowed[objectType] == pointer) { return; }
+ sCachedAllowed[objectType] = pointer;
+
+ struct CObjectAllowListNode* curNode = sAllowList[objectType];
+ struct CObjectAllowListNode* prevNode = NULL;
+ while (curNode != NULL) {
+ if (pointer == curNode->pointer) { return; }
+ if (pointer < curNode->pointer) { break; }
+ prevNode = curNode;
+ curNode = curNode->next;
+ }
+
+ struct CObjectAllowListNode* node = malloc(sizeof(struct CObjectAllowListNode));
+ node->pointer = pointer;
+ node->next = curNode;
+ if (prevNode == NULL) {
+ sAllowList[objectType] = node;
+ } else {
+ prevNode->next = node;
+ }
+}
+
+bool smlua_cobject_allowlist_contains(enum LuaObjectType objectType, u64 pointer) {
+ if (pointer == 0) { return false; }
+ if (objectType == LOT_NONE || objectType >= LOT_MAX) { return false; }
+ if (sCachedAllowed[objectType] == pointer) { return true; }
+
+ struct CObjectAllowListNode* node = sAllowList[objectType];
+ while (node != NULL) {
+ if (pointer == node->pointer) { return true; }
+ if (pointer < node->pointer) { return false; }
+ node = node->next;
+ }
+ return false;
+}
\ No newline at end of file
diff --git a/src/pc/lua/smlua_cobject_allowlist.h b/src/pc/lua/smlua_cobject_allowlist.h
new file mode 100644
index 00000000..68501991
--- /dev/null
+++ b/src/pc/lua/smlua_cobject_allowlist.h
@@ -0,0 +1,9 @@
+#ifndef SMLUA_COBJECT_ALLOWLIST_H
+#define SMLUA_COBJECT_ALLOWLIST_H
+
+void smlua_cobject_allowlist_init(void);
+void smlua_cobject_allowlist_shutdown(void);
+void smlua_cobject_allowlist_add(enum LuaObjectType objectType, u64 pointer);
+bool smlua_cobject_allowlist_contains(enum LuaObjectType objectType, u64 pointer);
+
+#endif
\ No newline at end of file
diff --git a/src/pc/lua/smlua_utils.c b/src/pc/lua/smlua_utils.c
index f3cced4c..f04ca5e9 100644
--- a/src/pc/lua/smlua_utils.c
+++ b/src/pc/lua/smlua_utils.c
@@ -71,6 +71,7 @@ void* smlua_to_cobject(lua_State* L, int index, enum LuaObjectType lot) {
return 0;
}
+ // get LOT
lua_getfield(L, index, "_lot");
enum LuaObjectType objLot = smlua_to_integer(L, -1);
lua_pop(L, 1);
@@ -83,11 +84,18 @@ void* smlua_to_cobject(lua_State* L, int index, enum LuaObjectType lot) {
return NULL;
}
+ // get pointer
lua_getfield(L, index, "_pointer");
void* pointer = (void*)smlua_to_integer(L, -1);
lua_pop(L, 1);
if (!gSmLuaConvertSuccess) { return NULL; }
- // TODO: check address whitelists
+
+ // check allowlist
+ if (!smlua_cobject_allowlist_contains(lot, (u64)pointer)) {
+ LOG_LUA("LUA: smlua_to_cobject received a pointer not in allow list. '%u', '%llu", lot, (u64)pointer);
+ gSmLuaConvertSuccess = false;
+ return NULL;
+ }
if (pointer == NULL) {
LOG_LUA("LUA: smlua_to_cobject received null pointer.");
@@ -107,6 +115,9 @@ void smlua_push_object(lua_State* L, enum LuaObjectType lot, void* p) {
lua_pushnil(L);
return;
}
+ // add to allowlist
+ smlua_cobject_allowlist_add(lot, (u64)p);
+
lua_newtable(L);
int t = lua_gettop(L);
smlua_push_integer_field(t, "_lot", lot);