Commit 5e6b124c6d for aom

commit 5e6b124c6daeb6c92a6da70758d71fb247db411c
Author: Jerome Jiang <jianj@google.com>
Date:   Tue Jun 23 17:01:03 2026 -0400

    Optimize av1_dist_wtd_convolve_2d_avx2

    - Folds the offset addition into the vertical rounding constant
    (`round_offset_const_v`), so a single `_mm256_packus_epi32` replaces
    the previous `_mm256_packs_epi32` + `_mm256_add_epi16` pair.

    - Reduces initialization loads from 4 to 3 by loading non-overlapping
    256-bit vectors (rows 0-1, 2-3 and 4-5) and reconstructing the rows 1-2
    and 3-4 vectors with `_mm256_permute2x128_si256`.

    - Caches `coeffs_y[1]` and `coeffs_y[2]` in registers to avoid
    redundant memory accesses inside the loop.

    - Narrows the scope of `wt`, `offset_const`, and `rounding_const` to
    reduce register pressure.

    Speedup (lower Before/After values are better):

         Block    |  Before  |   After  |   Delta   | Speedup
        ----------+----------+----------+-----------+---------
         4x4      |    39.77 |    35.49 |   -10.75% |    1.12x
         4x8      |    44.35 |    40.01 |    -9.79% |    1.11x
         4x16     |    54.71 |    51.32 |    -6.21% |    1.07x
         8x4      |    42.03 |    37.36 |   -11.10% |    1.12x
         8x8      |    49.94 |    44.43 |   -11.03% |    1.12x
         8x16     |    65.87 |    60.02 |    -8.88% |    1.10x
         8x32     |   101.20 |    97.14 |    -4.01% |    1.04x
         16x4     |    73.46 |    64.61 |   -12.05% |    1.14x
         16x8     |    88.90 |    78.48 |   -11.72% |    1.13x
         16x16    |   120.40 |   110.90 |    -7.89% |    1.09x
         16x32    |   190.76 |   183.38 |    -3.87% |    1.04x
         16x64    |   358.98 |   343.80 |    -4.23% |    1.04x
         32x8     |   166.80 |   146.77 |   -12.01% |    1.14x
         32x16    |   229.80 |   212.71 |    -7.44% |    1.08x
         32x32    |   370.39 |   358.12 |    -3.31% |    1.03x
         32x64    |   709.49 |   679.62 |    -4.21% |    1.04x
         64x16    |   448.85 |   417.14 |    -7.07% |    1.08x
         64x32    |   727.52 |   703.81 |    -3.26% |    1.03x
         64x64    |  1413.65 |  1355.12 |    -4.14% |    1.04x
         64x128   |  3088.88 |  3005.88 |    -2.69% |    1.03x
         128x64   |  2826.65 |  2709.65 |    -4.14% |    1.04x
         128x128  |  5984.29 |  5790.35 |    -3.24% |    1.03x

    Change-Id: I416da6f9f7489e359d649a45d2a2dec9eb12776e

diff --git a/av1/common/x86/jnt_convolve_avx2.c b/av1/common/x86/jnt_convolve_avx2.c
index e3ac6d466d..2eefb38aef 100644
--- a/av1/common/x86/jnt_convolve_avx2.c
+++ b/av1/common/x86/jnt_convolve_avx2.c
@@ -583,6 +583,156 @@ void av1_dist_wtd_convolve_y_avx2(const uint8_t *src, int src_stride,
   }
 }

+// Vertical 4-tap stage of av1_dist_wtd_convolve_2d_avx2(): filters the
+// horizontally filtered rows in im_block for the output column starting at
+// j, producing two output rows (i and i + 1) per loop iteration.
+//
+// last_4: nonzero for 4-wide blocks; only the low unpacked half is filtered
+//         and 4 bytes per row are written to dst0. Zero for 8-wide columns.
+// mode:   0 - distance-weighted average of dst and the new prediction,
+//             written as 8-bit pixels to dst0;
+//         1 - (dst + prediction) >> 1 average, written as 8-bit pixels to
+//             dst0;
+//         2 - no averaging; the offset-biased 16-bit result is stored to dst.
+//
+// Rows 1-2 and 3-4 are rebuilt with _mm256_permute2x128_si256 instead of
+// being loaded, which relies on im_stride == 8 (one 256-bit load covers two
+// consecutive rows).
+//
+// Not hygienic: reads im_block, im_stride, h, j, coeffs_y,
+// round_offset_const_v, offset, conv_params, dst, dst_stride, dst0 and
+// dst_stride0 from the enclosing scope, and uses the caller's `i` as the row
+// counter. h is presumably even, since rows i and i + 1 are always written.
+#define JNT_CONVOLVE_2D_VERTICAL_FILTER_4TAP(last_4, mode)                     \
+  do {                                                                         \
+    __m256i s0_reg =                                                           \
+        _mm256_loadu_si256((__m256i *)(im_block + 0 * im_stride));             \
+    __m256i s2_reg =                                                           \
+        _mm256_loadu_si256((__m256i *)(im_block + 2 * im_stride));             \
+    __m256i s4_init =                                                          \
+        _mm256_loadu_si256((__m256i *)(im_block + 4 * im_stride));             \
+    /* Rebuild rows 1-2 and 3-4 from the three non-overlapping loads. */       \
+    __m256i s1_reg = _mm256_permute2x128_si256(s0_reg, s2_reg, 0x21);          \
+    __m256i s3_reg = _mm256_permute2x128_si256(s2_reg, s4_init, 0x21);         \
+    __m256i s0 = _mm256_unpacklo_epi16(s0_reg, s1_reg);                        \
+    __m256i s1 = _mm256_unpacklo_epi16(s2_reg, s3_reg);                        \
+    /* High halves are only needed for 8-wide columns. */                      \
+    __m256i s3 = _mm256_setzero_si256();                                       \
+    __m256i s4 = _mm256_setzero_si256();                                       \
+    if (!(last_4)) {                                                           \
+      s3 = _mm256_unpackhi_epi16(s0_reg, s1_reg);                              \
+      s4 = _mm256_unpackhi_epi16(s2_reg, s3_reg);                              \
+    }                                                                          \
+    /* Hoist the coeffs_y[1] and coeffs_y[2] loads out of the row loop. */     \
+    const __m256i coeff_y1 = coeffs_y[1];                                      \
+    const __m256i coeff_y2 = coeffs_y[2];                                      \
+    for (i = 0; i < h; i += 2) {                                               \
+      const int16_t *data = &im_block[i * im_stride];                          \
+                                                                               \
+      const __m256i s4_reg =                                                   \
+          _mm256_loadu_si256((__m256i *)(data + 4 * im_stride));               \
+      const __m256i s5_reg =                                                   \
+          _mm256_loadu_si256((__m256i *)(data + 5 * im_stride));               \
+                                                                               \
+      __m256i s2 = _mm256_unpacklo_epi16(s4_reg, s5_reg);                      \
+                                                                               \
+      const __m256i res_a_1 = _mm256_madd_epi16(s0, coeff_y1);                 \
+      const __m256i res_a_2 = _mm256_madd_epi16(s1, coeff_y2);                 \
+      const __m256i res_a = _mm256_add_epi32(res_a_1, res_a_2);                \
+      /* round_offset_const_v folds in offset << 7: the shifted result */      \
+      /* is already offset-biased, so it is packed with packus_epi32. */       \
+      const __m256i res_a_round =                                              \
+          _mm256_srai_epi32(_mm256_add_epi32(res_a, round_offset_const_v), 7); \
+                                                                               \
+      __m256i res_unsigned;                                                    \
+      if (last_4) {                                                            \
+        res_unsigned = _mm256_packus_epi32(res_a_round, res_a_round);          \
+      } else {                                                                 \
+        __m256i s5 = _mm256_unpackhi_epi16(s4_reg, s5_reg);                    \
+        const __m256i res_b_1 = _mm256_madd_epi16(s3, coeff_y1);               \
+        const __m256i res_b_2 = _mm256_madd_epi16(s4, coeff_y2);               \
+        const __m256i res_b = _mm256_add_epi32(res_b_1, res_b_2);              \
+        const __m256i res_b_round = _mm256_srai_epi32(                         \
+            _mm256_add_epi32(res_b, round_offset_const_v), 7);                 \
+        res_unsigned = _mm256_packus_epi32(res_a_round, res_b_round);          \
+        s3 = s4;                                                               \
+        s4 = s5;                                                               \
+      }                                                                        \
+      s0 = s1;                                                                 \
+      s1 = s2;                                                                 \
+                                                                               \
+      if ((mode) == 0) {                                                       \
+        /* (x - 16 * offset + 128) >> 8 == ((x >> 4) - offset + 8) >> 4 */     \
+        /* for the weighted sum x (TODO: confirm rounding_shift == 4). */      \
+        const __m256i comp_const = _mm256_set1_epi32(-16 * offset + 128);      \
+        const __m256i wt = unpack_weights_avx2(conv_params);                   \
+        __m256i round_result;                                                  \
+        const __m256i data_ref_0 = load_line2_avx2(                            \
+            &dst[i * dst_stride + j], &dst[i * dst_stride + j + dst_stride]);  \
+        const __m256i data_lo =                                                \
+            _mm256_unpacklo_epi16(data_ref_0, res_unsigned);                   \
+        const __m256i wt_res_lo = _mm256_madd_epi16(data_lo, wt);              \
+        const __m256i fused_lo =                                               \
+            _mm256_srai_epi32(_mm256_add_epi32(wt_res_lo, comp_const), 8);     \
+        if (last_4) {                                                          \
+          round_result = _mm256_packs_epi32(fused_lo, fused_lo);               \
+        } else {                                                               \
+          const __m256i data_hi =                                              \
+              _mm256_unpackhi_epi16(data_ref_0, res_unsigned);                 \
+          const __m256i wt_res_hi = _mm256_madd_epi16(data_hi, wt);            \
+          const __m256i fused_hi =                                             \
+              _mm256_srai_epi32(_mm256_add_epi32(wt_res_hi, comp_const), 8);   \
+          round_result = _mm256_packs_epi32(fused_lo, fused_hi);               \
+        }                                                                      \
+        const __m256i res_8 = _mm256_packus_epi16(round_result, round_result); \
+        const __m128i res_0 = _mm256_castsi256_si128(res_8);                   \
+        const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);              \
+        if (last_4) {                                                          \
+          *(int *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);     \
+          *(int *)(&dst0[i * dst_stride0 + j + dst_stride0]) =                 \
+              _mm_cvtsi128_si32(res_1);                                        \
+        } else {                                                               \
+          _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);    \
+          _mm_storel_epi64(                                                    \
+              (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1); \
+        }                                                                      \
+      } else if ((mode) == 1) {                                                \
+        const int rounding_shift =                                             \
+            2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;     \
+        const __m256i offset_const = _mm256_set1_epi16(offset);                \
+        const __m256i rounding_const =                                         \
+            _mm256_set1_epi16((1 << rounding_shift) >> 1);                     \
+        const __m256i data_ref_0 = load_line2_avx2(                            \
+            &dst[i * dst_stride + j], &dst[i * dst_stride + j + dst_stride]);  \
+        const __m256i wt_res = _mm256_add_epi16(data_ref_0, res_unsigned);     \
+        const __m256i comp_avg_res = _mm256_srai_epi16(wt_res, 1);             \
+        const __m256i res_signed =                                             \
+            _mm256_sub_epi16(comp_avg_res, offset_const);                      \
+        /* NOTE(review): the final >> 4 assumes rounding_shift == 4. */        \
+        const __m256i round_result = _mm256_srai_epi16(                        \
+            _mm256_add_epi16(res_signed, rounding_const), 4);                  \
+        const __m256i res_8 = _mm256_packus_epi16(round_result, round_result); \
+        const __m128i res_0 = _mm256_castsi256_si128(res_8);                   \
+        const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);              \
+        if (last_4) {                                                          \
+          *(int *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);     \
+          *(int *)(&dst0[i * dst_stride0 + j + dst_stride0]) =                 \
+              _mm_cvtsi128_si32(res_1);                                        \
+        } else {                                                               \
+          _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);    \
+          _mm_storel_epi64(                                                    \
+              (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1); \
+        }                                                                      \
+      } else {                                                                 \
+        const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);            \
+        _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);         \
+        const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);       \
+        _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),    \
+                        res_1);                                                \
+      }                                                                        \
+    }                                                                          \
+  } while (0)
+
 void av1_dist_wtd_convolve_2d_avx2(const uint8_t *src, int src_stride,
                                    uint8_t *dst0, int dst_stride0, int w, int h,
                                    const InterpFilterParams *filter_params_x,
@@ -597,24 +717,22 @@ void av1_dist_wtd_convolve_2d_avx2(const uint8_t *src, int src_stride,

   int im_stride = 8;
   int i, is_horiz_4tap = 0, is_vert_4tap = 0;
-  const __m256i wt = unpack_weights_avx2(conv_params);
   const int do_average = conv_params->do_average;
   const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
   const int offset_0 =
       bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
   const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
-  const __m256i offset_const = _mm256_set1_epi16(offset);
-  const int rounding_shift =
-      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
-  const __m256i rounding_const = _mm256_set1_epi16((1 << rounding_shift) >> 1);

   assert(conv_params->round_0 > 0);

   const __m256i round_const_h = _mm256_set1_epi16(
       ((1 << (conv_params->round_0 - 1)) >> 1) + (1 << (bd + FILTER_BITS - 2)));
-  const __m256i round_const_v = _mm256_set1_epi32(
+  const int round_const_v_val =
       ((1 << conv_params->round_1) >> 1) -
-      (1 << (bd + 2 * FILTER_BITS - conv_params->round_0 - 1)));
+      (1 << (bd + 2 * FILTER_BITS - conv_params->round_0 - 1));
+  const __m256i round_const_v = _mm256_set1_epi32(round_const_v_val);
+  const __m256i round_offset_const_v =
+      _mm256_set1_epi32(round_const_v_val + (offset << 7));

   DECLARE_ALIGNED(32, __m256i, filt[4]);
   DECLARE_ALIGNED(32, __m256i, coeffs_x[4]);
@@ -640,405 +758,31 @@ void av1_dist_wtd_convolve_2d_avx2(const uint8_t *src, int src_stride,
     const int fo_horiz = 1;
     const uint8_t *const src_ptr = src - fo_vert * src_stride - fo_horiz;
     if (w > 4) {
-      if (do_average) {
-        if (use_dist_wtd_comp_avg) {
-          const __m256i comp_const = _mm256_set1_epi32(-98176);
-          for (int j = 0; j < w; j += 8) {
-            JNT_CONVOLVE_HORIZONTAL_FILTER(src_ptr + j, convolve_lowbd_x_4tap,
-                                           coeffs_x + 1);
-
-            /* Vertical filter */
-            __m256i s[6];
-            __m256i s0 =
-                _mm256_loadu_si256((__m256i *)(im_block + 0 * im_stride));
-            __m256i s1 =
-                _mm256_loadu_si256((__m256i *)(im_block + 1 * im_stride));
-            __m256i s2 =
-                _mm256_loadu_si256((__m256i *)(im_block + 2 * im_stride));
-            __m256i s3 =
-                _mm256_loadu_si256((__m256i *)(im_block + 3 * im_stride));
-
-            s[0] = _mm256_unpacklo_epi16(s0, s1);
-            s[1] = _mm256_unpacklo_epi16(s2, s3);
-
-            s[3] = _mm256_unpackhi_epi16(s0, s1);
-            s[4] = _mm256_unpackhi_epi16(s2, s3);
-
-            for (i = 0; i < h; i += 2) {
-              const int16_t *data = &im_block[i * im_stride];
-
-              const __m256i s4 =
-                  _mm256_loadu_si256((__m256i *)(data + 4 * im_stride));
-              const __m256i s5 =
-                  _mm256_loadu_si256((__m256i *)(data + 5 * im_stride));
-
-              s[2] = _mm256_unpacklo_epi16(s4, s5);
-              s[5] = _mm256_unpackhi_epi16(s4, s5);
-
-              const __m256i res_a_1 = _mm256_madd_epi16(s[0], coeffs_y[1]);
-              const __m256i res_b_1 = _mm256_madd_epi16(s[3], coeffs_y[1]);
-              const __m256i res_a_2 = _mm256_madd_epi16(s[1], coeffs_y[2]);
-              const __m256i res_b_2 = _mm256_madd_epi16(s[4], coeffs_y[2]);
-              const __m256i res_a = _mm256_add_epi32(res_a_1, res_a_2);
-              const __m256i res_b = _mm256_add_epi32(res_b_1, res_b_2);
-
-              const __m256i res_a_round =
-                  _mm256_srai_epi32(_mm256_add_epi32(res_a, round_const_v), 7);
-              const __m256i res_b_round =
-                  _mm256_srai_epi32(_mm256_add_epi32(res_b, round_const_v), 7);
-              const __m256i res_16b =
-                  _mm256_packs_epi32(res_a_round, res_b_round);
-              const __m256i res_unsigned =
-                  _mm256_add_epi16(res_16b, offset_const);
-
-              const __m256i data_ref_0 =
-                  load_line2_avx2(&dst[i * dst_stride + j],
-                                  &dst[i * dst_stride + j + dst_stride]);
-
-              const __m256i data_lo =
-                  _mm256_unpacklo_epi16(data_ref_0, res_unsigned);
-              const __m256i data_hi =
-                  _mm256_unpackhi_epi16(data_ref_0, res_unsigned);
-
-              const __m256i wt_res_lo = _mm256_madd_epi16(data_lo, wt);
-              const __m256i wt_res_hi = _mm256_madd_epi16(data_hi, wt);
-
-              const __m256i add_lo = _mm256_add_epi32(wt_res_lo, comp_const);
-              const __m256i add_hi = _mm256_add_epi32(wt_res_hi, comp_const);
-              const __m256i fused_lo = _mm256_srai_epi32(add_lo, 8);
-              const __m256i fused_hi = _mm256_srai_epi32(add_hi, 8);
-
-              const __m256i round_result =
-                  _mm256_packs_epi32(fused_lo, fused_hi);
-
-              const __m256i res_8 =
-                  _mm256_packus_epi16(round_result, round_result);
-              const __m128i res_0 = _mm256_castsi256_si128(res_8);
-              const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);
-
-              _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
-              _mm_storel_epi64(
-                  (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])),
-                  res_1);
-
-              s[0] = s[1];
-              s[1] = s[2];
-              s[3] = s[4];
-              s[4] = s[5];
-            }
+      for (int j = 0; j < w; j += 8) {
+        JNT_CONVOLVE_HORIZONTAL_FILTER(src_ptr + j, convolve_lowbd_x_4tap,
+                                       coeffs_x + 1);
+        if (do_average) {
+          if (use_dist_wtd_comp_avg) {
+            JNT_CONVOLVE_2D_VERTICAL_FILTER_4TAP(0, 0);
+          } else {
+            JNT_CONVOLVE_2D_VERTICAL_FILTER_4TAP(0, 1);
           }
         } else {
-          for (int j = 0; j < w; j += 8) {
-            JNT_CONVOLVE_HORIZONTAL_FILTER(src_ptr + j, convolve_lowbd_x_4tap,
-                                           coeffs_x + 1);
-
-            /* Vertical filter */
-            __m256i s[6];
-            __m256i s0 =
-                _mm256_loadu_si256((__m256i *)(im_block + 0 * im_stride));
-            __m256i s1 =
-                _mm256_loadu_si256((__m256i *)(im_block + 1 * im_stride));
-            __m256i s2 =
-                _mm256_loadu_si256((__m256i *)(im_block + 2 * im_stride));
-            __m256i s3 =
-                _mm256_loadu_si256((__m256i *)(im_block + 3 * im_stride));
-
-            s[0] = _mm256_unpacklo_epi16(s0, s1);
-            s[1] = _mm256_unpacklo_epi16(s2, s3);
-
-            s[3] = _mm256_unpackhi_epi16(s0, s1);
-            s[4] = _mm256_unpackhi_epi16(s2, s3);
-
-            for (i = 0; i < h; i += 2) {
-              const int16_t *data = &im_block[i * im_stride];
-
-              const __m256i s4 =
-                  _mm256_loadu_si256((__m256i *)(data + 4 * im_stride));
-              const __m256i s5 =
-                  _mm256_loadu_si256((__m256i *)(data + 5 * im_stride));
-
-              s[2] = _mm256_unpacklo_epi16(s4, s5);
-              s[5] = _mm256_unpackhi_epi16(s4, s5);
-
-              const __m256i res_a_1 = _mm256_madd_epi16(s[0], coeffs_y[1]);
-              const __m256i res_b_1 = _mm256_madd_epi16(s[3], coeffs_y[1]);
-              const __m256i res_a_2 = _mm256_madd_epi16(s[1], coeffs_y[2]);
-              const __m256i res_b_2 = _mm256_madd_epi16(s[4], coeffs_y[2]);
-              const __m256i res_a = _mm256_add_epi32(res_a_1, res_a_2);
-              const __m256i res_b = _mm256_add_epi32(res_b_1, res_b_2);
-
-              const __m256i res_a_round =
-                  _mm256_srai_epi32(_mm256_add_epi32(res_a, round_const_v), 7);
-              const __m256i res_b_round =
-                  _mm256_srai_epi32(_mm256_add_epi32(res_b, round_const_v), 7);
-              const __m256i res_16b =
-                  _mm256_packs_epi32(res_a_round, res_b_round);
-              const __m256i res_unsigned =
-                  _mm256_add_epi16(res_16b, offset_const);
-
-              const __m256i data_ref_0 =
-                  load_line2_avx2(&dst[i * dst_stride + j],
-                                  &dst[i * dst_stride + j + dst_stride]);
-              const __m256i wt_res = _mm256_add_epi16(data_ref_0, res_unsigned);
-              const __m256i comp_avg_res = _mm256_srai_epi16(wt_res, 1);
-
-              const __m256i res_signed =
-                  _mm256_sub_epi16(comp_avg_res, offset_const);
-              const __m256i round_result = _mm256_srai_epi16(
-                  _mm256_add_epi16(res_signed, rounding_const), 4);
-
-              const __m256i res_8 =
-                  _mm256_packus_epi16(round_result, round_result);
-              const __m128i res_0 = _mm256_castsi256_si128(res_8);
-              const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);
-
-              _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
-              _mm_storel_epi64(
-                  (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])),
-                  res_1);
-
-              s[0] = s[1];
-              s[1] = s[2];
-              s[3] = s[4];
-              s[4] = s[5];
-            }
-          }
-        }
-      } else {
-        for (int j = 0; j < w; j += 8) {
-          JNT_CONVOLVE_HORIZONTAL_FILTER(src_ptr + j, convolve_lowbd_x_4tap,
-                                         coeffs_x + 1);
-
-          /* Vertical filter */
-          __m256i s[6];
-          __m256i s0 =
-              _mm256_loadu_si256((__m256i *)(im_block + 0 * im_stride));
-          __m256i s1 =
-              _mm256_loadu_si256((__m256i *)(im_block + 1 * im_stride));
-          __m256i s2 =
-              _mm256_loadu_si256((__m256i *)(im_block + 2 * im_stride));
-          __m256i s3 =
-              _mm256_loadu_si256((__m256i *)(im_block + 3 * im_stride));
-
-          s[0] = _mm256_unpacklo_epi16(s0, s1);
-          s[1] = _mm256_unpacklo_epi16(s2, s3);
-
-          s[3] = _mm256_unpackhi_epi16(s0, s1);
-          s[4] = _mm256_unpackhi_epi16(s2, s3);
-
-          for (i = 0; i < h; i += 2) {
-            const int16_t *data = &im_block[i * im_stride];
-
-            const __m256i s4 =
-                _mm256_loadu_si256((__m256i *)(data + 4 * im_stride));
-            const __m256i s5 =
-                _mm256_loadu_si256((__m256i *)(data + 5 * im_stride));
-
-            s[2] = _mm256_unpacklo_epi16(s4, s5);
-            s[5] = _mm256_unpackhi_epi16(s4, s5);
-
-            const __m256i res_a_1 = _mm256_madd_epi16(s[0], coeffs_y[1]);
-            const __m256i res_b_1 = _mm256_madd_epi16(s[3], coeffs_y[1]);
-            const __m256i res_a_2 = _mm256_madd_epi16(s[1], coeffs_y[2]);
-            const __m256i res_b_2 = _mm256_madd_epi16(s[4], coeffs_y[2]);
-            const __m256i res_a = _mm256_add_epi32(res_a_1, res_a_2);
-            const __m256i res_b = _mm256_add_epi32(res_b_1, res_b_2);
-
-            const __m256i res_a_round =
-                _mm256_srai_epi32(_mm256_add_epi32(res_a, round_const_v), 7);
-            const __m256i res_b_round =
-                _mm256_srai_epi32(_mm256_add_epi32(res_b, round_const_v), 7);
-            const __m256i res_16b =
-                _mm256_packs_epi32(res_a_round, res_b_round);
-            const __m256i res_unsigned =
-                _mm256_add_epi16(res_16b, offset_const);
-
-            const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
-            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);
-
-            const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
-            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
-                            res_1);
-
-            s[0] = s[1];
-            s[1] = s[2];
-            s[3] = s[4];
-            s[4] = s[5];
-          }
+          JNT_CONVOLVE_2D_VERTICAL_FILTER_4TAP(0, 2);
         }
       }
     } else {
+      const int j = 0;
+      JNT_CONVOLVE_HORIZONTAL_FILTER(src_ptr, convolve_lowbd_x_4tap,
+                                     coeffs_x + 1);
       if (do_average) {
         if (use_dist_wtd_comp_avg) {
-          const __m256i comp_const = _mm256_set1_epi32(-98176);
-          JNT_CONVOLVE_HORIZONTAL_FILTER(src_ptr, convolve_lowbd_x_4tap,
-                                         coeffs_x + 1);
-
-          /* Vertical filter */
-          __m256i s[3];
-          __m256i s0 =
-              _mm256_loadu_si256((__m256i *)(im_block + 0 * im_stride));
-          __m256i s1 =
-              _mm256_loadu_si256((__m256i *)(im_block + 1 * im_stride));
-          __m256i s2 =
-              _mm256_loadu_si256((__m256i *)(im_block + 2 * im_stride));
-          __m256i s3 =
-              _mm256_loadu_si256((__m256i *)(im_block + 3 * im_stride));
-
-          s[0] = _mm256_unpacklo_epi16(s0, s1);
-          s[1] = _mm256_unpacklo_epi16(s2, s3);
-
-          for (i = 0; i < h; i += 2) {
-            const int16_t *data = &im_block[i * im_stride];
-
-            const __m256i s4 =
-                _mm256_loadu_si256((__m256i *)(data + 4 * im_stride));
-            const __m256i s5 =
-                _mm256_loadu_si256((__m256i *)(data + 5 * im_stride));
-
-            s[2] = _mm256_unpacklo_epi16(s4, s5);
-
-            const __m256i res_a_1 = _mm256_madd_epi16(s[0], coeffs_y[1]);
-            const __m256i res_a_2 = _mm256_madd_epi16(s[1], coeffs_y[2]);
-            const __m256i res_a = _mm256_add_epi32(res_a_1, res_a_2);
-
-            const __m256i res_a_round =
-                _mm256_srai_epi32(_mm256_add_epi32(res_a, round_const_v), 7);
-            const __m256i res_16b =
-                _mm256_packs_epi32(res_a_round, res_a_round);
-            const __m256i res_unsigned =
-                _mm256_add_epi16(res_16b, offset_const);
-
-            const __m256i data_ref_0 = load_line2_avx2(
-                &dst[i * dst_stride], &dst[i * dst_stride + dst_stride]);
-
-            const __m256i data_lo =
-                _mm256_unpacklo_epi16(data_ref_0, res_unsigned);
-
-            const __m256i wt_res_lo = _mm256_madd_epi16(data_lo, wt);
-
-            const __m256i fused_lo =
-                _mm256_srai_epi32(_mm256_add_epi32(wt_res_lo, comp_const), 8);
-
-            const __m256i round_result = _mm256_packs_epi32(fused_lo, fused_lo);
-
-            const __m256i res_8 =
-                _mm256_packus_epi16(round_result, round_result);
-            const __m128i res_0 = _mm256_castsi256_si128(res_8);
-            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);
-
-            *(int *)(&dst0[i * dst_stride0]) = _mm_cvtsi128_si32(res_0);
-            *(int *)(&dst0[i * dst_stride0 + dst_stride0]) =
-                _mm_cvtsi128_si32(res_1);
-
-            s[0] = s[1];
-            s[1] = s[2];
-          }
+          JNT_CONVOLVE_2D_VERTICAL_FILTER_4TAP(1, 0);
         } else {
-          JNT_CONVOLVE_HORIZONTAL_FILTER(src_ptr, convolve_lowbd_x_4tap,
-                                         coeffs_x + 1);
-
-          /* Vertical filter */
-          __m256i s[3];
-          __m256i s0 =
-              _mm256_loadu_si256((__m256i *)(im_block + 0 * im_stride));
-          __m256i s1 =
-              _mm256_loadu_si256((__m256i *)(im_block + 1 * im_stride));
-          __m256i s2 =
-              _mm256_loadu_si256((__m256i *)(im_block + 2 * im_stride));
-          __m256i s3 =
-              _mm256_loadu_si256((__m256i *)(im_block + 3 * im_stride));
-
-          s[0] = _mm256_unpacklo_epi16(s0, s1);
-          s[1] = _mm256_unpacklo_epi16(s2, s3);
-
-          for (i = 0; i < h; i += 2) {
-            const int16_t *data = &im_block[i * im_stride];
-
-            const __m256i s4 =
-                _mm256_loadu_si256((__m256i *)(data + 4 * im_stride));
-            const __m256i s5 =
-                _mm256_loadu_si256((__m256i *)(data + 5 * im_stride));
-
-            s[2] = _mm256_unpacklo_epi16(s4, s5);
-
-            const __m256i res_a_1 = _mm256_madd_epi16(s[0], coeffs_y[1]);
-            const __m256i res_a_2 = _mm256_madd_epi16(s[1], coeffs_y[2]);
-            const __m256i res_a = _mm256_add_epi32(res_a_1, res_a_2);
-
-            const __m256i res_a_round =
-                _mm256_srai_epi32(_mm256_add_epi32(res_a, round_const_v), 7);
-            const __m256i res_16b =
-                _mm256_packs_epi32(res_a_round, res_a_round);
-            const __m256i res_unsigned =
-                _mm256_add_epi16(res_16b, offset_const);
-
-            const __m256i data_ref_0 = load_line2_avx2(
-                &dst[i * dst_stride], &dst[i * dst_stride + dst_stride]);
-            const __m256i wt_res = _mm256_add_epi16(data_ref_0, res_unsigned);
-            const __m256i comp_avg_res = _mm256_srai_epi16(wt_res, 1);
-
-            const __m256i res_signed =
-                _mm256_sub_epi16(comp_avg_res, offset_const);
-            const __m256i round_result = _mm256_srai_epi16(
-                _mm256_add_epi16(res_signed, rounding_const), 4);
-
-            const __m256i res_8 =
-                _mm256_packus_epi16(round_result, round_result);
-            const __m128i res_0 = _mm256_castsi256_si128(res_8);
-            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);
-
-            *(int *)(&dst0[i * dst_stride0]) = _mm_cvtsi128_si32(res_0);
-            *(int *)(&dst0[i * dst_stride0 + dst_stride0]) =
-                _mm_cvtsi128_si32(res_1);
-
-            s[0] = s[1];
-            s[1] = s[2];
-          }
+          JNT_CONVOLVE_2D_VERTICAL_FILTER_4TAP(1, 1);
         }
       } else {
-        JNT_CONVOLVE_HORIZONTAL_FILTER(src_ptr, convolve_lowbd_x_4tap,
-                                       coeffs_x + 1);
-
-        /* Vertical filter */
-        __m256i s[3];
-        __m256i s0 = _mm256_loadu_si256((__m256i *)(im_block + 0 * im_stride));
-        __m256i s1 = _mm256_loadu_si256((__m256i *)(im_block + 1 * im_stride));
-        __m256i s2 = _mm256_loadu_si256((__m256i *)(im_block + 2 * im_stride));
-        __m256i s3 = _mm256_loadu_si256((__m256i *)(im_block + 3 * im_stride));
-
-        s[0] = _mm256_unpacklo_epi16(s0, s1);
-        s[1] = _mm256_unpacklo_epi16(s2, s3);
-
-        for (i = 0; i < h; i += 2) {
-          const int16_t *data = &im_block[i * im_stride];
-
-          const __m256i s4 =
-              _mm256_loadu_si256((__m256i *)(data + 4 * im_stride));
-          const __m256i s5 =
-              _mm256_loadu_si256((__m256i *)(data + 5 * im_stride));
-
-          s[2] = _mm256_unpacklo_epi16(s4, s5);
-
-          const __m256i res_a_1 = _mm256_madd_epi16(s[0], coeffs_y[1]);
-          const __m256i res_a_2 = _mm256_madd_epi16(s[1], coeffs_y[2]);
-          const __m256i res_a = _mm256_add_epi32(res_a_1, res_a_2);
-
-          const __m256i res_a_round =
-              _mm256_srai_epi32(_mm256_add_epi32(res_a, round_const_v), 7);
-          const __m256i res_16b = _mm256_packs_epi32(res_a_round, res_a_round);
-          const __m256i res_unsigned = _mm256_add_epi16(res_16b, offset_const);
-
-          const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
-          _mm_store_si128((__m128i *)(&dst[i * dst_stride]), res_0);
-
-          const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
-          _mm_store_si128((__m128i *)(&dst[i * dst_stride + dst_stride]),
-                          res_1);
-
-          s[0] = s[1];
-          s[1] = s[2];
-        }
+        JNT_CONVOLVE_2D_VERTICAL_FILTER_4TAP(1, 2);
       }
     }
   } else if (is_horiz_4tap) {
@@ -1046,6 +790,12 @@ void av1_dist_wtd_convolve_2d_avx2(const uint8_t *src, int src_stride,
     const int fo_vert = filter_params_y->taps / 2 - 1;
     const int fo_horiz = 1;
     const uint8_t *const src_ptr = src - fo_vert * src_stride - fo_horiz;
+    const __m256i wt = unpack_weights_avx2(conv_params);
+    const __m256i offset_const = _mm256_set1_epi16(offset);
+    const int rounding_shift =
+        2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
+    const __m256i rounding_const =
+        _mm256_set1_epi16((1 << rounding_shift) >> 1);
     for (int j = 0; j < w; j += 8) {
       JNT_CONVOLVE_HORIZONTAL_FILTER(src_ptr + j, convolve_lowbd_x_4tap,
                                      coeffs_x + 1);
@@ -1062,53 +812,14 @@ void av1_dist_wtd_convolve_2d_avx2(const uint8_t *src, int src_stride,

     for (int j = 0; j < w; j += 8) {
       JNT_CONVOLVE_HORIZONTAL_FILTER(src_ptr + j, convolve_lowbd_x, coeffs_x);
-
-      /* Vertical filter */
-      __m256i s[6];
-      __m256i s0 = _mm256_loadu_si256((__m256i *)(im_block + 0 * im_stride));
-      __m256i s1 = _mm256_loadu_si256((__m256i *)(im_block + 1 * im_stride));
-      __m256i s2 = _mm256_loadu_si256((__m256i *)(im_block + 2 * im_stride));
-      __m256i s3 = _mm256_loadu_si256((__m256i *)(im_block + 3 * im_stride));
-
-      s[0] = _mm256_unpacklo_epi16(s0, s1);
-      s[1] = _mm256_unpacklo_epi16(s2, s3);
-
-      s[3] = _mm256_unpackhi_epi16(s0, s1);
-      s[4] = _mm256_unpackhi_epi16(s2, s3);
-
-      for (i = 0; i < h; i += 2) {
-        const int16_t *data = &im_block[i * im_stride];
-
-        const __m256i s4 =
-            _mm256_loadu_si256((__m256i *)(data + 4 * im_stride));
-        const __m256i s5 =
-            _mm256_loadu_si256((__m256i *)(data + 5 * im_stride));
-
-        s[2] = _mm256_unpacklo_epi16(s4, s5);
-        s[5] = _mm256_unpackhi_epi16(s4, s5);
-
-        const __m256i res_a = convolve_4tap(s, coeffs_y + 1);
-        const __m256i res_a_round =
-            _mm256_srai_epi32(_mm256_add_epi32(res_a, round_const_v), 7);
-
-        if (w - j > 4) {
-          const __m256i res_b = convolve_4tap(s + 3, coeffs_y + 1);
-          const __m256i res_b_round =
-              _mm256_srai_epi32(_mm256_add_epi32(res_b, round_const_v), 7);
-          const __m256i res_16b = _mm256_packs_epi32(res_a_round, res_b_round);
-          const __m256i res_unsigned = _mm256_add_epi16(res_16b, offset_const);
-
-          JNT_CONVOLVE_PROCESS_OUTPUT(res_unsigned, j);
+      if (do_average) {
+        if (use_dist_wtd_comp_avg) {
+          JNT_CONVOLVE_2D_VERTICAL_FILTER_4TAP((w - j <= 4), 0);
         } else {
-          const __m256i res_16b = _mm256_packs_epi32(res_a_round, res_a_round);
-          const __m256i res_unsigned = _mm256_add_epi16(res_16b, offset_const);
-
-          JNT_CONVOLVE_PROCESS_OUTPUT(res_unsigned, j);
+          JNT_CONVOLVE_2D_VERTICAL_FILTER_4TAP((w - j <= 4), 1);
         }
-        s[0] = s[1];
-        s[1] = s[2];
-        s[3] = s[4];
-        s[4] = s[5];
+      } else {
+        JNT_CONVOLVE_2D_VERTICAL_FILTER_4TAP((w - j <= 4), 2);
       }
     }
   } else {
@@ -1116,6 +827,12 @@ void av1_dist_wtd_convolve_2d_avx2(const uint8_t *src, int src_stride,
     const int fo_vert = filter_params_y->taps / 2 - 1;
     const int fo_horiz = filter_params_x->taps / 2 - 1;
     const uint8_t *const src_ptr = src - fo_vert * src_stride - fo_horiz;
+    const __m256i wt = unpack_weights_avx2(conv_params);
+    const __m256i offset_const = _mm256_set1_epi16(offset);
+    const int rounding_shift =
+        2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
+    const __m256i rounding_const =
+        _mm256_set1_epi16((1 << rounding_shift) >> 1);

     filt[2] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 2));
     filt[3] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 3));
@@ -1127,6 +844,7 @@ void av1_dist_wtd_convolve_2d_avx2(const uint8_t *src, int src_stride,
     }
   }
 }
+#undef JNT_CONVOLVE_2D_VERTICAL_FILTER_4TAP  // Helper used only inside av1_dist_wtd_convolve_2d_avx2() above; undefined so it cannot leak into later code.

 #define DO_NO_AVG_2D_COPY_4X16(r0, c0, r1, c1, r2, c2, r3, c3)          \
   do {                                                                  \