Commit 9113e3206b for aom
commit 9113e3206bb431cbc57131588b4bd96cd2cf48d6
Author: Li Zhang <li.zhang2@arm.com>
Date: Wed Oct 29 09:52:39 2025 +0100
Optimize 8-tap Neon I8MM av1_dist_wtd_convolve_x implementation
Decomposing the 8-tap filter into a 7-tap filter followed by a 1-tap
filter enables us to use the Neon I8MM USMMLA instructions, which are
faster than the existing USDOT approach for the no-averaging and
basic-averaging cases. The distance-weighted averaging path keeps the
USDOT implementation.
Change-Id: Icae80657f9f83543ff70ebe9f356d484d774b9f8
diff --git a/av1/common/arm/compound_convolve_neon_i8mm.c b/av1/common/arm/compound_convolve_neon_i8mm.c
index 0589dfb153..4f22e427e1 100644
--- a/av1/common/arm/compound_convolve_neon_i8mm.c
+++ b/av1/common/arm/compound_convolve_neon_i8mm.c
@@ -23,13 +23,26 @@ DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = {
8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
};
-DECLARE_ALIGNED(16, static const uint8_t, kMatMulPermuteTbl[32]) = {
+DECLARE_ALIGNED(16, static const uint8_t, kMatMul6PermuteTbl[32]) = {
// clang-format off
0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9,
4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11, 12, 13
// clang-format on
};
+DECLARE_ALIGNED(16, static const uint8_t, kMatMul8PermuteTbl[32]) = {
+ // clang-format off
+ 1, 2, 3, 4, 5, 6, 7, 8, 3, 4, 5, 6, 7, 8, 9, 10,  // s[1..8], s[3..10]
+ 5, 6, 7, 8, 9, 10, 11, 12, 7, 8, 9, 10, 11, 12, 13, 14  // s[5..12], s[7..14]
+ // clang-format on
+};
+
+DECLARE_ALIGNED(16, static const uint8_t, kFilterPermuteTbl[16]) = {
+ // clang-format off
+ 1, 2, 3, 4, 5, 6, 7, 16, 16, 1, 2, 3, 4, 5, 6, 7  // 16: out of range -> 0
+ // clang-format on
+};
+
static inline int16x4_t convolve6_4_2d_h(uint8x16_t samples,
const int8x16_t x_filter,
const uint8x16_t permute_tbl,
@@ -90,7 +103,7 @@ static inline void dist_wtd_convolve_2d_horiz_6tap_neon_i8mm(
int height = im_h;
if (w == 4) {
- const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
+ const uint8x16_t permute_tbl = vld1q_u8(kMatMul6PermuteTbl);
do {
uint8x16_t s0, s1, s2, s3;
load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);
@@ -118,7 +131,7 @@ static inline void dist_wtd_convolve_2d_horiz_6tap_neon_i8mm(
dst_ptr += dst_stride;
} while (--height != 0);
} else {
- const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
+ const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMul6PermuteTbl);
do {
const uint8_t *s = src_ptr;
int16_t *d = dst_ptr;
@@ -366,10 +379,10 @@ static inline uint16x8_t convolve6_8_x(uint8x16_t samples,
return vreinterpretq_u16_s16(res);
}
-static inline uint16x8_t convolve8_8_x(uint8x16_t samples,
- const int8x8_t x_filter,
- const uint8x16x3_t permute_tbl,
- const int32x4_t round_offset) {
+static inline uint16x8_t convolve8_8_x_usdot(uint8x16_t samples,
+ const int8x8_t x_filter,
+ const uint8x16x3_t permute_tbl,
+ const int32x4_t round_offset) {
uint8x16_t permuted_samples[3];
int32x4_t sum[2];
@@ -395,6 +408,32 @@ static inline uint16x8_t convolve8_8_x(uint8x16_t samples,
return vreinterpretq_u16_s16(res);
}
+static inline uint16x8_t convolve8_8_x_usmmla(uint8x16_t samples,
+ const int8x16_t x_filter,
+ const uint8x8_t f0,
+ const uint8x16x2_t permute_tbl,
+ const uint16x8_t horiz_const) {
+ // 8-tap filter for 8 pixels. Permute into rows s[2k+1..2k+8], k = 0..3:
+ // { 1, 2, 3, 4, 5, 6, 7, 8, 3, 4, 5, 6, 7, 8, 9, 10 }
+ // { 5, 6, 7, 8, 9, 10, 11, 12, 7, 8, 9, 10, 11, 12, 13, 14 }
+ uint8x16_t perm_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
+ vqtbl1q_u8(samples, permute_tbl.val[1]) };
+
+ // Taps f1-f7 only; lanes of sum0123/sum4567 hold outputs 0-3/4-7.
+ int32x4_t sum0123 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[0], x_filter);
+ int32x4_t sum4567 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[1], x_filter);
+ uint16x8_t sum = vreinterpretq_u16_s16(
+ vcombine_s16(vmovn_s32(sum0123), vmovn_s32(sum4567)));
+
+ // Apply tap 0: callers pass f0 = -tap0 (halved), so subtract s[i] * f0.
+ sum = vmlsl_u8(sum, vget_low_u8(samples), f0);
+ // Add horiz_const (callers pass the round offset plus rounding shim).
+ sum = vaddq_u16(sum, horiz_const);
+
+ // Filter taps were halved by the caller, so shift by ROUND0_BITS - 1.
+ return vshrq_n_u16(sum, ROUND0_BITS - 1);
+}
+
static inline void dist_wtd_convolve_x_dist_wtd_avg_6tap_neon_i8mm(
const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride,
uint8_t *dst8, int dst8_stride, int w, int h, const int16_t *x_filter_ptr,
@@ -421,7 +460,7 @@ static inline void dist_wtd_convolve_x_dist_wtd_avg_6tap_neon_i8mm(
vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);
if (w == 4) {
- const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
+ const uint8x16_t permute_tbl = vld1q_u8(kMatMul6PermuteTbl);
do {
uint8x16_t s0, s1, s2, s3;
load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
@@ -451,7 +490,7 @@ static inline void dist_wtd_convolve_x_dist_wtd_avg_6tap_neon_i8mm(
h -= 4;
} while (h != 0);
} else {
- const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
+ const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMul6PermuteTbl);
do {
const uint8_t *s = src;
uint16_t *d = dst;
@@ -527,13 +566,13 @@ static inline void dist_wtd_convolve_x_dist_wtd_avg_8tap_neon_i8mm(
load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
uint16x8_t d0 =
- convolve8_8_x(s0, x_filter, permute_tbl, round_offset_shim);
+ convolve8_8_x_usdot(s0, x_filter, permute_tbl, round_offset_shim);
uint16x8_t d1 =
- convolve8_8_x(s1, x_filter, permute_tbl, round_offset_shim);
+ convolve8_8_x_usdot(s1, x_filter, permute_tbl, round_offset_shim);
uint16x8_t d2 =
- convolve8_8_x(s2, x_filter, permute_tbl, round_offset_shim);
+ convolve8_8_x_usdot(s2, x_filter, permute_tbl, round_offset_shim);
uint16x8_t d3 =
- convolve8_8_x(s3, x_filter, permute_tbl, round_offset_shim);
+ convolve8_8_x_usdot(s3, x_filter, permute_tbl, round_offset_shim);
uint16x8_t dd0, dd1, dd2, dd3;
load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
@@ -582,7 +621,7 @@ static inline void dist_wtd_convolve_x_avg_6tap_neon_i8mm(
vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);
if (w == 4) {
- const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
+ const uint8x16_t permute_tbl = vld1q_u8(kMatMul6PermuteTbl);
do {
uint8x16_t s0, s1, s2, s3;
load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
@@ -612,7 +651,7 @@ static inline void dist_wtd_convolve_x_avg_6tap_neon_i8mm(
h -= 4;
} while (h != 0);
} else {
- const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
+ const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMul6PermuteTbl);
do {
const uint8_t *s = src;
uint16_t *d = dst;
@@ -668,12 +707,23 @@ static inline void dist_wtd_convolve_x_avg_8tap_neon_i8mm(
// A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
// shifts - which are generally faster than rounding shifts on modern CPUs.
// (The extra -1 is needed because we halved the filter values.)
- const int32x4_t round_offset_shim = vdupq_n_s32(
+ const uint16x8_t round_offset_shim = vdupq_n_u16(
(round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));
- const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
+ const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMul8PermuteTbl);
+
// Filter values are even, so halve to reduce intermediate precision reqs.
- const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
+ const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
+
+ // Stagger the filter for use with the matrix multiply instructions.
+ // { f1, f2, f3, f4, f5, f6, f7, 0, 0, f1, f2, f3, f4, f5, f6, f7 }
+ const uint8x16_t filter_idx = vld1q_u8(kFilterPermuteTbl);
+ const int8x16_t x_filter =
+ vqtbl1q_s8(vcombine_s8(x_filter_s8, vdup_n_s8(0)), filter_idx);
+
+ // Since f0 is non-positive and s0 is unsigned, subtract (unsigned) s0 * -f0
+ // to avoid signed overflow.
+ const uint8x8_t f0 = vdup_n_u8(-x_filter_ptr[0] >> 1);
do {
const uint8_t *s = src;
@@ -685,14 +735,14 @@ static inline void dist_wtd_convolve_x_avg_8tap_neon_i8mm(
uint8x16_t s0, s1, s2, s3;
load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
- uint16x8_t d0 =
- convolve8_8_x(s0, x_filter, permute_tbl, round_offset_shim);
- uint16x8_t d1 =
- convolve8_8_x(s1, x_filter, permute_tbl, round_offset_shim);
- uint16x8_t d2 =
- convolve8_8_x(s2, x_filter, permute_tbl, round_offset_shim);
- uint16x8_t d3 =
- convolve8_8_x(s3, x_filter, permute_tbl, round_offset_shim);
+ uint16x8_t d0 = convolve8_8_x_usmmla(s0, x_filter, f0, permute_tbl,
+ round_offset_shim);
+ uint16x8_t d1 = convolve8_8_x_usmmla(s1, x_filter, f0, permute_tbl,
+ round_offset_shim);
+ uint16x8_t d2 = convolve8_8_x_usmmla(s2, x_filter, f0, permute_tbl,
+ round_offset_shim);
+ uint16x8_t d3 = convolve8_8_x_usmmla(s3, x_filter, f0, permute_tbl,
+ round_offset_shim);
uint16x8_t dd0, dd1, dd2, dd3;
load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
@@ -739,7 +789,7 @@ static inline void dist_wtd_convolve_x_6tap_neon_i8mm(
vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);
if (w == 4) {
- const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
+ const uint8x16_t permute_tbl = vld1q_u8(kMatMul6PermuteTbl);
do {
uint8x16_t s0, s1, s2, s3;
load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
@@ -760,7 +810,7 @@ static inline void dist_wtd_convolve_x_6tap_neon_i8mm(
h -= 4;
} while (h != 0);
} else {
- const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
+ const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMul6PermuteTbl);
do {
const uint8_t *s = src;
uint16_t *d = dst;
@@ -805,12 +855,23 @@ static inline void dist_wtd_convolve_x_8tap_neon_i8mm(
// A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
// shifts - which are generally faster than rounding shifts on modern CPUs.
// (The extra -1 is needed because we halved the filter values.)
- const int32x4_t round_offset_shim = vdupq_n_s32(
+ const uint16x8_t round_offset_shim = vdupq_n_u16(
(round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));
- const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
+ const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMul8PermuteTbl);
+
// Filter values are even, so halve to reduce intermediate precision reqs.
- const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
+ const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
+
+ // Stagger the filter for use with the matrix multiply instructions.
+ // { f1, f2, f3, f4, f5, f6, f7, 0, 0, f1, f2, f3, f4, f5, f6, f7 }
+ const uint8x16_t filter_idx = vld1q_u8(kFilterPermuteTbl);
+ const int8x16_t x_filter =
+ vqtbl1q_s8(vcombine_s8(x_filter_s8, vdup_n_s8(0)), filter_idx);
+
+ // Since f0 is non-positive and s0 is unsigned, subtract (unsigned) s0 * -f0
+ // to avoid signed overflow.
+ const uint8x8_t f0 = vdup_n_u8(-x_filter_ptr[0] >> 1);
do {
const uint8_t *s = src;
@@ -821,14 +882,14 @@ static inline void dist_wtd_convolve_x_8tap_neon_i8mm(
uint8x16_t s0, s1, s2, s3;
load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
- uint16x8_t d0 =
- convolve8_8_x(s0, x_filter, permute_tbl, round_offset_shim);
- uint16x8_t d1 =
- convolve8_8_x(s1, x_filter, permute_tbl, round_offset_shim);
- uint16x8_t d2 =
- convolve8_8_x(s2, x_filter, permute_tbl, round_offset_shim);
- uint16x8_t d3 =
- convolve8_8_x(s3, x_filter, permute_tbl, round_offset_shim);
+ uint16x8_t d0 = convolve8_8_x_usmmla(s0, x_filter, f0, permute_tbl,
+ round_offset_shim);
+ uint16x8_t d1 = convolve8_8_x_usmmla(s1, x_filter, f0, permute_tbl,
+ round_offset_shim);
+ uint16x8_t d2 = convolve8_8_x_usmmla(s2, x_filter, f0, permute_tbl,
+ round_offset_shim);
+ uint16x8_t d3 = convolve8_8_x_usmmla(s3, x_filter, f0, permute_tbl,
+ round_offset_shim);
store_u16_8x4(d, dst_stride, d0, d1, d2, d3);