Commit ce07b41f84 for aom

commit ce07b41f8492963b6a895a820f25044fce264c96
Author: Gerda Zsejke More <gerdazsejke.more@arm.com>
Date:   Sun Nov 23 12:54:41 2025 +0100

    Optimize Armv8.0 Neon impl of filter_intra_predictor_neon

    Read the intra predictor inputs directly from dst instead of from
    an intermediate buffer to eliminate redundant store/load operations.

    This is a port from SVT-AV1:
    https://gitlab.com/AOMediaCodec/SVT-AV1/-/merge_requests/2555

    Change-Id: If52f0f166821fe85cedf801148fc72996ca0def8
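
The heart of the change is the removal of the 33x33 scratch buffer: the first
two output rows are seeded directly from above[] and left[], and every later
two-row strip re-reads its top-row inputs from dst, where the previous strip
has already stored them. Below is a minimal scalar sketch of that data flow;
filt() and sketch() are hypothetical stand-ins that model which pixels feed
which outputs, not the Neon arithmetic in the diff that follows.

#include <stddef.h>
#include <stdint.h>

/* Hypothetical stand-in for the 7-tap filter; data flow only. The real
 * per-lane taps and arithmetic live in filter_intra_predictor() below. */
static uint8_t filt(uint8_t p0, uint8_t p1, uint8_t p2, uint8_t p3,
                    uint8_t p4, uint8_t p5, uint8_t p6) {
  (void)p0; (void)p2; (void)p3; (void)p4; (void)p5; (void)p6;
  return p1;
}

/* above[-1] is the top-left corner pixel, as in the real function.
 * width is a multiple of 4 and height a multiple of 2 for the tx sizes
 * this predictor covers. */
static void sketch(uint8_t *dst, ptrdiff_t stride, int width, int height,
                   const uint8_t *above, const uint8_t *left) {
  for (int r = 0; r < height; r += 2) {
    /* Top-row inputs: above[] for the first strip, afterwards dst itself. */
    const uint8_t *top = (r == 0) ? above : dst + (r - 1) * stride;
    for (int c = 0; c < width; c += 4) {
      uint8_t p0 = (c == 0) ? ((r == 0) ? above[-1] : left[r - 1])
                            : top[c - 1];
      uint8_t p1 = top[c + 0], p2 = top[c + 1];
      uint8_t p3 = top[c + 2], p4 = top[c + 3];
      /* Left inputs: left[] at the block edge, then prior outputs in dst. */
      uint8_t p5 = (c == 0) ? left[r + 0] : dst[(r + 0) * stride + c - 1];
      uint8_t p6 = (c == 0) ? left[r + 1] : dst[(r + 1) * stride + c - 1];
      for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 4; ++j)
          dst[(r + i) * stride + c + j] = filt(p0, p1, p2, p3, p4, p5, p6);
    }
  }
}

Because each strip only ever reads pixels that earlier strips have finished
writing, dst can double as the history buffer, and the old store-to-buffer /
load-from-buffer round-trip disappears.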

diff --git a/av1/common/arm/reconintra_neon.c b/av1/common/arm/reconintra_neon.c
index dc0cb17eb1..95503aacc9 100644
--- a/av1/common/arm/reconintra_neon.c
+++ b/av1/common/arm/reconintra_neon.c
@@ -20,10 +20,11 @@
 #include "aom_dsp/arm/sum_neon.h"

 #define MAX_UPSAMPLE_SZ 16
+#define FILTER_INTRA_SCALE_BITS 4

 // These kernels are a transposed version of those defined in reconintra.c,
 // with the absolute value of the negatives taken in the top row.
-DECLARE_ALIGNED(16, const uint8_t,
+DECLARE_ALIGNED(16, static const uint8_t,
                 av1_filter_intra_taps_neon[FILTER_INTRA_MODES][7][8]) = {
   // clang-format off
   {
@@ -74,7 +75,22 @@ DECLARE_ALIGNED(16, const uint8_t,
   // clang-format on
 };

-#define FILTER_INTRA_SCALE_BITS 4
+static inline uint8x8_t filter_intra_predictor(
+    uint8x8_t s0, uint8x8_t s1, uint8x8_t s2, uint8x8_t s3, uint8x8_t s4,
+    uint8x8_t s5, uint8x8_t s6, const uint8x8_t f0, const uint8x8_t f1,
+    const uint8x8_t f2, const uint8x8_t f3, const uint8x8_t f4,
+    const uint8x8_t f5, const uint8x8_t f6) {
+  uint16x8_t acc = vmull_u8(s1, f1);
+  // First row of each filter has all negative values so subtract.
+  acc = vmlsl_u8(acc, s0, f0);
+  acc = vmlal_u8(acc, s2, f2);
+  acc = vmlal_u8(acc, s3, f3);
+  acc = vmlal_u8(acc, s4, f4);
+  acc = vmlal_u8(acc, s5, f5);
+  acc = vmlal_u8(acc, s6, f6);
+
+  return vqrshrun_n_s16(vreinterpretq_s16_u16(acc), FILTER_INTRA_SCALE_BITS);
+}

 void av1_filter_intra_predictor_neon(uint8_t *dst, ptrdiff_t stride,
                                      TX_SIZE tx_size, const uint8_t *above,
@@ -91,82 +107,118 @@ void av1_filter_intra_predictor_neon(uint8_t *dst, ptrdiff_t stride,
   const uint8x8_t f5 = vld1_u8(av1_filter_intra_taps_neon[mode][5]);
   const uint8x8_t f6 = vld1_u8(av1_filter_intra_taps_neon[mode][6]);

-  uint8_t buffer[33][33];
-  // Populate the top row in the scratch buffer with data from above.
-  memcpy(buffer[0], &above[-1], (width + 1) * sizeof(uint8_t));
-  // Populate the first column in the scratch buffer with data from the left.
-  int r = 0;
-  do {
-    buffer[r + 1][0] = left[r];
-  } while (++r < height);
-
   // Computing 4 cols per iteration (instead of 8) for 8x<h> blocks is faster.
   if (width <= 8) {
-    r = 1;
+    uint8x8_t s0 = vdup_n_u8(above[-1]);
+    uint8x8_t s5 = vdup_n_u8(left[0]);
+    uint8x8_t s6 = vdup_n_u8(left[1]);
+
+    int c = 0;
+    do {
+      uint8x8_t s1234 = load_u8_4x1(above + c);
+      uint8x8_t s1 = vdup_lane_u8(s1234, 0);
+      uint8x8_t s2 = vdup_lane_u8(s1234, 1);
+      uint8x8_t s3 = vdup_lane_u8(s1234, 2);
+      uint8x8_t s4 = vdup_lane_u8(s1234, 3);
+
+      uint8x8_t res = filter_intra_predictor(s0, s1, s2, s3, s4, s5, s6, f0, f1,
+                                             f2, f3, f4, f5, f6);
+
+      store_u8x4_strided_x2(dst + c, stride, res);
+
+      s0 = s4;
+      s5 = vdup_lane_u8(res, 3);
+      s6 = vdup_lane_u8(res, 7);
+
+      c += 4;
+    } while (c < width);
+
+    int r = 2;
     do {
-      int c = 1;
-      uint8x8_t s0 = vld1_dup_u8(&buffer[r - 1][c - 1]);
-      uint8x8_t s5 = vld1_dup_u8(&buffer[r + 0][c - 1]);
-      uint8x8_t s6 = vld1_dup_u8(&buffer[r + 1][c - 1]);
+      s0 = vdup_n_u8(left[r - 1]);
+      s5 = vdup_n_u8(left[r + 0]);
+      s6 = vdup_n_u8(left[r + 1]);

+      c = 0;
       do {
-        uint8x8_t s1234 = load_unaligned_u8_4x1(&buffer[r - 1][c - 1] + 1);
+        uint8x8_t s1234 = load_u8_4x1(dst + (r - 1) * stride + c);
         uint8x8_t s1 = vdup_lane_u8(s1234, 0);
         uint8x8_t s2 = vdup_lane_u8(s1234, 1);
         uint8x8_t s3 = vdup_lane_u8(s1234, 2);
         uint8x8_t s4 = vdup_lane_u8(s1234, 3);

-        uint16x8_t sum = vmull_u8(s1, f1);
-        // First row of each filter has all negative values so subtract.
-        sum = vmlsl_u8(sum, s0, f0);
-        sum = vmlal_u8(sum, s2, f2);
-        sum = vmlal_u8(sum, s3, f3);
-        sum = vmlal_u8(sum, s4, f4);
-        sum = vmlal_u8(sum, s5, f5);
-        sum = vmlal_u8(sum, s6, f6);
-
-        uint8x8_t res =
-            vqrshrun_n_s16(vreinterpretq_s16_u16(sum), FILTER_INTRA_SCALE_BITS);
+        uint8x8_t res = filter_intra_predictor(s0, s1, s2, s3, s4, s5, s6, f0,
+                                               f1, f2, f3, f4, f5, f6);

-        // Store buffer[r + 0][c] and buffer[r + 1][c].
-        store_u8x4_strided_x2(&buffer[r][c], 33, res);
-
-        store_u8x4_strided_x2(dst + (r - 1) * stride + c - 1, stride, res);
+        store_u8x4_strided_x2(dst + r * stride + c, stride, res);

         s0 = s4;
         s5 = vdup_lane_u8(res, 3);
         s6 = vdup_lane_u8(res, 7);
+
         c += 4;
-      } while (c < width + 1);
+      } while (c < width);

       r += 2;
-    } while (r < height + 1);
+    } while (r < height);
   } else {
-    r = 1;
+    uint8x8_t s0_lo = vdup_n_u8(above[-1]);
+    uint8x8_t s5_lo = vdup_n_u8(left[0]);
+    uint8x8_t s6_lo = vdup_n_u8(left[1]);
+
+    int c = 0;
+    do {
+      uint8x8_t s1234 = vld1_u8(above + c);
+      uint8x8_t s1_lo = vdup_lane_u8(s1234, 0);
+      uint8x8_t s2_lo = vdup_lane_u8(s1234, 1);
+      uint8x8_t s3_lo = vdup_lane_u8(s1234, 2);
+      uint8x8_t s4_lo = vdup_lane_u8(s1234, 3);
+
+      uint8x8_t res_lo =
+          filter_intra_predictor(s0_lo, s1_lo, s2_lo, s3_lo, s4_lo, s5_lo,
+                                 s6_lo, f0, f1, f2, f3, f4, f5, f6);
+
+      uint8x8_t s0_hi = s4_lo;
+      uint8x8_t s1_hi = vdup_lane_u8(s1234, 4);
+      uint8x8_t s2_hi = vdup_lane_u8(s1234, 5);
+      uint8x8_t s3_hi = vdup_lane_u8(s1234, 6);
+      uint8x8_t s4_hi = vdup_lane_u8(s1234, 7);
+      uint8x8_t s5_hi = vdup_lane_u8(res_lo, 3);
+      uint8x8_t s6_hi = vdup_lane_u8(res_lo, 7);
+
+      uint8x8_t res_hi =
+          filter_intra_predictor(s0_hi, s1_hi, s2_hi, s3_hi, s4_hi, s5_hi,
+                                 s6_hi, f0, f1, f2, f3, f4, f5, f6);
+
+      uint32x2x2_t res =
+          vzip_u32(vreinterpret_u32_u8(res_lo), vreinterpret_u32_u8(res_hi));
+
+      vst1_u8(dst + 0 * stride + c, vreinterpret_u8_u32(res.val[0]));
+      vst1_u8(dst + 1 * stride + c, vreinterpret_u8_u32(res.val[1]));
+
+      s0_lo = s4_hi;
+      s5_lo = vdup_lane_u8(res_hi, 3);
+      s6_lo = vdup_lane_u8(res_hi, 7);
+      c += 8;
+    } while (c < width);
+
+    int r = 2;
     do {
-      int c = 1;
-      uint8x8_t s0_lo = vld1_dup_u8(&buffer[r - 1][c - 1]);
-      uint8x8_t s5_lo = vld1_dup_u8(&buffer[r + 0][c - 1]);
-      uint8x8_t s6_lo = vld1_dup_u8(&buffer[r + 1][c - 1]);
+      s0_lo = vdup_n_u8(left[r - 1]);
+      s5_lo = vdup_n_u8(left[r + 0]);
+      s6_lo = vdup_n_u8(left[r + 1]);

+      c = 0;
       do {
-        uint8x8_t s1234 = vld1_u8(&buffer[r - 1][c - 1] + 1);
+        uint8x8_t s1234 = vld1_u8(dst + (r - 1) * stride + c);
         uint8x8_t s1_lo = vdup_lane_u8(s1234, 0);
         uint8x8_t s2_lo = vdup_lane_u8(s1234, 1);
         uint8x8_t s3_lo = vdup_lane_u8(s1234, 2);
         uint8x8_t s4_lo = vdup_lane_u8(s1234, 3);

-        uint16x8_t sum_lo = vmull_u8(s1_lo, f1);
-        // First row of each filter has all negative values so subtract.
-        sum_lo = vmlsl_u8(sum_lo, s0_lo, f0);
-        sum_lo = vmlal_u8(sum_lo, s2_lo, f2);
-        sum_lo = vmlal_u8(sum_lo, s3_lo, f3);
-        sum_lo = vmlal_u8(sum_lo, s4_lo, f4);
-        sum_lo = vmlal_u8(sum_lo, s5_lo, f5);
-        sum_lo = vmlal_u8(sum_lo, s6_lo, f6);
-
-        uint8x8_t res_lo = vqrshrun_n_s16(vreinterpretq_s16_u16(sum_lo),
-                                          FILTER_INTRA_SCALE_BITS);
+        uint8x8_t res_lo =
+            filter_intra_predictor(s0_lo, s1_lo, s2_lo, s3_lo, s4_lo, s5_lo,
+                                   s6_lo, f0, f1, f2, f3, f4, f5, f6);

         uint8x8_t s0_hi = s4_lo;
         uint8x8_t s1_hi = vdup_lane_u8(s1234, 4);
@@ -176,37 +228,24 @@ void av1_filter_intra_predictor_neon(uint8_t *dst, ptrdiff_t stride,
         uint8x8_t s5_hi = vdup_lane_u8(res_lo, 3);
         uint8x8_t s6_hi = vdup_lane_u8(res_lo, 7);

-        uint16x8_t sum_hi = vmull_u8(s1_hi, f1);
-        // First row of each filter has all negative values so subtract.
-        sum_hi = vmlsl_u8(sum_hi, s0_hi, f0);
-        sum_hi = vmlal_u8(sum_hi, s2_hi, f2);
-        sum_hi = vmlal_u8(sum_hi, s3_hi, f3);
-        sum_hi = vmlal_u8(sum_hi, s4_hi, f4);
-        sum_hi = vmlal_u8(sum_hi, s5_hi, f5);
-        sum_hi = vmlal_u8(sum_hi, s6_hi, f6);
-
-        uint8x8_t res_hi = vqrshrun_n_s16(vreinterpretq_s16_u16(sum_hi),
-                                          FILTER_INTRA_SCALE_BITS);
+        uint8x8_t res_hi =
+            filter_intra_predictor(s0_hi, s1_hi, s2_hi, s3_hi, s4_hi, s5_hi,
+                                   s6_hi, f0, f1, f2, f3, f4, f5, f6);

         uint32x2x2_t res =
             vzip_u32(vreinterpret_u32_u8(res_lo), vreinterpret_u32_u8(res_hi));

-        vst1_u8(&buffer[r + 0][c], vreinterpret_u8_u32(res.val[0]));
-        vst1_u8(&buffer[r + 1][c], vreinterpret_u8_u32(res.val[1]));
-
-        vst1_u8(dst + (r - 1) * stride + c - 1,
-                vreinterpret_u8_u32(res.val[0]));
-        vst1_u8(dst + (r + 0) * stride + c - 1,
-                vreinterpret_u8_u32(res.val[1]));
+        vst1_u8(dst + (r + 0) * stride + c, vreinterpret_u8_u32(res.val[0]));
+        vst1_u8(dst + (r + 1) * stride + c, vreinterpret_u8_u32(res.val[1]));

         s0_lo = s4_hi;
         s5_lo = vdup_lane_u8(res_hi, 3);
         s6_lo = vdup_lane_u8(res_hi, 7);
         c += 8;
-      } while (c < width + 1);
+      } while (c < width);

       r += 2;
-    } while (r < height + 1);
+    } while (r < height);
   }
 }
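
A note on the arithmetic in the new filter_intra_predictor() helper: the
vmull_u8/vmlsl_u8/vmlal_u8 chain accumulates -p0*f0 + p1*f1 + ... + p6*f6 per
lane in 16 bits (modulo 2^16), and vqrshrun_n_s16 then reinterprets each lane
as signed, adds the rounding constant 1 << (FILTER_INTRA_SCALE_BITS - 1),
shifts right by FILTER_INTRA_SCALE_BITS, and saturates to [0, 255]. A scalar
model of one output lane, under the hypothetical name filter_intra_lane():

#include <stdint.h>

/* One output lane: p[] holds the 7 shared input pixels, f[] this lane's
 * taps (one column of av1_filter_intra_taps_neon). */
static uint8_t filter_intra_lane(const uint8_t p[7], const uint8_t f[7]) {
  uint16_t acc = (uint16_t)(p[1] * f[1]);
  acc -= (uint16_t)(p[0] * f[0]);  /* first filter row is all-negative */
  acc += (uint16_t)(p[2] * f[2]);
  acc += (uint16_t)(p[3] * f[3]);
  acc += (uint16_t)(p[4] * f[4]);
  acc += (uint16_t)(p[5] * f[5]);
  acc += (uint16_t)(p[6] * f[6]);

  /* vqrshrun_n_s16(..., 4): signed rounded shift, unsigned saturation. */
  int32_t r = ((int32_t)(int16_t)acc + 8) >> 4;
  if (r < 0) r = 0;
  if (r > 255) r = 255;
  return (uint8_t)r;
}

The signed reinterpretation matters: a negative total (possible because the
f0 row is subtracted) wraps to a u16 value with the high bit set, becomes
negative again as s16, and saturates to 0 rather than producing a large
bogus pixel value.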