Commit e719db5ab5 for aom
commit e719db5ab559946d8a5755e9bd930523f369f90b
Author: Li Zhang <li.zhang2@arm.com>
Date: Wed Oct 29 09:52:39 2025 +0100
Optimize 8-tap Neon I8MM av1_dist_wtd_convolve_2d implementation
Decomposing the 8-tap filter into a 7-tap filter followed by a 1-tap
filter enables us to use the Neon I8MM USMMLA instructions, which are
faster than the existing USDOT approach.
Change-Id: I183ceb7d1adb7ae7f205db715587e13513c5e383
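
As a rough illustration of the decomposition described above, here is a
stand-alone scalar sketch (plain C, not the Neon code from this patch; the
filter values and helper names are made up for the example). It splits an
8-tap convolution into a 7-tap part over taps 1-7 plus a separate tap-0
correction, applying the negative tap 0 as an unsigned multiply-subtract in
the same spirit as the patch:

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

// Illustrative 8-tap filter with a negative tap 0 (values are made up;
// libaom's real sub-pel filters differ, but are likewise even-valued).
static const int16_t kFilter[8] = { -2, 8, -20, 118, 38, -16, 6, -4 };

// Reference: full 8-tap convolution at position i.
static int32_t convolve8(const uint8_t *src, int i) {
  int32_t sum = 0;
  for (int k = 0; k < 8; ++k) sum += kFilter[k] * src[i + k];
  return sum;
}

// Decomposed: 7-tap convolution over taps 1..7, then tap 0 applied
// separately. Since kFilter[0] is negative and the samples are unsigned,
// the tap-0 term is an unsigned multiply-subtract: sum -= src[i] * -f0.
static int32_t convolve7_plus_tap0(const uint8_t *src, int i) {
  int32_t sum = 0;
  for (int k = 1; k < 8; ++k) sum += kFilter[k] * src[i + k];
  uint32_t minus_f0 = (uint32_t)(-kFilter[0]);
  sum -= (int32_t)(minus_f0 * src[i]);
  return sum;
}

int main(void) {
  uint8_t src[16];
  for (int i = 0; i < 16; ++i) src[i] = (uint8_t)(i * 17 + 3);
  for (int i = 0; i < 8; ++i) {
    assert(convolve8(src, i) == convolve7_plus_tap0(src, i));
  }
  printf("7-tap + tap-0 decomposition matches the 8-tap reference.\n");
  return 0;
}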
diff --git a/av1/common/arm/compound_convolve_neon_i8mm.c b/av1/common/arm/compound_convolve_neon_i8mm.c
index 4f22e427e1..fa913080fc 100644
--- a/av1/common/arm/compound_convolve_neon_i8mm.c
+++ b/av1/common/arm/compound_convolve_neon_i8mm.c
@@ -180,31 +180,29 @@ static inline void dist_wtd_convolve_2d_horiz_6tap_neon_i8mm(
}
static inline int16x8_t convolve8_8_2d_h(uint8x16_t samples,
- const int8x8_t x_filter,
- const uint8x16x3_t permute_tbl,
- const int32x4_t horiz_const) {
- uint8x16_t permuted_samples[3];
- int32x4_t sum[2];
+ const int8x16_t x_filter,
+ const uint8x8_t f0,
+ const uint8x16x2_t permute_tbl,
+ const uint16x8_t horiz_const) {
+ // Permute samples ready for matrix multiply.
+ // { 1, 2, 3, 4, 5, 6, 7, 8, 3, 4, 5, 6, 7, 8, 9, 10 }
+ // { 5, 6, 7, 8, 9, 10, 11, 12, 7, 8, 9, 10, 11, 12, 13, 14 }
+ uint8x16_t perm_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
+ vqtbl1q_u8(samples, permute_tbl.val[1]) };
- // Permute samples ready for dot product.
- // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
- permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
- // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
- permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
- // { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
- permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);
+ // Calculate partial 7-tap convolution.
+ int32x4_t sum0123 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[0], x_filter);
+ int32x4_t sum4567 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[1], x_filter);
+ uint16x8_t sum = vreinterpretq_u16_s16(
+ vcombine_s16(vmovn_s32(sum0123), vmovn_s32(sum4567)));
- // First 4 output values.
- sum[0] = vusdotq_lane_s32(horiz_const, permuted_samples[0], x_filter, 0);
- sum[0] = vusdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
- // Second 4 output values.
- sum[1] = vusdotq_lane_s32(horiz_const, permuted_samples[1], x_filter, 0);
- sum[1] = vusdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);
+ // Apply tap 0 and accumulate.
+ sum = vmlsl_u8(sum, vget_low_u8(samples), f0);
+
+ sum = vaddq_u16(sum, horiz_const);
- // Narrow and re-pack.
// We halved the convolution filter values so -1 from the right shift.
- return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
- vshrn_n_s32(sum[1], ROUND0_BITS - 1));
+ return vreinterpretq_s16_u16(vshrq_n_u16(sum, ROUND0_BITS - 1));
}
static inline void dist_wtd_convolve_2d_horiz_8tap_neon_i8mm(
@@ -214,12 +212,23 @@ static inline void dist_wtd_convolve_2d_horiz_8tap_neon_i8mm(
// A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
// shifts - which are generally faster than rounding shifts on modern CPUs.
// (The extra -1 is needed because we halved the filter values.)
- const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) +
- (1 << ((ROUND0_BITS - 1) - 1)));
+ const uint16x8_t horiz_const = vdupq_n_u16((1 << (bd + FILTER_BITS - 2)) +
+ (1 << ((ROUND0_BITS - 1) - 1)));
+
+ const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMul8PermuteTbl);
- const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
// Filter values are even, so halve to reduce intermediate precision reqs.
- const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
+ const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
+
+ // Stagger the filter for use with the matrix multiply instructions.
+ // { f1, f2, f3, f4, f5, f6, f7, 0, 0, f1, f2, f3, f4, f5, f6, f7 }
+ const uint8x16_t filter_idx = vld1q_u8(kFilterPermuteTbl);
+ const int8x16_t x_filter =
+ vqtbl1q_s8(vcombine_s8(x_filter_s8, vdup_n_s8(0)), filter_idx);
+
+ // Since f0 is always negative and s0 is unsigned, subtract (unsigned) s0 *
+ // -f0 to avoid signed overflow.
+ const uint8x8_t f0 = vdup_n_u8(-x_filter_ptr[0] >> 1);
const uint8_t *src_ptr = src;
int16_t *dst_ptr = im_block;
@@ -235,10 +244,14 @@ static inline void dist_wtd_convolve_2d_horiz_8tap_neon_i8mm(
uint8x16_t s0, s1, s2, s3;
load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
- int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
- int16x8_t d1 = convolve8_8_2d_h(s1, x_filter, permute_tbl, horiz_const);
- int16x8_t d2 = convolve8_8_2d_h(s2, x_filter, permute_tbl, horiz_const);
- int16x8_t d3 = convolve8_8_2d_h(s3, x_filter, permute_tbl, horiz_const);
+ int16x8_t d0 =
+ convolve8_8_2d_h(s0, x_filter, f0, permute_tbl, horiz_const);
+ int16x8_t d1 =
+ convolve8_8_2d_h(s1, x_filter, f0, permute_tbl, horiz_const);
+ int16x8_t d2 =
+ convolve8_8_2d_h(s2, x_filter, f0, permute_tbl, horiz_const);
+ int16x8_t d3 =
+ convolve8_8_2d_h(s3, x_filter, f0, permute_tbl, horiz_const);
store_s16_8x4(d, dst_stride, d0, d1, d2, d3);
@@ -259,7 +272,8 @@ static inline void dist_wtd_convolve_2d_horiz_8tap_neon_i8mm(
do {
uint8x16_t s0 = vld1q_u8(s);
- int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
+ int16x8_t d0 =
+ convolve8_8_2d_h(s0, x_filter, f0, permute_tbl, horiz_const);
vst1q_s16(d, d0);
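
For readers unfamiliar with USMMLA, here is a stand-alone scalar model (plain
C, outside the patch; the filter values are made up, while the permute
patterns and the staggered-filter layout follow the comments in the diff
above). It sketches what a single vusmmlaq_s32 call computes in this kernel:
a 2x8 matrix of unsigned samples times two rows of signed filter taps,
yielding the partial 7-tap sums for four adjacent output pixels.

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

// Scalar model of the USMMLA semantics: an unsigned 2x8 matrix of u8 samples
// multiplied against a signed matrix stored as two rows of 8 s8 taps,
// accumulating into a 2x2 block of s32 results.
static void usmmla_model(int32_t acc[2][2], const uint8_t a[2][8],
                         const int8_t b[2][8]) {
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      for (int k = 0; k < 8; ++k) acc[i][j] += a[i][k] * b[j][k];
    }
  }
}

int main(void) {
  // Illustrative (halved) 7-tap portion f1..f7 of an 8-tap filter.
  const int8_t f[7] = { 4, -10, 59, 19, -8, 3, -2 };
  uint8_t s[16];
  for (int i = 0; i < 16; ++i) s[i] = (uint8_t)(i * 13 + 7);

  // Permuted sample rows, matching the first entry of kMatMul8PermuteTbl as
  // described in the diff: { s1..s8 } and { s3..s10 }.
  uint8_t a[2][8];
  for (int k = 0; k < 8; ++k) {
    a[0][k] = s[1 + k];
    a[1][k] = s[3 + k];
  }
  // Staggered filter rows: { f1..f7, 0 } and { 0, f1..f7 }.
  int8_t b[2][8] = { { 0 }, { 0 } };
  for (int k = 0; k < 7; ++k) {
    b[0][k] = f[k];
    b[1][k + 1] = f[k];
  }

  int32_t acc[2][2] = { { 0, 0 }, { 0, 0 } };
  usmmla_model(acc, a, b);

  // acc holds the partial 7-tap sums for output pixels 0..3:
  // acc[0][0]=out0, acc[0][1]=out1, acc[1][0]=out2, acc[1][1]=out3.
  // The second permuted vector { s5..s12, s7..s14 } covers pixels 4..7 in
  // the same way, and tap 0 is applied afterwards as in the patch.
  for (int out = 0; out < 4; ++out) {
    int32_t ref = 0;
    for (int k = 0; k < 7; ++k) ref += f[k] * s[out + 1 + k];
    assert(acc[out / 2][out % 2] == ref);
  }
  printf("USMMLA model matches the partial 7-tap reference.\n");
  return 0;
}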