Commit ce07b41f84 for aom
commit ce07b41f8492963b6a895a820f25044fce264c96
Author: Gerda Zsejke More <gerdazsejke.more@arm.com>
Date: Sun Nov 23 12:54:41 2025 +0100
Optimize Armv8.0 Neon impl of filter_intra_predictor_neon
Read intra predictor inputs directly from dst instead of an
intermediate buffer to eliminate redundant store/load operations.
This is a port from SVT-AV1:
https://gitlab.com/AOMediaCodec/SVT-AV1/-/merge_requests/2555
Change-Id: If52f0f166821fe85cedf801148fc72996ca0def8
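The resulting data flow, as a scalar sketch (illustrative only: predict_4x2 is a hypothetical stand-in for the Neon kernel in the diff below, producing one 4-wide, two-row block of the prediction from its seven neighbouring pixels; as in the real function, above[-1] is the top-left reference sample):

    #include <stddef.h>
    #include <stdint.h>

    typedef void (*predict_4x2_fn)(uint8_t *dst, ptrdiff_t stride,
                                   uint8_t top_left, const uint8_t top[4],
                                   uint8_t left0, uint8_t left1);

    static void filter_intra_in_place(uint8_t *dst, ptrdiff_t stride,
                                      const uint8_t *above, const uint8_t *left,
                                      int width, int height,
                                      predict_4x2_fn predict_4x2) {
      // Rows 0 and 1: all inputs come straight from above[] and left[].
      for (int c = 0; c < width; c += 4) {
        predict_4x2(dst + c, stride, above[c - 1], above + c,
                    c == 0 ? left[0] : dst[c - 1],
                    c == 0 ? left[1] : dst[stride + c - 1]);
      }
      // Remaining rows, two at a time: the "above" inputs of every block are
      // read back from the dst rows written on the previous iteration, so the
      // 33x33 scratch buffer and its store/load round trip are gone.
      for (int r = 2; r < height; r += 2) {
        const uint8_t *top = dst + (r - 1) * stride;
        for (int c = 0; c < width; c += 4) {
          predict_4x2(dst + r * stride + c, stride,
                      c == 0 ? left[r - 1] : top[c - 1], top + c,
                      c == 0 ? left[r] : dst[r * stride + c - 1],
                      c == 0 ? left[r + 1] : dst[(r + 1) * stride + c - 1]);
        }
      }
    }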
diff --git a/av1/common/arm/reconintra_neon.c b/av1/common/arm/reconintra_neon.c
index dc0cb17eb1..95503aacc9 100644
--- a/av1/common/arm/reconintra_neon.c
+++ b/av1/common/arm/reconintra_neon.c
@@ -20,10 +20,11 @@
#include "aom_dsp/arm/sum_neon.h"
#define MAX_UPSAMPLE_SZ 16
+#define FILTER_INTRA_SCALE_BITS 4

// These kernels are a transposed version of those defined in reconintra.c,
// with the absolute value of the negatives taken in the top row.
-DECLARE_ALIGNED(16, const uint8_t,
+DECLARE_ALIGNED(16, static const uint8_t,
av1_filter_intra_taps_neon[FILTER_INTRA_MODES][7][8]) = {
// clang-format off
{
@@ -74,7 +75,22 @@ DECLARE_ALIGNED(16, const uint8_t,
// clang-format on
};

-#define FILTER_INTRA_SCALE_BITS 4
+static inline uint8x8_t filter_intra_predictor(
+ uint8x8_t s0, uint8x8_t s1, uint8x8_t s2, uint8x8_t s3, uint8x8_t s4,
+ uint8x8_t s5, uint8x8_t s6, const uint8x8_t f0, const uint8x8_t f1,
+ const uint8x8_t f2, const uint8x8_t f3, const uint8x8_t f4,
+ const uint8x8_t f5, const uint8x8_t f6) {
+ uint16x8_t acc = vmull_u8(s1, f1);
+ // First row of each filter has all negative values so subtract.
+ acc = vmlsl_u8(acc, s0, f0);
+ acc = vmlal_u8(acc, s2, f2);
+ acc = vmlal_u8(acc, s3, f3);
+ acc = vmlal_u8(acc, s4, f4);
+ acc = vmlal_u8(acc, s5, f5);
+ acc = vmlal_u8(acc, s6, f6);
+
+ return vqrshrun_n_s16(vreinterpretq_s16_u16(acc), FILTER_INTRA_SCALE_BITS);
+}

void av1_filter_intra_predictor_neon(uint8_t *dst, ptrdiff_t stride,
TX_SIZE tx_size, const uint8_t *above,
@@ -91,82 +107,118 @@ void av1_filter_intra_predictor_neon(uint8_t *dst, ptrdiff_t stride,
const uint8x8_t f5 = vld1_u8(av1_filter_intra_taps_neon[mode][5]);
const uint8x8_t f6 = vld1_u8(av1_filter_intra_taps_neon[mode][6]);

-  uint8_t buffer[33][33];
- // Populate the top row in the scratch buffer with data from above.
- memcpy(buffer[0], &above[-1], (width + 1) * sizeof(uint8_t));
- // Populate the first column in the scratch buffer with data from the left.
- int r = 0;
- do {
- buffer[r + 1][0] = left[r];
- } while (++r < height);
-
// Computing 4 cols per iteration (instead of 8) for 8x<h> blocks is faster.
if (width <= 8) {
- r = 1;
+ uint8x8_t s0 = vdup_n_u8(above[-1]);
+ uint8x8_t s5 = vdup_n_u8(left[0]);
+ uint8x8_t s6 = vdup_n_u8(left[1]);
+
+ int c = 0;
+ do {
+ uint8x8_t s1234 = load_u8_4x1(above + c);
+ uint8x8_t s1 = vdup_lane_u8(s1234, 0);
+ uint8x8_t s2 = vdup_lane_u8(s1234, 1);
+ uint8x8_t s3 = vdup_lane_u8(s1234, 2);
+ uint8x8_t s4 = vdup_lane_u8(s1234, 3);
+
+ uint8x8_t res = filter_intra_predictor(s0, s1, s2, s3, s4, s5, s6, f0, f1,
+ f2, f3, f4, f5, f6);
+
+ store_u8x4_strided_x2(dst + c, stride, res);
+
+ s0 = s4;
+ s5 = vdup_lane_u8(res, 3);
+ s6 = vdup_lane_u8(res, 7);
+
+ c += 4;
+ } while (c < width);
+
+ int r = 2;
do {
- int c = 1;
- uint8x8_t s0 = vld1_dup_u8(&buffer[r - 1][c - 1]);
- uint8x8_t s5 = vld1_dup_u8(&buffer[r + 0][c - 1]);
- uint8x8_t s6 = vld1_dup_u8(&buffer[r + 1][c - 1]);
+ s0 = vdup_n_u8(left[r - 1]);
+ s5 = vdup_n_u8(left[r + 0]);
+ s6 = vdup_n_u8(left[r + 1]);
+ c = 0;
do {
- uint8x8_t s1234 = load_unaligned_u8_4x1(&buffer[r - 1][c - 1] + 1);
+ uint8x8_t s1234 = load_u8_4x1(dst + (r - 1) * stride + c);
uint8x8_t s1 = vdup_lane_u8(s1234, 0);
uint8x8_t s2 = vdup_lane_u8(s1234, 1);
uint8x8_t s3 = vdup_lane_u8(s1234, 2);
uint8x8_t s4 = vdup_lane_u8(s1234, 3);

-        uint16x8_t sum = vmull_u8(s1, f1);
- // First row of each filter has all negative values so subtract.
- sum = vmlsl_u8(sum, s0, f0);
- sum = vmlal_u8(sum, s2, f2);
- sum = vmlal_u8(sum, s3, f3);
- sum = vmlal_u8(sum, s4, f4);
- sum = vmlal_u8(sum, s5, f5);
- sum = vmlal_u8(sum, s6, f6);
-
- uint8x8_t res =
- vqrshrun_n_s16(vreinterpretq_s16_u16(sum), FILTER_INTRA_SCALE_BITS);
+ uint8x8_t res = filter_intra_predictor(s0, s1, s2, s3, s4, s5, s6, f0,
+                                               f1, f2, f3, f4, f5, f6);

-      // Store buffer[r + 0][c] and buffer[r + 1][c].
- store_u8x4_strided_x2(&buffer[r][c], 33, res);
-
- store_u8x4_strided_x2(dst + (r - 1) * stride + c - 1, stride, res);
+        store_u8x4_strided_x2(dst + r * stride + c, stride, res);

s0 = s4;
s5 = vdup_lane_u8(res, 3);
s6 = vdup_lane_u8(res, 7);
+
c += 4;
- } while (c < width + 1);
+ } while (c < width);
r += 2;
- } while (r < height + 1);
+ } while (r < height);
} else {
- r = 1;
+ uint8x8_t s0_lo = vdup_n_u8(above[-1]);
+ uint8x8_t s5_lo = vdup_n_u8(left[0]);
+ uint8x8_t s6_lo = vdup_n_u8(left[1]);
+
+ int c = 0;
+ do {
+ uint8x8_t s1234 = vld1_u8(above + c);
+ uint8x8_t s1_lo = vdup_lane_u8(s1234, 0);
+ uint8x8_t s2_lo = vdup_lane_u8(s1234, 1);
+ uint8x8_t s3_lo = vdup_lane_u8(s1234, 2);
+ uint8x8_t s4_lo = vdup_lane_u8(s1234, 3);
+
+ uint8x8_t res_lo =
+ filter_intra_predictor(s0_lo, s1_lo, s2_lo, s3_lo, s4_lo, s5_lo,
+ s6_lo, f0, f1, f2, f3, f4, f5, f6);
+
+ uint8x8_t s0_hi = s4_lo;
+ uint8x8_t s1_hi = vdup_lane_u8(s1234, 4);
+ uint8x8_t s2_hi = vdup_lane_u8(s1234, 5);
+ uint8x8_t s3_hi = vdup_lane_u8(s1234, 6);
+ uint8x8_t s4_hi = vdup_lane_u8(s1234, 7);
+ uint8x8_t s5_hi = vdup_lane_u8(res_lo, 3);
+ uint8x8_t s6_hi = vdup_lane_u8(res_lo, 7);
+
+ uint8x8_t res_hi =
+ filter_intra_predictor(s0_hi, s1_hi, s2_hi, s3_hi, s4_hi, s5_hi,
+ s6_hi, f0, f1, f2, f3, f4, f5, f6);
+
+ uint32x2x2_t res =
+ vzip_u32(vreinterpret_u32_u8(res_lo), vreinterpret_u32_u8(res_hi));
+
+ vst1_u8(dst + 0 * stride + c, vreinterpret_u8_u32(res.val[0]));
+ vst1_u8(dst + 1 * stride + c, vreinterpret_u8_u32(res.val[1]));
+
+ s0_lo = s4_hi;
+ s5_lo = vdup_lane_u8(res_hi, 3);
+ s6_lo = vdup_lane_u8(res_hi, 7);
+ c += 8;
+ } while (c < width);
+
+ int r = 2;
do {
- int c = 1;
- uint8x8_t s0_lo = vld1_dup_u8(&buffer[r - 1][c - 1]);
- uint8x8_t s5_lo = vld1_dup_u8(&buffer[r + 0][c - 1]);
- uint8x8_t s6_lo = vld1_dup_u8(&buffer[r + 1][c - 1]);
+ s0_lo = vdup_n_u8(left[r - 1]);
+ s5_lo = vdup_n_u8(left[r + 0]);
+ s6_lo = vdup_n_u8(left[r + 1]);
+ c = 0;
do {
- uint8x8_t s1234 = vld1_u8(&buffer[r - 1][c - 1] + 1);
+ uint8x8_t s1234 = vld1_u8(dst + (r - 1) * stride + c);
uint8x8_t s1_lo = vdup_lane_u8(s1234, 0);
uint8x8_t s2_lo = vdup_lane_u8(s1234, 1);
uint8x8_t s3_lo = vdup_lane_u8(s1234, 2);
uint8x8_t s4_lo = vdup_lane_u8(s1234, 3);

-        uint16x8_t sum_lo = vmull_u8(s1_lo, f1);
- // First row of each filter has all negative values so subtract.
- sum_lo = vmlsl_u8(sum_lo, s0_lo, f0);
- sum_lo = vmlal_u8(sum_lo, s2_lo, f2);
- sum_lo = vmlal_u8(sum_lo, s3_lo, f3);
- sum_lo = vmlal_u8(sum_lo, s4_lo, f4);
- sum_lo = vmlal_u8(sum_lo, s5_lo, f5);
- sum_lo = vmlal_u8(sum_lo, s6_lo, f6);
-
- uint8x8_t res_lo = vqrshrun_n_s16(vreinterpretq_s16_u16(sum_lo),
- FILTER_INTRA_SCALE_BITS);
+ uint8x8_t res_lo =
+ filter_intra_predictor(s0_lo, s1_lo, s2_lo, s3_lo, s4_lo, s5_lo,
+                                   s6_lo, f0, f1, f2, f3, f4, f5, f6);

uint8x8_t s0_hi = s4_lo;
uint8x8_t s1_hi = vdup_lane_u8(s1234, 4);
@@ -176,37 +228,24 @@ void av1_filter_intra_predictor_neon(uint8_t *dst, ptrdiff_t stride,
uint8x8_t s5_hi = vdup_lane_u8(res_lo, 3);
uint8x8_t s6_hi = vdup_lane_u8(res_lo, 7);

-      uint16x8_t sum_hi = vmull_u8(s1_hi, f1);
- // First row of each filter has all negative values so subtract.
- sum_hi = vmlsl_u8(sum_hi, s0_hi, f0);
- sum_hi = vmlal_u8(sum_hi, s2_hi, f2);
- sum_hi = vmlal_u8(sum_hi, s3_hi, f3);
- sum_hi = vmlal_u8(sum_hi, s4_hi, f4);
- sum_hi = vmlal_u8(sum_hi, s5_hi, f5);
- sum_hi = vmlal_u8(sum_hi, s6_hi, f6);
-
- uint8x8_t res_hi = vqrshrun_n_s16(vreinterpretq_s16_u16(sum_hi),
- FILTER_INTRA_SCALE_BITS);
+ uint8x8_t res_hi =
+ filter_intra_predictor(s0_hi, s1_hi, s2_hi, s3_hi, s4_hi, s5_hi,
+                                   s6_hi, f0, f1, f2, f3, f4, f5, f6);

uint32x2x2_t res =
vzip_u32(vreinterpret_u32_u8(res_lo), vreinterpret_u32_u8(res_hi));

-      vst1_u8(&buffer[r + 0][c], vreinterpret_u8_u32(res.val[0]));
- vst1_u8(&buffer[r + 1][c], vreinterpret_u8_u32(res.val[1]));
-
- vst1_u8(dst + (r - 1) * stride + c - 1,
- vreinterpret_u8_u32(res.val[0]));
- vst1_u8(dst + (r + 0) * stride + c - 1,
- vreinterpret_u8_u32(res.val[1]));
+ vst1_u8(dst + (r + 0) * stride + c, vreinterpret_u8_u32(res.val[0]));
+      vst1_u8(dst + (r + 1) * stride + c, vreinterpret_u8_u32(res.val[1]));

s0_lo = s4_hi;
s5_lo = vdup_lane_u8(res_hi, 3);
s6_lo = vdup_lane_u8(res_hi, 7);
c += 8;
- } while (c < width + 1);
+ } while (c < width);
r += 2;
- } while (r < height + 1);
+ } while (r < height);
}
}
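For reference, each output lane of the filter_intra_predictor() helper added above reduces to the scalar computation sketched here. It assumes the transposed av1_filter_intra_taps_neon layout, i.e. taps[i][j] is the weight of input i for output lane j, with lanes 0-3 forming the first output row and lanes 4-7 the second (matching store_u8x4_strided_x2()); vqrshrun_n_s16 with shift FILTER_INTRA_SCALE_BITS is a rounding right shift by 4 with unsigned saturation:

    #include <stdint.h>

    // Scalar model of one 4x2 filter-intra block. s[0] is the top-left
    // neighbor, s[1..4] the four pixels above, s[5]/s[6] the left neighbors
    // of the two output rows.
    static void filter_intra_4x2_scalar(uint8_t out[8],
                                        const uint8_t taps[7][8],
                                        const uint8_t s[7]) {
      for (int j = 0; j < 8; j++) {
        // 7-tap dot product; the first tap row holds the absolute values of
        // negative coefficients, hence the subtraction.
        int acc = -(int)taps[0][j] * s[0];
        for (int i = 1; i < 7; i++) acc += taps[i][j] * s[i];
        // vqrshrun_n_s16(acc, 4): add the rounding bias, shift right, then
        // saturate to the uint8_t range.
        acc = (acc + 8) >> 4;
        out[j] = (uint8_t)(acc < 0 ? 0 : (acc > 255 ? 255 : acc));
      }
    }

The Neon helper accumulates in a uint16x8_t and reinterprets it as int16x8_t before narrowing, which agrees with the signed arithmetic above as long as each lane's sum stays within the int16 range, as the AV1 filter-intra taps guarantee.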