vk_shader_decompiler: Implement OperationCode decompilation interface

This commit is contained in:
ReinUsesLisp 2019-03-14 02:36:14 -03:00
parent fec4eb9776
commit b758c861b0
1 changed files with 411 additions and 1 deletions

View File

@ -30,6 +30,7 @@ using namespace VideoCommon::Shader;
using Maxwell = Tegra::Engines::Maxwell3D::Regs;
using ShaderStage = Tegra::Engines::Maxwell3D::Regs::ShaderStage;
using Operation = const OperationNode&;
// TODO(Rodrigo): Use rasterizer's value
constexpr u32 MAX_CONSTBUFFER_ELEMENTS = 0x1000;
@ -73,6 +74,19 @@ constexpr u32 GetGenericAttributeLocation(Attribute::Index attribute) {
return static_cast<u32>(attribute) - static_cast<u32>(Attribute::Index::Attribute_0);
}
/// Returns true if an object has to be treated as precise
bool IsPrecise(Operation operand) {
const auto& meta = operand.GetMeta();
if (std::holds_alternative<MetaArithmetic>(meta)) {
return std::get<MetaArithmetic>(meta).precise;
}
if (std::holds_alternative<MetaHalfArithmetic>(meta)) {
return std::get<MetaHalfArithmetic>(meta).precise;
}
return false;
}
} // namespace
class SPIRVDecompiler : public Sirit::Module {
@ -194,6 +208,10 @@ public:
}
private:
using OperationDecompilerFn = Id (SPIRVDecompiler::*)(Operation);
using OperationDecompilersArray =
std::array<OperationDecompilerFn, static_cast<std::size_t>(OperationCode::Amount)>;
static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
static constexpr u32 CBUF_STRIDE = 16;
@ -468,7 +486,12 @@ private:
Id Visit(Node node) {
if (const auto operation = std::get_if<OperationNode>(node)) {
UNIMPLEMENTED();
const auto operation_index = static_cast<std::size_t>(operation->GetCode());
const auto decompiler = operation_decompilers[operation_index];
if (decompiler == nullptr) {
UNREACHABLE_MSG("Operation decompiler {} not defined", operation_index);
}
return (this->*decompiler)(*operation);
} else if (const auto gpr = std::get_if<GprNode>(node)) {
UNIMPLEMENTED();
@ -500,6 +523,184 @@ private:
return {};
}
template <Id (Module::*func)(Id, Id), Type result_type, Type type_a = result_type>
Id Unary(Operation operation) {
const Id type_def = GetTypeDefinition(result_type);
const Id op_a = VisitOperand<type_a>(operation, 0);
const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a)));
if (IsPrecise(operation)) {
Decorate(value, spv::Decoration::NoContraction);
}
return value;
}
template <Id (Module::*func)(Id, Id, Id), Type result_type, Type type_a = result_type,
Type type_b = type_a>
Id Binary(Operation operation) {
const Id type_def = GetTypeDefinition(result_type);
const Id op_a = VisitOperand<type_a>(operation, 0);
const Id op_b = VisitOperand<type_b>(operation, 1);
const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b)));
if (IsPrecise(operation)) {
Decorate(value, spv::Decoration::NoContraction);
}
return value;
}
template <Id (Module::*func)(Id, Id, Id, Id), Type result_type, Type type_a = result_type,
Type type_b = type_a, Type type_c = type_b>
Id Ternary(Operation operation) {
const Id type_def = GetTypeDefinition(result_type);
const Id op_a = VisitOperand<type_a>(operation, 0);
const Id op_b = VisitOperand<type_b>(operation, 1);
const Id op_c = VisitOperand<type_c>(operation, 2);
const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b, op_c)));
if (IsPrecise(operation)) {
Decorate(value, spv::Decoration::NoContraction);
}
return value;
}
template <Id (Module::*func)(Id, Id, Id, Id, Id), Type result_type, Type type_a = result_type,
Type type_b = type_a, Type type_c = type_b, Type type_d = type_c>
Id Quaternary(Operation operation) {
const Id type_def = GetTypeDefinition(result_type);
const Id op_a = VisitOperand<type_a>(operation, 0);
const Id op_b = VisitOperand<type_b>(operation, 1);
const Id op_c = VisitOperand<type_c>(operation, 2);
const Id op_d = VisitOperand<type_d>(operation, 3);
const Id value =
BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b, op_c, op_d)));
if (IsPrecise(operation)) {
Decorate(value, spv::Decoration::NoContraction);
}
return value;
}
Id Assign(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id HNegate(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id HMergeF32(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id HMergeH0(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id HMergeH1(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id HPack2(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id LogicalAssign(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id LogicalPick2(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id LogicalAll2(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id LogicalAny2(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id Texture(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id TextureLod(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id TextureGather(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id TextureQueryDimensions(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id TextureQueryLod(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id TexelFetch(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id Branch(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id PushFlowStack(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id PopFlowStack(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id Exit(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id Discard(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id EmitVertex(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id EndPrimitive(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id YNegate(Operation operation) {
UNIMPLEMENTED();
return {};
}
Id DeclareBuiltIn(spv::BuiltIn builtin, spv::StorageClass storage, Id type,
const std::string& name) {
const Id id = OpVariable(type, storage);
@ -518,6 +719,215 @@ private:
return false;
}
template <Type type>
Id VisitOperand(Operation operation, std::size_t operand_index) {
const Id value = Visit(operation[operand_index]);
switch (type) {
case Type::Bool:
case Type::Bool2:
case Type::Float:
return value;
case Type::Int:
return Emit(OpBitcast(t_int, value));
case Type::Uint:
return Emit(OpBitcast(t_uint, value));
case Type::HalfFloat:
UNIMPLEMENTED();
}
UNREACHABLE();
return value;
}
template <Type type>
Id BitcastFrom(Id value) {
switch (type) {
case Type::Bool:
case Type::Bool2:
case Type::Float:
return value;
case Type::Int:
case Type::Uint:
return Emit(OpBitcast(t_float, value));
case Type::HalfFloat:
UNIMPLEMENTED();
}
UNREACHABLE();
return value;
}
template <Type type>
Id BitcastTo(Id value) {
switch (type) {
case Type::Bool:
case Type::Bool2:
UNREACHABLE();
case Type::Float:
return Emit(OpBitcast(t_float, value));
case Type::Int:
return Emit(OpBitcast(t_int, value));
case Type::Uint:
return Emit(OpBitcast(t_uint, value));
case Type::HalfFloat:
UNIMPLEMENTED();
}
UNREACHABLE();
return value;
}
Id GetTypeDefinition(Type type) {
switch (type) {
case Type::Bool:
return t_bool;
case Type::Bool2:
return t_bool2;
case Type::Float:
return t_float;
case Type::Int:
return t_int;
case Type::Uint:
return t_uint;
case Type::HalfFloat:
UNIMPLEMENTED();
}
UNREACHABLE();
return {};
}
static constexpr OperationDecompilersArray operation_decompilers = {
&SPIRVDecompiler::Assign,
&SPIRVDecompiler::Ternary<&Module::OpSelect, Type::Float, Type::Bool, Type::Float,
Type::Float>,
&SPIRVDecompiler::Binary<&Module::OpFAdd, Type::Float>,
&SPIRVDecompiler::Binary<&Module::OpFMul, Type::Float>,
&SPIRVDecompiler::Binary<&Module::OpFDiv, Type::Float>,
&SPIRVDecompiler::Ternary<&Module::OpFma, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpFNegate, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpFAbs, Type::Float>,
&SPIRVDecompiler::Ternary<&Module::OpFClamp, Type::Float>,
&SPIRVDecompiler::Binary<&Module::OpFMin, Type::Float>,
&SPIRVDecompiler::Binary<&Module::OpFMax, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpCos, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpSin, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpExp2, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpLog2, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpInverseSqrt, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpSqrt, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpRoundEven, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpFloor, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpCeil, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpTrunc, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpConvertSToF, Type::Float, Type::Int>,
&SPIRVDecompiler::Unary<&Module::OpConvertUToF, Type::Float, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpIAdd, Type::Int>,
&SPIRVDecompiler::Binary<&Module::OpIMul, Type::Int>,
&SPIRVDecompiler::Binary<&Module::OpSDiv, Type::Int>,
&SPIRVDecompiler::Unary<&Module::OpSNegate, Type::Int>,
&SPIRVDecompiler::Unary<&Module::OpSAbs, Type::Int>,
&SPIRVDecompiler::Binary<&Module::OpSMin, Type::Int>,
&SPIRVDecompiler::Binary<&Module::OpSMax, Type::Int>,
&SPIRVDecompiler::Unary<&Module::OpConvertFToS, Type::Int, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpBitcast, Type::Int, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpShiftLeftLogical, Type::Int, Type::Int, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Int, Type::Int, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpShiftRightArithmetic, Type::Int, Type::Int, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpBitwiseAnd, Type::Int>,
&SPIRVDecompiler::Binary<&Module::OpBitwiseOr, Type::Int>,
&SPIRVDecompiler::Binary<&Module::OpBitwiseXor, Type::Int>,
&SPIRVDecompiler::Unary<&Module::OpNot, Type::Int>,
&SPIRVDecompiler::Quaternary<&Module::OpBitFieldInsert, Type::Int>,
&SPIRVDecompiler::Ternary<&Module::OpBitFieldSExtract, Type::Int>,
&SPIRVDecompiler::Unary<&Module::OpBitCount, Type::Int>,
&SPIRVDecompiler::Binary<&Module::OpIAdd, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpIMul, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpUDiv, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpUMin, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpUMax, Type::Uint>,
&SPIRVDecompiler::Unary<&Module::OpConvertFToU, Type::Uint, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpBitcast, Type::Uint, Type::Int>,
&SPIRVDecompiler::Binary<&Module::OpShiftLeftLogical, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpShiftRightArithmetic, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpBitwiseAnd, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpBitwiseOr, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpBitwiseXor, Type::Uint>,
&SPIRVDecompiler::Unary<&Module::OpNot, Type::Uint>,
&SPIRVDecompiler::Quaternary<&Module::OpBitFieldInsert, Type::Uint>,
&SPIRVDecompiler::Ternary<&Module::OpBitFieldUExtract, Type::Uint>,
&SPIRVDecompiler::Unary<&Module::OpBitCount, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpFAdd, Type::HalfFloat>,
&SPIRVDecompiler::Binary<&Module::OpFMul, Type::HalfFloat>,
&SPIRVDecompiler::Ternary<&Module::OpFma, Type::HalfFloat>,
&SPIRVDecompiler::Unary<&Module::OpFAbs, Type::HalfFloat>,
&SPIRVDecompiler::HNegate,
&SPIRVDecompiler::HMergeF32,
&SPIRVDecompiler::HMergeH0,
&SPIRVDecompiler::HMergeH1,
&SPIRVDecompiler::HPack2,
&SPIRVDecompiler::LogicalAssign,
&SPIRVDecompiler::Binary<&Module::OpLogicalAnd, Type::Bool>,
&SPIRVDecompiler::Binary<&Module::OpLogicalOr, Type::Bool>,
&SPIRVDecompiler::Binary<&Module::OpLogicalNotEqual, Type::Bool>,
&SPIRVDecompiler::Unary<&Module::OpLogicalNot, Type::Bool>,
&SPIRVDecompiler::LogicalPick2,
&SPIRVDecompiler::LogicalAll2,
&SPIRVDecompiler::LogicalAny2,
&SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::Float>,
&SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::Float>,
&SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::Float>,
&SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::Float>,
&SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::Float>,
&SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpIsNan, Type::Bool>,
&SPIRVDecompiler::Binary<&Module::OpSLessThan, Type::Bool, Type::Int>,
&SPIRVDecompiler::Binary<&Module::OpIEqual, Type::Bool, Type::Int>,
&SPIRVDecompiler::Binary<&Module::OpSLessThanEqual, Type::Bool, Type::Int>,
&SPIRVDecompiler::Binary<&Module::OpSGreaterThan, Type::Bool, Type::Int>,
&SPIRVDecompiler::Binary<&Module::OpINotEqual, Type::Bool, Type::Int>,
&SPIRVDecompiler::Binary<&Module::OpSGreaterThanEqual, Type::Bool, Type::Int>,
&SPIRVDecompiler::Binary<&Module::OpULessThan, Type::Bool, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpIEqual, Type::Bool, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpULessThanEqual, Type::Bool, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpUGreaterThan, Type::Bool, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpINotEqual, Type::Bool, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpUGreaterThanEqual, Type::Bool, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::HalfFloat>,
&SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::HalfFloat>,
&SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::HalfFloat>,
&SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::HalfFloat>,
&SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::HalfFloat>,
&SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::HalfFloat>,
&SPIRVDecompiler::Texture,
&SPIRVDecompiler::TextureLod,
&SPIRVDecompiler::TextureGather,
&SPIRVDecompiler::TextureQueryDimensions,
&SPIRVDecompiler::TextureQueryLod,
&SPIRVDecompiler::TexelFetch,
&SPIRVDecompiler::Branch,
&SPIRVDecompiler::PushFlowStack,
&SPIRVDecompiler::PopFlowStack,
&SPIRVDecompiler::Exit,
&SPIRVDecompiler::Discard,
&SPIRVDecompiler::EmitVertex,
&SPIRVDecompiler::EndPrimitive,
&SPIRVDecompiler::YNegate,
};
const ShaderIR& ir;
const ShaderStage stage;
const Tegra::Shader::Header header;