Commit 039042e9eb for aom

commit 039042e9ebb008b3fd0d39d6dbc635f8c80ea30e
Author: Jerome Jiang <jianj@google.com>
Date:   Wed Jun 10 11:23:21 2026 -0400

    Optimize av1_lowbd_pixel_proj_error_avx2.

    Optimize the AVX2 implementation of the low bitdepth pixel projection
    error calculation.

    - Implement stride decoupling to support distinct strides for all
      input buffers.
    - Add width-specific optimized paths (width == 8, width == 16, widths
      that are multiples of 32 and, when no filter is active, width >= 16)
      to improve efficiency for different block sizes.
    - Use register-direct SIMD horizontal reduction to avoid memory
      roundtrips.

        | Block   | Before  | After   | Speedup   |
        | :------ | :-----: | :-----: | :-------- |
        | 8x8     |  71.78n |  12.35n | 5.81x     |
        | 16x16   |  51.81n |  39.27n | 1.32x     |
        | 32x32   | 165.80n | 151.40n | 1.10x     |
        | 64x64   | 743.00n | 653.10n | 1.14x     |
        | 128x128 |  2.583µ |  2.300µ | 1.12x     |
        | geomean | 259.80n | 161.60n | 1.61x     |

    Change-Id: I80072dec3cb710fb9a6e4fdebc2276eed485af6e

diff --git a/av1/encoder/x86/pickrst_avx2.c b/av1/encoder/x86/pickrst_avx2.c
index f5375dd296..792bc81bb9 100644
--- a/av1/encoder/x86/pickrst_avx2.c
+++ b/av1/encoder/x86/pickrst_avx2.c
@@ -1545,6 +1545,74 @@ static inline __m256i pair_set_epi16(int a, int b) {
       (int32_t)(((uint16_t)(a)) | (((uint32_t)(uint16_t)(b)) << 16)));
 }

+// Loads 16 pixels and widens them to 16 bits in the element order produced by
+// _mm256_packs_epi32(lo, hi), where lo and hi hold the int32 values for pixels
+// 0-7 and 8-15. packs_epi32 works within 128-bit lanes, giving the order
+// {0-3, 8-11 | 4-7, 12-15}; swapping the middle two dwords of the input
+// before widening yields the same order. Pixels and filter outputs therefore
+// stay paired without a cross-lane permute on either operand, and the final
+// sum of squared errors does not depend on element order. Reads 16 bytes.
+static inline __m256i load_shuffled_u8_to_epi16(const uint8_t *ptr) {
+  const __m128i raw = xx_loadu_128(ptr);
+  const __m128i shuffled = _mm_shuffle_epi32(raw, _MM_SHUFFLE(3, 1, 2, 0));
+  return _mm256_cvtepu8_epi16(shuffled);
+}
+
+// Loads 8 pixels from each of two rows and widens them to 16 bits in the
+// order {A0-3, B0-3 | A4-7, B4-7}, which matches _mm256_packs_epi32() applied
+// to the 8 int32 values of row A and the 8 int32 values of row B. Reads 8
+// bytes from each pointer.
+static inline __m256i load_shuffled_u8_dual8_to_epi16(const uint8_t *ptrA,
+                                                      const uint8_t *ptrB) {
+  const __m128i rawA = _mm_loadl_epi64((const __m128i *)ptrA);
+  const __m128i rawB = _mm_loadl_epi64((const __m128i *)ptrB);
+  const __m128i raw_AB = _mm_unpacklo_epi64(rawA, rawB);
+  const __m128i shuffled = _mm_shuffle_epi32(raw_AB, _MM_SHUFFLE(3, 1, 2, 0));
+  return _mm256_cvtepu8_epi16(shuffled);
+}
+
+// Per-pixel projection error when both filters are active, in 16 bits:
+//   e = ((xq[0] * (flt0 - u) + xq[1] * (flt1 - u) + rounding) >> shift)
+//       + d - s,  where u = d << SGRPROJ_RST_BITS.
+// xq_coeff is expected to hold the (xq[0], xq[1]) pair in every 32-bit lane.
+// The shifted 32-bit values are saturated to 16 bits by packs_epi32 and the
+// remaining arithmetic wraps in 16 bits. All operands must use the same
+// element order (see the loaders above).
+static inline __m256i calc_proj_err_r0_r1_avx2(
+    const __m256i d0, const __m256i s0, const __m256i flt0_16b,
+    const __m256i flt1_16b, const __m256i xq_coeff, const __m256i rounding,
+    int shift) {
+  const __m256i u0 = _mm256_slli_epi16(d0, SGRPROJ_RST_BITS);
+  const __m256i v0 = _mm256_madd_epi16(
+      xq_coeff, _mm256_unpacklo_epi16(_mm256_sub_epi16(flt0_16b, u0),
+                                      _mm256_sub_epi16(flt1_16b, u0)));
+  const __m256i v1 = _mm256_madd_epi16(
+      xq_coeff, _mm256_unpackhi_epi16(_mm256_sub_epi16(flt0_16b, u0),
+                                      _mm256_sub_epi16(flt1_16b, u0)));
+  const __m256i vr = _mm256_packs_epi32(
+      _mm256_srai_epi32(_mm256_add_epi32(v0, rounding), shift),
+      _mm256_srai_epi32(_mm256_add_epi32(v1, rounding), shift));
+  return _mm256_add_epi16(vr, _mm256_sub_epi16(d0, s0));
+}
+
+// Per-pixel projection error when exactly one filter is active, in 16 bits.
+// xq_coeff is expected to hold the pair (xq, -xq * (1 << SGRPROJ_RST_BITS)),
+// so madd over the interleaved (flt, d) values yields
+// xq * (flt - (d << SGRPROJ_RST_BITS)). Rounding, shifting, saturation and
+// the final "+ d - s" match calc_proj_err_r0_r1_avx2().
+static inline __m256i calc_proj_err_r0_or_r1_avx2(
+    const __m256i d0, const __m256i s0, const __m256i flt_16b,
+    const __m256i xq_coeff, const __m256i rounding, int shift) {
+  const __m256i v0 =
+      _mm256_madd_epi16(xq_coeff, _mm256_unpacklo_epi16(flt_16b, d0));
+  const __m256i v1 =
+      _mm256_madd_epi16(xq_coeff, _mm256_unpackhi_epi16(flt_16b, d0));
+  const __m256i vr_16b = _mm256_packs_epi32(
+      _mm256_srai_epi32(_mm256_add_epi32(v0, rounding), shift),
+      _mm256_srai_epi32(_mm256_add_epi32(v1, rounding), shift));
+  return _mm256_add_epi16(vr_16b, _mm256_sub_epi16(d0, s0));
+}
+
 int64_t av1_lowbd_pixel_proj_error_avx2(
     const uint8_t *src8, int width, int height, int src_stride,
     const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
@@ -1556,53 +1601,220 @@ int64_t av1_lowbd_pixel_proj_error_avx2(
   const uint8_t *src = src8;
   const uint8_t *dat = dat8;
   int64_t err = 0;
+  // Vector paths sum squared errors in int32 lanes, then widen into sum64.
   if (params->r[0] > 0 && params->r[1] > 0) {
     __m256i xq_coeff = pair_set_epi16(xq[0], xq[1]);
-    for (i = 0; i < height; ++i) {
+    if (width == 8) {
       __m256i sum32 = _mm256_setzero_si256();
-      for (j = 0; j <= width - 16; j += 16) {
-        const __m256i d0 = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
-        const __m256i s0 = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
-        const __m256i flt0_16b = _mm256_permute4x64_epi64(
-            _mm256_packs_epi32(yy_loadu_256(flt0 + j),
-                               yy_loadu_256(flt0 + j + 8)),
-            0xd8);
-        const __m256i flt1_16b = _mm256_permute4x64_epi64(
-            _mm256_packs_epi32(yy_loadu_256(flt1 + j),
-                               yy_loadu_256(flt1 + j + 8)),
-            0xd8);
-        const __m256i u0 = _mm256_slli_epi16(d0, SGRPROJ_RST_BITS);
-        const __m256i flt0_0_sub_u = _mm256_sub_epi16(flt0_16b, u0);
-        const __m256i flt1_0_sub_u = _mm256_sub_epi16(flt1_16b, u0);
-        const __m256i v0 = _mm256_madd_epi16(
-            xq_coeff, _mm256_unpacklo_epi16(flt0_0_sub_u, flt1_0_sub_u));
-        const __m256i v1 = _mm256_madd_epi16(
-            xq_coeff, _mm256_unpackhi_epi16(flt0_0_sub_u, flt1_0_sub_u));
-        const __m256i vr0 =
-            _mm256_srai_epi32(_mm256_add_epi32(v0, rounding), shift);
-        const __m256i vr1 =
-            _mm256_srai_epi32(_mm256_add_epi32(v1, rounding), shift);
-        const __m256i e0 = _mm256_sub_epi16(
-            _mm256_add_epi16(_mm256_packs_epi32(vr0, vr1), d0), s0);
+      const int height_even = height & ~1;
+      for (i = 0; i < height_even; i += 2) {
+        const uint8_t *dat_rowB = dat + dat_stride;
+        const uint8_t *src_rowB = src + src_stride;
+        const int32_t *flt0_rowB = flt0 + flt0_stride;
+        const int32_t *flt1_rowB = flt1 + flt1_stride;
+        // Rows A and B are processed together as one 16-element vector.
+        const __m256i d0 = load_shuffled_u8_dual8_to_epi16(dat, dat_rowB);
+        const __m256i s0 = load_shuffled_u8_dual8_to_epi16(src, src_rowB);
+        const __m256i flt0_16b =
+            _mm256_packs_epi32(yy_loadu_256(flt0), yy_loadu_256(flt0_rowB));
+        const __m256i flt1_16b =
+            _mm256_packs_epi32(yy_loadu_256(flt1), yy_loadu_256(flt1_rowB));
+
+        const __m256i e0 = calc_proj_err_r0_r1_avx2(d0, s0, flt0_16b, flt1_16b,
+                                                    xq_coeff, rounding, shift);
         const __m256i err0 = _mm256_madd_epi16(e0, e0);
         sum32 = _mm256_add_epi32(sum32, err0);
+
+        dat += 2 * dat_stride;
+        src += 2 * src_stride;
+        flt0 += 2 * flt0_stride;
+        flt1 += 2 * flt1_stride;
       }
-      for (k = j; k < width; ++k) {
-        const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
-        int32_t v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u);
-        const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
-        err += ((int64_t)e * e);
+      if (i < height) {
+        for (k = 0; k < 8; ++k) {
+          const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+          int32_t v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u);
+          const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+          err += ((int64_t)e * e);
+        }
       }
-      dat += dat_stride;
-      src += src_stride;
-      flt0 += flt0_stride;
-      flt1 += flt1_stride;
       const __m256i sum64_0 =
           _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
       const __m256i sum64_1 =
           _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
-      sum64 = _mm256_add_epi64(sum64, sum64_0);
-      sum64 = _mm256_add_epi64(sum64, sum64_1);
+      sum64 = _mm256_add_epi64(sum64, _mm256_add_epi64(sum64_0, sum64_1));
+    } else if (width == 16) {
+      __m256i sum32_A = _mm256_setzero_si256();
+      __m256i sum32_B = _mm256_setzero_si256();
+      __m256i sum32_C = _mm256_setzero_si256();
+      __m256i sum32_D = _mm256_setzero_si256();
+      const int height_v4 = height & ~3;
+      for (i = 0; i < height_v4; i += 4) {
+        const uint8_t *dat_rowB = dat + dat_stride;
+        const uint8_t *dat_rowC = dat_rowB + dat_stride;
+        const uint8_t *dat_rowD = dat_rowC + dat_stride;
+        const uint8_t *src_rowB = src + src_stride;
+        const uint8_t *src_rowC = src_rowB + src_stride;
+        const uint8_t *src_rowD = src_rowC + src_stride;
+        const int32_t *flt0_rowB = flt0 + flt0_stride;
+        const int32_t *flt0_rowC = flt0_rowB + flt0_stride;
+        const int32_t *flt0_rowD = flt0_rowC + flt0_stride;
+        const int32_t *flt1_rowB = flt1 + flt1_stride;
+        const int32_t *flt1_rowC = flt1_rowB + flt1_stride;
+        const int32_t *flt1_rowD = flt1_rowC + flt1_stride;
+        // Rows A-D use separate accumulators, combined after all rows.
+        // Row A
+        {
+          const __m256i d0 = load_shuffled_u8_to_epi16(dat);
+          const __m256i s0 = load_shuffled_u8_to_epi16(src);
+          const __m256i flt0_16b =
+              _mm256_packs_epi32(yy_loadu_256(flt0), yy_loadu_256(flt0 + 8));
+          const __m256i flt1_16b =
+              _mm256_packs_epi32(yy_loadu_256(flt1), yy_loadu_256(flt1 + 8));
+          const __m256i e = calc_proj_err_r0_r1_avx2(d0, s0, flt0_16b, flt1_16b,
+                                                     xq_coeff, rounding, shift);
+          sum32_A = _mm256_add_epi32(sum32_A, _mm256_madd_epi16(e, e));
+        }
+        // Row B
+        {
+          const __m256i d0 = load_shuffled_u8_to_epi16(dat_rowB);
+          const __m256i s0 = load_shuffled_u8_to_epi16(src_rowB);
+          const __m256i flt0_16b = _mm256_packs_epi32(
+              yy_loadu_256(flt0_rowB), yy_loadu_256(flt0_rowB + 8));
+          const __m256i flt1_16b = _mm256_packs_epi32(
+              yy_loadu_256(flt1_rowB), yy_loadu_256(flt1_rowB + 8));
+          const __m256i e = calc_proj_err_r0_r1_avx2(d0, s0, flt0_16b, flt1_16b,
+                                                     xq_coeff, rounding, shift);
+          sum32_B = _mm256_add_epi32(sum32_B, _mm256_madd_epi16(e, e));
+        }
+        // Row C
+        {
+          const __m256i d0 = load_shuffled_u8_to_epi16(dat_rowC);
+          const __m256i s0 = load_shuffled_u8_to_epi16(src_rowC);
+          const __m256i flt0_16b = _mm256_packs_epi32(
+              yy_loadu_256(flt0_rowC), yy_loadu_256(flt0_rowC + 8));
+          const __m256i flt1_16b = _mm256_packs_epi32(
+              yy_loadu_256(flt1_rowC), yy_loadu_256(flt1_rowC + 8));
+          const __m256i e = calc_proj_err_r0_r1_avx2(d0, s0, flt0_16b, flt1_16b,
+                                                     xq_coeff, rounding, shift);
+          sum32_C = _mm256_add_epi32(sum32_C, _mm256_madd_epi16(e, e));
+        }
+        // Row D
+        {
+          const __m256i d0 = load_shuffled_u8_to_epi16(dat_rowD);
+          const __m256i s0 = load_shuffled_u8_to_epi16(src_rowD);
+          const __m256i flt0_16b = _mm256_packs_epi32(
+              yy_loadu_256(flt0_rowD), yy_loadu_256(flt0_rowD + 8));
+          const __m256i flt1_16b = _mm256_packs_epi32(
+              yy_loadu_256(flt1_rowD), yy_loadu_256(flt1_rowD + 8));
+          const __m256i e = calc_proj_err_r0_r1_avx2(d0, s0, flt0_16b, flt1_16b,
+                                                     xq_coeff, rounding, shift);
+          sum32_D = _mm256_add_epi32(sum32_D, _mm256_madd_epi16(e, e));
+        }
+
+        dat += 4 * dat_stride;
+        src += 4 * src_stride;
+        flt0 += 4 * flt0_stride;
+        flt1 += 4 * flt1_stride;
+      }
+      for (; i < height; ++i) {
+        const __m256i d0 = load_shuffled_u8_to_epi16(dat);
+        const __m256i s0 = load_shuffled_u8_to_epi16(src);
+        const __m256i flt0_16b =
+            _mm256_packs_epi32(yy_loadu_256(flt0), yy_loadu_256(flt0 + 8));
+        const __m256i flt1_16b =
+            _mm256_packs_epi32(yy_loadu_256(flt1), yy_loadu_256(flt1 + 8));
+        const __m256i e = calc_proj_err_r0_r1_avx2(d0, s0, flt0_16b, flt1_16b,
+                                                   xq_coeff, rounding, shift);
+        sum32_A = _mm256_add_epi32(sum32_A, _mm256_madd_epi16(e, e));
+
+        dat += dat_stride;
+        src += src_stride;
+        flt0 += flt0_stride;
+        flt1 += flt1_stride;
+      }
+      __m256i sum32 = _mm256_add_epi32(_mm256_add_epi32(sum32_A, sum32_B),
+                                       _mm256_add_epi32(sum32_C, sum32_D));
+      const __m256i sum64_0 =
+          _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+      const __m256i sum64_1 =
+          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+      sum64 = _mm256_add_epi64(sum64, _mm256_add_epi64(sum64_0, sum64_1));
+    } else if (width >= 32 && (width % 32 == 0)) {
+      int rows_per_batch = 4096 / width;  // ~4096 pixels per int32 flush
+      if (rows_per_batch < 1) rows_per_batch = 1;
+      for (i = 0; i < height;) {
+        int rows_to_do = height - i;
+        if (rows_to_do > rows_per_batch) rows_to_do = rows_per_batch;
+        const int next_i = i + rows_to_do;
+        __m256i sum32_A = _mm256_setzero_si256();
+        __m256i sum32_B = _mm256_setzero_si256();
+        for (; i < next_i; ++i) {
+          for (j = 0; j <= width - 32; j += 32) {
+            const __m256i d_A = load_shuffled_u8_to_epi16(dat + j);
+            const __m256i s_A = load_shuffled_u8_to_epi16(src + j);
+            const __m256i flt0_A = _mm256_packs_epi32(
+                yy_loadu_256(flt0 + j), yy_loadu_256(flt0 + j + 8));
+            const __m256i flt1_A = _mm256_packs_epi32(
+                yy_loadu_256(flt1 + j), yy_loadu_256(flt1 + j + 8));
+            const __m256i e_A = calc_proj_err_r0_r1_avx2(
+                d_A, s_A, flt0_A, flt1_A, xq_coeff, rounding, shift);
+            sum32_A = _mm256_add_epi32(sum32_A, _mm256_madd_epi16(e_A, e_A));
+            // Second 16-pixel half of the 32-pixel chunk.
+            const __m256i d_B = load_shuffled_u8_to_epi16(dat + j + 16);
+            const __m256i s_B = load_shuffled_u8_to_epi16(src + j + 16);
+            const __m256i flt0_B = _mm256_packs_epi32(
+                yy_loadu_256(flt0 + j + 16), yy_loadu_256(flt0 + j + 24));
+            const __m256i flt1_B = _mm256_packs_epi32(
+                yy_loadu_256(flt1 + j + 16), yy_loadu_256(flt1 + j + 24));
+            const __m256i e_B = calc_proj_err_r0_r1_avx2(
+                d_B, s_B, flt0_B, flt1_B, xq_coeff, rounding, shift);
+            sum32_B = _mm256_add_epi32(sum32_B, _mm256_madd_epi16(e_B, e_B));
+          }
+          dat += dat_stride;
+          src += src_stride;
+          flt0 += flt0_stride;
+          flt1 += flt1_stride;
+        }
+        __m256i sum32 = _mm256_add_epi32(sum32_A, sum32_B);
+        const __m256i sum64_0 =
+            _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+        const __m256i sum64_1 =
+            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+        sum64 = _mm256_add_epi64(sum64, _mm256_add_epi64(sum64_0, sum64_1));
+      }
+    } else {
+      // General fallback: 16-pixel SIMD steps plus a scalar tail.
+      for (i = 0; i < height; ++i) {
+        __m256i sum32 = _mm256_setzero_si256();
+        for (j = 0; j <= width - 16; j += 16) {
+          const __m256i d0 = load_shuffled_u8_to_epi16(dat + j);
+          const __m256i s0 = load_shuffled_u8_to_epi16(src + j);
+          const __m256i flt0_16b = _mm256_packs_epi32(
+              yy_loadu_256(flt0 + j), yy_loadu_256(flt0 + j + 8));
+          const __m256i flt1_16b = _mm256_packs_epi32(
+              yy_loadu_256(flt1 + j), yy_loadu_256(flt1 + j + 8));
+          const __m256i e0 = calc_proj_err_r0_r1_avx2(
+              d0, s0, flt0_16b, flt1_16b, xq_coeff, rounding, shift);
+          sum32 = _mm256_add_epi32(sum32, _mm256_madd_epi16(e0, e0));
+        }
+        for (k = j; k < width; ++k) {
+          const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+          int32_t v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u);
+          const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+          err += ((int64_t)e * e);
+        }
+        dat += dat_stride;
+        src += src_stride;
+        flt0 += flt0_stride;
+        flt1 += flt1_stride;
+        const __m256i sum64_0 =
+            _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+        const __m256i sum64_1 =
+            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+        sum64 = _mm256_add_epi64(sum64, sum64_0);
+        sum64 = _mm256_add_epi64(sum64, sum64_1);
+      }
     }
   } else if (params->r[0] > 0 || params->r[1] > 0) {
     const int xq_active = (params->r[0] > 0) ? xq[0] : xq[1];
@@ -1610,66 +1822,322 @@ int64_t av1_lowbd_pixel_proj_error_avx2(
         pair_set_epi16(xq_active, -xq_active * (1 << SGRPROJ_RST_BITS));
     const int32_t *flt = (params->r[0] > 0) ? flt0 : flt1;
     const int flt_stride = (params->r[0] > 0) ? flt0_stride : flt1_stride;
-    for (i = 0; i < height; ++i) {
+    // Width-specialized paths for the single active filter output flt.
+    if (width == 8) {
       __m256i sum32 = _mm256_setzero_si256();
-      for (j = 0; j <= width - 16; j += 16) {
-        const __m256i d0 = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
-        const __m256i s0 = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
-        const __m256i flt_16b = _mm256_permute4x64_epi64(
-            _mm256_packs_epi32(yy_loadu_256(flt + j),
-                               yy_loadu_256(flt + j + 8)),
-            0xd8);
-        const __m256i v0 =
-            _mm256_madd_epi16(xq_coeff, _mm256_unpacklo_epi16(flt_16b, d0));
-        const __m256i v1 =
-            _mm256_madd_epi16(xq_coeff, _mm256_unpackhi_epi16(flt_16b, d0));
-        const __m256i vr0 =
-            _mm256_srai_epi32(_mm256_add_epi32(v0, rounding), shift);
-        const __m256i vr1 =
-            _mm256_srai_epi32(_mm256_add_epi32(v1, rounding), shift);
-        const __m256i e0 = _mm256_sub_epi16(
-            _mm256_add_epi16(_mm256_packs_epi32(vr0, vr1), d0), s0);
+      const int height_even = height & ~1;
+      for (i = 0; i < height_even; i += 2) {
+        const uint8_t *dat_rowB = dat + dat_stride;
+        const uint8_t *src_rowB = src + src_stride;
+        const int32_t *flt_rowB = flt + flt_stride;
+        // Rows A and B are processed together as one 16-element vector.
+        const __m256i d0 = load_shuffled_u8_dual8_to_epi16(dat, dat_rowB);
+        const __m256i s0 = load_shuffled_u8_dual8_to_epi16(src, src_rowB);
+        const __m256i flt_16b =
+            _mm256_packs_epi32(yy_loadu_256(flt), yy_loadu_256(flt_rowB));
+
+        const __m256i e0 = calc_proj_err_r0_or_r1_avx2(
+            d0, s0, flt_16b, xq_coeff, rounding, shift);
         const __m256i err0 = _mm256_madd_epi16(e0, e0);
         sum32 = _mm256_add_epi32(sum32, err0);
+
+        dat += 2 * dat_stride;
+        src += 2 * src_stride;
+        flt += 2 * flt_stride;
       }
-      for (k = j; k < width; ++k) {
-        const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
-        int32_t v = xq_active * (flt[k] - u);
-        const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
-        err += ((int64_t)e * e);
+      if (i < height) {
+        for (k = 0; k < 8; ++k) {
+          const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+          int32_t v = xq_active * (flt[k] - u);
+          const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+          err += ((int64_t)e * e);
+        }
       }
-      dat += dat_stride;
-      src += src_stride;
-      flt += flt_stride;
       const __m256i sum64_0 =
           _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
       const __m256i sum64_1 =
           _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
-      sum64 = _mm256_add_epi64(sum64, sum64_0);
-      sum64 = _mm256_add_epi64(sum64, sum64_1);
+      sum64 = _mm256_add_epi64(sum64, _mm256_add_epi64(sum64_0, sum64_1));
+    } else if (width == 16) {
+      __m256i sum32_A = _mm256_setzero_si256();
+      __m256i sum32_B = _mm256_setzero_si256();
+      __m256i sum32_C = _mm256_setzero_si256();
+      __m256i sum32_D = _mm256_setzero_si256();
+      const int height_v4 = height & ~3;
+      for (i = 0; i < height_v4; i += 4) {
+        const uint8_t *dat_rowB = dat + dat_stride;
+        const uint8_t *dat_rowC = dat_rowB + dat_stride;
+        const uint8_t *dat_rowD = dat_rowC + dat_stride;
+        const uint8_t *src_rowB = src + src_stride;
+        const uint8_t *src_rowC = src_rowB + src_stride;
+        const uint8_t *src_rowD = src_rowC + src_stride;
+        const int32_t *flt_rowB = flt + flt_stride;
+        const int32_t *flt_rowC = flt_rowB + flt_stride;
+        const int32_t *flt_rowD = flt_rowC + flt_stride;
+        // Rows A-D use separate accumulators, combined after all rows.
+        // Row A
+        {
+          const __m256i d0 = load_shuffled_u8_to_epi16(dat);
+          const __m256i s0 = load_shuffled_u8_to_epi16(src);
+          const __m256i flt_16b =
+              _mm256_packs_epi32(yy_loadu_256(flt), yy_loadu_256(flt + 8));
+          const __m256i e = calc_proj_err_r0_or_r1_avx2(
+              d0, s0, flt_16b, xq_coeff, rounding, shift);
+          sum32_A = _mm256_add_epi32(sum32_A, _mm256_madd_epi16(e, e));
+        }
+        // Row B
+        {
+          const __m256i d0 = load_shuffled_u8_to_epi16(dat_rowB);
+          const __m256i s0 = load_shuffled_u8_to_epi16(src_rowB);
+          const __m256i flt_16b = _mm256_packs_epi32(
+              yy_loadu_256(flt_rowB), yy_loadu_256(flt_rowB + 8));
+          const __m256i e = calc_proj_err_r0_or_r1_avx2(
+              d0, s0, flt_16b, xq_coeff, rounding, shift);
+          sum32_B = _mm256_add_epi32(sum32_B, _mm256_madd_epi16(e, e));
+        }
+        // Row C
+        {
+          const __m256i d0 = load_shuffled_u8_to_epi16(dat_rowC);
+          const __m256i s0 = load_shuffled_u8_to_epi16(src_rowC);
+          const __m256i flt_16b = _mm256_packs_epi32(
+              yy_loadu_256(flt_rowC), yy_loadu_256(flt_rowC + 8));
+          const __m256i e = calc_proj_err_r0_or_r1_avx2(
+              d0, s0, flt_16b, xq_coeff, rounding, shift);
+          sum32_C = _mm256_add_epi32(sum32_C, _mm256_madd_epi16(e, e));
+        }
+        // Row D
+        {
+          const __m256i d0 = load_shuffled_u8_to_epi16(dat_rowD);
+          const __m256i s0 = load_shuffled_u8_to_epi16(src_rowD);
+          const __m256i flt_16b = _mm256_packs_epi32(
+              yy_loadu_256(flt_rowD), yy_loadu_256(flt_rowD + 8));
+          const __m256i e = calc_proj_err_r0_or_r1_avx2(
+              d0, s0, flt_16b, xq_coeff, rounding, shift);
+          sum32_D = _mm256_add_epi32(sum32_D, _mm256_madd_epi16(e, e));
+        }
+
+        dat += 4 * dat_stride;
+        src += 4 * src_stride;
+        flt += 4 * flt_stride;
+      }
+      for (; i < height; ++i) {
+        const __m256i d0 = load_shuffled_u8_to_epi16(dat);
+        const __m256i s0 = load_shuffled_u8_to_epi16(src);
+        const __m256i flt_16b =
+            _mm256_packs_epi32(yy_loadu_256(flt), yy_loadu_256(flt + 8));
+        const __m256i e = calc_proj_err_r0_or_r1_avx2(d0, s0, flt_16b, xq_coeff,
+                                                      rounding, shift);
+        sum32_A = _mm256_add_epi32(sum32_A, _mm256_madd_epi16(e, e));
+
+        dat += dat_stride;
+        src += src_stride;
+        flt += flt_stride;
+      }
+      __m256i sum32 = _mm256_add_epi32(_mm256_add_epi32(sum32_A, sum32_B),
+                                       _mm256_add_epi32(sum32_C, sum32_D));
+      const __m256i sum64_0 =
+          _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+      const __m256i sum64_1 =
+          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+      sum64 = _mm256_add_epi64(sum64, _mm256_add_epi64(sum64_0, sum64_1));
+    } else if (width >= 32 && (width % 32 == 0)) {
+      int rows_per_batch = 4096 / width;  // ~4096 pixels per int32 flush
+      if (rows_per_batch < 1) rows_per_batch = 1;
+      for (i = 0; i < height;) {
+        int rows_to_do = height - i;
+        if (rows_to_do > rows_per_batch) rows_to_do = rows_per_batch;
+        const int next_i = i + rows_to_do;
+        __m256i sum32_A = _mm256_setzero_si256();
+        __m256i sum32_B = _mm256_setzero_si256();
+        for (; i < next_i; ++i) {
+          for (j = 0; j <= width - 32; j += 32) {
+            const __m256i d_A = load_shuffled_u8_to_epi16(dat + j);
+            const __m256i s_A = load_shuffled_u8_to_epi16(src + j);
+            const __m256i flt_A = _mm256_packs_epi32(yy_loadu_256(flt + j),
+                                                     yy_loadu_256(flt + j + 8));
+            const __m256i e_A = calc_proj_err_r0_or_r1_avx2(
+                d_A, s_A, flt_A, xq_coeff, rounding, shift);
+            sum32_A = _mm256_add_epi32(sum32_A, _mm256_madd_epi16(e_A, e_A));
+            // Second 16-pixel half of the 32-pixel chunk.
+            const __m256i d_B = load_shuffled_u8_to_epi16(dat + j + 16);
+            const __m256i s_B = load_shuffled_u8_to_epi16(src + j + 16);
+            const __m256i flt_B = _mm256_packs_epi32(
+                yy_loadu_256(flt + j + 16), yy_loadu_256(flt + j + 24));
+            const __m256i e_B = calc_proj_err_r0_or_r1_avx2(
+                d_B, s_B, flt_B, xq_coeff, rounding, shift);
+            sum32_B = _mm256_add_epi32(sum32_B, _mm256_madd_epi16(e_B, e_B));
+          }
+          dat += dat_stride;
+          src += src_stride;
+          flt += flt_stride;
+        }
+        __m256i sum32 = _mm256_add_epi32(sum32_A, sum32_B);
+        const __m256i sum64_0 =
+            _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+        const __m256i sum64_1 =
+            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+        sum64 = _mm256_add_epi64(sum64, _mm256_add_epi64(sum64_0, sum64_1));
+      }
+    } else {
+      // General fallback: 16-pixel SIMD steps plus a scalar tail.
+      for (i = 0; i < height; ++i) {
+        __m256i sum32 = _mm256_setzero_si256();
+        for (j = 0; j <= width - 16; j += 16) {
+          const __m256i d0 = load_shuffled_u8_to_epi16(dat + j);
+          const __m256i s0 = load_shuffled_u8_to_epi16(src + j);
+          const __m256i flt_16b = _mm256_packs_epi32(yy_loadu_256(flt + j),
+                                                     yy_loadu_256(flt + j + 8));
+          const __m256i e0 = calc_proj_err_r0_or_r1_avx2(
+              d0, s0, flt_16b, xq_coeff, rounding, shift);
+          sum32 = _mm256_add_epi32(sum32, _mm256_madd_epi16(e0, e0));
+        }
+        for (k = j; k < width; ++k) {
+          const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+          int32_t v = xq_active * (flt[k] - u);
+          const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+          err += ((int64_t)e * e);
+        }
+        dat += dat_stride;
+        src += src_stride;
+        flt += flt_stride;
+        const __m256i sum64_0 =
+            _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+        const __m256i sum64_1 =
+            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+        sum64 = _mm256_add_epi64(sum64, sum64_0);
+        sum64 = _mm256_add_epi64(sum64, sum64_1);
+      }
     }
   } else {
-    __m256i sum32 = _mm256_setzero_si256();
-    for (i = 0; i < height; ++i) {
-      for (j = 0; j <= width - 16; j += 16) {
-        const __m256i d0 = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
-        const __m256i s0 = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
-        const __m256i diff0 = _mm256_sub_epi16(d0, s0);
-        const __m256i err0 = _mm256_madd_epi16(diff0, diff0);
-        sum32 = _mm256_add_epi32(sum32, err0);
+    if (width == 8) {
+      __m256i sum32 = _mm256_setzero_si256();
+      const int height_even = height & ~1;
+      for (i = 0; i < height_even; i += 2) {
+        const uint8_t *dat_rowB = dat + dat_stride;
+        const uint8_t *src_rowB = src + src_stride;
+        // Neither filter is active, so the error is just dat - src.
+        const __m128i d_AB =
+            _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)dat),
+                               _mm_loadl_epi64((const __m128i *)dat_rowB));
+        const __m128i s_AB =
+            _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)src),
+                               _mm_loadl_epi64((const __m128i *)src_rowB));
+        const __m256i diff = _mm256_sub_epi16(_mm256_cvtepu8_epi16(d_AB),
+                                              _mm256_cvtepu8_epi16(s_AB));
+        sum32 = _mm256_add_epi32(sum32, _mm256_madd_epi16(diff, diff));
+        dat += 2 * dat_stride;
+        src += 2 * src_stride;
       }
-      for (k = j; k < width; ++k) {
-        const int32_t e = (int32_t)(dat[k]) - src[k];
-        err += ((int64_t)e * e);
+      if (i < height) {
+        const __m128i d_A = _mm_loadl_epi64((const __m128i *)dat);
+        const __m128i s_A = _mm_loadl_epi64((const __m128i *)src);
+        const __m256i diff = _mm256_sub_epi16(_mm256_cvtepu8_epi16(d_A),
+                                              _mm256_cvtepu8_epi16(s_A));
+        sum32 = _mm256_add_epi32(sum32, _mm256_madd_epi16(diff, diff));
       }
-      dat += dat_stride;
-      src += src_stride;
+      const __m256i sum64_0 =
+          _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+      const __m256i sum64_1 =
+          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+      sum64 = _mm256_add_epi64(sum64_0, sum64_1);
+    } else if (width >= 32 && (width % 32 == 0)) {
+      __m256i sum32_A = _mm256_setzero_si256();
+      __m256i sum32_B = _mm256_setzero_si256();
+      for (i = 0; i < height; ++i) {
+        for (j = 0; j <= width - 32; j += 32) {
+          const __m256i d_A = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
+          const __m256i s_A = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
+          const __m256i diff_A = _mm256_sub_epi16(d_A, s_A);
+          sum32_A =
+              _mm256_add_epi32(sum32_A, _mm256_madd_epi16(diff_A, diff_A));
+          // Second 16-pixel half of the 32-pixel chunk.
+          const __m256i d_B = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j + 16));
+          const __m256i s_B = _mm256_cvtepu8_epi16(xx_loadu_128(src + j + 16));
+          const __m256i diff_B = _mm256_sub_epi16(d_B, s_B);
+          sum32_B =
+              _mm256_add_epi32(sum32_B, _mm256_madd_epi16(diff_B, diff_B));
+        }
+        dat += dat_stride;
+        src += src_stride;
+      }
+      __m256i sum32 = _mm256_add_epi32(sum32_A, sum32_B);
+      const __m256i sum64_0 =
+          _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+      const __m256i sum64_1 =
+          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+      sum64 = _mm256_add_epi64(sum64_0, sum64_1);
+    } else if (width >= 16) {
+      __m256i sum32_A = _mm256_setzero_si256();
+      __m256i sum32_B = _mm256_setzero_si256();
+      const int height_even = height & ~1;
+      for (i = 0; i < height_even; i += 2) {
+        const uint8_t *dat_rowB = dat + dat_stride;
+        const uint8_t *src_rowB = src + src_stride;
+        for (j = 0; j <= width - 16; j += 16) {
+          const __m256i d_A = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
+          const __m256i s_A = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
+          const __m256i diff_A = _mm256_sub_epi16(d_A, s_A);
+          sum32_A =
+              _mm256_add_epi32(sum32_A, _mm256_madd_epi16(diff_A, diff_A));
+          // Same columns of the next row (row B).
+          const __m256i d_B = _mm256_cvtepu8_epi16(xx_loadu_128(dat_rowB + j));
+          const __m256i s_B = _mm256_cvtepu8_epi16(xx_loadu_128(src_rowB + j));
+          const __m256i diff_B = _mm256_sub_epi16(d_B, s_B);
+          sum32_B =
+              _mm256_add_epi32(sum32_B, _mm256_madd_epi16(diff_B, diff_B));
+        }
+        for (k = j; k < width; ++k) {
+          const int32_t e_A = (int32_t)dat[k] - src[k];
+          err += (int64_t)e_A * e_A;
+          const int32_t e_B = (int32_t)dat_rowB[k] - src_rowB[k];
+          err += (int64_t)e_B * e_B;
+        }
+        dat += 2 * dat_stride;
+        src += 2 * src_stride;
+      }
+      if (i < height) {
+        for (j = 0; j <= width - 16; j += 16) {
+          const __m256i d_A = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
+          const __m256i s_A = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
+          const __m256i diff_A = _mm256_sub_epi16(d_A, s_A);
+          sum32_A =
+              _mm256_add_epi32(sum32_A, _mm256_madd_epi16(diff_A, diff_A));
+        }
+        for (k = j; k < width; ++k) {
+          const int32_t e_A = (int32_t)dat[k] - src[k];
+          err += (int64_t)e_A * e_A;
+        }
+      }
+      __m256i sum32 = _mm256_add_epi32(sum32_A, sum32_B);
+      const __m256i sum64_0 =
+          _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+      const __m256i sum64_1 =
+          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+      sum64 = _mm256_add_epi64(sum64_0, sum64_1);
+    } else {
+      // General fallback: 16-pixel SIMD steps plus a scalar tail.
+      __m256i sum32 = _mm256_setzero_si256();
+      for (i = 0; i < height; ++i) {
+        for (j = 0; j <= width - 16; j += 16) {
+          const __m256i d0 = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
+          const __m256i s0 = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
+          const __m256i diff0 = _mm256_sub_epi16(d0, s0);
+          const __m256i err0 = _mm256_madd_epi16(diff0, diff0);
+          sum32 = _mm256_add_epi32(sum32, err0);
+        }
+        for (k = j; k < width; ++k) {
+          const int32_t e = (int32_t)(dat[k]) - src[k];
+          err += ((int64_t)e * e);
+        }
+        dat += dat_stride;
+        src += src_stride;
+      }
+      const __m256i sum64_0 =
+          _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+      const __m256i sum64_1 =
+          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+      sum64 = _mm256_add_epi64(sum64_0, sum64_1);
     }
-    const __m256i sum64_0 =
-        _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
-    const __m256i sum64_1 =
-        _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
-    sum64 = _mm256_add_epi64(sum64_0, sum64_1);
   }
   int64_t sum[4];
   yy_storeu_256(sum, sum64);