#pragma once #include "basis/seadRawPrint.h" #include "container/seadFreeList.h" #include "prim/seadBitUtil.h" #include "prim/seadDelegate.h" #include "prim/seadSafeString.h" namespace sead { template class TreeMapNode; /// Sorted associative container, implemented using a left-leaning red-black tree. /// For an explanation of the algorithm, see https://www.cs.princeton.edu/~rs/talks/LLRB/LLRB.pdf template class TreeMapImpl { public: using Node = TreeMapNode; void insert(Node* node); void erase(const Key& key); void clear(); Node* find(const Key& key) const { return find(mRoot, key); } template void forEach(const Callable& callable) const { if (mRoot) forEach(mRoot, callable); } Node* startIterating() const { if (!mRoot) return nullptr; return startIterating(mRoot); } Node* nextNode(Node* node) const { if (!node) return nullptr; // If there is a right child node, explore that branch first. if (node->mRight) { node->mRight->setParent(node); return startIterating(node->mRight); } // Otherwise, walk back up to the node P from which we reached this node // by following P's left child pointer. while (auto* const parent = node->getParent()) { if (parent->mLeft == node) return parent; node = parent; } return nullptr; } protected: /// Returns the left most child of a given node, marking each node with its parent /// along the way. static Node* startIterating(Node* node) { while (node->mLeft) { node->mLeft->setParent(node); node = node->mLeft; } return node; } Node* insert(Node* root, Node* node); Node* erase(Node* root, const Key& key); Node* find(Node* root, const Key& key) const; static inline Node* rotateLeft(Node* node); static inline Node* rotateRight(Node* node); static inline Node* moveRedLeft(Node* node); static inline Node* moveRedRight(Node* node); static Node* findMin(Node* node); static Node* eraseMin(Node* node); static inline Node* fixUp(Node* node); static bool isRed(const Node* node) { return node && node->isRed(); } static inline void flipColors(Node* node); template static void forEach(Node* start, const Callable& callable); Node* mRoot = nullptr; }; /// Requires Key to have a compare() member function, which returns -1 if lhs < rhs, 0 if lhs = rhs /// and 1 if lhs > rhs. template class TreeMapNode { public: TreeMapNode() { mLeft = mRight = nullptr; mColorAndPtr = 0; } virtual ~TreeMapNode() = default; virtual void erase_() = 0; const Key& key() const { return mKey; } protected: friend class TreeMapImpl; enum class Color { Red = 0, Black = 1, }; void flipColor() { BitUtil::bitCastWrite(mColorAndPtr ^ 1u, &mColorAndPtr); } void setColor(Color color) { mColorAndPtr = uintptr_t(color); } void setParent(TreeMapNode* parent) { mColorAndPtr = (mColorAndPtr & 1) | uintptr_t(parent); } /// @warning Only valid if setParent has been called! TreeMapNode* getParent() const { return reinterpret_cast(mColorAndPtr & ~1); } bool isRed() const { return (mColorAndPtr & 1u) == bool(Color::Red); } TreeMapNode* mLeft; TreeMapNode* mRight; uintptr_t mColorAndPtr; Key mKey; }; /// Requires Key to have operator< defined /// This can be specialized, but all specializations must define `compare` and `key` as follows. template struct TreeMapKeyImpl { TreeMapKeyImpl() = default; TreeMapKeyImpl(const Key& key_) : key(key_) {} TreeMapKeyImpl& operator=(const Key& key_) { key = key_; return *this; } /// Returns -1 if mKey < rhs, 0 if mKey = rhs and 1 if mKey > rhs. s32 compare(const TreeMapKeyImpl& rhs) const { if (key < rhs.key) return -1; if (rhs.key < key) return 1; return 0; } Key key; }; /// Sorted associative container. /// This is essentially std::map template class TreeMap : public TreeMapImpl> { public: using MapImpl = TreeMapImpl>; class Node : public MapImpl::Node { public: Node(TreeMap* map, const Key& key, const Value& value) : mValue(value), mMap(map) { this->mKey = key; } void erase_() override; Value& value() { return mValue; } const Value& value() const { return mValue; } private: friend class TreeMap; Value mValue; TreeMap* mMap; }; ~TreeMap(); void allocBuffer(s32 node_max, Heap* heap, s32 alignment = sizeof(void*)); void setBuffer(s32 node_max, void* buffer); void freeBuffer(); Value* insert(const Key& key, const Value& value); void clear(); Node* find(const Key& key) const; // Callable must have the signature Key&, Value& template void forEach(const Callable& delegate) const; Node* startIterating() const { return static_cast(MapImpl::startIterating()); } Node* nextNode(Node* node) const { return static_cast(MapImpl::nextNode(node)); } private: void eraseNodeForClear_(typename MapImpl::Node* node); FreeList mFreeList; s32 mSize = 0; s32 mCapacity = 0; }; template class IntrusiveTreeMap : public TreeMapImpl { public: using MapImpl = TreeMapImpl; Node* find(const Key& key) const { return static_cast(MapImpl::find(key)); } // Callable must have the signature Node* template void forEach(const Callable& delegate) const { MapImpl::forEach([delegate](auto* base_node) { auto* node = static_cast(base_node); delegate(node); }); } Node* startIterating() const { return static_cast(MapImpl::startIterating()); } Node* nextNode(Node* node) const { return static_cast(MapImpl::nextNode(node)); } }; template inline void TreeMapImpl::insert(Node* node) { mRoot = insert(mRoot, node); mRoot->setColor(Node::Color::Black); } template inline TreeMapNode* TreeMapImpl::insert(Node* root, Node* node) { if (!root) { node->mLeft = node->mRight = nullptr; node->setColor(Node::Color::Red); return node; } const s32 cmp = node->key().compare(root->key()); if (cmp < 0) { root->mLeft = insert(root->mLeft, node); } else if (cmp > 0) { root->mRight = insert(root->mRight, node); } else if (root != node) { node->mRight = root->mRight; node->mLeft = root->mLeft; node->mColorAndPtr = root->mColorAndPtr; root->erase_(); root = node; } if (isRed(root->mRight) && !isRed(root->mLeft)) root = rotateLeft(root); if (isRed(root->mLeft) && isRed(root->mLeft->mLeft)) root = rotateRight(root); if (isRed(root->mLeft) && isRed(root->mRight)) flipColors(root); return root; } template inline void TreeMapImpl::erase(const Key& key) { mRoot = erase(mRoot, key); if (mRoot) mRoot->setColor(Node::Color::Black); } template inline TreeMapNode* TreeMapImpl::erase(Node* root, const Key& key) { if (key.compare(root->key()) < 0) { if (!isRed(root->mLeft) && !isRed(root->mLeft->mLeft)) root = moveRedLeft(root); root->mLeft = erase(root->mLeft, key); } else { if (isRed(root->mLeft)) root = rotateRight(root); if (key.compare(root->key()) == 0 && !root->mRight) { root->erase_(); return nullptr; } if (!isRed(root->mRight) && !isRed(root->mRight->mLeft)) root = moveRedRight(root); if (key.compare(root->key()) == 0) { Node* const min_node = findMin(root->mRight); Node* target = root->mRight; if (root->mRight) target = find(root->mRight, min_node->key()); target->mRight = eraseMin(root->mRight); target->mLeft = root->mLeft; target->mColorAndPtr = root->mColorAndPtr; root->erase_(); root = target; } else { root->mRight = erase(root->mRight, key); } } return fixUp(root); } template inline void TreeMapImpl::clear() { mRoot = nullptr; } template inline TreeMapNode* TreeMapImpl::find(Node* root, const Key& key) const { Node* node = root; while (node) { const s32 cmp = key.compare(node->key()); if (cmp < 0) node = node->mLeft; else if (cmp > 0) node = node->mRight; else return node; } return nullptr; } template template inline void TreeMapImpl::forEach(Node* start, const Callable& callable) { Node* i = start; do { Node* node = i; if (i->mLeft) forEach(i->mLeft, callable); i = i->mRight; callable(node); } while (i); } template inline TreeMapNode* TreeMapImpl::rotateLeft(Node* node) { TreeMapNode* j = node->mRight; node->mRight = j->mLeft; j->mLeft = node; j->mColorAndPtr = node->mColorAndPtr; node->setColor(Node::Color::Red); return j; } template inline TreeMapNode* TreeMapImpl::rotateRight(Node* node) { TreeMapNode* j = node->mLeft; node->mLeft = j->mRight; j->mRight = node; j->mColorAndPtr = node->mColorAndPtr; node->setColor(Node::Color::Red); return j; } // NON_MATCHING: this version matches the LLRB tree implementation and is better optimized; // there is a useless store to node->mRight in the original version template inline TreeMapNode* TreeMapImpl::moveRedLeft(Node* node) { flipColors(node); if (isRed(node->mRight->mLeft)) { node->mRight = rotateRight(node->mRight); node = rotateLeft(node); flipColors(node); } return node; } template inline TreeMapNode* TreeMapImpl::moveRedRight(Node* node) { flipColors(node); if (isRed(node->mLeft->mLeft)) { node = rotateRight(node); flipColors(node); } return node; } template inline TreeMapNode* TreeMapImpl::findMin(Node* node) { while (node->mLeft) node = node->mLeft; return node; } // NON_MATCHING: this version matches the LLRB tree implementation and is better optimized template inline TreeMapNode* TreeMapImpl::eraseMin(Node* node) { if (!node->mLeft) return nullptr; if (!isRed(node->mLeft) && !isRed(node->mLeft->mLeft)) node = moveRedLeft(node); node->mLeft = eraseMin(node->mLeft); #ifdef MATCHING_HACK_NX_CLANG asm(""); #endif return fixUp(node); } template inline TreeMapNode* TreeMapImpl::fixUp(Node* node) { if (isRed(node->mRight)) node = rotateLeft(node); if (isRed(node->mLeft) && isRed(node->mLeft->mLeft)) node = rotateRight(node); if (isRed(node->mLeft) && isRed(node->mRight)) flipColors(node); return node; } template inline void TreeMapImpl::flipColors(Node* node) { node->flipColor(); node->mLeft->flipColor(); node->mRight->flipColor(); } template inline void TreeMap::Node::erase_() { TreeMap* const map = mMap; void* const this_ = this; // Note: Nintendo does not call the destructor, which is dangerous... map->mFreeList.free(this_); --map->mSize; } template inline TreeMap::~TreeMap() { void* work = mFreeList.work(); if (!work) return; clear(); freeBuffer(); } template inline void TreeMap::allocBuffer(s32 node_max, Heap* heap, s32 alignment) { SEAD_ASSERT(mFreeList.work() == nullptr); if (node_max <= 0) { SEAD_ASSERT_MSG(false, "node_max[%d] must be larger than zero", node_max); AllocFailAssert(heap, node_max * sizeof(Node), alignment); } void* work = AllocBuffer(node_max * sizeof(Node), heap, alignment); if (work) setBuffer(node_max, work); } template inline void TreeMap::setBuffer(s32 node_max, void* buffer) { mCapacity = node_max; mFreeList.setWork(buffer, sizeof(Node), node_max); } template inline void TreeMap::freeBuffer() { void* buffer = mFreeList.work(); if (!buffer) return; ::operator delete[](buffer); mCapacity = 0; mFreeList.reset(); } template inline Value* TreeMap::insert(const Key& key, const Value& value) { if (mSize >= mCapacity) { if (Node* node = find(key)) { node->value() = value; return &node->value(); } SEAD_ASSERT_MSG(false, "map is full."); return nullptr; } Node* node = new (mFreeList.alloc()) Node(this, key, value); ++mSize; MapImpl::insert(node); return &node->value(); } template inline void TreeMap::clear() { Delegate1, typename MapImpl::Node*> delegate(this, &TreeMap::eraseNodeForClear_); MapImpl::forEach(delegate); mSize = 0; MapImpl::clear(); } template inline typename TreeMap::Node* TreeMap::find(const Key& key) const { return static_cast(MapImpl::find(key)); } template template inline void TreeMap::forEach(const Callable& delegate) const { MapImpl::forEach([&delegate](auto* base_node) { auto* node = static_cast(base_node); delegate(node->key(), node->value()); }); } template inline void TreeMap::eraseNodeForClear_(typename MapImpl::Node* node) { // Note: Nintendo does not call the destructor, which is dangerous... mFreeList.free(node); } } // namespace sead