Commit faa8b1669c for aom

commit faa8b1669c59cd927374aa53bd4a127e9be1b4a3
Author: Jerome Jiang <jianj@google.com>
Date:   Wed May 27 12:52:36 2026 -0400

    Conv horiz: expand hwy avx512

    Further optimize for small blocks.
    But fall back to avx2 for blocks with height 32.

    - Use 8-bit pairwise multiply-accumulate (SatWidenMulPairwiseAdd)
      instead of 16-bit math for w <= 32 with even coefficients.
    - Halve filter coefficients to fit in int8_t and avoid overflow,
      adjusting final scaling shift to FILTER_BITS - 1.
    - Eliminate expensive 8-bit to 16-bit pixel promotion.
    - Add specialized unrolled loops for w = 4, 8, 16, and 32.

    Most block sizes now show a significant speedup. The exceptions are
    16x32, 32x32 and 64x32, which slow down by roughly 10-14%, and 32x16,
    which is essentially unchanged (+1.3%).

    I'll investigate these block sizes further.

    Size       | avx2      | avx512 (diff)
    ------------------------------------------
    4x4        |    5.62µs |    4.03µs (-28.3%)
    4x8        |    6.78µs |    5.17µs (-23.7%)
    8x4        |    5.94µs |    4.03µs (-32.2%)
    8x8        |    6.75µs |    5.17µs (-23.4%)
    8x16       |   10.01µs |    7.66µs (-23.4%)
    16x8       |    7.28µs |    6.49µs (-10.8%)
    16x16      |   10.92µs |   10.47µs (-4.1%)
    16x32      |   17.94µs |   19.83µs (+10.5%)
    32x16      |   19.34µs |   19.59µs (+1.3%)
    32x32      |   33.67µs |   38.31µs (+13.8%)
    32x64      |  170.90µs |  153.10µs (-10.4%)
    64x32      |   68.21µs |   76.28µs (+11.8%)
    64x64      |  307.20µs |  151.80µs (-50.6%)
    64x128     |  677.80µs |  305.30µs (-55.0%)
    128x64     |  527.90µs |  298.60µs (-43.4%)
    128x128    |    1.35ms |  593.90µs (-56.1%)

    Change-Id: I4134a9ca0e233855761f6b03c5f35e8fcf8e25fa

diff --git a/aom_dsp/convolve_hwy.h b/aom_dsp/convolve_hwy.h
index e5be37ec9f..0b28531038 100644
--- a/aom_dsp/convolve_hwy.h
+++ b/aom_dsp/convolve_hwy.h
@@ -134,26 +134,196 @@ HWY_ATTR HWY_INLINE void StoreUnaligned4x8(D tag, uint8_t *buf,
 HWY_ATTR inline void ConvolveHoriz2Tap(const uint8_t *src, ptrdiff_t src_stride,
                                        uint8_t *dst, ptrdiff_t dst_stride,
                                        const int16_t *filter_x, int w, int h) {
-  hn::ScalableTag<int16_t> mul_tag;
-  hn::Rebind<uint8_t, decltype(mul_tag)> pixel_tag;
-  auto filter_0 = hn::Set(mul_tag, filter_x[3]);
-  auto filter_1 = hn::Set(mul_tag, filter_x[4]);
-  auto vw = hn::Lanes(mul_tag);
-  for (int i = 0; i < h; ++i) {
-    for (int j = 0; j < w; j += vw) {
-      auto src0 = hn::PromoteTo(mul_tag, hn::LoadU(pixel_tag, &src[j]));
-      auto src1 = hn::PromoteTo(mul_tag, hn::LoadU(pixel_tag, &src[j + 1]));
-      auto mulv = hn::RoundingShiftRight<FILTER_BITS>(src0 * filter_0 +
-                                                      src1 * filter_1);
-      auto mulv_demoted = hn::DemoteTo(pixel_tag, mulv);
-      if (j + static_cast<int>(vw) > w) {
-        hn::StoreN(mulv_demoted, pixel_tag, &dst[j], w - j);
-      } else {
-        hn::StoreU(mulv_demoted, pixel_tag, &dst[j]);
+  const bool can_use_optimized_path =
+      (w <= 32) && (filter_x[3] % 2 == 0) && (filter_x[4] % 2 == 0);
+
+  if (can_use_optimized_path) {
+    hn::CappedTag<uint8_t, 16> tag8_16;
+    hn::CappedTag<int8_t, 16> tag_i8;
+    hn::CappedTag<int16_t, 8> tag16_8;
+    hn::CappedTag<uint8_t, 8> tag8_8;
+    hn::CappedTag<uint8_t, 4> tag8_4;
+    const auto bias_val = hn::Set(tag16_8, 1 << (FILTER_BITS - 2));
+
+    const auto shuffle_mask = hn::Dup128VecFromValues(
+        tag8_16, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8);
+
+    const int8_t c3 = static_cast<int8_t>(filter_x[3] / 2);
+    const int8_t c4 = static_cast<int8_t>(filter_x[4] / 2);
+
+    const auto coeff34 = hn::Dup128VecFromValues(
+        tag_i8, c3, c4, c3, c4, c3, c4, c3, c4, c3, c4, c3, c4, c3, c4, c3, c4);
+
+    if (w == 4) {
+      while (h >= 2) {
+        auto r0_d0 = hn::LoadU(tag8_16, src + 0);
+        auto r1_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+
+        auto r0_sum = hn::SatWidenMulPairwiseAdd(
+            tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask), coeff34);
+        auto r1_sum = hn::SatWidenMulPairwiseAdd(
+            tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask), coeff34);
+
+        hn::StoreU(
+            hn::LowerHalf(tag8_4,
+                          hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                                   hn::Add(r0_sum, bias_val)))),
+            tag8_4, dst);
+        hn::StoreU(
+            hn::LowerHalf(tag8_4,
+                          hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                                   hn::Add(r1_sum, bias_val)))),
+            tag8_4, dst + dst_stride);
+
+        src += 2 * src_stride;
+        dst += 2 * dst_stride;
+        h -= 2;
+      }
+    } else if (w == 8) {
+      while (h >= 2) {
+        auto r0_d0 = hn::LoadU(tag8_16, src + 0);
+        auto r1_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+
+        auto r0_sum = hn::SatWidenMulPairwiseAdd(
+            tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask), coeff34);
+        auto r1_sum = hn::SatWidenMulPairwiseAdd(
+            tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask), coeff34);
+
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r0_sum, bias_val))),
+                   tag8_8, dst);
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r1_sum, bias_val))),
+                   tag8_8, dst + dst_stride);
+
+        src += 2 * src_stride;
+        dst += 2 * dst_stride;
+        h -= 2;
+      }
+    } else if (w == 16) {
+      while (h >= 2) {
+        auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0);
+        auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8);
+
+        auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+        auto r1_j8_d0 = hn::LoadU(tag8_16, src + src_stride + 8);
+
+        auto r0_j0_sum = hn::SatWidenMulPairwiseAdd(
+            tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask), coeff34);
+        auto r0_j8_sum = hn::SatWidenMulPairwiseAdd(
+            tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask), coeff34);
+
+        auto r1_j0_sum = hn::SatWidenMulPairwiseAdd(
+            tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask), coeff34);
+        auto r1_j8_sum = hn::SatWidenMulPairwiseAdd(
+            tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask), coeff34);
+
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r0_j0_sum, bias_val))),
+                   tag8_8, dst + 0);
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r0_j8_sum, bias_val))),
+                   tag8_8, dst + 8);
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r1_j0_sum, bias_val))),
+                   tag8_8, dst + dst_stride + 0);
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r1_j8_sum, bias_val))),
+                   tag8_8, dst + dst_stride + 8);
+
+        src += 2 * src_stride;
+        dst += 2 * dst_stride;
+        h -= 2;
+      }
+    } else if (w == 32) {
+      while (h >= 2) {
+        {
+          auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0);
+          auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8);
+
+          auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+          auto r1_j8_d0 = hn::LoadU(tag8_16, src + src_stride + 8);
+
+          auto r0_j0_sum = hn::SatWidenMulPairwiseAdd(
+              tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask), coeff34);
+          auto r0_j8_sum = hn::SatWidenMulPairwiseAdd(
+              tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask), coeff34);
+
+          auto r1_j0_sum = hn::SatWidenMulPairwiseAdd(
+              tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask), coeff34);
+          auto r1_j8_sum = hn::SatWidenMulPairwiseAdd(
+              tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask), coeff34);
+
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r0_j0_sum, bias_val))),
+                     tag8_8, dst + 0);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r0_j8_sum, bias_val))),
+                     tag8_8, dst + 8);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r1_j0_sum, bias_val))),
+                     tag8_8, dst + dst_stride + 0);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r1_j8_sum, bias_val))),
+                     tag8_8, dst + dst_stride + 8);
+        }
+        {
+          auto r0_j16_d0 = hn::LoadU(tag8_16, src + 16);
+          auto r0_j24_d0 = hn::LoadU(tag8_16, src + 24);
+
+          auto r1_j16_d0 = hn::LoadU(tag8_16, src + src_stride + 16);
+          auto r1_j24_d0 = hn::LoadU(tag8_16, src + src_stride + 24);
+
+          auto r0_j16_sum = hn::SatWidenMulPairwiseAdd(
+              tag16_8, hn::TableLookupBytes(r0_j16_d0, shuffle_mask), coeff34);
+          auto r0_j24_sum = hn::SatWidenMulPairwiseAdd(
+              tag16_8, hn::TableLookupBytes(r0_j24_d0, shuffle_mask), coeff34);
+
+          auto r1_j16_sum = hn::SatWidenMulPairwiseAdd(
+              tag16_8, hn::TableLookupBytes(r1_j16_d0, shuffle_mask), coeff34);
+          auto r1_j24_sum = hn::SatWidenMulPairwiseAdd(
+              tag16_8, hn::TableLookupBytes(r1_j24_d0, shuffle_mask), coeff34);
+
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r0_j16_sum, bias_val))),
+                     tag8_8, dst + 16);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r0_j24_sum, bias_val))),
+                     tag8_8, dst + 24);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r1_j16_sum, bias_val))),
+                     tag8_8, dst + dst_stride + 16);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r1_j24_sum, bias_val))),
+                     tag8_8, dst + dst_stride + 24);
+        }
+        src += 2 * src_stride;
+        dst += 2 * dst_stride;
+        h -= 2;
       }
     }
-    src += src_stride;
-    dst += dst_stride;
+  } else {
+    hn::ScalableTag<int16_t> mul_tag;
+    hn::Rebind<uint8_t, decltype(mul_tag)> pixel_tag;
+    auto filter_0 = hn::Set(mul_tag, filter_x[3]);
+    auto filter_1 = hn::Set(mul_tag, filter_x[4]);
+    auto vw = hn::Lanes(mul_tag);
+    for (int i = 0; i < h; ++i) {
+      for (int j = 0; j < w; j += vw) {
+        auto src0 = hn::PromoteTo(mul_tag, hn::LoadU(pixel_tag, &src[j]));
+        auto src1 = hn::PromoteTo(mul_tag, hn::LoadU(pixel_tag, &src[j + 1]));
+        auto mulv = hn::RoundingShiftRight<FILTER_BITS>(src0 * filter_0 +
+                                                        src1 * filter_1);
+        auto mulv_demoted = hn::DemoteTo(pixel_tag, mulv);
+        if (j + static_cast<int>(vw) > w) {
+          hn::StoreN(mulv_demoted, pixel_tag, &dst[j], w - j);
+        } else {
+          hn::StoreU(mulv_demoted, pixel_tag, &dst[j]);
+        }
+      }
+      src += src_stride;
+      dst += dst_stride;
+    }
   }
 }

@@ -175,59 +345,286 @@ HWY_ATTR HWY_INLINE hn::VFromD<D> Convolve4_8(
 HWY_ATTR inline void ConvolveHoriz4Tap(const uint8_t *src, ptrdiff_t src_stride,
                                        uint8_t *dst, ptrdiff_t dst_stride,
                                        const int16_t *filter_x, int w, int h) {
-  hn::CappedTag<int16_t, 16> tag16;
-  hn::CappedTag<int16_t, 4> filter_tag;
-  auto f_vec = hn::LoadU(filter_tag, filter_x + 2);
-  // All filter values are even, halve to reduce intermediate precision
-  // requirements.
-  f_vec = hn::ShiftRight<1>(f_vec);
+  const bool can_use_optimized_path =
+      (w <= 32) && (filter_x[2] % 2 == 0) && (filter_x[3] % 2 == 0) &&
+      (filter_x[4] % 2 == 0) && (filter_x[5] % 2 == 0);

-  if (w == 4) {
-    // Each iteration processes a 4x4 block
-    do {
-      auto src0 = LoadUnaligned4x4(tag16, src, src_stride);
-      auto src1 = LoadUnaligned4x4(tag16, src + 1, src_stride);
-      auto src2 = LoadUnaligned4x4(tag16, src + 2, src_stride);
-      auto src3 = LoadUnaligned4x4(tag16, src + 3, src_stride);
-      auto result =
-          Convolve4_8(tag16, filter_tag, src0, src1, src2, src3, f_vec);
-      StoreUnaligned4x4(tag16, dst, dst_stride, result);
-      h -= 4;
-      src += 4 * src_stride;
-      dst += 4 * dst_stride;
-    } while (h > 0);
-  } else if (w == 8) {
-    // Each iteration processes a 2x8 block
-    do {
-      auto src0 = LoadUnaligned2x8(tag16, src, src_stride);
-      auto src1 = LoadUnaligned2x8(tag16, src + 1, src_stride);
-      auto src2 = LoadUnaligned2x8(tag16, src + 2, src_stride);
-      auto src3 = LoadUnaligned2x8(tag16, src + 3, src_stride);
-      auto result =
-          Convolve4_8(tag16, filter_tag, src0, src1, src2, src3, f_vec);
-      StoreUnaligned2x8(tag16, dst, dst_stride, result);
-      h -= 2;
-      src += 2 * src_stride;
-      dst += 2 * dst_stride;
-    } while (h > 0);
-  } else if (w == 16) {
-    // One 1x16 block a time
-    do {
-      hn::Rebind<uint8_t, decltype(tag16)> tag8;
-      auto src0 = hn::PromoteTo(tag16, hn::LoadU(tag8, src));
-      auto src1 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 1));
-      auto src2 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 2));
-      auto src3 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 3));
-      auto result =
-          Convolve4_8(tag16, filter_tag, src0, src1, src2, src3, f_vec);
-      hn::StoreU(hn::DemoteTo(tag8, result), tag8, dst);
-      h--;
-      src += src_stride;
-      dst += dst_stride;
-    } while (h > 0);
+  if (can_use_optimized_path) {
+    hn::CappedTag<uint8_t, 16> tag8_16;
+    hn::CappedTag<int8_t, 16> tag_i8;
+    hn::CappedTag<int16_t, 8> tag16_8;
+    hn::CappedTag<uint8_t, 8> tag8_8;
+    hn::CappedTag<uint8_t, 4> tag8_4;
+    const auto bias_val = hn::Set(tag16_8, 1 << (FILTER_BITS - 2));
+
+    const auto shuffle_mask = hn::Dup128VecFromValues(
+        tag8_16, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8);
+
+    const int8_t c2 = static_cast<int8_t>(filter_x[2] / 2);
+    const int8_t c3 = static_cast<int8_t>(filter_x[3] / 2);
+    const int8_t c4 = static_cast<int8_t>(filter_x[4] / 2);
+    const int8_t c5 = static_cast<int8_t>(filter_x[5] / 2);
+
+    const auto coeff23 = hn::Dup128VecFromValues(
+        tag_i8, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3);
+
+    const auto coeff45 = hn::Dup128VecFromValues(
+        tag_i8, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5);
+
+    if (w == 4) {
+      while (h >= 2) {
+        auto r0_d0 = hn::LoadU(tag8_16, src + 0);
+        auto r0_d2 = hn::LoadU(tag8_16, src + 2);
+
+        auto r1_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+        auto r1_d2 = hn::LoadU(tag8_16, src + src_stride + 2);
+
+        auto r0_sum = hn::Add(
+            hn::SatWidenMulPairwiseAdd(
+                tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask), coeff23),
+            hn::SatWidenMulPairwiseAdd(
+                tag16_8, hn::TableLookupBytes(r0_d2, shuffle_mask), coeff45));
+
+        auto r1_sum = hn::Add(
+            hn::SatWidenMulPairwiseAdd(
+                tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask), coeff23),
+            hn::SatWidenMulPairwiseAdd(
+                tag16_8, hn::TableLookupBytes(r1_d2, shuffle_mask), coeff45));
+
+        hn::StoreU(
+            hn::LowerHalf(tag8_4,
+                          hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                                   hn::Add(r0_sum, bias_val)))),
+            tag8_4, dst);
+        hn::StoreU(
+            hn::LowerHalf(tag8_4,
+                          hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                                   hn::Add(r1_sum, bias_val)))),
+            tag8_4, dst + dst_stride);
+
+        src += 2 * src_stride;
+        dst += 2 * dst_stride;
+        h -= 2;
+      }
+    } else if (w == 8) {
+      while (h >= 2) {
+        auto r0_d0 = hn::LoadU(tag8_16, src + 0);
+        auto r0_d2 = hn::LoadU(tag8_16, src + 2);
+
+        auto r1_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+        auto r1_d2 = hn::LoadU(tag8_16, src + src_stride + 2);
+
+        auto r0_sum = hn::Add(
+            hn::SatWidenMulPairwiseAdd(
+                tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask), coeff23),
+            hn::SatWidenMulPairwiseAdd(
+                tag16_8, hn::TableLookupBytes(r0_d2, shuffle_mask), coeff45));
+
+        auto r1_sum = hn::Add(
+            hn::SatWidenMulPairwiseAdd(
+                tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask), coeff23),
+            hn::SatWidenMulPairwiseAdd(
+                tag16_8, hn::TableLookupBytes(r1_d2, shuffle_mask), coeff45));
+
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r0_sum, bias_val))),
+                   tag8_8, dst);
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r1_sum, bias_val))),
+                   tag8_8, dst + dst_stride);
+
+        src += 2 * src_stride;
+        dst += 2 * dst_stride;
+        h -= 2;
+      }
+    } else if (w == 16) {
+      while (h >= 2) {
+        auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0);
+        auto r0_j0_d2 = hn::LoadU(tag8_16, src + 2);
+
+        auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8);
+        auto r0_j8_d2 = hn::LoadU(tag8_16, src + 10);
+
+        auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+        auto r1_j0_d2 = hn::LoadU(tag8_16, src + src_stride + 2);
+
+        auto r1_j8_d0 = hn::LoadU(tag8_16, src + src_stride + 8);
+        auto r1_j8_d2 = hn::LoadU(tag8_16, src + src_stride + 10);
+
+        auto r0_j0_sum = hn::Add(
+            hn::SatWidenMulPairwiseAdd(
+                tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask), coeff23),
+            hn::SatWidenMulPairwiseAdd(
+                tag16_8, hn::TableLookupBytes(r0_j0_d2, shuffle_mask),
+                coeff45));
+
+        auto r0_j8_sum = hn::Add(
+            hn::SatWidenMulPairwiseAdd(
+                tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask), coeff23),
+            hn::SatWidenMulPairwiseAdd(
+                tag16_8, hn::TableLookupBytes(r0_j8_d2, shuffle_mask),
+                coeff45));
+
+        auto r1_j0_sum = hn::Add(
+            hn::SatWidenMulPairwiseAdd(
+                tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask), coeff23),
+            hn::SatWidenMulPairwiseAdd(
+                tag16_8, hn::TableLookupBytes(r1_j0_d2, shuffle_mask),
+                coeff45));
+
+        auto r1_j8_sum = hn::Add(
+            hn::SatWidenMulPairwiseAdd(
+                tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask), coeff23),
+            hn::SatWidenMulPairwiseAdd(
+                tag16_8, hn::TableLookupBytes(r1_j8_d2, shuffle_mask),
+                coeff45));
+
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r0_j0_sum, bias_val))),
+                   tag8_8, dst + 0);
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r0_j8_sum, bias_val))),
+                   tag8_8, dst + 8);
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r1_j0_sum, bias_val))),
+                   tag8_8, dst + dst_stride + 0);
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r1_j8_sum, bias_val))),
+                   tag8_8, dst + dst_stride + 8);
+
+        src += 2 * src_stride;
+        dst += 2 * dst_stride;
+        h -= 2;
+      }
+    } else if (w == 32) {
+      while (h >= 2) {
+        {
+          auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0);
+          auto r0_j0_d2 = hn::LoadU(tag8_16, src + 2);
+
+          auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8);
+          auto r0_j8_d2 = hn::LoadU(tag8_16, src + 10);
+
+          auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+          auto r1_j0_d2 = hn::LoadU(tag8_16, src + src_stride + 2);
+
+          auto r1_j8_d0 = hn::LoadU(tag8_16, src + src_stride + 8);
+          auto r1_j8_d2 = hn::LoadU(tag8_16, src + src_stride + 10);
+
+          auto r0_j0_sum =
+              hn::Add(hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask),
+                          coeff23),
+                      hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r0_j0_d2, shuffle_mask),
+                          coeff45));
+
+          auto r0_j8_sum =
+              hn::Add(hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask),
+                          coeff23),
+                      hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r0_j8_d2, shuffle_mask),
+                          coeff45));
+
+          auto r1_j0_sum =
+              hn::Add(hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask),
+                          coeff23),
+                      hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r1_j0_d2, shuffle_mask),
+                          coeff45));
+
+          auto r1_j8_sum =
+              hn::Add(hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask),
+                          coeff23),
+                      hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r1_j8_d2, shuffle_mask),
+                          coeff45));
+
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r0_j0_sum, bias_val))),
+                     tag8_8, dst + 0);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r0_j8_sum, bias_val))),
+                     tag8_8, dst + 8);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r1_j0_sum, bias_val))),
+                     tag8_8, dst + dst_stride + 0);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r1_j8_sum, bias_val))),
+                     tag8_8, dst + dst_stride + 8);
+        }
+        {
+          auto r0_j16_d0 = hn::LoadU(tag8_16, src + 16);
+          auto r0_j16_d2 = hn::LoadU(tag8_16, src + 18);
+
+          auto r0_j24_d0 = hn::LoadU(tag8_16, src + 24);
+          auto r0_j24_d2 = hn::LoadU(tag8_16, src + 26);
+
+          auto r1_j16_d0 = hn::LoadU(tag8_16, src + src_stride + 16);
+          auto r1_j16_d2 = hn::LoadU(tag8_16, src + src_stride + 18);
+
+          auto r1_j24_d0 = hn::LoadU(tag8_16, src + src_stride + 24);
+          auto r1_j24_d2 = hn::LoadU(tag8_16, src + src_stride + 26);
+
+          auto r0_j16_sum = hn::Add(
+              hn::SatWidenMulPairwiseAdd(
+                  tag16_8, hn::TableLookupBytes(r0_j16_d0, shuffle_mask),
+                  coeff23),
+              hn::SatWidenMulPairwiseAdd(
+                  tag16_8, hn::TableLookupBytes(r0_j16_d2, shuffle_mask),
+                  coeff45));
+
+          auto r0_j24_sum = hn::Add(
+              hn::SatWidenMulPairwiseAdd(
+                  tag16_8, hn::TableLookupBytes(r0_j24_d0, shuffle_mask),
+                  coeff23),
+              hn::SatWidenMulPairwiseAdd(
+                  tag16_8, hn::TableLookupBytes(r0_j24_d2, shuffle_mask),
+                  coeff45));
+
+          auto r1_j16_sum = hn::Add(
+              hn::SatWidenMulPairwiseAdd(
+                  tag16_8, hn::TableLookupBytes(r1_j16_d0, shuffle_mask),
+                  coeff23),
+              hn::SatWidenMulPairwiseAdd(
+                  tag16_8, hn::TableLookupBytes(r1_j16_d2, shuffle_mask),
+                  coeff45));
+
+          auto r1_j24_sum = hn::Add(
+              hn::SatWidenMulPairwiseAdd(
+                  tag16_8, hn::TableLookupBytes(r1_j24_d0, shuffle_mask),
+                  coeff23),
+              hn::SatWidenMulPairwiseAdd(
+                  tag16_8, hn::TableLookupBytes(r1_j24_d2, shuffle_mask),
+                  coeff45));
+
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r0_j16_sum, bias_val))),
+                     tag8_8, dst + 16);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r0_j24_sum, bias_val))),
+                     tag8_8, dst + 24);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r1_j16_sum, bias_val))),
+                     tag8_8, dst + dst_stride + 16);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r1_j24_sum, bias_val))),
+                     tag8_8, dst + dst_stride + 24);
+        }
+        src += 2 * src_stride;
+        dst += 2 * dst_stride;
+        h -= 2;
+      }
+    }
   } else {
     hn::ScalableTag<int16_t> mul_tag;
     hn::Rebind<uint8_t, decltype(mul_tag)> pixel_tag;
+    hn::CappedTag<int16_t, 4> filter_tag;
+    auto f_vec = hn::LoadU(filter_tag, filter_x + 2);
+    f_vec = hn::ShiftRight<1>(f_vec);
     auto vw = hn::Lanes(mul_tag);
     for (int i = 0; i < h; ++i) {
       for (int j = 0; j < w; j += vw) {
@@ -278,82 +675,447 @@ HWY_ATTR HWY_INLINE hn::VFromD<D> Convolve8_8(
   return hn::RoundingShiftRight<FILTER_BITS - 1>(res);
 }

-DECLARE_ALIGNED(32, static const uint8_t, filt_global[]) = {
-  0,  1,  1,  2,  2, 3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,  0,  1,  1,
-  2,  2,  3,  3,  4, 4,  5,  5,  6,  6,  7,  7,  8,  2,  3,  3,  4,  4,  5,
-  5,  6,  6,  7,  7, 8,  8,  9,  9,  10, 2,  3,  3,  4,  4,  5,  5,  6,  6,
-  7,  7,  8,  8,  9, 9,  10, 4,  5,  5,  6,  6,  7,  7,  8,  8,  9,  9,  10,
-  10, 11, 11, 12, 4, 5,  5,  6,  6,  7,  7,  8,  8,  9,  9,  10, 10, 11, 11,
-  12, 6,  7,  7,  8, 8,  9,  9,  10, 10, 11, 11, 12, 12, 13, 13, 14, 6,  7,
-  7,  8,  8,  9,  9, 10, 10, 11, 11, 12, 12, 13, 13, 14
-};
-
 HWY_ATTR inline void ConvolveHoriz8Tap(const uint8_t *src, ptrdiff_t src_stride,
                                        uint8_t *dst, ptrdiff_t dst_stride,
                                        const int16_t *filter_x, int w, int h) {
-  hn::CappedTag<int16_t, 16> tag16;
-  hn::CappedTag<int16_t, 8> filter_tag;
-  auto f_vec = hn::LoadU(filter_tag, filter_x);
-  // All filter values are even, halve to reduce intermediate precision
-  // requirements.
-  f_vec = hn::ShiftRight<1>(f_vec);
+  // Only w in {4, 8, 16, 32} has an 8-bit kernel below; taps must all be even.
+  const bool can_use_optimized_path =
+      (w == 4 || w == 8 || w == 16 || w == 32) &&
+      ((filter_x[0] | filter_x[1] | filter_x[2] | filter_x[3] | filter_x[4] |
+        filter_x[5] | filter_x[6] | filter_x[7]) & 1) == 0;

-  if (w == 4) {
-    do {
-      auto src0 = LoadUnaligned4x4(tag16, src, src_stride);
-      auto src1 = LoadUnaligned4x4(tag16, src + 1, src_stride);
-      auto src2 = LoadUnaligned4x4(tag16, src + 2, src_stride);
-      auto src3 = LoadUnaligned4x4(tag16, src + 3, src_stride);
-      auto src4 = LoadUnaligned4x4(tag16, src + 4, src_stride);
-      auto src5 = LoadUnaligned4x4(tag16, src + 5, src_stride);
-      auto src6 = LoadUnaligned4x4(tag16, src + 6, src_stride);
-      auto src7 = LoadUnaligned4x4(tag16, src + 7, src_stride);
-      auto result = Convolve8_8(tag16, filter_tag, src0, src1, src2, src3, src4,
-                                src5, src6, src7, f_vec);
-      StoreUnaligned4x4(tag16, dst, dst_stride, result);
-      h -= 4;
-      src += 4 * src_stride;
-      dst += 4 * dst_stride;
-    } while (h > 0);
-  } else if (w == 8) {
-    // Each iteration processes a 2x8 block
-    do {
-      auto src0 = LoadUnaligned2x8(tag16, src, src_stride);
-      auto src1 = LoadUnaligned2x8(tag16, src + 1, src_stride);
-      auto src2 = LoadUnaligned2x8(tag16, src + 2, src_stride);
-      auto src3 = LoadUnaligned2x8(tag16, src + 3, src_stride);
-      auto src4 = LoadUnaligned2x8(tag16, src + 4, src_stride);
-      auto src5 = LoadUnaligned2x8(tag16, src + 5, src_stride);
-      auto src6 = LoadUnaligned2x8(tag16, src + 6, src_stride);
-      auto src7 = LoadUnaligned2x8(tag16, src + 7, src_stride);
-      auto result = Convolve8_8(tag16, filter_tag, src0, src1, src2, src3, src4,
-                                src5, src6, src7, f_vec);
-      StoreUnaligned2x8(tag16, dst, dst_stride, result);
-      h -= 2;
-      src += 2 * src_stride;
-      dst += 2 * dst_stride;
-    } while (h > 0);
-  } else if (w == 16) {
-    // One 1x16 block a time
-    do {
-      hn::Rebind<uint8_t, decltype(tag16)> tag8;
-      auto src0 = hn::PromoteTo(tag16, hn::LoadU(tag8, src));
-      auto src1 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 1));
-      auto src2 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 2));
-      auto src3 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 3));
-      auto src4 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 4));
-      auto src5 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 5));
-      auto src6 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 6));
-      auto src7 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 7));
-      auto result = Convolve8_8(tag16, filter_tag, src0, src1, src2, src3, src4,
-                                src5, src6, src7, f_vec);
-      hn::StoreU(hn::DemoteTo(tag8, result), tag8, dst);
-      h--;
-      src += src_stride;
-      dst += dst_stride;
-    } while (h > 0);
+  if (can_use_optimized_path) {
+    hn::CappedTag<uint8_t, 16> tag8_16;
+    hn::CappedTag<int8_t, 16> tag_i8;
+    hn::CappedTag<int16_t, 8> tag16_8;
+    hn::CappedTag<uint8_t, 8> tag8_8;
+    hn::CappedTag<uint8_t, 4> tag8_4;
+    const auto bias_val = hn::Set(tag16_8, 1 << (FILTER_BITS - 2));
+
+    const auto shuffle_mask = hn::Dup128VecFromValues(
+        tag8_16, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8);
+
+    const int8_t c0 = static_cast<int8_t>(filter_x[0] / 2);
+    const int8_t c1 = static_cast<int8_t>(filter_x[1] / 2);
+    const int8_t c2 = static_cast<int8_t>(filter_x[2] / 2);
+    const int8_t c3 = static_cast<int8_t>(filter_x[3] / 2);
+    const int8_t c4 = static_cast<int8_t>(filter_x[4] / 2);
+    const int8_t c5 = static_cast<int8_t>(filter_x[5] / 2);
+    const int8_t c6 = static_cast<int8_t>(filter_x[6] / 2);
+    const int8_t c7 = static_cast<int8_t>(filter_x[7] / 2);
+
+    const auto coeff01 = hn::Dup128VecFromValues(
+        tag_i8, c0, c1, c0, c1, c0, c1, c0, c1, c0, c1, c0, c1, c0, c1, c0, c1);
+
+    const auto coeff23 = hn::Dup128VecFromValues(
+        tag_i8, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3);
+
+    const auto coeff45 = hn::Dup128VecFromValues(
+        tag_i8, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5);
+
+    const auto coeff67 = hn::Dup128VecFromValues(
+        tag_i8, c6, c7, c6, c7, c6, c7, c6, c7, c6, c7, c6, c7, c6, c7, c6, c7);
+
+    if (w == 4) {
+      while (h >= 2) {
+        auto r0_d0 = hn::LoadU(tag8_16, src + 0);
+        auto r0_d2 = hn::LoadU(tag8_16, src + 2);
+        auto r0_d4 = hn::LoadU(tag8_16, src + 4);
+        auto r0_d6 = hn::LoadU(tag8_16, src + 6);
+
+        auto r1_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+        auto r1_d2 = hn::LoadU(tag8_16, src + src_stride + 2);
+        auto r1_d4 = hn::LoadU(tag8_16, src + src_stride + 4);
+        auto r1_d6 = hn::LoadU(tag8_16, src + src_stride + 6);
+
+        auto r0_sum = hn::Add(
+            hn::Add(hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask),
+                        coeff01),
+                    hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r0_d2, shuffle_mask),
+                        coeff23)),
+            hn::Add(hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r0_d4, shuffle_mask),
+                        coeff45),
+                    hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r0_d6, shuffle_mask),
+                        coeff67)));
+
+        auto r1_sum = hn::Add(
+            hn::Add(hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask),
+                        coeff01),
+                    hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r1_d2, shuffle_mask),
+                        coeff23)),
+            hn::Add(hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r1_d4, shuffle_mask),
+                        coeff45),
+                    hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r1_d6, shuffle_mask),
+                        coeff67)));
+
+        hn::StoreU(
+            hn::LowerHalf(tag8_4,
+                          hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                                   hn::Add(r0_sum, bias_val)))),
+            tag8_4, dst);
+        hn::StoreU(
+            hn::LowerHalf(tag8_4,
+                          hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                                   hn::Add(r1_sum, bias_val)))),
+            tag8_4, dst + dst_stride);
+
+        src += 2 * src_stride;
+        dst += 2 * dst_stride;
+        h -= 2;
+      }
+    } else if (w == 8) {
+      while (h >= 2) {
+        auto r0_d0 = hn::LoadU(tag8_16, src + 0);
+        auto r0_d2 = hn::LoadU(tag8_16, src + 2);
+        auto r0_d4 = hn::LoadU(tag8_16, src + 4);
+        auto r0_d6 = hn::LoadU(tag8_16, src + 6);
+
+        auto r1_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+        auto r1_d2 = hn::LoadU(tag8_16, src + src_stride + 2);
+        auto r1_d4 = hn::LoadU(tag8_16, src + src_stride + 4);
+        auto r1_d6 = hn::LoadU(tag8_16, src + src_stride + 6);
+
+        auto r0_sum = hn::Add(
+            hn::Add(hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask),
+                        coeff01),
+                    hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r0_d2, shuffle_mask),
+                        coeff23)),
+            hn::Add(hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r0_d4, shuffle_mask),
+                        coeff45),
+                    hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r0_d6, shuffle_mask),
+                        coeff67)));
+
+        auto r1_sum = hn::Add(
+            hn::Add(hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask),
+                        coeff01),
+                    hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r1_d2, shuffle_mask),
+                        coeff23)),
+            hn::Add(hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r1_d4, shuffle_mask),
+                        coeff45),
+                    hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r1_d6, shuffle_mask),
+                        coeff67)));
+
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r0_sum, bias_val))),
+                   tag8_8, dst);
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r1_sum, bias_val))),
+                   tag8_8, dst + dst_stride);
+
+        src += 2 * src_stride;
+        dst += 2 * dst_stride;
+        h -= 2;
+      }
+    } else if (w == 16) {
+      while (h >= 2) {
+        auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0);
+        auto r0_j0_d2 = hn::LoadU(tag8_16, src + 2);
+        auto r0_j0_d4 = hn::LoadU(tag8_16, src + 4);
+        auto r0_j0_d6 = hn::LoadU(tag8_16, src + 6);
+
+        auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8);
+        auto r0_j8_d2 = hn::LoadU(tag8_16, src + 10);
+        auto r0_j8_d4 = hn::LoadU(tag8_16, src + 12);
+        auto r0_j8_d6 = hn::LoadU(tag8_16, src + 14);
+
+        auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+        auto r1_j0_d2 = hn::LoadU(tag8_16, src + src_stride + 2);
+        auto r1_j0_d4 = hn::LoadU(tag8_16, src + src_stride + 4);
+        auto r1_j0_d6 = hn::LoadU(tag8_16, src + src_stride + 6);
+
+        auto r1_j8_d0 = hn::LoadU(tag8_16, src + src_stride + 8);
+        auto r1_j8_d2 = hn::LoadU(tag8_16, src + src_stride + 10);
+        auto r1_j8_d4 = hn::LoadU(tag8_16, src + src_stride + 12);
+        auto r1_j8_d6 = hn::LoadU(tag8_16, src + src_stride + 14);
+
+        auto r0_j0_sum = hn::Add(
+            hn::Add(hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask),
+                        coeff01),
+                    hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r0_j0_d2, shuffle_mask),
+                        coeff23)),
+            hn::Add(hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r0_j0_d4, shuffle_mask),
+                        coeff45),
+                    hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r0_j0_d6, shuffle_mask),
+                        coeff67)));
+
+        auto r0_j8_sum = hn::Add(
+            hn::Add(hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask),
+                        coeff01),
+                    hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r0_j8_d2, shuffle_mask),
+                        coeff23)),
+            hn::Add(hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r0_j8_d4, shuffle_mask),
+                        coeff45),
+                    hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r0_j8_d6, shuffle_mask),
+                        coeff67)));
+
+        auto r1_j0_sum = hn::Add(
+            hn::Add(hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask),
+                        coeff01),
+                    hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r1_j0_d2, shuffle_mask),
+                        coeff23)),
+            hn::Add(hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r1_j0_d4, shuffle_mask),
+                        coeff45),
+                    hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r1_j0_d6, shuffle_mask),
+                        coeff67)));
+
+        auto r1_j8_sum = hn::Add(
+            hn::Add(hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask),
+                        coeff01),
+                    hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r1_j8_d2, shuffle_mask),
+                        coeff23)),
+            hn::Add(hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r1_j8_d4, shuffle_mask),
+                        coeff45),
+                    hn::SatWidenMulPairwiseAdd(
+                        tag16_8, hn::TableLookupBytes(r1_j8_d6, shuffle_mask),
+                        coeff67)));
+
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r0_j0_sum, bias_val))),
+                   tag8_8, dst + 0);
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r0_j8_sum, bias_val))),
+                   tag8_8, dst + 8);
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r1_j0_sum, bias_val))),
+                   tag8_8, dst + dst_stride + 0);
+        hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                            hn::Add(r1_j8_sum, bias_val))),
+                   tag8_8, dst + dst_stride + 8);
+
+        src += 2 * src_stride;
+        dst += 2 * dst_stride;
+        h -= 2;
+      }
+    } else if (w == 32) {
+      while (h >= 2) {
+        {
+          auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0);
+          auto r0_j0_d2 = hn::LoadU(tag8_16, src + 2);
+          auto r0_j0_d4 = hn::LoadU(tag8_16, src + 4);
+          auto r0_j0_d6 = hn::LoadU(tag8_16, src + 6);
+
+          auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8);
+          auto r0_j8_d2 = hn::LoadU(tag8_16, src + 10);
+          auto r0_j8_d4 = hn::LoadU(tag8_16, src + 12);
+          auto r0_j8_d6 = hn::LoadU(tag8_16, src + 14);
+
+          auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+          auto r1_j0_d2 = hn::LoadU(tag8_16, src + src_stride + 2);
+          auto r1_j0_d4 = hn::LoadU(tag8_16, src + src_stride + 4);
+          auto r1_j0_d6 = hn::LoadU(tag8_16, src + src_stride + 6);
+
+          auto r1_j8_d0 = hn::LoadU(tag8_16, src + src_stride + 8);
+          auto r1_j8_d2 = hn::LoadU(tag8_16, src + src_stride + 10);
+          auto r1_j8_d4 = hn::LoadU(tag8_16, src + src_stride + 12);
+          auto r1_j8_d6 = hn::LoadU(tag8_16, src + src_stride + 14);
+
+          auto r0_j0_sum = hn::Add(
+              hn::Add(hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask),
+                          coeff01),
+                      hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r0_j0_d2, shuffle_mask),
+                          coeff23)),
+              hn::Add(hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r0_j0_d4, shuffle_mask),
+                          coeff45),
+                      hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r0_j0_d6, shuffle_mask),
+                          coeff67)));
+
+          auto r0_j8_sum = hn::Add(
+              hn::Add(hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask),
+                          coeff01),
+                      hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r0_j8_d2, shuffle_mask),
+                          coeff23)),
+              hn::Add(hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r0_j8_d4, shuffle_mask),
+                          coeff45),
+                      hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r0_j8_d6, shuffle_mask),
+                          coeff67)));
+
+          auto r1_j0_sum = hn::Add(
+              hn::Add(hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask),
+                          coeff01),
+                      hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r1_j0_d2, shuffle_mask),
+                          coeff23)),
+              hn::Add(hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r1_j0_d4, shuffle_mask),
+                          coeff45),
+                      hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r1_j0_d6, shuffle_mask),
+                          coeff67)));
+
+          auto r1_j8_sum = hn::Add(
+              hn::Add(hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask),
+                          coeff01),
+                      hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r1_j8_d2, shuffle_mask),
+                          coeff23)),
+              hn::Add(hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r1_j8_d4, shuffle_mask),
+                          coeff45),
+                      hn::SatWidenMulPairwiseAdd(
+                          tag16_8, hn::TableLookupBytes(r1_j8_d6, shuffle_mask),
+                          coeff67)));
+
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r0_j0_sum, bias_val))),
+                     tag8_8, dst + 0);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r0_j8_sum, bias_val))),
+                     tag8_8, dst + 8);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r1_j0_sum, bias_val))),
+                     tag8_8, dst + dst_stride + 0);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r1_j8_sum, bias_val))),
+                     tag8_8, dst + dst_stride + 8);
+        }
+        {
+          auto r0_j16_d0 = hn::LoadU(tag8_16, src + 16);
+          auto r0_j16_d2 = hn::LoadU(tag8_16, src + 18);
+          auto r0_j16_d4 = hn::LoadU(tag8_16, src + 20);
+          auto r0_j16_d6 = hn::LoadU(tag8_16, src + 22);
+
+          auto r0_j24_d0 = hn::LoadU(tag8_16, src + 24);
+          auto r0_j24_d2 = hn::LoadU(tag8_16, src + 26);
+          auto r0_j24_d4 = hn::LoadU(tag8_16, src + 28);
+          auto r0_j24_d6 = hn::LoadU(tag8_16, src + 30);
+
+          auto r1_j16_d0 = hn::LoadU(tag8_16, src + src_stride + 16);
+          auto r1_j16_d2 = hn::LoadU(tag8_16, src + src_stride + 18);
+          auto r1_j16_d4 = hn::LoadU(tag8_16, src + src_stride + 20);
+          auto r1_j16_d6 = hn::LoadU(tag8_16, src + src_stride + 22);
+
+          auto r1_j24_d0 = hn::LoadU(tag8_16, src + src_stride + 24);
+          auto r1_j24_d2 = hn::LoadU(tag8_16, src + src_stride + 26);
+          auto r1_j24_d4 = hn::LoadU(tag8_16, src + src_stride + 28);
+          auto r1_j24_d6 = hn::LoadU(tag8_16, src + src_stride + 30);
+
+          auto r0_j16_sum = hn::Add(
+              hn::Add(
+                  hn::SatWidenMulPairwiseAdd(
+                      tag16_8, hn::TableLookupBytes(r0_j16_d0, shuffle_mask),
+                      coeff01),
+                  hn::SatWidenMulPairwiseAdd(
+                      tag16_8, hn::TableLookupBytes(r0_j16_d2, shuffle_mask),
+                      coeff23)),
+              hn::Add(
+                  hn::SatWidenMulPairwiseAdd(
+                      tag16_8, hn::TableLookupBytes(r0_j16_d4, shuffle_mask),
+                      coeff45),
+                  hn::SatWidenMulPairwiseAdd(
+                      tag16_8, hn::TableLookupBytes(r0_j16_d6, shuffle_mask),
+                      coeff67)));
+
+          auto r0_j24_sum = hn::Add(
+              hn::Add(
+                  hn::SatWidenMulPairwiseAdd(
+                      tag16_8, hn::TableLookupBytes(r0_j24_d0, shuffle_mask),
+                      coeff01),
+                  hn::SatWidenMulPairwiseAdd(
+                      tag16_8, hn::TableLookupBytes(r0_j24_d2, shuffle_mask),
+                      coeff23)),
+              hn::Add(
+                  hn::SatWidenMulPairwiseAdd(
+                      tag16_8, hn::TableLookupBytes(r0_j24_d4, shuffle_mask),
+                      coeff45),
+                  hn::SatWidenMulPairwiseAdd(
+                      tag16_8, hn::TableLookupBytes(r0_j24_d6, shuffle_mask),
+                      coeff67)));
+
+          auto r1_j16_sum = hn::Add(
+              hn::Add(
+                  hn::SatWidenMulPairwiseAdd(
+                      tag16_8, hn::TableLookupBytes(r1_j16_d0, shuffle_mask),
+                      coeff01),
+                  hn::SatWidenMulPairwiseAdd(
+                      tag16_8, hn::TableLookupBytes(r1_j16_d2, shuffle_mask),
+                      coeff23)),
+              hn::Add(
+                  hn::SatWidenMulPairwiseAdd(
+                      tag16_8, hn::TableLookupBytes(r1_j16_d4, shuffle_mask),
+                      coeff45),
+                  hn::SatWidenMulPairwiseAdd(
+                      tag16_8, hn::TableLookupBytes(r1_j16_d6, shuffle_mask),
+                      coeff67)));
+
+          auto r1_j24_sum = hn::Add(
+              hn::Add(
+                  hn::SatWidenMulPairwiseAdd(
+                      tag16_8, hn::TableLookupBytes(r1_j24_d0, shuffle_mask),
+                      coeff01),
+                  hn::SatWidenMulPairwiseAdd(
+                      tag16_8, hn::TableLookupBytes(r1_j24_d2, shuffle_mask),
+                      coeff23)),
+              hn::Add(
+                  hn::SatWidenMulPairwiseAdd(
+                      tag16_8, hn::TableLookupBytes(r1_j24_d4, shuffle_mask),
+                      coeff45),
+                  hn::SatWidenMulPairwiseAdd(
+                      tag16_8, hn::TableLookupBytes(r1_j24_d6, shuffle_mask),
+                      coeff67)));
+
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r0_j16_sum, bias_val))),
+                     tag8_8, dst + 16);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r0_j24_sum, bias_val))),
+                     tag8_8, dst + 24);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r1_j16_sum, bias_val))),
+                     tag8_8, dst + dst_stride + 16);
+          hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+                                              hn::Add(r1_j24_sum, bias_val))),
+                     tag8_8, dst + dst_stride + 24);
+        }
+        src += 2 * src_stride;
+        dst += 2 * dst_stride;
+        h -= 2;
+      }
+    }
   } else {
-    // This tag will have 32 lanes (for avx512) or 16 lanes (for avx2)
+    hn::CappedTag<int16_t, 8> filter_tag;
+    auto f_vec = hn::LoadU(filter_tag, filter_x);
+    f_vec = hn::ShiftRight<1>(f_vec);
     hn::ScalableTag<int16_t> mul_tag;
     hn::Rebind<uint8_t, decltype(mul_tag)> pixel_tag;
     auto vw = hn::Lanes(mul_tag);
diff --git a/aom_dsp/x86/convolve_hwy_avx512.cc b/aom_dsp/x86/convolve_hwy_avx512.cc
index c1aa90492f..62557047ca 100644
--- a/aom_dsp/x86/convolve_hwy_avx512.cc
+++ b/aom_dsp/x86/convolve_hwy_avx512.cc
@@ -32,10 +32,11 @@ HWY_ATTR void aom_convolve8_horiz_avx512(const uint8_t *src,
                                          const int16_t *filter_x, int x_step_q4,
                                          const int16_t *filter_y, int y_step_q4,
                                          int w, int h) {
-  // Fallback to AVX2 for small block sizes (w <= 16) where the handwritten
-  // AVX2 implementation was measured to be faster than the Highway AVX512
-  // implementation in benchmarks.
-  if (w <= 16) {
+  // Blocks of height 32 (16x32, 32x32, 64x32) are ~10% slower than avx2 here,
+  // while all other block sizes are significantly faster. Fall back to avx2
+  // whenever h == 32.
+  // TODO: jianj - Investigate and optimize for blocks of height 32.
+  if (h == 32) {
     aom_convolve8_horiz_avx2(src, src_stride, dst, dst_stride, filter_x,
                              x_step_q4, filter_y, y_step_q4, w, h);
   } else {