Merge pull request #2784 from ReinUsesLisp/smem
shader_ir: Implement shared memory
This commit is contained in:
commit
b31880dc5e
5 changed files with 81 additions and 21 deletions
|
@ -325,6 +325,7 @@ public:
|
||||||
DeclareRegisters();
|
DeclareRegisters();
|
||||||
DeclarePredicates();
|
DeclarePredicates();
|
||||||
DeclareLocalMemory();
|
DeclareLocalMemory();
|
||||||
|
DeclareSharedMemory();
|
||||||
DeclareInternalFlags();
|
DeclareInternalFlags();
|
||||||
DeclareInputAttributes();
|
DeclareInputAttributes();
|
||||||
DeclareOutputAttributes();
|
DeclareOutputAttributes();
|
||||||
|
@ -499,6 +500,13 @@ private:
|
||||||
code.AddNewLine();
|
code.AddNewLine();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Declares the shared memory array in the generated code.
/// Emits nothing unless the current stage is ProgramType::Compute.
void DeclareSharedMemory() {
|
||||||
|
// Every stage other than compute returns before any declaration is written.
if (stage != ProgramType::Compute) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Writes `shared uint <name>[];`, where <name> is supplied by GetSharedMemory().
// NOTE(review): the array is declared without an explicit size - confirm that the
// compiler consuming this output (presumably GLSL) accepts an unsized shared array.
code.AddLine("shared uint {}[];", GetSharedMemory());
|
||||||
|
}
|
||||||
|
|
||||||
void DeclareInternalFlags() {
|
void DeclareInternalFlags() {
|
||||||
for (u32 flag = 0; flag < static_cast<u32>(InternalFlag::Amount); flag++) {
|
for (u32 flag = 0; flag < static_cast<u32>(InternalFlag::Amount); flag++) {
|
||||||
const auto flag_code = static_cast<InternalFlag>(flag);
|
const auto flag_code = static_cast<InternalFlag>(flag);
|
||||||
|
@ -881,6 +889,12 @@ private:
|
||||||
Type::Uint};
|
Type::Uint};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (const auto smem = std::get_if<SmemNode>(&*node)) {
|
||||||
|
return {
|
||||||
|
fmt::format("{}[{} >> 2]", GetSharedMemory(), Visit(smem->GetAddress()).AsUint()),
|
||||||
|
Type::Uint};
|
||||||
|
}
|
||||||
|
|
||||||
if (const auto internal_flag = std::get_if<InternalFlagNode>(&*node)) {
|
if (const auto internal_flag = std::get_if<InternalFlagNode>(&*node)) {
|
||||||
return {GetInternalFlag(internal_flag->GetFlag()), Type::Bool};
|
return {GetInternalFlag(internal_flag->GetFlag()), Type::Bool};
|
||||||
}
|
}
|
||||||
|
@ -1286,6 +1300,11 @@ private:
|
||||||
target = {
|
target = {
|
||||||
fmt::format("{}[{} >> 2]", GetLocalMemory(), Visit(lmem->GetAddress()).AsUint()),
|
fmt::format("{}[{} >> 2]", GetLocalMemory(), Visit(lmem->GetAddress()).AsUint()),
|
||||||
Type::Uint};
|
Type::Uint};
|
||||||
|
} else if (const auto smem = std::get_if<SmemNode>(&*dest)) {
|
||||||
|
ASSERT(stage == ProgramType::Compute);
|
||||||
|
target = {
|
||||||
|
fmt::format("{}[{} >> 2]", GetSharedMemory(), Visit(smem->GetAddress()).AsUint()),
|
||||||
|
Type::Uint};
|
||||||
} else if (const auto gmem = std::get_if<GmemNode>(&*dest)) {
|
} else if (const auto gmem = std::get_if<GmemNode>(&*dest)) {
|
||||||
const std::string real = Visit(gmem->GetRealAddress()).AsUint();
|
const std::string real = Visit(gmem->GetRealAddress()).AsUint();
|
||||||
const std::string base = Visit(gmem->GetBaseAddress()).AsUint();
|
const std::string base = Visit(gmem->GetBaseAddress()).AsUint();
|
||||||
|
@ -2175,6 +2194,10 @@ private:
|
||||||
return "lmem_" + suffix;
|
return "lmem_" + suffix;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the identifier used for the shared memory array: "smem_" followed by suffix.
std::string GetSharedMemory() const {
|
||||||
|
return fmt::format("smem_{}", suffix);
|
||||||
|
}
|
||||||
|
|
||||||
std::string GetInternalFlag(InternalFlag flag) const {
|
std::string GetInternalFlag(InternalFlag flag) const {
|
||||||
constexpr std::array InternalFlagNames = {"zero_flag", "sign_flag", "carry_flag",
|
constexpr std::array InternalFlagNames = {"zero_flag", "sign_flag", "carry_flag",
|
||||||
"overflow_flag"};
|
"overflow_flag"};
|
||||||
|
|
|
@ -35,7 +35,7 @@ u32 GetUniformTypeElementsCount(Tegra::Shader::UniformType uniform_type) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace
|
} // Anonymous namespace
|
||||||
|
|
||||||
u32 ShaderIR::DecodeMemory(NodeBlock& bb, u32 pc) {
|
u32 ShaderIR::DecodeMemory(NodeBlock& bb, u32 pc) {
|
||||||
const Instruction instr = {program_code[pc]};
|
const Instruction instr = {program_code[pc]};
|
||||||
|
@ -106,16 +106,17 @@ u32 ShaderIR::DecodeMemory(NodeBlock& bb, u32 pc) {
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OpCode::Id::LD_L: {
|
case OpCode::Id::LD_L:
|
||||||
LOG_DEBUG(HW_GPU, "LD_L cache management mode: {}",
|
LOG_DEBUG(HW_GPU, "LD_L cache management mode: {}", static_cast<u64>(instr.ld_l.unknown));
|
||||||
static_cast<u64>(instr.ld_l.unknown.Value()));
|
[[fallthrough]];
|
||||||
|
case OpCode::Id::LD_S: {
|
||||||
const auto GetLmem = [&](s32 offset) {
|
const auto GetMemory = [&](s32 offset) {
|
||||||
ASSERT(offset % 4 == 0);
|
ASSERT(offset % 4 == 0);
|
||||||
const Node immediate_offset = Immediate(static_cast<s32>(instr.smem_imm) + offset);
|
const Node immediate_offset = Immediate(static_cast<s32>(instr.smem_imm) + offset);
|
||||||
const Node address = Operation(OperationCode::IAdd, NO_PRECISE, GetRegister(instr.gpr8),
|
const Node address = Operation(OperationCode::IAdd, NO_PRECISE, GetRegister(instr.gpr8),
|
||||||
immediate_offset);
|
immediate_offset);
|
||||||
return GetLocalMemory(address);
|
return opcode->get().GetId() == OpCode::Id::LD_S ? GetSharedMemory(address)
|
||||||
|
: GetLocalMemory(address);
|
||||||
};
|
};
|
||||||
|
|
||||||
switch (instr.ldst_sl.type.Value()) {
|
switch (instr.ldst_sl.type.Value()) {
|
||||||
|
@ -135,14 +136,16 @@ u32 ShaderIR::DecodeMemory(NodeBlock& bb, u32 pc) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
}();
|
}();
|
||||||
for (u32 i = 0; i < count; ++i)
|
for (u32 i = 0; i < count; ++i) {
|
||||||
SetTemporary(bb, i, GetLmem(i * 4));
|
SetTemporary(bb, i, GetMemory(i * 4));
|
||||||
for (u32 i = 0; i < count; ++i)
|
}
|
||||||
|
for (u32 i = 0; i < count; ++i) {
|
||||||
SetRegister(bb, instr.gpr0.Value() + i, GetTemporary(i));
|
SetRegister(bb, instr.gpr0.Value() + i, GetTemporary(i));
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
UNIMPLEMENTED_MSG("LD_L Unhandled type: {}",
|
UNIMPLEMENTED_MSG("{} Unhandled type: {}", opcode->get().GetName(),
|
||||||
static_cast<u32>(instr.ldst_sl.type.Value()));
|
static_cast<u32>(instr.ldst_sl.type.Value()));
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
@ -209,27 +212,34 @@ u32 ShaderIR::DecodeMemory(NodeBlock& bb, u32 pc) {
|
||||||
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OpCode::Id::ST_L: {
|
case OpCode::Id::ST_L:
|
||||||
LOG_DEBUG(HW_GPU, "ST_L cache management mode: {}",
|
LOG_DEBUG(HW_GPU, "ST_L cache management mode: {}",
|
||||||
static_cast<u64>(instr.st_l.cache_management.Value()));
|
static_cast<u64>(instr.st_l.cache_management.Value()));
|
||||||
|
[[fallthrough]];
|
||||||
const auto GetLmemAddr = [&](s32 offset) {
|
case OpCode::Id::ST_S: {
|
||||||
|
const auto GetAddress = [&](s32 offset) {
|
||||||
ASSERT(offset % 4 == 0);
|
ASSERT(offset % 4 == 0);
|
||||||
const Node immediate = Immediate(static_cast<s32>(instr.smem_imm) + offset);
|
const Node immediate = Immediate(static_cast<s32>(instr.smem_imm) + offset);
|
||||||
return Operation(OperationCode::IAdd, NO_PRECISE, GetRegister(instr.gpr8), immediate);
|
return Operation(OperationCode::IAdd, NO_PRECISE, GetRegister(instr.gpr8), immediate);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const auto set_memory = opcode->get().GetId() == OpCode::Id::ST_L
|
||||||
|
? &ShaderIR::SetLocalMemory
|
||||||
|
: &ShaderIR::SetSharedMemory;
|
||||||
|
|
||||||
switch (instr.ldst_sl.type.Value()) {
|
switch (instr.ldst_sl.type.Value()) {
|
||||||
case Tegra::Shader::StoreType::Bits128:
|
case Tegra::Shader::StoreType::Bits128:
|
||||||
SetLocalMemory(bb, GetLmemAddr(12), GetRegister(instr.gpr0.Value() + 3));
|
(this->*set_memory)(bb, GetAddress(12), GetRegister(instr.gpr0.Value() + 3));
|
||||||
SetLocalMemory(bb, GetLmemAddr(8), GetRegister(instr.gpr0.Value() + 2));
|
(this->*set_memory)(bb, GetAddress(8), GetRegister(instr.gpr0.Value() + 2));
|
||||||
|
[[fallthrough]];
|
||||||
case Tegra::Shader::StoreType::Bits64:
|
case Tegra::Shader::StoreType::Bits64:
|
||||||
SetLocalMemory(bb, GetLmemAddr(4), GetRegister(instr.gpr0.Value() + 1));
|
(this->*set_memory)(bb, GetAddress(4), GetRegister(instr.gpr0.Value() + 1));
|
||||||
|
[[fallthrough]];
|
||||||
case Tegra::Shader::StoreType::Bits32:
|
case Tegra::Shader::StoreType::Bits32:
|
||||||
SetLocalMemory(bb, GetLmemAddr(0), GetRegister(instr.gpr0));
|
(this->*set_memory)(bb, GetAddress(0), GetRegister(instr.gpr0));
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
UNIMPLEMENTED_MSG("ST_L Unhandled type: {}",
|
UNIMPLEMENTED_MSG("{} unhandled type: {}", opcode->get().GetName(),
|
||||||
static_cast<u32>(instr.ldst_sl.type.Value()));
|
static_cast<u32>(instr.ldst_sl.type.Value()));
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -206,12 +206,13 @@ class PredicateNode;
|
||||||
class AbufNode;
|
class AbufNode;
|
||||||
class CbufNode;
|
class CbufNode;
|
||||||
class LmemNode;
|
class LmemNode;
|
||||||
|
class SmemNode;
|
||||||
class GmemNode;
|
class GmemNode;
|
||||||
class CommentNode;
|
class CommentNode;
|
||||||
|
|
||||||
using NodeData =
|
using NodeData =
|
||||||
std::variant<OperationNode, ConditionalNode, GprNode, ImmediateNode, InternalFlagNode,
|
std::variant<OperationNode, ConditionalNode, GprNode, ImmediateNode, InternalFlagNode,
|
||||||
PredicateNode, AbufNode, CbufNode, LmemNode, GmemNode, CommentNode>;
|
PredicateNode, AbufNode, CbufNode, LmemNode, SmemNode, GmemNode, CommentNode>;
|
||||||
using Node = std::shared_ptr<NodeData>;
|
using Node = std::shared_ptr<NodeData>;
|
||||||
using Node4 = std::array<Node, 4>;
|
using Node4 = std::array<Node, 4>;
|
||||||
using NodeBlock = std::vector<Node>;
|
using NodeBlock = std::vector<Node>;
|
||||||
|
@ -583,6 +584,19 @@ private:
|
||||||
Node address;
|
Node address;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Shared memory node. Holds the node that yields the address being accessed.
|
||||||
|
class SmemNode final {
|
||||||
|
public:
|
||||||
|
/// Stores the given address node; the argument is moved into the member.
explicit SmemNode(Node address) : address{std::move(address)} {}
|
||||||
|
|
||||||
|
/// Returns the stored address node by const reference.
const Node& GetAddress() const {
|
||||||
|
return address;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// Address node received at construction.
Node address;
|
||||||
|
};
|
||||||
|
|
||||||
/// Global memory node
|
/// Global memory node
|
||||||
class GmemNode final {
|
class GmemNode final {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -137,6 +137,10 @@ Node ShaderIR::GetLocalMemory(Node address) {
|
||||||
return MakeNode<LmemNode>(std::move(address));
|
return MakeNode<LmemNode>(std::move(address));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builds a shared memory node for the given address.
/// @param address Node yielding the address; it is moved into the new SmemNode.
/// @return The node produced by MakeNode<SmemNode>.
Node ShaderIR::GetSharedMemory(Node address) {
|
||||||
|
return MakeNode<SmemNode>(std::move(address));
|
||||||
|
}
|
||||||
|
|
||||||
Node ShaderIR::GetTemporary(u32 id) {
|
Node ShaderIR::GetTemporary(u32 id) {
|
||||||
return GetRegister(Register::ZeroIndex + 1 + id);
|
return GetRegister(Register::ZeroIndex + 1 + id);
|
||||||
}
|
}
|
||||||
|
@ -378,6 +382,11 @@ void ShaderIR::SetLocalMemory(NodeBlock& bb, Node address, Node value) {
|
||||||
Operation(OperationCode::Assign, GetLocalMemory(std::move(address)), std::move(value)));
|
Operation(OperationCode::Assign, GetLocalMemory(std::move(address)), std::move(value)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Appends an Assign operation to bb that pairs a shared memory node with a value.
/// @param bb      Block that receives the new operation node.
/// @param address Node yielding the address; wrapped through GetSharedMemory.
/// @param value   Node passed as the second operand of the Assign operation.
void ShaderIR::SetSharedMemory(NodeBlock& bb, Node address, Node value) {
|
||||||
|
bb.push_back(
|
||||||
|
// First operand is the shared memory node, second operand is the value node.
Operation(OperationCode::Assign, GetSharedMemory(std::move(address)), std::move(value)));
|
||||||
|
}
|
||||||
|
|
||||||
void ShaderIR::SetTemporary(NodeBlock& bb, u32 id, Node value) {
|
void ShaderIR::SetTemporary(NodeBlock& bb, u32 id, Node value) {
|
||||||
SetRegister(bb, Register::ZeroIndex + 1 + id, std::move(value));
|
SetRegister(bb, Register::ZeroIndex + 1 + id, std::move(value));
|
||||||
}
|
}
|
||||||
|
|
|
@ -208,6 +208,8 @@ private:
|
||||||
Node GetInternalFlag(InternalFlag flag, bool negated = false);
|
Node GetInternalFlag(InternalFlag flag, bool negated = false);
|
||||||
/// Generates a node representing a local memory address
|
/// Generates a node representing a local memory address
|
||||||
Node GetLocalMemory(Node address);
|
Node GetLocalMemory(Node address);
|
||||||
|
/// Generates a node representing a shared memory address
|
||||||
|
Node GetSharedMemory(Node address);
|
||||||
/// Generates a temporary, internally it uses a post-RZ register
|
/// Generates a temporary, internally it uses a post-RZ register
|
||||||
Node GetTemporary(u32 id);
|
Node GetTemporary(u32 id);
|
||||||
|
|
||||||
|
@ -217,8 +219,10 @@ private:
|
||||||
void SetPredicate(NodeBlock& bb, u64 dest, Node src);
|
void SetPredicate(NodeBlock& bb, u64 dest, Node src);
|
||||||
/// Sets an internal flag. src value must be a bool-evaluated node
|
/// Sets an internal flag. src value must be a bool-evaluated node
|
||||||
void SetInternalFlag(NodeBlock& bb, InternalFlag flag, Node value);
|
void SetInternalFlag(NodeBlock& bb, InternalFlag flag, Node value);
|
||||||
/// Sets a local memory address. address and value must be a number-evaluated node
|
/// Sets a local memory address with a value.
|
||||||
void SetLocalMemory(NodeBlock& bb, Node address, Node value);
|
void SetLocalMemory(NodeBlock& bb, Node address, Node value);
|
||||||
|
/// Sets a shared memory address with a value.
|
||||||
|
void SetSharedMemory(NodeBlock& bb, Node address, Node value);
|
||||||
/// Sets a temporary. Internally it uses a post-RZ register
|
/// Sets a temporary. Internally it uses a post-RZ register
|
||||||
void SetTemporary(NodeBlock& bb, u32 id, Node value);
|
void SetTemporary(NodeBlock& bb, u32 id, Node value);
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue