Commit 32059d384c for aom
commit 32059d384c261a9719bf78e5ffd7d86c63841de5
Author: Jerome Jiang <jianj@google.com>
Date: Thu Jun 11 13:05:34 2026 -0400
Optimize Highway SIMD for av1_warp_affine
1. Replaces template function pointer parameters with
compile-time enum/boolean template routing to eliminate
indirect call overhead.
2. Indexes the 2D filter LUTs (av1_filter_8bit, av1_warped_filter)
through flat 1D pointers so UBSan array-bounds checks are not
emitted on these hot-path loads.
3. Replaces AVX2 vector concatenation permutes with contiguous
unaligned loads (hn::LoadU) to eliminate Port 5 permute
congestion.
4. Seeds the vertical dot-product accumulators with the
res_add_const rounding bias, removing the separate vector
additions after filtering.
5. Replaces per-row index multiplications in StoreRows and
StoreVerticalFilterOutput with strided pointer additions.
The timings below are aggregated over the tested
alpha/beta/gamma/delta (a/b/g/d) combinations for each block
size; "n" denotes nanoseconds and "u" microseconds.
AVX2 Speedup:
Block | Before | After | Speedup
--------+---------+---------+---------
4x4 | 61 n | 45 n | 1.37x
4x8 | 79 n | 58 n | 1.35x
4x16 | 141 n | 106 n | 1.33x
8x4 | 62 n | 45 n | 1.38x
8x8 | 80 n | 59 n | 1.37x
8x16 | 145 n | 107 n | 1.36x
16x8 | 144 n | 106 n | 1.36x
16x16 | 272 n | 201 n | 1.35x
16x32 | 530 n | 392 n | 1.35x
32x8 | 270 n | 200 n | 1.35x
32x16 | 524 n | 389 n | 1.35x
32x32 | 1.0 u | 776 n | 1.34x
32x64 | 2.1 u | 1.5 u | 1.33x
64x32 | 2.1 u | 1.5 u | 1.35x
64x64 | 4.1 u | 3.1 u | 1.34x
64x128 | 8.3 u | 6.4 u | 1.30x
128x64 | 8.6 u | 6.8 u | 1.27x
128x128| 17.2 u | 14.1 u | 1.22x
AVX512 Speedup:
Block | Before | After | Speedup
--------+---------+---------+---------
4x4 | 63 n | 46 n | 1.38x
4x8 | 79 n | 56 n | 1.40x
4x16 | 139 n | 102 n | 1.36x
8x4 | 64 n | 46 n | 1.41x
8x8 | 81 n | 56 n | 1.45x
8x16 | 143 n | 102 n | 1.41x
16x8 | 140 n | 100 n | 1.40x
16x16 | 263 n | 190 n | 1.38x
16x32 | 507 n | 370 n | 1.37x
32x8 | 259 n | 189 n | 1.38x
32x16 | 500 n | 367 n | 1.36x
32x32 | 984 n | 726 n | 1.35x
32x64 | 1.9 u | 1.4 u | 1.35x
64x32 | 1.9 u | 1.4 u | 1.34x
64x64 | 3.8 u | 2.9 u | 1.34x
64x128 | 7.7 u | 6.0 u | 1.29x
128x64 | 8.1 u | 6.2 u | 1.30x
128x128| 16.3 u | 13.0 u | 1.25x
Change-Id: I3e5ed8ccc0a0dd23c97e1ca9d89ad6f508a4a590
diff --git a/av1/common/warp_plane_hwy.h b/av1/common/warp_plane_hwy.h
index 61d5fc6bcd..da0d5450bc 100644
--- a/av1/common/warp_plane_hwy.h
+++ b/av1/common/warp_plane_hwy.h
@@ -47,6 +47,27 @@ constexpr hn::FixedTag<int64_t, 2> int64x2_tag;
constexpr hn::ScalableTag<int8_t> coeff_tag;
+HWY_ATTR HWY_INLINE hn::Vec<decltype(uint8xN_tag)> GetShuffle0() {
+ return hn::Dup128VecFromValues(
+ uint8xN_tag, 0, 2, 2, 4, 4, 6, 6, 8, 1, 3, 3, 5, 5, 7, 7, 9 //
+ );
+}
+HWY_ATTR HWY_INLINE hn::Vec<decltype(uint8xN_tag)> GetShuffle1() {
+ return hn::Dup128VecFromValues(
+ uint8xN_tag, 4, 6, 6, 8, 8, 10, 10, 12, 5, 7, 7, 9, 9, 11, 11, 13 //
+ );
+}
+HWY_ATTR HWY_INLINE hn::Vec<decltype(uint8xN_tag)> GetShuffle2() {
+ return hn::Dup128VecFromValues(
+ uint8xN_tag, 1, 3, 3, 5, 5, 7, 7, 9, 2, 4, 4, 6, 6, 8, 8, 10 //
+ );
+}
+HWY_ATTR HWY_INLINE hn::Vec<decltype(uint8xN_tag)> GetShuffle3() {
+ return hn::Dup128VecFromValues(
+ uint8xN_tag, 5, 7, 7, 9, 9, 11, 11, 13, 6, 8, 8, 10, 10, 12, 12, 14 //
+ );
+}
+
using IVec16 = hn::Vec<decltype(int16xN_tag)>;
using IVec32 = hn::Vec<decltype(int32xN_tag)>;
using IVec8x16 = hn::Vec<decltype(int8x16_tag)>;
@@ -67,27 +88,14 @@ HWY_ATTR inline void FilterPixelsHorizontal(D tag, const hn::VFromD<D> src,
const auto coeff2 = hn::Load(int8_tag, coeff + hn::MaxLanes(coeff_tag) * 2);
const auto coeff3 = hn::Load(int8_tag, coeff + hn::MaxLanes(coeff_tag) * 3);
- const auto shuffle0 = hn::Dup128VecFromValues(
- uint8xN_tag, 0, 2, 2, 4, 4, 6, 6, 8, 1, 3, 3, 5, 5, 7, 7, 9 //
- );
- const auto shuffle1 = hn::Dup128VecFromValues(
- uint8xN_tag, 4, 6, 6, 8, 8, 10, 10, 12, 5, 7, 7, 9, 9, 11, 11, 13 //
- );
- const auto shuffle2 = hn::Dup128VecFromValues(
- uint8xN_tag, 1, 3, 3, 5, 5, 7, 7, 9, 2, 4, 4, 6, 6, 8, 8, 10 //
- );
- const auto shuffle3 = hn::Dup128VecFromValues(
- uint8xN_tag, 5, 7, 7, 9, 9, 11, 11, 13, 6, 8, 8, 10, 10, 12, 12, 14 //
- );
-
const auto src_0 =
- hn::TableLookupBytes(src, hn::ResizeBitCast(tag, shuffle0));
+ hn::TableLookupBytes(src, hn::ResizeBitCast(tag, GetShuffle0()));
const auto src_1 =
- hn::TableLookupBytes(src, hn::ResizeBitCast(tag, shuffle1));
+ hn::TableLookupBytes(src, hn::ResizeBitCast(tag, GetShuffle1()));
const auto src_2 =
- hn::TableLookupBytes(src, hn::ResizeBitCast(tag, shuffle2));
+ hn::TableLookupBytes(src, hn::ResizeBitCast(tag, GetShuffle2()));
const auto src_3 =
- hn::TableLookupBytes(src, hn::ResizeBitCast(tag, shuffle3));
+ hn::TableLookupBytes(src, hn::ResizeBitCast(tag, GetShuffle3()));
const auto res_02 = hn::SatWidenMulPairwiseAdd(result_tag, src_0, coeff0);
const auto res_46 = hn::SatWidenMulPairwiseAdd(result_tag, src_1, coeff1);
@@ -107,14 +115,16 @@ HWY_ATTR inline void FilterPixelsHorizontal(D tag, const hn::VFromD<D> src,
}
HWY_ATTR HWY_INLINE IVec8x16 LoadAV1Filter8Bit(unsigned int offset) {
- return hn::LoadN(int8x16_tag, av1_filter_8bit[offset >> WARPEDDIFF_PREC_BITS],
- 8);
+ const int8_t *HWY_RESTRICT filter_ptr = &av1_filter_8bit[0][0];
+ return hn::LoadN(int8x16_tag,
+ filter_ptr + (offset >> WARPEDDIFF_PREC_BITS) * 8, 8);
}
template <typename D>
HWY_ATTR HWY_INLINE hn::VFromD<D> LoadAV1Filter8BitLower(D int8_tag,
unsigned int offset) {
- return hn::LoadN(int8_tag, av1_filter_8bit[offset >> WARPEDDIFF_PREC_BITS],
+ const int8_t *HWY_RESTRICT filter_ptr = &av1_filter_8bit[0][0];
+ return hn::LoadN(int8_tag, filter_ptr + (offset >> WARPEDDIFF_PREC_BITS) * 8,
8);
}
@@ -123,9 +133,10 @@ HWY_ATTR HWY_INLINE hn::VFromD<D> LoadAV1Filter8BitUpper(D int8_tag,
unsigned int offset,
hn::VFromD<D> src) {
(void)int8_tag;
+ const int8_t *HWY_RESTRICT filter_ptr = &av1_filter_8bit[0][0];
return hn::InsertBlock<Block>(
src, hn::LoadN(int8x16_tag,
- av1_filter_8bit[offset >> WARPEDDIFF_PREC_BITS], 8));
+ filter_ptr + (offset >> WARPEDDIFF_PREC_BITS) * 8, 8));
}
template <typename D>
@@ -390,26 +401,23 @@ HWY_ATTR HWY_INLINE hn::VFromD<D> LoadRowsClamped(
return src;
}
-template <void (*PrepareCoeffs)(int alpha, int beta, int sx,
- int8_t *HWY_RESTRICT coeffs),
- typename D>
-HWY_ATTR int WarpHorizontalFilterLoop(
- D tag, const uint8_t *HWY_RESTRICT ref, int16_t *HWY_RESTRICT horz_out,
- int stride, int32_t ix4, int32_t iy4, int32_t sx4, int alpha, int beta,
- int p_height, int height, int i, const IVec16 round_const,
- const int reduce_bits_horiz, int k, int8_t *HWY_RESTRICT coeff) {
- constexpr int kNumRows = tag.MaxBlocks();
- for (; k < AOMMIN(8, p_height - i) - kNumRows; k += kNumRows) {
- const auto src =
- LoadRowsClamped(tag, ref + ix4 - 7, stride, iy4 + k, height);
- if constexpr (PrepareCoeffs != nullptr) {
- int sx = sx4 + beta * (k + 4);
- PrepareCoeffs(alpha, beta, sx, coeff);
- }
- FilterPixelsHorizontal(tag, src, horz_out, coeff, round_const,
- reduce_bits_horiz, k + 7);
+template <typename D>
+HWY_ATTR HWY_INLINE hn::VFromD<D> LoadRowsInterior(
+ D tag, const uint8_t *HWY_RESTRICT ptr, const int stride, const int stride2,
+ const int stride3) {
+ constexpr hn::BlockDFromD<D> block_tag;
+ auto src = hn::ResizeBitCast(tag, hn::LoadU(block_tag, ptr));
+ if constexpr (tag.MaxBlocks() >= 2) {
+ const auto src_1 = hn::LoadU(block_tag, ptr + stride);
+ src = hn::InsertBlock<1>(src, src_1);
}
- return k;
+ if constexpr (tag.MaxBlocks() >= 3) {
+ const auto src_2 = hn::LoadU(block_tag, ptr + stride2);
+ const auto src_3 = hn::LoadU(block_tag, ptr + stride3);
+ src = hn::InsertBlock<2>(src, src_2);
+ src = hn::InsertBlock<3>(src, src_3);
+ }
+ return src;
}
enum class HorizontalFilterCoeffs {
@@ -443,6 +451,49 @@ HWY_ATTR void WarpHorizontalPrepareCoeffs(int alpha, int beta, int sx,
}
}
+template <bool InnerCoeffUpdate, bool IsLast, HorizontalFilterCoeffs Filter,
+ typename D, typename D16>
+HWY_ATTR int WarpHorizontalFilterLoop(
+ D tag, D16, const uint8_t *HWY_RESTRICT ref, int16_t *HWY_RESTRICT horz_out,
+ int stride, int32_t ix4, int32_t iy4, int32_t sx4, int alpha, int beta,
+ int p_height, int height, int i, const IVec16 round_const,
+ const int reduce_bits_horiz, int k, int8_t *HWY_RESTRICT coeff) {
+ constexpr int kNumRows = tag.MaxBlocks();
+ for (; k < AOMMIN(8, p_height - i) - kNumRows; k += kNumRows) {
+ const auto src =
+ LoadRowsClamped(tag, ref + ix4 - 7, stride, iy4 + k, height);
+ if constexpr (InnerCoeffUpdate) {
+ int sx = sx4 + beta * (k + 4);
+ WarpHorizontalPrepareCoeffs<IsLast, Filter, D16>(alpha, beta, sx, coeff);
+ }
+ FilterPixelsHorizontal(tag, src, horz_out, coeff, round_const,
+ reduce_bits_horiz, k + 7);
+ }
+ return k;
+}
+
+template <bool InnerCoeffUpdate, bool IsLast, HorizontalFilterCoeffs Filter,
+ typename D, typename D16>
+HWY_ATTR int WarpHorizontalFilterLoopInterior(
+ D tag, D16, const uint8_t *HWY_RESTRICT base_ptr,
+ int16_t *HWY_RESTRICT horz_out, int stride, int stride2, int stride3,
+ int32_t sx4, int alpha, int beta, int p_height, int i,
+ const IVec16 round_const, const int reduce_bits_horiz, int k,
+ int8_t *HWY_RESTRICT coeff) {
+ constexpr int kNumRows = tag.MaxBlocks();
+ for (; k < AOMMIN(8, p_height - i) - kNumRows; k += kNumRows) {
+ const auto src = LoadRowsInterior(tag, base_ptr + (k + 7) * stride, stride,
+ stride2, stride3);
+ if constexpr (InnerCoeffUpdate) {
+ int sx = sx4 + beta * (k + 4);
+ WarpHorizontalPrepareCoeffs<IsLast, Filter, D16>(alpha, beta, sx, coeff);
+ }
+ FilterPixelsHorizontal(tag, src, horz_out, coeff, round_const,
+ reduce_bits_horiz, k + 7);
+ }
+ return k;
+}
+
template <bool InnerCoeffUpdate, HorizontalFilterCoeffs Filter>
HWY_ATTR inline void WarpHorizontalFilterTemplate(
const uint8_t *HWY_RESTRICT ref, int16_t *HWY_RESTRICT horz_out, int stride,
@@ -455,28 +506,19 @@ HWY_ATTR inline void WarpHorizontalFilterTemplate(
alpha, beta, sx4, coeff);
}
if constexpr (uint8xN_tag.MaxBlocks() >= 3) {
- k = WarpHorizontalFilterLoop<(
- InnerCoeffUpdate
- ? WarpHorizontalPrepareCoeffs<false, Filter, decltype(int16xN_tag)>
- : nullptr)>(uint8xN_tag, ref, horz_out, stride, ix4, iy4, sx4,
- alpha, beta, p_height, height, i, round_const,
- reduce_bits_horiz, k, coeff);
+ k = WarpHorizontalFilterLoop<InnerCoeffUpdate, false, Filter>(
+ uint8xN_tag, int16xN_tag, ref, horz_out, stride, ix4, iy4, sx4, alpha,
+ beta, p_height, height, i, round_const, reduce_bits_horiz, k, coeff);
}
if constexpr (uint8xN_tag.MaxBlocks() >= 2) {
- k = WarpHorizontalFilterLoop<(
- InnerCoeffUpdate
- ? WarpHorizontalPrepareCoeffs<false, Filter, decltype(int16x16_tag)>
- : nullptr)>(uint8x32_tag, ref, horz_out, stride, ix4, iy4, sx4,
- alpha, beta, p_height, height, i, round_const,
- reduce_bits_horiz, k, coeff);
+ k = WarpHorizontalFilterLoop<InnerCoeffUpdate, false, Filter>(
+ uint8x32_tag, int16x16_tag, ref, horz_out, stride, ix4, iy4, sx4, alpha,
+ beta, p_height, height, i, round_const, reduce_bits_horiz, k, coeff);
}
if constexpr (uint8xN_tag.MaxBlocks() == 1) {
- k = WarpHorizontalFilterLoop<(
- InnerCoeffUpdate
- ? WarpHorizontalPrepareCoeffs<true, Filter, decltype(int16x8_tag)>
- : nullptr)>(uint8x16_tag, ref, horz_out, stride, ix4, iy4, sx4,
- alpha, beta, p_height, height, i, round_const,
- reduce_bits_horiz, k, coeff);
+ k = WarpHorizontalFilterLoop<InnerCoeffUpdate, true, Filter>(
+ uint8x16_tag, int16x8_tag, ref, horz_out, stride, ix4, iy4, sx4, alpha,
+ beta, p_height, height, i, round_const, reduce_bits_horiz, k, coeff);
}
iy = iy4 + k;
iy = clamp(iy, 0, height - 1);
@@ -490,6 +532,49 @@ HWY_ATTR inline void WarpHorizontalFilterTemplate(
reduce_bits_horiz, k + 7);
}
+template <bool InnerCoeffUpdate, HorizontalFilterCoeffs Filter>
+HWY_ATTR inline void WarpHorizontalFilterTemplateInterior(
+ const uint8_t *HWY_RESTRICT ref, int16_t *HWY_RESTRICT horz_out, int stride,
+ int32_t ix4, int32_t iy4, int32_t sx4, int alpha, int beta, int p_height,
+ int i, const IVec16 round_const, const int reduce_bits_horiz) {
+ const uint8_t *HWY_RESTRICT base_ptr = ref + (iy4 - 7) * stride + ix4 - 7;
+ const int stride2 = stride * 2;
+ const int stride3 = stride * 3;
+
+ int k = -7;
+ HWY_ALIGN int8_t coeff[4 * hn::MaxLanes(coeff_tag)];
+ if constexpr (!InnerCoeffUpdate) {
+ WarpHorizontalPrepareCoeffs<false, Filter, decltype(int16xN_tag)>(
+ alpha, beta, sx4, coeff);
+ }
+ if constexpr (uint8xN_tag.MaxBlocks() >= 3) {
+ k = WarpHorizontalFilterLoopInterior<InnerCoeffUpdate, false, Filter>(
+ uint8xN_tag, int16xN_tag, base_ptr, horz_out, stride, stride2, stride3,
+ sx4, alpha, beta, p_height, i, round_const, reduce_bits_horiz, k,
+ coeff);
+ }
+ if constexpr (uint8xN_tag.MaxBlocks() >= 2) {
+ k = WarpHorizontalFilterLoopInterior<InnerCoeffUpdate, false, Filter>(
+ uint8x32_tag, int16x16_tag, base_ptr, horz_out, stride, stride2,
+ stride3, sx4, alpha, beta, p_height, i, round_const, reduce_bits_horiz,
+ k, coeff);
+ }
+ if constexpr (uint8xN_tag.MaxBlocks() == 1) {
+ k = WarpHorizontalFilterLoopInterior<InnerCoeffUpdate, true, Filter>(
+ uint8x16_tag, int16x8_tag, base_ptr, horz_out, stride, stride2, stride3,
+ sx4, alpha, beta, p_height, i, round_const, reduce_bits_horiz, k,
+ coeff);
+ }
+ const auto src = hn::LoadU(uint8x16_tag, base_ptr + (k + 7) * stride);
+ if constexpr (InnerCoeffUpdate) {
+ int sx = sx4 + beta * (k + 4);
+ WarpHorizontalPrepareCoeffs<true, Filter, decltype(int16x8_tag)>(
+ alpha, beta, sx, coeff);
+ }
+ FilterPixelsHorizontal(uint8x16_tag, src, horz_out, coeff, round_const,
+ reduce_bits_horiz, k + 7);
+}
+
HWY_ATTR inline void UnpackWeightsAndSetRoundConst(
ConvolveParams *HWY_RESTRICT conv_params, const int round_bits,
const int offset_bits, IVec16 &HWY_RESTRICT res_sub_const,
@@ -507,21 +592,28 @@ HWY_ATTR inline void UnpackWeightsAndSetRoundConst(
}
HWY_ATTR HWY_INLINE IVec16 LoadAV1WarpedFilter(size_t offset) {
+ const WarpedFilterCoeff *HWY_RESTRICT warped_filter_ptr =
+ &av1_warped_filter[0][0];
return hn::LoadN(int16xN_tag,
- av1_warped_filter[offset >> WARPEDDIFF_PREC_BITS], 8);
+ warped_filter_ptr + (offset >> WARPEDDIFF_PREC_BITS) * 8, 8);
}
HWY_ATTR HWY_INLINE IVec16 LoadAV1WarpedFilterLower(size_t offset) {
+ const WarpedFilterCoeff *HWY_RESTRICT warped_filter_ptr =
+ &av1_warped_filter[0][0];
return hn::ResizeBitCast(
int16xN_tag,
- hn::Load(int16x8_tag, av1_warped_filter[offset >> WARPEDDIFF_PREC_BITS]));
+ hn::Load(int16x8_tag,
+ warped_filter_ptr + (offset >> WARPEDDIFF_PREC_BITS) * 8));
}
template <int Block>
HWY_ATTR HWY_INLINE IVec16 LoadAV1WarpedFilterUpper(size_t offset, IVec16 src) {
+ const WarpedFilterCoeff *HWY_RESTRICT warped_filter_ptr =
+ &av1_warped_filter[0][0];
return hn::InsertBlock<Block>(
- src,
- hn::Load(int16x8_tag, av1_warped_filter[offset >> WARPEDDIFF_PREC_BITS]));
+ src, hn::Load(int16x8_tag,
+ warped_filter_ptr + (offset >> WARPEDDIFF_PREC_BITS) * 8));
}
HWY_ATTR inline void PrepareVerticalFilterCoeffs(int gamma, int delta, int sy,
@@ -709,10 +801,13 @@ HWY_ATTR inline void PrepareVerticalFilterCoeffsGamma0(
hn::Store(broadcast_3, int16xN_tag, coeffs + 7 * hn::MaxLanes(int16xN_tag));
}
-HWY_ATTR inline void FilterPixelsVertical(
- int16_t *HWY_RESTRICT horz_out, int16_t *HWY_RESTRICT src_lo,
- int16_t *HWY_RESTRICT src_hi, int16_t *HWY_RESTRICT coeffs,
- IVec32 &HWY_RESTRICT res_lo, IVec32 &HWY_RESTRICT res_hi, int row) {
+HWY_ATTR inline void FilterPixelsVertical(int16_t *HWY_RESTRICT horz_out,
+ int16_t *HWY_RESTRICT src_lo,
+ int16_t *HWY_RESTRICT src_hi,
+ int16_t *HWY_RESTRICT coeffs,
+ IVec32 &HWY_RESTRICT res_lo,
+ IVec32 &HWY_RESTRICT res_hi, int row,
+ const IVec32 res_add_const) {
if constexpr (int16xN_tag.MaxBlocks() >= 3) {
const auto horz_out_4 =
hn::Load(int16xN_tag, horz_out + (row + 4) * hn::MaxLanes(int16x8_tag));
@@ -737,10 +832,8 @@ HWY_ATTR inline void FilterPixelsVertical(
} else if constexpr (int16xN_tag.MaxBlocks() == 2) {
const auto horz_out_6 =
hn::Load(int16xN_tag, horz_out + (row + 6) * hn::MaxLanes(int16x8_tag));
- const auto horz_out_8 =
- hn::Load(int16xN_tag, horz_out + (row + 8) * hn::MaxLanes(int16x8_tag));
- const auto horz_out_7 =
- hn::ConcatLowerUpper(int16xN_tag, horz_out_8, horz_out_6);
+ const auto horz_out_7 = hn::LoadU(
+ int16xN_tag, horz_out + (row + 7) * hn::MaxLanes(int16x8_tag));
const auto src_lo_3 =
hn::InterleaveLower(int16xN_tag, horz_out_6, horz_out_7);
const auto src_hi_3 =
@@ -794,7 +887,7 @@ HWY_ATTR inline void FilterPixelsVertical(
const auto src_hi_3 =
hn::Load(int16xN_tag, src_hi + 3 * hn::MaxLanes(int16xN_tag));
- auto even_sum0 = hn::Zero(int32xN_tag);
+ auto even_sum0 = res_add_const;
auto even_sum1 = hn::Zero(int32xN_tag);
even_sum0 = hn::ReorderWidenMulAccumulate(int32xN_tag, src_lo_0, coeff_0,
even_sum0, even_sum1);
@@ -806,7 +899,7 @@ HWY_ATTR inline void FilterPixelsVertical(
even_sum0, even_sum1);
auto res_even = hn::RearrangeToOddPlusEven(even_sum0, even_sum1);
- auto odd_sum0 = hn::Zero(int32xN_tag);
+ auto odd_sum0 = res_add_const;
auto odd_sum1 = hn::Zero(int32xN_tag);
odd_sum0 = hn::ReorderWidenMulAccumulate(int32xN_tag, src_hi_0, coeff_4,
odd_sum0, odd_sum1);
@@ -828,8 +921,14 @@ HWY_ATTR HWY_INLINE void StoreRows(DS store_tag, DR row_tag, hn::VFromD<DR> vec,
A stride, B y, C x,
hn::TFromD<DS> *HWY_RESTRICT out) {
hn::TFromD<DS> *HWY_RESTRICT pointers[row_tag.MaxBlocks()];
- for (int i = 0; i < static_cast<int>(row_tag.MaxBlocks()); ++i) {
- pointers[i] = &out[(y + i) * stride + x];
+ hn::TFromD<DS> *HWY_RESTRICT base_ptr = &out[y * stride + x];
+ pointers[0] = base_ptr;
+ if constexpr (row_tag.MaxBlocks() >= 2) {
+ pointers[1] = base_ptr + stride;
+ }
+ if constexpr (row_tag.MaxBlocks() >= 3) {
+ pointers[2] = base_ptr + 2 * stride;
+ pointers[3] = base_ptr + 3 * stride;
}
hn::Store(hn::ResizeBitCast(store_tag, hn::ExtractBlock<0>(vec)), store_tag,
pointers[0]);
@@ -846,149 +945,151 @@ HWY_ATTR HWY_INLINE void StoreRows(DS store_tag, DR row_tag, hn::VFromD<DR> vec,
}
HWY_ATTR HWY_INLINE void StoreVerticalFilterOutput(
- IVec32 res_lo, IVec32 res_hi, const IVec32 res_add_const, const IVec16 wt,
- const IVec16 res_sub_const, const IVec16 round_bits_const,
- uint8_t *HWY_RESTRICT pred, ConvolveParams *HWY_RESTRICT conv_params, int i,
- int j, int k, const int reduce_bits_vert, int p_stride, int p_width,
+ IVec32 res_lo, IVec32 res_hi, const IVec16 wt,
+ const IVec16 res_sub_round_const, uint8_t *HWY_RESTRICT pred,
+ ConvolveParams *HWY_RESTRICT conv_params, int i, int j, int k,
+ const int reduce_bits_vert, int p_stride, int p_width,
const int round_bits) {
constexpr int kNumRows = uint16xN_tag.MaxBlocks();
if (conv_params->is_compound) {
uint16_t *HWY_RESTRICT pointers[kNumRows];
- for (int row = 0; row < kNumRows; ++row) {
- pointers[row] =
- &conv_params->dst[(i + k + row) * conv_params->dst_stride + j];
+ uint16_t *const base_ptr =
+ &conv_params->dst[(i + k) * conv_params->dst_stride + j];
+ const int dst_stride = conv_params->dst_stride;
+ pointers[0] = base_ptr;
+ if constexpr (kNumRows >= 2) {
+ pointers[1] = base_ptr + dst_stride;
}
-
- res_lo =
- hn::ShiftRightSame(hn::Add(res_lo, res_add_const), reduce_bits_vert);
-
- const auto temp_lo_16 = hn::ReorderDemote2To(uint16xN_tag, res_lo, res_lo);
- if (conv_params->do_average) {
- auto p_16 =
- hn::ResizeBitCast(uint16xN_tag, hn::Load(uint16x4_tag, pointers[0]));
- if constexpr (kNumRows >= 2) {
- p_16 = hn::InsertBlock<1>(
- p_16, hn::ResizeBitCast(uint16x8_tag,
- hn::Load(uint16x4_tag, pointers[1])));
- }
- if constexpr (kNumRows >= 3) {
- p_16 = hn::InsertBlock<2>(
- p_16, hn::ResizeBitCast(uint16x8_tag,
- hn::Load(uint16x4_tag, pointers[2])));
- p_16 = hn::InsertBlock<3>(
- p_16, hn::ResizeBitCast(uint16x8_tag,
- hn::Load(uint16x4_tag, pointers[3])));
- }
- auto res_lo_16 = hn::Undefined(int16xN_tag);
- if (conv_params->use_dist_wtd_comp_avg) {
- const auto p_16_lo =
- hn::BitCast(int16xN_tag, hn::InterleaveLower(p_16, temp_lo_16));
- const auto wt_res_lo =
- hn::WidenMulPairwiseAdd(int32xN_tag, p_16_lo, wt);
- const auto shifted_32 = hn::ShiftRight<DIST_PRECISION_BITS>(wt_res_lo);
- res_lo_16 = hn::BitCast(
- int16xN_tag,
- hn::ReorderDemote2To(uint16xN_tag, shifted_32, shifted_32));
- } else {
- res_lo_16 = hn::ShiftRight<1>(
- hn::BitCast(int16xN_tag, hn::Add(p_16, temp_lo_16)));
- }
- res_lo_16 = hn::Add(res_lo_16, res_sub_const);
- res_lo_16 =
- hn::ShiftRightSame(hn::Add(res_lo_16, round_bits_const), round_bits);
- const auto res_8_lo =
- hn::ReorderDemote2To(uint8xN_tag, res_lo_16, res_lo_16);
- StoreRows(uint8x4_tag, uint8xN_tag, res_8_lo, p_stride, i + k, j, pred);
- } else {
- hn::Store(
- hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<0>(temp_lo_16)),
- uint16x4_tag, pointers[0]);
- if constexpr (kNumRows >= 2) {
- hn::Store(
- hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<1>(temp_lo_16)),
- uint16x4_tag, pointers[1]);
- }
- if constexpr (kNumRows >= 3) {
- hn::Store(
- hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<2>(temp_lo_16)),
- uint16x4_tag, pointers[2]);
- hn::Store(
- hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<3>(temp_lo_16)),
- uint16x4_tag, pointers[3]);
- }
+ if constexpr (kNumRows >= 3) {
+ pointers[2] = base_ptr + 2 * dst_stride;
+ pointers[3] = base_ptr + 3 * dst_stride;
}
+
+ res_lo = hn::ShiftRightSame(res_lo, reduce_bits_vert);
+
if (p_width > 4) {
- uint16_t *HWY_RESTRICT pointers4[kNumRows];
- for (int row = 0; row < kNumRows; ++row) {
- pointers4[row] =
- &conv_params->dst[(i + k + row) * conv_params->dst_stride + j + 4];
- }
- res_hi =
- hn::ShiftRightSame(hn::Add(res_hi, res_add_const), reduce_bits_vert);
- const auto temp_hi_16 =
- hn::ReorderDemote2To(uint16xN_tag, res_hi, res_hi);
+ res_hi = hn::ShiftRightSame(res_hi, reduce_bits_vert);
+ const auto temp_16 = hn::ReorderDemote2To(uint16xN_tag, res_lo, res_hi);
if (conv_params->do_average) {
- auto p4_16 = hn::ResizeBitCast(uint16xN_tag,
- hn::Load(uint16x4_tag, pointers4[0]));
+ auto p_16 = hn::ResizeBitCast(uint16xN_tag,
+ hn::LoadU(uint16x8_tag, pointers[0]));
if constexpr (kNumRows >= 2) {
- p4_16 = hn::InsertBlock<1>(
- p4_16, hn::ResizeBitCast(uint16x8_tag,
- hn::Load(uint16x4_tag, pointers4[1])));
+ p_16 = hn::InsertBlock<1>(
+ p_16, hn::ResizeBitCast(uint16x8_tag,
+ hn::LoadU(uint16x8_tag, pointers[1])));
}
if constexpr (kNumRows >= 3) {
- p4_16 = hn::InsertBlock<2>(
- p4_16, hn::ResizeBitCast(uint16x8_tag,
- hn::Load(uint16x4_tag, pointers4[2])));
- p4_16 = hn::InsertBlock<3>(
- p4_16, hn::ResizeBitCast(uint16x8_tag,
- hn::Load(uint16x4_tag, pointers4[3])));
+ p_16 = hn::InsertBlock<2>(
+ p_16, hn::ResizeBitCast(uint16x8_tag,
+ hn::LoadU(uint16x8_tag, pointers[2])));
+ p_16 = hn::InsertBlock<3>(
+ p_16, hn::ResizeBitCast(uint16x8_tag,
+ hn::LoadU(uint16x8_tag, pointers[3])));
}
-
- auto res_hi_16 = hn::Undefined(int16xN_tag);
+ auto res_16 = hn::Undefined(int16xN_tag);
if (conv_params->use_dist_wtd_comp_avg) {
- const auto p_16_hi =
- hn::BitCast(int16xN_tag, hn::InterleaveLower(p4_16, temp_hi_16));
+ const auto p_16_lo = hn::BitCast(
+ int16xN_tag, hn::InterleaveLower(uint16xN_tag, p_16, temp_16));
+ const auto p_16_hi = hn::BitCast(
+ int16xN_tag, hn::InterleaveUpper(uint16xN_tag, p_16, temp_16));
+ const auto wt_res_lo =
+ hn::WidenMulPairwiseAdd(int32xN_tag, p_16_lo, wt);
const auto wt_res_hi =
hn::WidenMulPairwiseAdd(int32xN_tag, p_16_hi, wt);
- const auto shifted_32 =
+ const auto shifted_lo_32 =
+ hn::ShiftRight<DIST_PRECISION_BITS>(wt_res_lo);
+ const auto shifted_hi_32 =
hn::ShiftRight<DIST_PRECISION_BITS>(wt_res_hi);
- res_hi_16 = hn::BitCast(
+ res_16 = hn::BitCast(
+ int16xN_tag,
+ hn::ReorderDemote2To(uint16xN_tag, shifted_lo_32, shifted_hi_32));
+ } else {
+ res_16 = hn::ShiftRight<1>(
+ hn::BitCast(int16xN_tag, hn::Add(p_16, temp_16)));
+ }
+ res_16 = hn::ShiftRightSame(hn::Add(res_16, res_sub_round_const),
+ round_bits);
+ const auto res_8 = hn::ReorderDemote2To(uint8xN_tag, res_16, res_16);
+ StoreRows(uint8x8_tag, uint8xN_tag, res_8, p_stride, i + k, j, pred);
+ } else {
+ hn::StoreU(
+ hn::ResizeBitCast(uint16x8_tag, hn::ExtractBlock<0>(temp_16)),
+ uint16x8_tag, pointers[0]);
+ if constexpr (kNumRows >= 2) {
+ hn::StoreU(
+ hn::ResizeBitCast(uint16x8_tag, hn::ExtractBlock<1>(temp_16)),
+ uint16x8_tag, pointers[1]);
+ }
+ if constexpr (kNumRows >= 3) {
+ hn::StoreU(
+ hn::ResizeBitCast(uint16x8_tag, hn::ExtractBlock<2>(temp_16)),
+ uint16x8_tag, pointers[2]);
+ hn::StoreU(
+ hn::ResizeBitCast(uint16x8_tag, hn::ExtractBlock<3>(temp_16)),
+ uint16x8_tag, pointers[3]);
+ }
+ }
+ } else {
+ const auto temp_lo_16 =
+ hn::ReorderDemote2To(uint16xN_tag, res_lo, res_lo);
+ if (conv_params->do_average) {
+ auto p_16 = hn::ResizeBitCast(uint16xN_tag,
+ hn::LoadU(uint16x4_tag, pointers[0]));
+ if constexpr (kNumRows >= 2) {
+ p_16 = hn::InsertBlock<1>(
+ p_16, hn::ResizeBitCast(uint16x8_tag,
+ hn::LoadU(uint16x4_tag, pointers[1])));
+ }
+ if constexpr (kNumRows >= 3) {
+ p_16 = hn::InsertBlock<2>(
+ p_16, hn::ResizeBitCast(uint16x8_tag,
+ hn::LoadU(uint16x4_tag, pointers[2])));
+ p_16 = hn::InsertBlock<3>(
+ p_16, hn::ResizeBitCast(uint16x8_tag,
+ hn::LoadU(uint16x4_tag, pointers[3])));
+ }
+ auto res_lo_16 = hn::Undefined(int16xN_tag);
+ if (conv_params->use_dist_wtd_comp_avg) {
+ const auto p_16_lo = hn::BitCast(
+ int16xN_tag, hn::InterleaveLower(uint16xN_tag, p_16, temp_lo_16));
+ const auto wt_res_lo =
+ hn::WidenMulPairwiseAdd(int32xN_tag, p_16_lo, wt);
+ const auto shifted_32 =
+ hn::ShiftRight<DIST_PRECISION_BITS>(wt_res_lo);
+ res_lo_16 = hn::BitCast(
int16xN_tag,
hn::ReorderDemote2To(uint16xN_tag, shifted_32, shifted_32));
} else {
- res_hi_16 = hn::ShiftRight<1>(
- hn::BitCast(int16xN_tag, hn::Add(p4_16, temp_hi_16)));
+ res_lo_16 = hn::ShiftRight<1>(
+ hn::BitCast(int16xN_tag, hn::Add(p_16, temp_lo_16)));
}
- res_hi_16 = hn::Add(res_hi_16, res_sub_const);
- res_hi_16 = hn::ShiftRightSame(hn::Add(res_hi_16, round_bits_const),
+ res_lo_16 = hn::ShiftRightSame(hn::Add(res_lo_16, res_sub_round_const),
round_bits);
- const auto res_8_hi =
- hn::ReorderDemote2To(uint8xN_tag, res_hi_16, res_hi_16);
- StoreRows(uint8x4_tag, uint8xN_tag, res_8_hi, p_stride, i + k, j + 4,
- pred);
+ const auto res_8_lo =
+ hn::ReorderDemote2To(uint8xN_tag, res_lo_16, res_lo_16);
+ StoreRows(uint8x4_tag, uint8xN_tag, res_8_lo, p_stride, i + k, j, pred);
} else {
- hn::Store(hn::ResizeBitCast(uint16x4_tag, temp_hi_16), uint16x4_tag,
- pointers4[0]);
+ hn::Store(
+ hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<0>(temp_lo_16)),
+ uint16x4_tag, pointers[0]);
if constexpr (kNumRows >= 2) {
hn::Store(
- hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<1>(temp_hi_16)),
- uint16x4_tag, pointers4[1]);
+ hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<1>(temp_lo_16)),
+ uint16x4_tag, pointers[1]);
}
if constexpr (kNumRows >= 3) {
hn::Store(
- hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<2>(temp_hi_16)),
- uint16x4_tag, pointers4[2]);
+ hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<2>(temp_lo_16)),
+ uint16x4_tag, pointers[2]);
hn::Store(
- hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<3>(temp_hi_16)),
- uint16x4_tag, pointers4[3]);
+ hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<3>(temp_lo_16)),
+ uint16x4_tag, pointers[3]);
}
}
}
} else {
- const auto res_lo_round =
- hn::ShiftRightSame(hn::Add(res_lo, res_add_const), reduce_bits_vert);
- const auto res_hi_round =
- hn::ShiftRightSame(hn::Add(res_hi, res_add_const), reduce_bits_vert);
+ const auto res_lo_round = hn::ShiftRightSame(res_lo, reduce_bits_vert);
+ const auto res_hi_round = hn::ShiftRightSame(res_hi, reduce_bits_vert);
const auto res_16bit =
hn::ReorderDemote2To(int16xN_tag, res_lo_round, res_hi_round);
@@ -1003,16 +1104,27 @@ HWY_ATTR HWY_INLINE void StoreVerticalFilterOutput(
}
}
-template <bool InnerCoeffUpdate,
- void (*PrepareCoeffs)(int gamma, int delta, int sy,
- int16_t *HWY_RESTRICT coeffs)>
+enum class VerticalFilterCoeffs { kGamma0, kDelta0, kDefault };
+
+template <VerticalFilterCoeffs Filter>
+HWY_ATTR inline void WarpVerticalPrepareCoeffs(int gamma, int delta, int sy,
+ int16_t *HWY_RESTRICT coeffs) {
+ if constexpr (Filter == VerticalFilterCoeffs::kGamma0) {
+ PrepareVerticalFilterCoeffsGamma0(gamma, delta, sy, coeffs);
+ } else if constexpr (Filter == VerticalFilterCoeffs::kDelta0) {
+ PrepareVerticalFilterCoeffsDelta0(gamma, delta, sy, coeffs);
+ } else {
+ PrepareVerticalFilterCoeffs(gamma, delta, sy, coeffs);
+ }
+}
+
+template <bool InnerCoeffUpdate, VerticalFilterCoeffs Filter>
HWY_ATTR inline void WarpVerticalFilterTemplate(
uint8_t *HWY_RESTRICT pred, int16_t *HWY_RESTRICT horz_out,
ConvolveParams *HWY_RESTRICT conv_params, int16_t gamma, int16_t delta,
int p_height, int p_stride, int p_width, int i, int j, int sy4,
const int reduce_bits_vert, const IVec32 res_add_const,
- const int round_bits, const IVec16 res_sub_const,
- const IVec16 round_bits_const, const IVec16 wt) {
+ const int round_bits, const IVec16 res_sub_round_const, const IVec16 wt) {
HWY_ALIGN int16_t src_lo[4 * hn::MaxLanes(int16xN_tag)];
HWY_ALIGN int16_t src_hi[4 * hn::MaxLanes(int16xN_tag)];
if constexpr (int16xN_tag.MaxBlocks() >= 3) {
@@ -1039,14 +1151,12 @@ HWY_ATTR inline void WarpVerticalFilterTemplate(
hn::Load(int16xN_tag, horz_out + 1 * hn::MaxLanes(int16xN_tag));
const auto horz_out_4 =
hn::Load(int16xN_tag, horz_out + 2 * hn::MaxLanes(int16xN_tag));
- const auto horz_out_6 =
- hn::Load(int16xN_tag, horz_out + 3 * hn::MaxLanes(int16xN_tag));
const auto horz_out_1 =
- hn::ConcatLowerUpper(int16xN_tag, horz_out_2, horz_out_0);
+ hn::LoadU(int16xN_tag, horz_out + 1 * hn::MaxLanes(int16x8_tag));
const auto horz_out_3 =
- hn::ConcatLowerUpper(int16xN_tag, horz_out_4, horz_out_2);
+ hn::LoadU(int16xN_tag, horz_out + 3 * hn::MaxLanes(int16x8_tag));
const auto horz_out_5 =
- hn::ConcatLowerUpper(int16xN_tag, horz_out_6, horz_out_4);
+ hn::LoadU(int16xN_tag, horz_out + 5 * hn::MaxLanes(int16x8_tag));
hn::Store(hn::InterleaveLower(int16xN_tag, horz_out_0, horz_out_1),
int16xN_tag, src_lo + 0 * hn::MaxLanes(int16xN_tag));
hn::Store(hn::InterleaveUpper(int16xN_tag, horz_out_0, horz_out_1),
@@ -1088,22 +1198,22 @@ HWY_ATTR inline void WarpVerticalFilterTemplate(
HWY_ALIGN int16_t coeffs[8 * hn::MaxLanes(int16xN_tag)];
if constexpr (!InnerCoeffUpdate) {
- PrepareCoeffs(gamma, delta, sy4, coeffs);
+ WarpVerticalPrepareCoeffs<Filter>(gamma, delta, sy4, coeffs);
}
for (int k = -4; k < AOMMIN(4, p_height - i - 4);
k += static_cast<int>(int16xN_tag.MaxBlocks())) {
if constexpr (InnerCoeffUpdate) {
int sy = sy4 + delta * (k + 4);
- PrepareCoeffs(gamma, delta, sy, coeffs);
+ WarpVerticalPrepareCoeffs<Filter>(gamma, delta, sy, coeffs);
}
IVec32 res_lo, res_hi;
FilterPixelsVertical(horz_out, src_lo, src_hi, coeffs, res_lo, res_hi,
- k + 4);
- StoreVerticalFilterOutput(res_lo, res_hi, res_add_const, wt, res_sub_const,
- round_bits_const, pred, conv_params, i, j, k + 4,
- reduce_bits_vert, p_stride, p_width, round_bits);
+ k + 4, res_add_const);
+ StoreVerticalFilterOutput(res_lo, res_hi, wt, res_sub_round_const, pred,
+ conv_params, i, j, k + 4, reduce_bits_vert,
+ p_stride, p_width, round_bits);
if constexpr (int16xN_tag.MaxBlocks() >= 3) {
hn::Store(hn::Load(int16xN_tag, src_lo + 2 * hn::MaxLanes(int16xN_tag)),
@@ -1177,28 +1287,27 @@ HWY_ATTR inline void PrepareWarpVerticalFilter(
ConvolveParams *HWY_RESTRICT conv_params, int16_t gamma, int16_t delta,
int p_height, int p_stride, int p_width, int i, int j, int sy4,
const int reduce_bits_vert, const IVec32 res_add_const,
- const int round_bits, const IVec16 res_sub_const,
- const IVec16 round_bits_const, const IVec16 wt) {
+ const int round_bits, const IVec16 res_sub_round_const, const IVec16 wt) {
if (gamma == 0 && delta == 0)
- WarpVerticalFilterTemplate<false, PrepareVerticalFilterCoeffsGamma0>(
+ WarpVerticalFilterTemplate<false, VerticalFilterCoeffs::kGamma0>(
pred, horz_out, conv_params, gamma, delta, p_height, p_stride, p_width,
- i, j, sy4, reduce_bits_vert, res_add_const, round_bits, res_sub_const,
- round_bits_const, wt);
+ i, j, sy4, reduce_bits_vert, res_add_const, round_bits,
+ res_sub_round_const, wt);
else if (gamma == 0 && delta != 0)
- WarpVerticalFilterTemplate<true, PrepareVerticalFilterCoeffsGamma0>(
+ WarpVerticalFilterTemplate<true, VerticalFilterCoeffs::kGamma0>(
pred, horz_out, conv_params, gamma, delta, p_height, p_stride, p_width,
- i, j, sy4, reduce_bits_vert, res_add_const, round_bits, res_sub_const,
- round_bits_const, wt);
+ i, j, sy4, reduce_bits_vert, res_add_const, round_bits,
+ res_sub_round_const, wt);
else if (gamma != 0 && delta == 0)
- WarpVerticalFilterTemplate<false, PrepareVerticalFilterCoeffsDelta0>(
+ WarpVerticalFilterTemplate<false, VerticalFilterCoeffs::kDelta0>(
pred, horz_out, conv_params, gamma, delta, p_height, p_stride, p_width,
- i, j, sy4, reduce_bits_vert, res_add_const, round_bits, res_sub_const,
- round_bits_const, wt);
+ i, j, sy4, reduce_bits_vert, res_add_const, round_bits,
+ res_sub_round_const, wt);
else
- WarpVerticalFilterTemplate<true, PrepareVerticalFilterCoeffs>(
+ WarpVerticalFilterTemplate<true, VerticalFilterCoeffs::kDefault>(
pred, horz_out, conv_params, gamma, delta, p_height, p_stride, p_width,
- i, j, sy4, reduce_bits_vert, res_add_const, round_bits, res_sub_const,
- round_bits_const, wt);
+ i, j, sy4, reduce_bits_vert, res_add_const, round_bits,
+ res_sub_round_const, wt);
}
HWY_ATTR inline void PrepareWarpHorizontalFilter(
@@ -1353,14 +1462,14 @@ HWY_ATTR void WarpHorizontalFilterOutOfBoundsPad(
reduce_bits_horiz, k + 7);
}
-HWY_ATTR void WarpAffine(const int32_t *HWY_RESTRICT mat,
- const uint8_t *HWY_RESTRICT ref, int width, int height,
- int stride, uint8_t *HWY_RESTRICT pred, int p_col,
- int p_row, int p_width, int p_height, int p_stride,
- int subsampling_x, int subsampling_y,
- ConvolveParams *HWY_RESTRICT conv_params,
- int16_t alpha, int16_t beta, int16_t gamma,
- int16_t delta) {
+template <bool HorizInnerCoeffUpdate, HorizontalFilterCoeffs HorizFilter,
+ bool VertInnerCoeffUpdate, VerticalFilterCoeffs VertFilter>
+HWY_ATTR inline void WarpAffineTemplate(
+ const int32_t *HWY_RESTRICT mat, const uint8_t *HWY_RESTRICT ref, int width,
+ int height, int stride, uint8_t *HWY_RESTRICT pred, int p_col, int p_row,
+ int p_width, int p_height, int p_stride, int subsampling_x,
+ int subsampling_y, ConvolveParams *HWY_RESTRICT conv_params, int16_t alpha,
+ int16_t beta, int16_t gamma, int16_t delta) {
int i, j;
const int bd = 8;
const int reduce_bits_horiz = conv_params->round_0;
@@ -1385,6 +1494,7 @@ HWY_ATTR void WarpAffine(const int32_t *HWY_RESTRICT mat,
IVec16 res_sub_const, round_bits_const, wt;
UnpackWeightsAndSetRoundConst(conv_params, round_bits, offset_bits,
res_sub_const, round_bits_const, wt);
+ const auto res_sub_round_const = hn::Add(res_sub_const, round_bits_const);
IVec32 res_add_const_1;
if (conv_params->is_compound == 1) {
@@ -1403,21 +1513,38 @@ HWY_ATTR void WarpAffine(const int32_t *HWY_RESTRICT mat,
const int16_t const4 = (1 << (bd + FILTER_BITS - reduce_bits_horiz - 1));
const int16_t const5 = (1 << (FILTER_BITS - reduce_bits_horiz));
+ const int64_t dst_x_stride_j =
+ static_cast<int64_t>(mat[2]) * (8 << subsampling_x);
+ const int64_t dst_y_stride_j =
+ static_cast<int64_t>(mat[4]) * (8 << subsampling_x);
+ const int64_t dst_x_stride_i =
+ static_cast<int64_t>(mat[3]) * (8 << subsampling_y);
+ const int64_t dst_y_stride_i =
+ static_cast<int64_t>(mat[5]) * (8 << subsampling_y);
+
+ const int32_t src_x_0 = (p_col + 4) << subsampling_x;
+ const int32_t src_y_0 = (p_row + 4) << subsampling_y;
+ const int64_t dst_x_start = static_cast<int64_t>(mat[2]) * src_x_0 +
+ static_cast<int64_t>(mat[3]) * src_y_0 +
+ static_cast<int64_t>(mat[0]);
+ const int64_t dst_y_start = static_cast<int64_t>(mat[4]) * src_x_0 +
+ static_cast<int64_t>(mat[5]) * src_y_0 +
+ static_cast<int64_t>(mat[1]);
+
+ int64_t dst_x_row = dst_x_start;
+ int64_t dst_y_row = dst_y_start;
+
for (i = 0; i < p_height; i += 8) {
+ int64_t dst_x = dst_x_row;
+ int64_t dst_y = dst_y_row;
for (j = 0; j < p_width; j += 8) {
HWY_ALIGN int16_t horz_out[8 * 16 + hn::MaxLanes(int16xN_tag)];
- const int32_t src_x = (p_col + j + 4) << subsampling_x;
- const int32_t src_y = (p_row + i + 4) << subsampling_y;
- const int64_t dst_x =
- (int64_t)mat[2] * src_x + (int64_t)mat[3] * src_y + (int64_t)mat[0];
- const int64_t dst_y =
- (int64_t)mat[4] * src_x + (int64_t)mat[5] * src_y + (int64_t)mat[1];
const int64_t x4 = dst_x >> subsampling_x;
const int64_t y4 = dst_y >> subsampling_y;
- int32_t ix4 = (int32_t)(x4 >> WARPEDMODEL_PREC_BITS);
+ int32_t ix4 = static_cast<int32_t>(x4 >> WARPEDMODEL_PREC_BITS);
int32_t sx4 = x4 & ((1 << WARPEDMODEL_PREC_BITS) - 1);
- int32_t iy4 = (int32_t)(y4 >> WARPEDMODEL_PREC_BITS);
+ int32_t iy4 = static_cast<int32_t>(y4 >> WARPEDMODEL_PREC_BITS);
int32_t sy4 = y4 & ((1 << WARPEDMODEL_PREC_BITS) - 1);
// Add in all the constant terms, including rounding and offset
@@ -1444,17 +1571,119 @@ HWY_ATTR void WarpAffine(const int32_t *HWY_RESTRICT mat,
ref, stride, ix4, iy4, sx4, alpha, beta, p_height, width, height, i,
round_const, reduce_bits_horiz, horz_out);
} else {
- PrepareWarpHorizontalFilter(ref, horz_out, stride, ix4, iy4, sx4, alpha,
- beta, p_height, height, i, round_const,
- reduce_bits_horiz);
+ if (iy4 - 7 >= 0 && iy4 + 7 < height) {
+ if constexpr (HorizFilter == HorizontalFilterCoeffs::kAlpha0) {
+ WarpHorizontalFilterTemplateInterior<
+ HorizInnerCoeffUpdate, HorizontalFilterCoeffs::kAlpha0>(
+ ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height, i,
+ round_const, reduce_bits_horiz);
+ } else if constexpr (HorizFilter == HorizontalFilterCoeffs::kBeta0) {
+ WarpHorizontalFilterTemplateInterior<
+ HorizInnerCoeffUpdate, HorizontalFilterCoeffs::kBeta0>(
+ ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height, i,
+ round_const, reduce_bits_horiz);
+ } else {
+ WarpHorizontalFilterTemplateInterior<
+ HorizInnerCoeffUpdate, HorizontalFilterCoeffs::kDefault>(
+ ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height, i,
+ round_const, reduce_bits_horiz);
+ }
+ } else {
+ if constexpr (HorizFilter == HorizontalFilterCoeffs::kAlpha0) {
+ WarpHorizontalFilterTemplate<HorizInnerCoeffUpdate,
+ HorizontalFilterCoeffs::kAlpha0>(
+ ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height,
+ height, i, round_const, reduce_bits_horiz);
+ } else if constexpr (HorizFilter == HorizontalFilterCoeffs::kBeta0) {
+ WarpHorizontalFilterTemplate<HorizInnerCoeffUpdate,
+ HorizontalFilterCoeffs::kBeta0>(
+ ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height,
+ height, i, round_const, reduce_bits_horiz);
+ } else {
+ WarpHorizontalFilterTemplate<HorizInnerCoeffUpdate,
+ HorizontalFilterCoeffs::kDefault>(
+ ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height,
+ height, i, round_const, reduce_bits_horiz);
+ }
+ }
}
// Vertical filter
- PrepareWarpVerticalFilter(pred, horz_out, conv_params, gamma, delta,
- p_height, p_stride, p_width, i, j, sy4,
- reduce_bits_vert, res_add_const_1, round_bits,
- res_sub_const, round_bits_const, wt);
+ WarpVerticalFilterTemplate<VertInnerCoeffUpdate, VertFilter>(
+ pred, horz_out, conv_params, gamma, delta, p_height, p_stride,
+ p_width, i, j, sy4, reduce_bits_vert, res_add_const_1, round_bits,
+ res_sub_round_const, wt);
+
+ dst_x += dst_x_stride_j;
+ dst_y += dst_y_stride_j;
}
+ dst_x_row += dst_x_stride_i;
+ dst_y_row += dst_y_stride_i;
+ }
+}
+
+template <bool HorizInnerCoeffUpdate, HorizontalFilterCoeffs HorizFilter>  // Horizontal args: chosen by WarpAffine from alpha/beta.
+HWY_ATTR inline void WarpAffineHorizDispatched(  // Maps runtime gamma/delta to WarpAffineTemplate's vertical template args.
+ const int32_t *HWY_RESTRICT mat, const uint8_t *HWY_RESTRICT ref, int width,
+ int height, int stride, uint8_t *HWY_RESTRICT pred, int p_col, int p_row,
+ int p_width, int p_height, int p_stride, int subsampling_x,
+ int subsampling_y, ConvolveParams *HWY_RESTRICT conv_params, int16_t alpha,
+ int16_t beta, int16_t gamma, int16_t delta) {
+ if (gamma == 0 && delta == 0) {  // delta == 0: vertical coeffs built once, not per k step (kGamma0).
+ WarpAffineTemplate<HorizInnerCoeffUpdate, HorizFilter, false,
+ VerticalFilterCoeffs::kGamma0>(
+ mat, ref, width, height, stride, pred, p_col, p_row, p_width, p_height,
+ p_stride, subsampling_x, subsampling_y, conv_params, alpha, beta, gamma,
+ delta);
+ } else if (gamma == 0 && delta != 0) {  // delta != 0: coeffs rebuilt each k step from sy4 + delta * (k + 4) (kGamma0).
+ WarpAffineTemplate<HorizInnerCoeffUpdate, HorizFilter, true,
+ VerticalFilterCoeffs::kGamma0>(
+ mat, ref, width, height, stride, pred, p_col, p_row, p_width, p_height,
+ p_stride, subsampling_x, subsampling_y, conv_params, alpha, beta, gamma,
+ delta);
+ } else if (gamma != 0 && delta == 0) {  // delta == 0: coeffs built once (kDelta0).
+ WarpAffineTemplate<HorizInnerCoeffUpdate, HorizFilter, false,
+ VerticalFilterCoeffs::kDelta0>(
+ mat, ref, width, height, stride, pred, p_col, p_row, p_width, p_height,
+ p_stride, subsampling_x, subsampling_y, conv_params, alpha, beta, gamma,
+ delta);
+ } else {  // gamma and delta both nonzero: per-k rebuild, general kDefault coeffs.
+ WarpAffineTemplate<HorizInnerCoeffUpdate, HorizFilter, true,
+ VerticalFilterCoeffs::kDefault>(
+ mat, ref, width, height, stride, pred, p_col, p_row, p_width, p_height,
+ p_stride, subsampling_x, subsampling_y, conv_params, alpha, beta, gamma,
+ delta);
+ }
+}
+
+HWY_ATTR void WarpAffine(const int32_t *HWY_RESTRICT mat,  // Template dispatch: alpha/beta resolved here, gamma/delta in WarpAffineHorizDispatched.
+ const uint8_t *HWY_RESTRICT ref, int width, int height,
+ int stride, uint8_t *HWY_RESTRICT pred, int p_col,
+ int p_row, int p_width, int p_height, int p_stride,
+ int subsampling_x, int subsampling_y,
+ ConvolveParams *HWY_RESTRICT conv_params,
+ int16_t alpha, int16_t beta, int16_t gamma,
+ int16_t delta) {
+ if (alpha == 0 && beta == 0) {  // beta == 0: HorizInnerCoeffUpdate=false; alpha == 0: kAlpha0.
+ WarpAffineHorizDispatched<false, HorizontalFilterCoeffs::kAlpha0>(
+ mat, ref, width, height, stride, pred, p_col, p_row, p_width, p_height,
+ p_stride, subsampling_x, subsampling_y, conv_params, alpha, beta, gamma,
+ delta);
+ } else if (alpha == 0 && beta != 0) {  // beta != 0: HorizInnerCoeffUpdate=true; alpha == 0: kAlpha0.
+ WarpAffineHorizDispatched<true, HorizontalFilterCoeffs::kAlpha0>(
+ mat, ref, width, height, stride, pred, p_col, p_row, p_width, p_height,
+ p_stride, subsampling_x, subsampling_y, conv_params, alpha, beta, gamma,
+ delta);
+ } else if (alpha != 0 && beta == 0) {  // beta == 0: HorizInnerCoeffUpdate=false; alpha != 0: kBeta0.
+ WarpAffineHorizDispatched<false, HorizontalFilterCoeffs::kBeta0>(
+ mat, ref, width, height, stride, pred, p_col, p_row, p_width, p_height,
+ p_stride, subsampling_x, subsampling_y, conv_params, alpha, beta, gamma,
+ delta);
+ } else {  // alpha and beta both nonzero: HorizInnerCoeffUpdate=true, general kDefault coeffs.
+ WarpAffineHorizDispatched<true, HorizontalFilterCoeffs::kDefault>(
+ mat, ref, width, height, stride, pred, p_col, p_row, p_width, p_height,
+ p_stride, subsampling_x, subsampling_y, conv_params, alpha, beta, gamma,
+ delta);
}
}