vk_shader_compiler: Implement the decompiler in SPIR-V

This commit is contained in:
Fernando Sahmkow 2019-08-25 15:32:00 -04:00 committed by FernandoS27
parent 0366c18d87
commit ca9901867e
3 changed files with 301 additions and 23 deletions

View File

@ -88,6 +88,9 @@ bool IsPrecise(Operation operand) {
} // namespace } // namespace
class ASTDecompiler;
class ExprDecompiler;
class SPIRVDecompiler : public Sirit::Module { class SPIRVDecompiler : public Sirit::Module {
public: public:
explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderStage stage) explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderStage stage)
@ -97,27 +100,7 @@ public:
AddExtension("SPV_KHR_variable_pointers"); AddExtension("SPV_KHR_variable_pointers");
} }
void Decompile() { void DecompileBranchMode() {
AllocateBindings();
AllocateLabels();
DeclareVertex();
DeclareGeometry();
DeclareFragment();
DeclareRegisters();
DeclarePredicates();
DeclareLocalMemory();
DeclareInternalFlags();
DeclareInputAttributes();
DeclareOutputAttributes();
DeclareConstantBuffers();
DeclareGlobalBuffers();
DeclareSamplers();
execute_function =
Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void)));
Emit(OpLabel());
const u32 first_address = ir.GetBasicBlocks().begin()->first; const u32 first_address = ir.GetBasicBlocks().begin()->first;
const Id loop_label = OpLabel("loop"); const Id loop_label = OpLabel("loop");
const Id merge_label = OpLabel("merge"); const Id merge_label = OpLabel("merge");
@ -174,6 +157,43 @@ public:
Emit(continue_label); Emit(continue_label);
Emit(OpBranch(loop_label)); Emit(OpBranch(loop_label));
Emit(merge_label); Emit(merge_label);
}
void DecompileAST();
void Decompile() {
const bool is_fully_decompiled = ir.IsDecompiled();
AllocateBindings();
if (!is_fully_decompiled) {
AllocateLabels();
}
DeclareVertex();
DeclareGeometry();
DeclareFragment();
DeclareRegisters();
DeclarePredicates();
if (is_fully_decompiled) {
DeclareFlowVariables();
}
DeclareLocalMemory();
DeclareInternalFlags();
DeclareInputAttributes();
DeclareOutputAttributes();
DeclareConstantBuffers();
DeclareGlobalBuffers();
DeclareSamplers();
execute_function =
Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void)));
Emit(OpLabel());
if (is_fully_decompiled) {
DecompileAST();
} else {
DecompileBranchMode();
}
Emit(OpReturn()); Emit(OpReturn());
Emit(OpFunctionEnd()); Emit(OpFunctionEnd());
} }
@ -206,6 +226,9 @@ public:
} }
private: private:
friend class ASTDecompiler;
friend class ExprDecompiler;
static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount); static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
void AllocateBindings() { void AllocateBindings() {
@ -294,6 +317,14 @@ private:
} }
} }
void DeclareFlowVariables() {
for (u32 i = 0; i < ir.GetASTNumVariables(); i++) {
const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
Name(id, fmt::format("flow_var_{}", static_cast<u32>(i)));
flow_variables.emplace(i, AddGlobalVariable(id));
}
}
void DeclareLocalMemory() { void DeclareLocalMemory() {
if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) { if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) {
const auto element_count = static_cast<u32>(Common::AlignUp(local_memory_size, 4) / 4); const auto element_count = static_cast<u32>(Common::AlignUp(local_memory_size, 4) / 4);
@ -1019,7 +1050,7 @@ private:
return {}; return {};
} }
Id Exit(Operation operation) { Id PreExit() {
switch (stage) { switch (stage) {
case ShaderStage::Vertex: { case ShaderStage::Vertex: {
// TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't // TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't
@ -1067,6 +1098,11 @@ private:
} }
} }
return {};
}
Id Exit(Operation operation) {
PreExit();
BranchingOp([&]() { Emit(OpReturn()); }); BranchingOp([&]() { Emit(OpReturn()); });
return {}; return {};
} }
@ -1545,6 +1581,7 @@ private:
Id per_vertex{}; Id per_vertex{};
std::map<u32, Id> registers; std::map<u32, Id> registers;
std::map<Tegra::Shader::Pred, Id> predicates; std::map<Tegra::Shader::Pred, Id> predicates;
std::map<u32, Id> flow_variables;
Id local_memory{}; Id local_memory{};
std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{}; std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{};
std::map<Attribute::Index, Id> input_attributes; std::map<Attribute::Index, Id> input_attributes;
@ -1580,6 +1617,223 @@ private:
std::map<u32, Id> labels; std::map<u32, Id> labels;
}; };
class ExprDecompiler {
public:
ExprDecompiler(SPIRVDecompiler& decomp) : decomp{decomp} {}
void operator()(VideoCommon::Shader::ExprAnd& expr) {
const Id type_def = decomp.GetTypeDefinition(Type::Bool);
const Id op1 = Visit(expr.operand1);
const Id op2 = Visit(expr.operand2);
current_id = decomp.Emit(decomp.OpLogicalAnd(type_def, op1, op2));
}
void operator()(VideoCommon::Shader::ExprOr& expr) {
const Id type_def = decomp.GetTypeDefinition(Type::Bool);
const Id op1 = Visit(expr.operand1);
const Id op2 = Visit(expr.operand2);
current_id = decomp.Emit(decomp.OpLogicalOr(type_def, op1, op2));
}
void operator()(VideoCommon::Shader::ExprNot& expr) {
const Id type_def = decomp.GetTypeDefinition(Type::Bool);
const Id op1 = Visit(expr.operand1);
current_id = decomp.Emit(decomp.OpLogicalNot(type_def, op1));
}
void operator()(VideoCommon::Shader::ExprPredicate& expr) {
auto pred = static_cast<Tegra::Shader::Pred>(expr.predicate);
current_id = decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred)));
}
void operator()(VideoCommon::Shader::ExprCondCode& expr) {
Node cc = decomp.ir.GetConditionCode(expr.cc);
Id target;
if (const auto pred = std::get_if<PredicateNode>(&*cc)) {
const auto index = pred->GetIndex();
switch (index) {
case Tegra::Shader::Pred::NeverExecute:
target = decomp.v_false;
case Tegra::Shader::Pred::UnusedIndex:
target = decomp.v_true;
default:
target = decomp.predicates.at(index);
}
} else if (const auto flag = std::get_if<InternalFlagNode>(&*cc)) {
target = decomp.internal_flags.at(static_cast<u32>(flag->GetFlag()));
}
current_id = decomp.Emit(decomp.OpLoad(decomp.t_bool, target));
}
void operator()(VideoCommon::Shader::ExprVar& expr) {
current_id = decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index)));
}
void operator()(VideoCommon::Shader::ExprBoolean& expr) {
current_id = expr.value ? decomp.v_true : decomp.v_false;
}
Id GetResult() {
return current_id;
}
Id Visit(VideoCommon::Shader::Expr& node) {
std::visit(*this, *node);
return current_id;
}
private:
Id current_id;
SPIRVDecompiler& decomp;
};
class ASTDecompiler {
public:
ASTDecompiler(SPIRVDecompiler& decomp) : decomp{decomp} {}
void operator()(VideoCommon::Shader::ASTProgram& ast) {
ASTNode current = ast.nodes.GetFirst();
while (current) {
Visit(current);
current = current->GetNext();
}
}
void operator()(VideoCommon::Shader::ASTIfThen& ast) {
ExprDecompiler expr_parser{decomp};
const Id condition = expr_parser.Visit(ast.condition);
const Id then_label = decomp.OpLabel();
const Id endif_label = decomp.OpLabel();
decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
decomp.Emit(then_label);
ASTNode current = ast.nodes.GetFirst();
while (current) {
Visit(current);
current = current->GetNext();
}
decomp.Emit(endif_label);
}
void operator()(VideoCommon::Shader::ASTIfElse& ast) {
UNREACHABLE();
}
void operator()(VideoCommon::Shader::ASTBlockEncoded& ast) {
UNREACHABLE();
}
void operator()(VideoCommon::Shader::ASTBlockDecoded& ast) {
decomp.VisitBasicBlock(ast.nodes);
}
void operator()(VideoCommon::Shader::ASTVarSet& ast) {
ExprDecompiler expr_parser{decomp};
const Id condition = expr_parser.Visit(ast.condition);
decomp.Emit(decomp.OpStore(decomp.flow_variables.at(ast.index), condition));
}
void operator()(VideoCommon::Shader::ASTLabel& ast) {
// Do nothing
}
void operator()(VideoCommon::Shader::ASTGoto& ast) {
UNREACHABLE();
}
void operator()(VideoCommon::Shader::ASTDoWhile& ast) {
const Id loop_label = decomp.OpLabel();
const Id endloop_label = decomp.OpLabel();
const Id loop_start_block = decomp.OpLabel();
const Id loop_end_block = decomp.OpLabel();
current_loop_exit = endloop_label;
decomp.Emit(loop_label);
decomp.Emit(decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone));
decomp.Emit(decomp.OpBranch(loop_start_block));
decomp.Emit(loop_start_block);
ASTNode current = ast.nodes.GetFirst();
while (current) {
Visit(current);
current = current->GetNext();
}
decomp.Emit(decomp.OpBranch(loop_end_block));
decomp.Emit(loop_end_block);
ExprDecompiler expr_parser{decomp};
const Id condition = expr_parser.Visit(ast.condition);
decomp.Emit(decomp.OpBranchConditional(condition, loop_label, endloop_label));
decomp.Emit(endloop_label);
}
void operator()(VideoCommon::Shader::ASTReturn& ast) {
bool is_true = VideoCommon::Shader::ExprIsTrue(ast.condition);
if (!is_true) {
ExprDecompiler expr_parser{decomp};
const Id condition = expr_parser.Visit(ast.condition);
const Id then_label = decomp.OpLabel();
const Id endif_label = decomp.OpLabel();
decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
decomp.Emit(then_label);
if (ast.kills) {
decomp.Emit(decomp.OpKill());
} else {
decomp.PreExit();
decomp.Emit(decomp.OpReturn());
}
decomp.Emit(endif_label);
} else {
decomp.Emit(decomp.OpLabel());
if (ast.kills) {
decomp.Emit(decomp.OpKill());
} else {
decomp.PreExit();
decomp.Emit(decomp.OpReturn());
}
decomp.Emit(decomp.OpLabel());
}
}
void operator()(VideoCommon::Shader::ASTBreak& ast) {
bool is_true = VideoCommon::Shader::ExprIsTrue(ast.condition);
if (!is_true) {
ExprDecompiler expr_parser{decomp};
const Id condition = expr_parser.Visit(ast.condition);
const Id then_label = decomp.OpLabel();
const Id endif_label = decomp.OpLabel();
decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
decomp.Emit(then_label);
decomp.Emit(decomp.OpBranch(current_loop_exit));
decomp.Emit(endif_label);
} else {
decomp.Emit(decomp.OpLabel());
decomp.Emit(decomp.OpBranch(current_loop_exit));
decomp.Emit(decomp.OpLabel());
}
}
void Visit(VideoCommon::Shader::ASTNode& node) {
std::visit(*this, *node->GetInnerData());
}
private:
SPIRVDecompiler& decomp;
Id current_loop_exit;
};
void SPIRVDecompiler::DecompileAST() {
u32 num_flow_variables = ir.GetASTNumVariables();
for (u32 i = 0; i < num_flow_variables; i++) {
const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
Name(id, fmt::format("flow_var_{}", i));
flow_variables.emplace(i, AddGlobalVariable(id));
}
ASTDecompiler decompiler{*this};
VideoCommon::Shader::ASTNode program = ir.GetASTProgram();
decompiler.Visit(program);
}
DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir, DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
Maxwell::ShaderStage stage) { Maxwell::ShaderStage stage) {
auto decompiler = std::make_unique<SPIRVDecompiler>(device, ir, stage); auto decompiler = std::make_unique<SPIRVDecompiler>(device, ir, stage);

View File

@ -205,13 +205,29 @@ public:
return nullptr; return nullptr;
} }
void MarkLabelUnused() const { void MarkLabelUnused() {
auto inner = std::get_if<ASTLabel>(&data); auto inner = std::get_if<ASTLabel>(&data);
if (inner) { if (inner) {
inner->unused = true; inner->unused = true;
} }
} }
bool IsLabelUnused() const {
auto inner = std::get_if<ASTLabel>(&data);
if (inner) {
return inner->unused;
}
return true;
}
u32 GetLabelIndex() const {
auto inner = std::get_if<ASTLabel>(&data);
if (inner) {
return inner->index;
}
return -1;
}
Expr GetIfCondition() const { Expr GetIfCondition() const {
auto inner = std::get_if<ASTIfThen>(&data); auto inner = std::get_if<ASTIfThen>(&data);
if (inner) { if (inner) {
@ -336,6 +352,10 @@ public:
return variables; return variables;
} }
const std::vector<ASTNode>& GetLabels() const {
return labels;
}
private: private:
bool IsBackwardsJump(ASTNode goto_node, ASTNode label_node) const; bool IsBackwardsJump(ASTNode goto_node, ASTNode label_node) const;

View File

@ -151,6 +151,10 @@ public:
return decompiled; return decompiled;
} }
const ASTManager& GetASTManager() const {
return program_manager;
}
ASTNode GetASTProgram() const { ASTNode GetASTProgram() const {
return program_manager.GetProgram(); return program_manager.GetProgram();
} }