From b8ca47e094738689e4e8470b0f1f00d60ee26f4c Mon Sep 17 00:00:00 2001 From: Ameer J <52414509+ameerj@users.noreply.github.com> Date: Sun, 30 Jul 2023 12:56:32 -0400 Subject: [PATCH] EncodingData pack --- src/video_core/host_shaders/astc_decoder.comp | 113 +++++++++++------- 1 file changed, 69 insertions(+), 44 deletions(-) diff --git a/src/video_core/host_shaders/astc_decoder.comp b/src/video_core/host_shaders/astc_decoder.comp index f720df6d2..37b502324 100644 --- a/src/video_core/host_shaders/astc_decoder.comp +++ b/src/video_core/host_shaders/astc_decoder.comp @@ -33,10 +33,7 @@ UNIFORM(6) uint block_height_mask; END_PUSH_CONSTANTS struct EncodingData { - uint encoding; - uint num_bits; - uint bit_value; - uint quint_trit_value; + uint data; }; struct TexelWeightParams { @@ -66,17 +63,47 @@ const int TRIT = 2; // ASTC Encodings data, sorted in ascending order based on their BitLength value // (see GetBitLength() function) -EncodingData encoding_values[22] = EncodingData[]( - EncodingData(JUST_BITS, 0, 0, 0), EncodingData(JUST_BITS, 1, 0, 0), EncodingData(TRIT, 0, 0, 0), - EncodingData(JUST_BITS, 2, 0, 0), EncodingData(QUINT, 0, 0, 0), EncodingData(TRIT, 1, 0, 0), - EncodingData(JUST_BITS, 3, 0, 0), EncodingData(QUINT, 1, 0, 0), EncodingData(TRIT, 2, 0, 0), - EncodingData(JUST_BITS, 4, 0, 0), EncodingData(QUINT, 2, 0, 0), EncodingData(TRIT, 3, 0, 0), - EncodingData(JUST_BITS, 5, 0, 0), EncodingData(QUINT, 3, 0, 0), EncodingData(TRIT, 4, 0, 0), - EncodingData(JUST_BITS, 6, 0, 0), EncodingData(QUINT, 4, 0, 0), EncodingData(TRIT, 5, 0, 0), - EncodingData(JUST_BITS, 7, 0, 0), EncodingData(QUINT, 5, 0, 0), EncodingData(TRIT, 6, 0, 0), - EncodingData(JUST_BITS, 8, 0, 0) +const EncodingData encoding_values[22] = EncodingData[]( + EncodingData(JUST_BITS), EncodingData(JUST_BITS | (1u << 8u)), EncodingData(TRIT), EncodingData(JUST_BITS | (2u << 8u)), + EncodingData(QUINT), EncodingData(TRIT | (1u << 8u)), EncodingData(JUST_BITS | (3u << 8u)), EncodingData(QUINT | (1u << 8u)), + EncodingData(TRIT | (2u << 8u)), EncodingData(JUST_BITS | (4u << 8u)), EncodingData(QUINT | (2u << 8u)), EncodingData(TRIT | (3u << 8u)), + EncodingData(JUST_BITS | (5u << 8u)), EncodingData(QUINT | (3u << 8u)), EncodingData(TRIT | (4u << 8u)), EncodingData(JUST_BITS | (6u << 8u)), + EncodingData(QUINT | (4u << 8u)), EncodingData(TRIT | (5u << 8u)), EncodingData(JUST_BITS | (7u << 8u)), EncodingData(QUINT | (5u << 8u)), + EncodingData(TRIT | (6u << 8u)), EncodingData(JUST_BITS | (8u << 8u)) ); +// EncodingData helpers +uint Encoding(EncodingData val) { + return bitfieldExtract(val.data, 0, 8); +} +uint NumBits(EncodingData val) { + return bitfieldExtract(val.data, 8, 8); +} +uint BitValue(EncodingData val) { + return bitfieldExtract(val.data, 16, 8); +} +uint QuintTritValue(EncodingData val) { + return bitfieldExtract(val.data, 24, 8); +} + +void Encoding(inout EncodingData val, uint v) { + val.data = bitfieldInsert(val.data, v, 0, 8); +} +void NumBits(inout EncodingData val, uint v) { + val.data = bitfieldInsert(val.data, v, 8, 8); +} +void BitValue(inout EncodingData val, uint v) { + val.data = bitfieldInsert(val.data, v, 16, 8); +} +void QuintTritValue(inout EncodingData val, uint v) { + val.data = bitfieldInsert(val.data, v, 24, 8); +} + +EncodingData CreateEncodingData(uint encoding, uint num_bits, uint bit_val, uint quint_trit_val) { + return EncodingData(((encoding) << 0u) | ((num_bits) << 8u) | + ((bit_val) << 16u) | ((quint_trit_val) << 24u)); +} + // The following constants are expanded variants of the Replicate() // function calls corresponding to the following arguments: // value: index into the generated table @@ -379,10 +406,12 @@ void ResultEmplaceBack(EncodingData val) { // Returns the number of bits required to encode n_vals values. uint GetBitLength(uint n_vals, uint encoding_index) { - uint total_bits = encoding_values[encoding_index].num_bits * n_vals; - if (encoding_values[encoding_index].encoding == TRIT) { + const EncodingData encoding_value = encoding_values[encoding_index]; + const uint encoding = Encoding(encoding_value); + uint total_bits = NumBits(encoding_value) * n_vals; + if (encoding == TRIT) { total_bits += Div5Ceil(n_vals * 8); - } else if (encoding_values[encoding_index].encoding == QUINT) { + } else if (encoding == QUINT) { total_bits += Div3Ceil(n_vals * 7); } return total_bits; @@ -451,11 +480,7 @@ void DecodeQuintBlock(uint num_bits) { } } for (uint i = 0; i < 3; i++) { - EncodingData val; - val.encoding = QUINT; - val.num_bits = num_bits; - val.bit_value = m[i]; - val.quint_trit_value = q[i]; + const EncodingData val = CreateEncodingData(QUINT, num_bits, m[i], q[i]); ResultEmplaceBack(val); } } @@ -503,30 +528,28 @@ void DecodeTritBlock(uint num_bits) { t[0] = (BitsBracket(C, 1) << 1) | (BitsBracket(C, 0) & ~BitsBracket(C, 1)); } for (uint i = 0; i < 5; i++) { - EncodingData val; - val.encoding = TRIT; - val.num_bits = num_bits; - val.bit_value = m[i]; - val.quint_trit_value = t[i]; + const EncodingData val = CreateEncodingData(TRIT, num_bits, m[i], t[i]); ResultEmplaceBack(val); } } void DecodeIntegerSequence(uint max_range, uint num_values) { EncodingData val = encoding_values[max_range]; + const uint encoding = Encoding(val); + const uint num_bits = NumBits(val); uint vals_decoded = 0; while (vals_decoded < num_values) { - switch (val.encoding) { + switch (encoding) { case QUINT: - DecodeQuintBlock(val.num_bits); + DecodeQuintBlock(num_bits); vals_decoded += 3; break; case TRIT: - DecodeTritBlock(val.num_bits); + DecodeTritBlock(num_bits); vals_decoded += 5; break; case JUST_BITS: - val.bit_value = StreamColorBits(val.num_bits); + BitValue(val, StreamColorBits(num_bits)); ResultEmplaceBack(val); vals_decoded++; break; @@ -554,17 +577,18 @@ void DecodeColorValues(uvec4 modes, uint num_partitions, uint color_data_bits) { if (out_index >= num_values) { break; } - EncodingData val = result_vector[itr]; - uint bitlen = val.num_bits; - uint bitval = val.bit_value; + const EncodingData val = result_vector[itr]; + const uint encoding = Encoding(val); + const uint bitlen = NumBits(val); + const uint bitval = BitValue(val); uint A = 0, B = 0, C = 0, D = 0; A = ReplicateBitTo9((bitval & 1)); - switch (val.encoding) { + switch (encoding) { case JUST_BITS: color_values[out_index++] = FastReplicateTo8(bitval, bitlen); break; case TRIT: { - D = val.quint_trit_value; + D = QuintTritValue(val); switch (bitlen) { case 1: C = 204; @@ -603,7 +627,7 @@ void DecodeColorValues(uvec4 modes, uint num_partitions, uint color_data_bits) { break; } case QUINT: { - D = val.quint_trit_value; + D = QuintTritValue(val); switch (bitlen) { case 1: C = 113; @@ -636,7 +660,7 @@ void DecodeColorValues(uvec4 modes, uint num_partitions, uint color_data_bits) { break; } } - if (val.encoding != JUST_BITS) { + if (encoding != JUST_BITS) { uint T = (D * C) + B; T ^= A; T = (A & 0x80) | (T >> 2); @@ -806,17 +830,18 @@ void ComputeEndpoints(out uvec4 ep1, out uvec4 ep2, uint color_endpoint_mode) { } uint UnquantizeTexelWeight(EncodingData val) { - uint bitval = val.bit_value; - uint bitlen = val.num_bits; - uint A = ReplicateBitTo7((bitval & 1)); + const uint encoding = Encoding(val); + const uint bitlen = NumBits(val); + const uint bitval = BitValue(val); + const uint A = ReplicateBitTo7((bitval & 1)); uint B = 0, C = 0, D = 0; uint result = 0; - switch (val.encoding) { + switch (encoding) { case JUST_BITS: result = FastReplicateTo6(bitval, bitlen); break; case TRIT: { - D = val.quint_trit_value; + D = QuintTritValue(val); switch (bitlen) { case 0: { uint results[3] = {0, 32, 63}; @@ -845,7 +870,7 @@ uint UnquantizeTexelWeight(EncodingData val) { break; } case QUINT: { - D = val.quint_trit_value; + D = QuintTritValue(val); switch (bitlen) { case 0: { uint results[5] = {0, 16, 32, 47, 63}; @@ -866,7 +891,7 @@ uint UnquantizeTexelWeight(EncodingData val) { break; } } - if (val.encoding != JUST_BITS && bitlen > 0) { + if (encoding != JUST_BITS && bitlen > 0) { result = D * C + B; result ^= A; result = (A & 0x20) | (result >> 2);