Commit faa8b1669c for aom
commit faa8b1669c59cd927374aa53bd4a127e9be1b4a3
Author: Jerome Jiang <jianj@google.com>
Date: Wed May 27 12:52:36 2026 -0400
Conv horiz: expand hwy avx512
Further optimize for small blocks, but fall back to avx2 for blocks
with height 32.
- Use 8-bit pairwise multiply-accumulate (SatWidenMulPairwiseAdd)
instead of 16-bit math for w <= 32 with even coefficients.
- Halve filter coefficients to fit in int8_t and avoid overflow,
adjusting final scaling shift to FILTER_BITS - 1.
- Eliminate expensive 8-bit to 16-bit pixel promotion.
- Add specialized unrolled loops for w = 4, 8, 16, and 32.
All blocks now show significant speedups, except for small
slowdowns for blocks 16x32, 32x32 and 64x32.
I'll investigate these block sizes further.
Size | avx2 | avx512 (diff)
------------------------------------------
4x4 | 5.62µs | 4.03µs (-28.3%)
4x8 | 6.78µs | 5.17µs (-23.7%)
8x4 | 5.94µs | 4.03µs (-32.2%)
8x8 | 6.75µs | 5.17µs (-23.4%)
8x16 | 10.01µs | 7.66µs (-23.4%)
16x8 | 7.28µs | 6.49µs (-10.8%)
16x16 | 10.92µs | 10.47µs (-4.1%)
16x32 | 17.94µs | 19.83µs (+10.5%)
32x16 | 19.34µs | 19.59µs (+1.3%)
32x32 | 33.67µs | 38.31µs (+13.8%)
32x64 | 170.90µs | 153.10µs (-10.4%)
64x32 | 68.21µs | 76.28µs (+11.8%)
64x64 | 307.20µs | 151.80µs (-50.6%)
64x128 | 677.80µs | 305.30µs (-55.0%)
128x64 | 527.90µs | 298.60µs (-43.4%)
128x128 | 1.35ms | 593.90µs (-56.1%)
Change-Id: I4134a9ca0e233855761f6b03c5f35e8fcf8e25fa
diff --git a/aom_dsp/convolve_hwy.h b/aom_dsp/convolve_hwy.h
index e5be37ec9f..0b28531038 100644
--- a/aom_dsp/convolve_hwy.h
+++ b/aom_dsp/convolve_hwy.h
@@ -134,26 +134,196 @@ HWY_ATTR HWY_INLINE void StoreUnaligned4x8(D tag, uint8_t *buf,
HWY_ATTR inline void ConvolveHoriz2Tap(const uint8_t *src, ptrdiff_t src_stride,
uint8_t *dst, ptrdiff_t dst_stride,
const int16_t *filter_x, int w, int h) {
- hn::ScalableTag<int16_t> mul_tag;
- hn::Rebind<uint8_t, decltype(mul_tag)> pixel_tag;
- auto filter_0 = hn::Set(mul_tag, filter_x[3]);
- auto filter_1 = hn::Set(mul_tag, filter_x[4]);
- auto vw = hn::Lanes(mul_tag);
- for (int i = 0; i < h; ++i) {
- for (int j = 0; j < w; j += vw) {
- auto src0 = hn::PromoteTo(mul_tag, hn::LoadU(pixel_tag, &src[j]));
- auto src1 = hn::PromoteTo(mul_tag, hn::LoadU(pixel_tag, &src[j + 1]));
- auto mulv = hn::RoundingShiftRight<FILTER_BITS>(src0 * filter_0 +
- src1 * filter_1);
- auto mulv_demoted = hn::DemoteTo(pixel_tag, mulv);
- if (j + static_cast<int>(vw) > w) {
- hn::StoreN(mulv_demoted, pixel_tag, &dst[j], w - j);
- } else {
- hn::StoreU(mulv_demoted, pixel_tag, &dst[j]);
+ const bool can_use_optimized_path =
+ (w <= 32) && (filter_x[3] % 2 == 0) && (filter_x[4] % 2 == 0);
+
+ if (can_use_optimized_path) {
+ hn::CappedTag<uint8_t, 16> tag8_16;
+ hn::CappedTag<int8_t, 16> tag_i8;
+ hn::CappedTag<int16_t, 8> tag16_8;
+ hn::CappedTag<uint8_t, 8> tag8_8;
+ hn::CappedTag<uint8_t, 4> tag8_4;
+ const auto bias_val = hn::Set(tag16_8, 1 << (FILTER_BITS - 2));
+
+ const auto shuffle_mask = hn::Dup128VecFromValues(
+ tag8_16, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8);
+
+ const int8_t c3 = static_cast<int8_t>(filter_x[3] / 2);
+ const int8_t c4 = static_cast<int8_t>(filter_x[4] / 2);
+
+ const auto coeff34 = hn::Dup128VecFromValues(
+ tag_i8, c3, c4, c3, c4, c3, c4, c3, c4, c3, c4, c3, c4, c3, c4, c3, c4);
+
+ if (w == 4) {
+ while (h >= 2) {
+ auto r0_d0 = hn::LoadU(tag8_16, src + 0);
+ auto r1_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+
+ auto r0_sum = hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask), coeff34);
+ auto r1_sum = hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask), coeff34);
+
+ hn::StoreU(
+ hn::LowerHalf(tag8_4,
+ hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_sum, bias_val)))),
+ tag8_4, dst);
+ hn::StoreU(
+ hn::LowerHalf(tag8_4,
+ hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_sum, bias_val)))),
+ tag8_4, dst + dst_stride);
+
+ src += 2 * src_stride;
+ dst += 2 * dst_stride;
+ h -= 2;
+ }
+ } else if (w == 8) {
+ while (h >= 2) {
+ auto r0_d0 = hn::LoadU(tag8_16, src + 0);
+ auto r1_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+
+ auto r0_sum = hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask), coeff34);
+ auto r1_sum = hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask), coeff34);
+
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_sum, bias_val))),
+ tag8_8, dst);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_sum, bias_val))),
+ tag8_8, dst + dst_stride);
+
+ src += 2 * src_stride;
+ dst += 2 * dst_stride;
+ h -= 2;
+ }
+ } else if (w == 16) {
+ while (h >= 2) {
+ auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0);
+ auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8);
+
+ auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+ auto r1_j8_d0 = hn::LoadU(tag8_16, src + src_stride + 8);
+
+ auto r0_j0_sum = hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask), coeff34);
+ auto r0_j8_sum = hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask), coeff34);
+
+ auto r1_j0_sum = hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask), coeff34);
+ auto r1_j8_sum = hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask), coeff34);
+
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j0_sum, bias_val))),
+ tag8_8, dst + 0);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j8_sum, bias_val))),
+ tag8_8, dst + 8);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j0_sum, bias_val))),
+ tag8_8, dst + dst_stride + 0);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j8_sum, bias_val))),
+ tag8_8, dst + dst_stride + 8);
+
+ src += 2 * src_stride;
+ dst += 2 * dst_stride;
+ h -= 2;
+ }
+ } else if (w == 32) {
+ while (h >= 2) {
+ {
+ auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0);
+ auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8);
+
+ auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+ auto r1_j8_d0 = hn::LoadU(tag8_16, src + src_stride + 8);
+
+ auto r0_j0_sum = hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask), coeff34);
+ auto r0_j8_sum = hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask), coeff34);
+
+ auto r1_j0_sum = hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask), coeff34);
+ auto r1_j8_sum = hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask), coeff34);
+
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j0_sum, bias_val))),
+ tag8_8, dst + 0);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j8_sum, bias_val))),
+ tag8_8, dst + 8);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j0_sum, bias_val))),
+ tag8_8, dst + dst_stride + 0);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j8_sum, bias_val))),
+ tag8_8, dst + dst_stride + 8);
+ }
+ {
+ auto r0_j16_d0 = hn::LoadU(tag8_16, src + 16);
+ auto r0_j24_d0 = hn::LoadU(tag8_16, src + 24);
+
+ auto r1_j16_d0 = hn::LoadU(tag8_16, src + src_stride + 16);
+ auto r1_j24_d0 = hn::LoadU(tag8_16, src + src_stride + 24);
+
+ auto r0_j16_sum = hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j16_d0, shuffle_mask), coeff34);
+ auto r0_j24_sum = hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j24_d0, shuffle_mask), coeff34);
+
+ auto r1_j16_sum = hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j16_d0, shuffle_mask), coeff34);
+ auto r1_j24_sum = hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j24_d0, shuffle_mask), coeff34);
+
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j16_sum, bias_val))),
+ tag8_8, dst + 16);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j24_sum, bias_val))),
+ tag8_8, dst + 24);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j16_sum, bias_val))),
+ tag8_8, dst + dst_stride + 16);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j24_sum, bias_val))),
+ tag8_8, dst + dst_stride + 24);
+ }
+ src += 2 * src_stride;
+ dst += 2 * dst_stride;
+ h -= 2;
}
}
- src += src_stride;
- dst += dst_stride;
+ } else {
+ hn::ScalableTag<int16_t> mul_tag;
+ hn::Rebind<uint8_t, decltype(mul_tag)> pixel_tag;
+ auto filter_0 = hn::Set(mul_tag, filter_x[3]);
+ auto filter_1 = hn::Set(mul_tag, filter_x[4]);
+ auto vw = hn::Lanes(mul_tag);
+ for (int i = 0; i < h; ++i) {
+ for (int j = 0; j < w; j += vw) {
+ auto src0 = hn::PromoteTo(mul_tag, hn::LoadU(pixel_tag, &src[j]));
+ auto src1 = hn::PromoteTo(mul_tag, hn::LoadU(pixel_tag, &src[j + 1]));
+ auto mulv = hn::RoundingShiftRight<FILTER_BITS>(src0 * filter_0 +
+ src1 * filter_1);
+ auto mulv_demoted = hn::DemoteTo(pixel_tag, mulv);
+ if (j + static_cast<int>(vw) > w) {
+ hn::StoreN(mulv_demoted, pixel_tag, &dst[j], w - j);
+ } else {
+ hn::StoreU(mulv_demoted, pixel_tag, &dst[j]);
+ }
+ }
+ src += src_stride;
+ dst += dst_stride;
+ }
}
}
@@ -175,59 +345,286 @@ HWY_ATTR HWY_INLINE hn::VFromD<D> Convolve4_8(
HWY_ATTR inline void ConvolveHoriz4Tap(const uint8_t *src, ptrdiff_t src_stride,
uint8_t *dst, ptrdiff_t dst_stride,
const int16_t *filter_x, int w, int h) {
- hn::CappedTag<int16_t, 16> tag16;
- hn::CappedTag<int16_t, 4> filter_tag;
- auto f_vec = hn::LoadU(filter_tag, filter_x + 2);
- // All filter values are even, halve to reduce intermediate precision
- // requirements.
- f_vec = hn::ShiftRight<1>(f_vec);
+ const bool can_use_optimized_path =
+ (w <= 32) && (filter_x[2] % 2 == 0) && (filter_x[3] % 2 == 0) &&
+ (filter_x[4] % 2 == 0) && (filter_x[5] % 2 == 0);
- if (w == 4) {
- // Each iteration processes a 4x4 block
- do {
- auto src0 = LoadUnaligned4x4(tag16, src, src_stride);
- auto src1 = LoadUnaligned4x4(tag16, src + 1, src_stride);
- auto src2 = LoadUnaligned4x4(tag16, src + 2, src_stride);
- auto src3 = LoadUnaligned4x4(tag16, src + 3, src_stride);
- auto result =
- Convolve4_8(tag16, filter_tag, src0, src1, src2, src3, f_vec);
- StoreUnaligned4x4(tag16, dst, dst_stride, result);
- h -= 4;
- src += 4 * src_stride;
- dst += 4 * dst_stride;
- } while (h > 0);
- } else if (w == 8) {
- // Each iteration processes a 2x8 block
- do {
- auto src0 = LoadUnaligned2x8(tag16, src, src_stride);
- auto src1 = LoadUnaligned2x8(tag16, src + 1, src_stride);
- auto src2 = LoadUnaligned2x8(tag16, src + 2, src_stride);
- auto src3 = LoadUnaligned2x8(tag16, src + 3, src_stride);
- auto result =
- Convolve4_8(tag16, filter_tag, src0, src1, src2, src3, f_vec);
- StoreUnaligned2x8(tag16, dst, dst_stride, result);
- h -= 2;
- src += 2 * src_stride;
- dst += 2 * dst_stride;
- } while (h > 0);
- } else if (w == 16) {
- // One 1x16 block a time
- do {
- hn::Rebind<uint8_t, decltype(tag16)> tag8;
- auto src0 = hn::PromoteTo(tag16, hn::LoadU(tag8, src));
- auto src1 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 1));
- auto src2 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 2));
- auto src3 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 3));
- auto result =
- Convolve4_8(tag16, filter_tag, src0, src1, src2, src3, f_vec);
- hn::StoreU(hn::DemoteTo(tag8, result), tag8, dst);
- h--;
- src += src_stride;
- dst += dst_stride;
- } while (h > 0);
+ if (can_use_optimized_path) {
+ hn::CappedTag<uint8_t, 16> tag8_16;
+ hn::CappedTag<int8_t, 16> tag_i8;
+ hn::CappedTag<int16_t, 8> tag16_8;
+ hn::CappedTag<uint8_t, 8> tag8_8;
+ hn::CappedTag<uint8_t, 4> tag8_4;
+ const auto bias_val = hn::Set(tag16_8, 1 << (FILTER_BITS - 2));
+
+ const auto shuffle_mask = hn::Dup128VecFromValues(
+ tag8_16, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8);
+
+ const int8_t c2 = static_cast<int8_t>(filter_x[2] / 2);
+ const int8_t c3 = static_cast<int8_t>(filter_x[3] / 2);
+ const int8_t c4 = static_cast<int8_t>(filter_x[4] / 2);
+ const int8_t c5 = static_cast<int8_t>(filter_x[5] / 2);
+
+ const auto coeff23 = hn::Dup128VecFromValues(
+ tag_i8, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3);
+
+ const auto coeff45 = hn::Dup128VecFromValues(
+ tag_i8, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5);
+
+ if (w == 4) {
+ while (h >= 2) {
+ auto r0_d0 = hn::LoadU(tag8_16, src + 0);
+ auto r0_d2 = hn::LoadU(tag8_16, src + 2);
+
+ auto r1_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+ auto r1_d2 = hn::LoadU(tag8_16, src + src_stride + 2);
+
+ auto r0_sum = hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask), coeff23),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_d2, shuffle_mask), coeff45));
+
+ auto r1_sum = hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask), coeff23),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_d2, shuffle_mask), coeff45));
+
+ hn::StoreU(
+ hn::LowerHalf(tag8_4,
+ hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_sum, bias_val)))),
+ tag8_4, dst);
+ hn::StoreU(
+ hn::LowerHalf(tag8_4,
+ hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_sum, bias_val)))),
+ tag8_4, dst + dst_stride);
+
+ src += 2 * src_stride;
+ dst += 2 * dst_stride;
+ h -= 2;
+ }
+ } else if (w == 8) {
+ while (h >= 2) {
+ auto r0_d0 = hn::LoadU(tag8_16, src + 0);
+ auto r0_d2 = hn::LoadU(tag8_16, src + 2);
+
+ auto r1_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+ auto r1_d2 = hn::LoadU(tag8_16, src + src_stride + 2);
+
+ auto r0_sum = hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask), coeff23),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_d2, shuffle_mask), coeff45));
+
+ auto r1_sum = hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask), coeff23),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_d2, shuffle_mask), coeff45));
+
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_sum, bias_val))),
+ tag8_8, dst);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_sum, bias_val))),
+ tag8_8, dst + dst_stride);
+
+ src += 2 * src_stride;
+ dst += 2 * dst_stride;
+ h -= 2;
+ }
+ } else if (w == 16) {
+ while (h >= 2) {
+ auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0);
+ auto r0_j0_d2 = hn::LoadU(tag8_16, src + 2);
+
+ auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8);
+ auto r0_j8_d2 = hn::LoadU(tag8_16, src + 10);
+
+ auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+ auto r1_j0_d2 = hn::LoadU(tag8_16, src + src_stride + 2);
+
+ auto r1_j8_d0 = hn::LoadU(tag8_16, src + src_stride + 8);
+ auto r1_j8_d2 = hn::LoadU(tag8_16, src + src_stride + 10);
+
+ auto r0_j0_sum = hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask), coeff23),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j0_d2, shuffle_mask),
+ coeff45));
+
+ auto r0_j8_sum = hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask), coeff23),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j8_d2, shuffle_mask),
+ coeff45));
+
+ auto r1_j0_sum = hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask), coeff23),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j0_d2, shuffle_mask),
+ coeff45));
+
+ auto r1_j8_sum = hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask), coeff23),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j8_d2, shuffle_mask),
+ coeff45));
+
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j0_sum, bias_val))),
+ tag8_8, dst + 0);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j8_sum, bias_val))),
+ tag8_8, dst + 8);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j0_sum, bias_val))),
+ tag8_8, dst + dst_stride + 0);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j8_sum, bias_val))),
+ tag8_8, dst + dst_stride + 8);
+
+ src += 2 * src_stride;
+ dst += 2 * dst_stride;
+ h -= 2;
+ }
+ } else if (w == 32) {
+ while (h >= 2) {
+ {
+ auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0);
+ auto r0_j0_d2 = hn::LoadU(tag8_16, src + 2);
+
+ auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8);
+ auto r0_j8_d2 = hn::LoadU(tag8_16, src + 10);
+
+ auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+ auto r1_j0_d2 = hn::LoadU(tag8_16, src + src_stride + 2);
+
+ auto r1_j8_d0 = hn::LoadU(tag8_16, src + src_stride + 8);
+ auto r1_j8_d2 = hn::LoadU(tag8_16, src + src_stride + 10);
+
+ auto r0_j0_sum =
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask),
+ coeff23),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j0_d2, shuffle_mask),
+ coeff45));
+
+ auto r0_j8_sum =
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask),
+ coeff23),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j8_d2, shuffle_mask),
+ coeff45));
+
+ auto r1_j0_sum =
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask),
+ coeff23),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j0_d2, shuffle_mask),
+ coeff45));
+
+ auto r1_j8_sum =
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask),
+ coeff23),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j8_d2, shuffle_mask),
+ coeff45));
+
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j0_sum, bias_val))),
+ tag8_8, dst + 0);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j8_sum, bias_val))),
+ tag8_8, dst + 8);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j0_sum, bias_val))),
+ tag8_8, dst + dst_stride + 0);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j8_sum, bias_val))),
+ tag8_8, dst + dst_stride + 8);
+ }
+ {
+ auto r0_j16_d0 = hn::LoadU(tag8_16, src + 16);
+ auto r0_j16_d2 = hn::LoadU(tag8_16, src + 18);
+
+ auto r0_j24_d0 = hn::LoadU(tag8_16, src + 24);
+ auto r0_j24_d2 = hn::LoadU(tag8_16, src + 26);
+
+ auto r1_j16_d0 = hn::LoadU(tag8_16, src + src_stride + 16);
+ auto r1_j16_d2 = hn::LoadU(tag8_16, src + src_stride + 18);
+
+ auto r1_j24_d0 = hn::LoadU(tag8_16, src + src_stride + 24);
+ auto r1_j24_d2 = hn::LoadU(tag8_16, src + src_stride + 26);
+
+ auto r0_j16_sum = hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j16_d0, shuffle_mask),
+ coeff23),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j16_d2, shuffle_mask),
+ coeff45));
+
+ auto r0_j24_sum = hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j24_d0, shuffle_mask),
+ coeff23),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j24_d2, shuffle_mask),
+ coeff45));
+
+ auto r1_j16_sum = hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j16_d0, shuffle_mask),
+ coeff23),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j16_d2, shuffle_mask),
+ coeff45));
+
+ auto r1_j24_sum = hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j24_d0, shuffle_mask),
+ coeff23),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j24_d2, shuffle_mask),
+ coeff45));
+
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j16_sum, bias_val))),
+ tag8_8, dst + 16);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j24_sum, bias_val))),
+ tag8_8, dst + 24);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j16_sum, bias_val))),
+ tag8_8, dst + dst_stride + 16);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j24_sum, bias_val))),
+ tag8_8, dst + dst_stride + 24);
+ }
+ src += 2 * src_stride;
+ dst += 2 * dst_stride;
+ h -= 2;
+ }
+ }
} else {
hn::ScalableTag<int16_t> mul_tag;
hn::Rebind<uint8_t, decltype(mul_tag)> pixel_tag;
+ hn::CappedTag<int16_t, 4> filter_tag;
+ auto f_vec = hn::LoadU(filter_tag, filter_x + 2);
+ f_vec = hn::ShiftRight<1>(f_vec);
auto vw = hn::Lanes(mul_tag);
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; j += vw) {
@@ -278,82 +675,447 @@ HWY_ATTR HWY_INLINE hn::VFromD<D> Convolve8_8(
return hn::RoundingShiftRight<FILTER_BITS - 1>(res);
}
-DECLARE_ALIGNED(32, static const uint8_t, filt_global[]) = {
- 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 0, 1, 1,
- 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 2, 3, 3, 4, 4, 5,
- 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 2, 3, 3, 4, 4, 5, 5, 6, 6,
- 7, 7, 8, 8, 9, 9, 10, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10,
- 10, 11, 11, 12, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11,
- 12, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 6, 7,
- 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14
-};
-
HWY_ATTR inline void ConvolveHoriz8Tap(const uint8_t *src, ptrdiff_t src_stride,
uint8_t *dst, ptrdiff_t dst_stride,
const int16_t *filter_x, int w, int h) {
- hn::CappedTag<int16_t, 16> tag16;
- hn::CappedTag<int16_t, 8> filter_tag;
- auto f_vec = hn::LoadU(filter_tag, filter_x);
- // All filter values are even, halve to reduce intermediate precision
- // requirements.
- f_vec = hn::ShiftRight<1>(f_vec);
+ const bool can_use_optimized_path =
+ (w <= 32) && (filter_x[0] % 2 == 0) && (filter_x[1] % 2 == 0) &&
+ (filter_x[2] % 2 == 0) && (filter_x[3] % 2 == 0) &&
+ (filter_x[4] % 2 == 0) && (filter_x[5] % 2 == 0) &&
+ (filter_x[6] % 2 == 0) && (filter_x[7] % 2 == 0);
- if (w == 4) {
- do {
- auto src0 = LoadUnaligned4x4(tag16, src, src_stride);
- auto src1 = LoadUnaligned4x4(tag16, src + 1, src_stride);
- auto src2 = LoadUnaligned4x4(tag16, src + 2, src_stride);
- auto src3 = LoadUnaligned4x4(tag16, src + 3, src_stride);
- auto src4 = LoadUnaligned4x4(tag16, src + 4, src_stride);
- auto src5 = LoadUnaligned4x4(tag16, src + 5, src_stride);
- auto src6 = LoadUnaligned4x4(tag16, src + 6, src_stride);
- auto src7 = LoadUnaligned4x4(tag16, src + 7, src_stride);
- auto result = Convolve8_8(tag16, filter_tag, src0, src1, src2, src3, src4,
- src5, src6, src7, f_vec);
- StoreUnaligned4x4(tag16, dst, dst_stride, result);
- h -= 4;
- src += 4 * src_stride;
- dst += 4 * dst_stride;
- } while (h > 0);
- } else if (w == 8) {
- // Each iteration processes a 2x8 block
- do {
- auto src0 = LoadUnaligned2x8(tag16, src, src_stride);
- auto src1 = LoadUnaligned2x8(tag16, src + 1, src_stride);
- auto src2 = LoadUnaligned2x8(tag16, src + 2, src_stride);
- auto src3 = LoadUnaligned2x8(tag16, src + 3, src_stride);
- auto src4 = LoadUnaligned2x8(tag16, src + 4, src_stride);
- auto src5 = LoadUnaligned2x8(tag16, src + 5, src_stride);
- auto src6 = LoadUnaligned2x8(tag16, src + 6, src_stride);
- auto src7 = LoadUnaligned2x8(tag16, src + 7, src_stride);
- auto result = Convolve8_8(tag16, filter_tag, src0, src1, src2, src3, src4,
- src5, src6, src7, f_vec);
- StoreUnaligned2x8(tag16, dst, dst_stride, result);
- h -= 2;
- src += 2 * src_stride;
- dst += 2 * dst_stride;
- } while (h > 0);
- } else if (w == 16) {
- // One 1x16 block a time
- do {
- hn::Rebind<uint8_t, decltype(tag16)> tag8;
- auto src0 = hn::PromoteTo(tag16, hn::LoadU(tag8, src));
- auto src1 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 1));
- auto src2 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 2));
- auto src3 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 3));
- auto src4 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 4));
- auto src5 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 5));
- auto src6 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 6));
- auto src7 = hn::PromoteTo(tag16, hn::LoadU(tag8, src + 7));
- auto result = Convolve8_8(tag16, filter_tag, src0, src1, src2, src3, src4,
- src5, src6, src7, f_vec);
- hn::StoreU(hn::DemoteTo(tag8, result), tag8, dst);
- h--;
- src += src_stride;
- dst += dst_stride;
- } while (h > 0);
+ if (can_use_optimized_path) {
+ hn::CappedTag<uint8_t, 16> tag8_16;
+ hn::CappedTag<int8_t, 16> tag_i8;
+ hn::CappedTag<int16_t, 8> tag16_8;
+ hn::CappedTag<uint8_t, 8> tag8_8;
+ hn::CappedTag<uint8_t, 4> tag8_4;
+ const auto bias_val = hn::Set(tag16_8, 1 << (FILTER_BITS - 2));
+
+ const auto shuffle_mask = hn::Dup128VecFromValues(
+ tag8_16, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8);
+
+ const int8_t c0 = static_cast<int8_t>(filter_x[0] / 2);
+ const int8_t c1 = static_cast<int8_t>(filter_x[1] / 2);
+ const int8_t c2 = static_cast<int8_t>(filter_x[2] / 2);
+ const int8_t c3 = static_cast<int8_t>(filter_x[3] / 2);
+ const int8_t c4 = static_cast<int8_t>(filter_x[4] / 2);
+ const int8_t c5 = static_cast<int8_t>(filter_x[5] / 2);
+ const int8_t c6 = static_cast<int8_t>(filter_x[6] / 2);
+ const int8_t c7 = static_cast<int8_t>(filter_x[7] / 2);
+
+ const auto coeff01 = hn::Dup128VecFromValues(
+ tag_i8, c0, c1, c0, c1, c0, c1, c0, c1, c0, c1, c0, c1, c0, c1, c0, c1);
+
+ const auto coeff23 = hn::Dup128VecFromValues(
+ tag_i8, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3, c2, c3);
+
+ const auto coeff45 = hn::Dup128VecFromValues(
+ tag_i8, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5, c4, c5);
+
+ const auto coeff67 = hn::Dup128VecFromValues(
+ tag_i8, c6, c7, c6, c7, c6, c7, c6, c7, c6, c7, c6, c7, c6, c7, c6, c7);
+
+ if (w == 4) {
+ while (h >= 2) {
+ auto r0_d0 = hn::LoadU(tag8_16, src + 0);
+ auto r0_d2 = hn::LoadU(tag8_16, src + 2);
+ auto r0_d4 = hn::LoadU(tag8_16, src + 4);
+ auto r0_d6 = hn::LoadU(tag8_16, src + 6);
+
+ auto r1_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+ auto r1_d2 = hn::LoadU(tag8_16, src + src_stride + 2);
+ auto r1_d4 = hn::LoadU(tag8_16, src + src_stride + 4);
+ auto r1_d6 = hn::LoadU(tag8_16, src + src_stride + 6);
+
+ auto r0_sum = hn::Add(
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask),
+ coeff01),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_d2, shuffle_mask),
+ coeff23)),
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_d4, shuffle_mask),
+ coeff45),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_d6, shuffle_mask),
+ coeff67)));
+
+ auto r1_sum = hn::Add(
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask),
+ coeff01),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_d2, shuffle_mask),
+ coeff23)),
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_d4, shuffle_mask),
+ coeff45),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_d6, shuffle_mask),
+ coeff67)));
+
+ hn::StoreU(
+ hn::LowerHalf(tag8_4,
+ hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_sum, bias_val)))),
+ tag8_4, dst);
+ hn::StoreU(
+ hn::LowerHalf(tag8_4,
+ hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_sum, bias_val)))),
+ tag8_4, dst + dst_stride);
+
+ src += 2 * src_stride;
+ dst += 2 * dst_stride;
+ h -= 2;
+ }
+ } else if (w == 8) {
+ while (h >= 2) {
+ auto r0_d0 = hn::LoadU(tag8_16, src + 0);
+ auto r0_d2 = hn::LoadU(tag8_16, src + 2);
+ auto r0_d4 = hn::LoadU(tag8_16, src + 4);
+ auto r0_d6 = hn::LoadU(tag8_16, src + 6);
+
+ auto r1_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+ auto r1_d2 = hn::LoadU(tag8_16, src + src_stride + 2);
+ auto r1_d4 = hn::LoadU(tag8_16, src + src_stride + 4);
+ auto r1_d6 = hn::LoadU(tag8_16, src + src_stride + 6);
+
+ auto r0_sum = hn::Add(
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_d0, shuffle_mask),
+ coeff01),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_d2, shuffle_mask),
+ coeff23)),
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_d4, shuffle_mask),
+ coeff45),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_d6, shuffle_mask),
+ coeff67)));
+
+ auto r1_sum = hn::Add(
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_d0, shuffle_mask),
+ coeff01),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_d2, shuffle_mask),
+ coeff23)),
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_d4, shuffle_mask),
+ coeff45),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_d6, shuffle_mask),
+ coeff67)));
+
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_sum, bias_val))),
+ tag8_8, dst);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_sum, bias_val))),
+ tag8_8, dst + dst_stride);
+
+ src += 2 * src_stride;
+ dst += 2 * dst_stride;
+ h -= 2;
+ }
+ } else if (w == 16) {
+ while (h >= 2) {
+ auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0);
+ auto r0_j0_d2 = hn::LoadU(tag8_16, src + 2);
+ auto r0_j0_d4 = hn::LoadU(tag8_16, src + 4);
+ auto r0_j0_d6 = hn::LoadU(tag8_16, src + 6);
+
+ auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8);
+ auto r0_j8_d2 = hn::LoadU(tag8_16, src + 10);
+ auto r0_j8_d4 = hn::LoadU(tag8_16, src + 12);
+ auto r0_j8_d6 = hn::LoadU(tag8_16, src + 14);
+
+ auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+ auto r1_j0_d2 = hn::LoadU(tag8_16, src + src_stride + 2);
+ auto r1_j0_d4 = hn::LoadU(tag8_16, src + src_stride + 4);
+ auto r1_j0_d6 = hn::LoadU(tag8_16, src + src_stride + 6);
+
+ auto r1_j8_d0 = hn::LoadU(tag8_16, src + src_stride + 8);
+ auto r1_j8_d2 = hn::LoadU(tag8_16, src + src_stride + 10);
+ auto r1_j8_d4 = hn::LoadU(tag8_16, src + src_stride + 12);
+ auto r1_j8_d6 = hn::LoadU(tag8_16, src + src_stride + 14);
+
+ auto r0_j0_sum = hn::Add(
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask),
+ coeff01),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j0_d2, shuffle_mask),
+ coeff23)),
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j0_d4, shuffle_mask),
+ coeff45),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j0_d6, shuffle_mask),
+ coeff67)));
+
+ auto r0_j8_sum = hn::Add(
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask),
+ coeff01),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j8_d2, shuffle_mask),
+ coeff23)),
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j8_d4, shuffle_mask),
+ coeff45),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j8_d6, shuffle_mask),
+ coeff67)));
+
+ auto r1_j0_sum = hn::Add(
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask),
+ coeff01),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j0_d2, shuffle_mask),
+ coeff23)),
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j0_d4, shuffle_mask),
+ coeff45),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j0_d6, shuffle_mask),
+ coeff67)));
+
+ auto r1_j8_sum = hn::Add(
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask),
+ coeff01),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j8_d2, shuffle_mask),
+ coeff23)),
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j8_d4, shuffle_mask),
+ coeff45),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j8_d6, shuffle_mask),
+ coeff67)));
+
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j0_sum, bias_val))),
+ tag8_8, dst + 0);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j8_sum, bias_val))),
+ tag8_8, dst + 8);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j0_sum, bias_val))),
+ tag8_8, dst + dst_stride + 0);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j8_sum, bias_val))),
+ tag8_8, dst + dst_stride + 8);
+
+ src += 2 * src_stride;
+ dst += 2 * dst_stride;
+ h -= 2;
+ }
+ } else if (w == 32) {
+ while (h >= 2) {
+ {
+ auto r0_j0_d0 = hn::LoadU(tag8_16, src + 0);
+ auto r0_j0_d2 = hn::LoadU(tag8_16, src + 2);
+ auto r0_j0_d4 = hn::LoadU(tag8_16, src + 4);
+ auto r0_j0_d6 = hn::LoadU(tag8_16, src + 6);
+
+ auto r0_j8_d0 = hn::LoadU(tag8_16, src + 8);
+ auto r0_j8_d2 = hn::LoadU(tag8_16, src + 10);
+ auto r0_j8_d4 = hn::LoadU(tag8_16, src + 12);
+ auto r0_j8_d6 = hn::LoadU(tag8_16, src + 14);
+
+ auto r1_j0_d0 = hn::LoadU(tag8_16, src + src_stride + 0);
+ auto r1_j0_d2 = hn::LoadU(tag8_16, src + src_stride + 2);
+ auto r1_j0_d4 = hn::LoadU(tag8_16, src + src_stride + 4);
+ auto r1_j0_d6 = hn::LoadU(tag8_16, src + src_stride + 6);
+
+ auto r1_j8_d0 = hn::LoadU(tag8_16, src + src_stride + 8);
+ auto r1_j8_d2 = hn::LoadU(tag8_16, src + src_stride + 10);
+ auto r1_j8_d4 = hn::LoadU(tag8_16, src + src_stride + 12);
+ auto r1_j8_d6 = hn::LoadU(tag8_16, src + src_stride + 14);
+
+ auto r0_j0_sum = hn::Add(
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j0_d0, shuffle_mask),
+ coeff01),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j0_d2, shuffle_mask),
+ coeff23)),
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j0_d4, shuffle_mask),
+ coeff45),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j0_d6, shuffle_mask),
+ coeff67)));
+
+ auto r0_j8_sum = hn::Add(
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j8_d0, shuffle_mask),
+ coeff01),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j8_d2, shuffle_mask),
+ coeff23)),
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j8_d4, shuffle_mask),
+ coeff45),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j8_d6, shuffle_mask),
+ coeff67)));
+
+ auto r1_j0_sum = hn::Add(
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j0_d0, shuffle_mask),
+ coeff01),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j0_d2, shuffle_mask),
+ coeff23)),
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j0_d4, shuffle_mask),
+ coeff45),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j0_d6, shuffle_mask),
+ coeff67)));
+
+ auto r1_j8_sum = hn::Add(
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j8_d0, shuffle_mask),
+ coeff01),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j8_d2, shuffle_mask),
+ coeff23)),
+ hn::Add(hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j8_d4, shuffle_mask),
+ coeff45),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j8_d6, shuffle_mask),
+ coeff67)));
+
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j0_sum, bias_val))),
+ tag8_8, dst + 0);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j8_sum, bias_val))),
+ tag8_8, dst + 8);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j0_sum, bias_val))),
+ tag8_8, dst + dst_stride + 0);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j8_sum, bias_val))),
+ tag8_8, dst + dst_stride + 8);
+ }
+ {
+ auto r0_j16_d0 = hn::LoadU(tag8_16, src + 16);
+ auto r0_j16_d2 = hn::LoadU(tag8_16, src + 18);
+ auto r0_j16_d4 = hn::LoadU(tag8_16, src + 20);
+ auto r0_j16_d6 = hn::LoadU(tag8_16, src + 22);
+
+ auto r0_j24_d0 = hn::LoadU(tag8_16, src + 24);
+ auto r0_j24_d2 = hn::LoadU(tag8_16, src + 26);
+ auto r0_j24_d4 = hn::LoadU(tag8_16, src + 28);
+ auto r0_j24_d6 = hn::LoadU(tag8_16, src + 30);
+
+ auto r1_j16_d0 = hn::LoadU(tag8_16, src + src_stride + 16);
+ auto r1_j16_d2 = hn::LoadU(tag8_16, src + src_stride + 18);
+ auto r1_j16_d4 = hn::LoadU(tag8_16, src + src_stride + 20);
+ auto r1_j16_d6 = hn::LoadU(tag8_16, src + src_stride + 22);
+
+ auto r1_j24_d0 = hn::LoadU(tag8_16, src + src_stride + 24);
+ auto r1_j24_d2 = hn::LoadU(tag8_16, src + src_stride + 26);
+ auto r1_j24_d4 = hn::LoadU(tag8_16, src + src_stride + 28);
+ auto r1_j24_d6 = hn::LoadU(tag8_16, src + src_stride + 30);
+
+ auto r0_j16_sum = hn::Add(
+ hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j16_d0, shuffle_mask),
+ coeff01),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j16_d2, shuffle_mask),
+ coeff23)),
+ hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j16_d4, shuffle_mask),
+ coeff45),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j16_d6, shuffle_mask),
+ coeff67)));
+
+ auto r0_j24_sum = hn::Add(
+ hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j24_d0, shuffle_mask),
+ coeff01),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j24_d2, shuffle_mask),
+ coeff23)),
+ hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j24_d4, shuffle_mask),
+ coeff45),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r0_j24_d6, shuffle_mask),
+ coeff67)));
+
+ auto r1_j16_sum = hn::Add(
+ hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j16_d0, shuffle_mask),
+ coeff01),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j16_d2, shuffle_mask),
+ coeff23)),
+ hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j16_d4, shuffle_mask),
+ coeff45),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j16_d6, shuffle_mask),
+ coeff67)));
+
+ auto r1_j24_sum = hn::Add(
+ hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j24_d0, shuffle_mask),
+ coeff01),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j24_d2, shuffle_mask),
+ coeff23)),
+ hn::Add(
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j24_d4, shuffle_mask),
+ coeff45),
+ hn::SatWidenMulPairwiseAdd(
+ tag16_8, hn::TableLookupBytes(r1_j24_d6, shuffle_mask),
+ coeff67)));
+
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j16_sum, bias_val))),
+ tag8_8, dst + 16);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r0_j24_sum, bias_val))),
+ tag8_8, dst + 24);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j16_sum, bias_val))),
+ tag8_8, dst + dst_stride + 16);
+ hn::StoreU(hn::DemoteTo(tag8_8, hn::ShiftRight<FILTER_BITS - 1>(
+ hn::Add(r1_j24_sum, bias_val))),
+ tag8_8, dst + dst_stride + 24);
+ }
+ src += 2 * src_stride;
+ dst += 2 * dst_stride;
+ h -= 2;
+ }
+ }
} else {
- // This tag will have 32 lanes (for avx512) or 16 lanes (for avx2)
+ hn::CappedTag<int16_t, 8> filter_tag;
+ auto f_vec = hn::LoadU(filter_tag, filter_x);
+ f_vec = hn::ShiftRight<1>(f_vec);
hn::ScalableTag<int16_t> mul_tag;
hn::Rebind<uint8_t, decltype(mul_tag)> pixel_tag;
auto vw = hn::Lanes(mul_tag);
diff --git a/aom_dsp/x86/convolve_hwy_avx512.cc b/aom_dsp/x86/convolve_hwy_avx512.cc
index c1aa90492f..62557047ca 100644
--- a/aom_dsp/x86/convolve_hwy_avx512.cc
+++ b/aom_dsp/x86/convolve_hwy_avx512.cc
@@ -32,10 +32,11 @@ HWY_ATTR void aom_convolve8_horiz_avx512(const uint8_t *src,
const int16_t *filter_x, int x_step_q4,
const int16_t *filter_y, int y_step_q4,
int w, int h) {
- // Fallback to AVX2 for small block sizes (w <= 16) where the handwritten
- // AVX2 implementation was measured to be faster than the Highway AVX512
- // implementation in benchmarks.
- if (w <= 16) {
+ // 16x32, 32x32 and 64x32 blocks are ~10% slower than avx2, while the other
+ // block sizes show a significant speedup. Fall back to avx2 for blocks with
+ // height 32 (wx32).
+ // TODO(jianj): Investigate and optimize for wx32 blocks.
+ if (h == 32) {
aom_convolve8_horiz_avx2(src, src_stride, dst, dst_stride, filter_x,
x_step_q4, filter_y, y_step_q4, w, h);
} else {