Commit 039042e9eb for aom
commit 039042e9ebb008b3fd0d39d6dbc635f8c80ea30e
Author: Jerome Jiang <jianj@google.com>
Date: Wed Jun 10 11:23:21 2026 -0400
Optimize av1_lowbd_pixel_proj_error_avx2.
Optimize the AVX2 implementation of the low bitdepth pixel projection
error calculation.
- Implement stride decoupling to support distinct strides for all
input buffers.
- Add width-specific optimized paths (width == 8, width >= 32,
width >= 16) to improve efficiency for different block sizes.
- Use register-direct SIMD horizontal reduction to avoid memory
roundtrips.
| Block | Before | After | Speedup |
| :------ | :-----: | :-----: | :-------- |
| 8x8 | 71.78n | 12.35n | 5.81x |
| 16x16 | 51.81n | 39.27n | 1.32x |
| 32x32 | 165.80n | 151.40n | 1.10x |
| 64x64 | 743.00n | 653.10n | 1.14x |
| 128x128 | 2.583µ | 2.300µ | 1.12x |
| geomean | 259.80n | 161.60n | 1.61x |
Change-Id: I80072dec3cb710fb9a6e4fdebc2276eed485af6e
diff --git a/av1/encoder/x86/pickrst_avx2.c b/av1/encoder/x86/pickrst_avx2.c
index f5375dd296..792bc81bb9 100644
--- a/av1/encoder/x86/pickrst_avx2.c
+++ b/av1/encoder/x86/pickrst_avx2.c
@@ -1545,6 +1545,51 @@ static inline __m256i pair_set_epi16(int a, int b) {
(int32_t)(((uint16_t)(a)) | (((uint32_t)(uint16_t)(b)) << 16)));
}
+static inline __m256i load_shuffled_u8_to_epi16(const uint8_t *ptr) {  // Widen 16 bytes at ptr to 16-bit lanes, ordered like _mm256_packs_epi32(load(p), load(p + 8)).
+ const __m128i raw = xx_loadu_128(ptr);  // presumably an unaligned 16-byte load (helper defined elsewhere)
+ const __m128i shuffled = _mm_shuffle_epi32(raw, _MM_SHUFFLE(3, 1, 2, 0));  // 32-bit groups 0,1,2,3 -> 0,2,1,3
+ return _mm256_cvtepu8_epi16(shuffled);  // low half: bytes 0-3, 8-11; high half: bytes 4-7, 12-15 (zero-extended)
+}
+
+static inline __m256i load_shuffled_u8_dual8_to_epi16(const uint8_t *ptrA,
+ const uint8_t *ptrB) {  // Widen 8 bytes from each of two rows to 16-bit lanes, ordered like _mm256_packs_epi32(rowA, rowB).
+ const __m128i rawA = _mm_loadl_epi64((const __m128i *)ptrA);  // 8 bytes of row A in the low 64 bits
+ const __m128i rawB = _mm_loadl_epi64((const __m128i *)ptrB);  // 8 bytes of row B in the low 64 bits
+ const __m128i raw_AB = _mm_unpacklo_epi64(rawA, rawB);  // [A0-7 | B0-7]
+ const __m128i shuffled = _mm_shuffle_epi32(raw_AB, _MM_SHUFFLE(3, 1, 2, 0));  // [A0-3, B0-3, A4-7, B4-7]
+ return _mm256_cvtepu8_epi16(shuffled);  // low half: A0-3, B0-3; high half: A4-7, B4-7 (zero-extended)
+}
+
+static inline __m256i calc_proj_err_r0_r1_avx2(
+ const __m256i d0, const __m256i s0, const __m256i flt0_16b,
+ const __m256i flt1_16b, const __m256i xq_coeff, const __m256i rounding,
+ int shift) {  // Per-pixel 16-bit projection error when both filters are active; callers pass xq_coeff = pair_set_epi16(xq[0], xq[1]).
+ const __m256i u0 = _mm256_slli_epi16(d0, SGRPROJ_RST_BITS);  // u = dat << SGRPROJ_RST_BITS
+ const __m256i v0 = _mm256_madd_epi16(  // lower four pixels of each 128-bit half: coeff pair dotted with (flt0 - u, flt1 - u)
+ xq_coeff, _mm256_unpacklo_epi16(_mm256_sub_epi16(flt0_16b, u0),
+ _mm256_sub_epi16(flt1_16b, u0)));
+ const __m256i v1 = _mm256_madd_epi16(  // same for the upper four pixels of each 128-bit half
+ xq_coeff, _mm256_unpackhi_epi16(_mm256_sub_epi16(flt0_16b, u0),
+ _mm256_sub_epi16(flt1_16b, u0)));
+ const __m256i vr = _mm256_packs_epi32(  // add rounding, arithmetic shift, then saturating pack back to 16 bits in input order
+ _mm256_srai_epi32(_mm256_add_epi32(v0, rounding), shift),
+ _mm256_srai_epi32(_mm256_add_epi32(v1, rounding), shift));
+ return _mm256_add_epi16(vr, _mm256_sub_epi16(d0, s0));  // e = vr + dat - src
+}
+
+static inline __m256i calc_proj_err_r0_or_r1_avx2(
+ const __m256i d0, const __m256i s0, const __m256i flt_16b,
+ const __m256i xq_coeff, const __m256i rounding, int shift) {  // Per-pixel 16-bit projection error when only one filter is active.
+ const __m256i v0 =  // callers pass xq_coeff = (xq_active, -xq_active << SGRPROJ_RST_BITS), so each madd lane is xq_active * flt - (xq_active << SGRPROJ_RST_BITS) * dat
+ _mm256_madd_epi16(xq_coeff, _mm256_unpacklo_epi16(flt_16b, d0));
+ const __m256i v1 =  // same for the upper four pixels of each 128-bit half
+ _mm256_madd_epi16(xq_coeff, _mm256_unpackhi_epi16(flt_16b, d0));
+ const __m256i vr_16b = _mm256_packs_epi32(  // add rounding, arithmetic shift, then saturating pack back to 16 bits in input order
+ _mm256_srai_epi32(_mm256_add_epi32(v0, rounding), shift),
+ _mm256_srai_epi32(_mm256_add_epi32(v1, rounding), shift));
+ return _mm256_add_epi16(vr_16b, _mm256_sub_epi16(d0, s0));  // e = vr + dat - src
+}
+
int64_t av1_lowbd_pixel_proj_error_avx2(
const uint8_t *src8, int width, int height, int src_stride,
const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
@@ -1556,53 +1601,220 @@ int64_t av1_lowbd_pixel_proj_error_avx2(
const uint8_t *src = src8;
const uint8_t *dat = dat8;
int64_t err = 0;
+
if (params->r[0] > 0 && params->r[1] > 0) {
__m256i xq_coeff = pair_set_epi16(xq[0], xq[1]);
- for (i = 0; i < height; ++i) {
+ if (width == 8) {
__m256i sum32 = _mm256_setzero_si256();
- for (j = 0; j <= width - 16; j += 16) {
- const __m256i d0 = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
- const __m256i s0 = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
- const __m256i flt0_16b = _mm256_permute4x64_epi64(
- _mm256_packs_epi32(yy_loadu_256(flt0 + j),
- yy_loadu_256(flt0 + j + 8)),
- 0xd8);
- const __m256i flt1_16b = _mm256_permute4x64_epi64(
- _mm256_packs_epi32(yy_loadu_256(flt1 + j),
- yy_loadu_256(flt1 + j + 8)),
- 0xd8);
- const __m256i u0 = _mm256_slli_epi16(d0, SGRPROJ_RST_BITS);
- const __m256i flt0_0_sub_u = _mm256_sub_epi16(flt0_16b, u0);
- const __m256i flt1_0_sub_u = _mm256_sub_epi16(flt1_16b, u0);
- const __m256i v0 = _mm256_madd_epi16(
- xq_coeff, _mm256_unpacklo_epi16(flt0_0_sub_u, flt1_0_sub_u));
- const __m256i v1 = _mm256_madd_epi16(
- xq_coeff, _mm256_unpackhi_epi16(flt0_0_sub_u, flt1_0_sub_u));
- const __m256i vr0 =
- _mm256_srai_epi32(_mm256_add_epi32(v0, rounding), shift);
- const __m256i vr1 =
- _mm256_srai_epi32(_mm256_add_epi32(v1, rounding), shift);
- const __m256i e0 = _mm256_sub_epi16(
- _mm256_add_epi16(_mm256_packs_epi32(vr0, vr1), d0), s0);
+ const int height_even = height & ~1;
+ for (i = 0; i < height_even; i += 2) {
+ const uint8_t *dat_rowB = dat + dat_stride;
+ const uint8_t *src_rowB = src + src_stride;
+ const int32_t *flt0_rowB = flt0 + flt0_stride;
+ const int32_t *flt1_rowB = flt1 + flt1_stride;
+
+ const __m256i d0 = load_shuffled_u8_dual8_to_epi16(dat, dat_rowB);
+ const __m256i s0 = load_shuffled_u8_dual8_to_epi16(src, src_rowB);
+ const __m256i flt0_16b =
+ _mm256_packs_epi32(yy_loadu_256(flt0), yy_loadu_256(flt0_rowB));
+ const __m256i flt1_16b =
+ _mm256_packs_epi32(yy_loadu_256(flt1), yy_loadu_256(flt1_rowB));
+
+ const __m256i e0 = calc_proj_err_r0_r1_avx2(d0, s0, flt0_16b, flt1_16b,
+ xq_coeff, rounding, shift);
const __m256i err0 = _mm256_madd_epi16(e0, e0);
sum32 = _mm256_add_epi32(sum32, err0);
+
+ dat += 2 * dat_stride;
+ src += 2 * src_stride;
+ flt0 += 2 * flt0_stride;
+ flt1 += 2 * flt1_stride;
}
- for (k = j; k < width; ++k) {
- const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
- int32_t v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u);
- const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
- err += ((int64_t)e * e);
+ if (i < height) {
+ for (k = 0; k < 8; ++k) {
+ const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+ int32_t v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u);
+ const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+ err += ((int64_t)e * e);
+ }
}
- dat += dat_stride;
- src += src_stride;
- flt0 += flt0_stride;
- flt1 += flt1_stride;
const __m256i sum64_0 =
_mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
const __m256i sum64_1 =
_mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
- sum64 = _mm256_add_epi64(sum64, sum64_0);
- sum64 = _mm256_add_epi64(sum64, sum64_1);
+ sum64 = _mm256_add_epi64(sum64, _mm256_add_epi64(sum64_0, sum64_1));
+ } else if (width == 16) {
+ __m256i sum32_A = _mm256_setzero_si256();
+ __m256i sum32_B = _mm256_setzero_si256();
+ __m256i sum32_C = _mm256_setzero_si256();
+ __m256i sum32_D = _mm256_setzero_si256();
+ const int height_v4 = height & ~3;
+ for (i = 0; i < height_v4; i += 4) {
+ const uint8_t *dat_rowB = dat + dat_stride;
+ const uint8_t *dat_rowC = dat_rowB + dat_stride;
+ const uint8_t *dat_rowD = dat_rowC + dat_stride;
+ const uint8_t *src_rowB = src + src_stride;
+ const uint8_t *src_rowC = src_rowB + src_stride;
+ const uint8_t *src_rowD = src_rowC + src_stride;
+ const int32_t *flt0_rowB = flt0 + flt0_stride;
+ const int32_t *flt0_rowC = flt0_rowB + flt0_stride;
+ const int32_t *flt0_rowD = flt0_rowC + flt0_stride;
+ const int32_t *flt1_rowB = flt1 + flt1_stride;
+ const int32_t *flt1_rowC = flt1_rowB + flt1_stride;
+ const int32_t *flt1_rowD = flt1_rowC + flt1_stride;
+
+ // Row A
+ {
+ const __m256i d0 = load_shuffled_u8_to_epi16(dat);
+ const __m256i s0 = load_shuffled_u8_to_epi16(src);
+ const __m256i flt0_16b =
+ _mm256_packs_epi32(yy_loadu_256(flt0), yy_loadu_256(flt0 + 8));
+ const __m256i flt1_16b =
+ _mm256_packs_epi32(yy_loadu_256(flt1), yy_loadu_256(flt1 + 8));
+ const __m256i e = calc_proj_err_r0_r1_avx2(d0, s0, flt0_16b, flt1_16b,
+ xq_coeff, rounding, shift);
+ sum32_A = _mm256_add_epi32(sum32_A, _mm256_madd_epi16(e, e));
+ }
+ // Row B
+ {
+ const __m256i d0 = load_shuffled_u8_to_epi16(dat_rowB);
+ const __m256i s0 = load_shuffled_u8_to_epi16(src_rowB);
+ const __m256i flt0_16b = _mm256_packs_epi32(
+ yy_loadu_256(flt0_rowB), yy_loadu_256(flt0_rowB + 8));
+ const __m256i flt1_16b = _mm256_packs_epi32(
+ yy_loadu_256(flt1_rowB), yy_loadu_256(flt1_rowB + 8));
+ const __m256i e = calc_proj_err_r0_r1_avx2(d0, s0, flt0_16b, flt1_16b,
+ xq_coeff, rounding, shift);
+ sum32_B = _mm256_add_epi32(sum32_B, _mm256_madd_epi16(e, e));
+ }
+ // Row C
+ {
+ const __m256i d0 = load_shuffled_u8_to_epi16(dat_rowC);
+ const __m256i s0 = load_shuffled_u8_to_epi16(src_rowC);
+ const __m256i flt0_16b = _mm256_packs_epi32(
+ yy_loadu_256(flt0_rowC), yy_loadu_256(flt0_rowC + 8));
+ const __m256i flt1_16b = _mm256_packs_epi32(
+ yy_loadu_256(flt1_rowC), yy_loadu_256(flt1_rowC + 8));
+ const __m256i e = calc_proj_err_r0_r1_avx2(d0, s0, flt0_16b, flt1_16b,
+ xq_coeff, rounding, shift);
+ sum32_C = _mm256_add_epi32(sum32_C, _mm256_madd_epi16(e, e));
+ }
+ // Row D
+ {
+ const __m256i d0 = load_shuffled_u8_to_epi16(dat_rowD);
+ const __m256i s0 = load_shuffled_u8_to_epi16(src_rowD);
+ const __m256i flt0_16b = _mm256_packs_epi32(
+ yy_loadu_256(flt0_rowD), yy_loadu_256(flt0_rowD + 8));
+ const __m256i flt1_16b = _mm256_packs_epi32(
+ yy_loadu_256(flt1_rowD), yy_loadu_256(flt1_rowD + 8));
+ const __m256i e = calc_proj_err_r0_r1_avx2(d0, s0, flt0_16b, flt1_16b,
+ xq_coeff, rounding, shift);
+ sum32_D = _mm256_add_epi32(sum32_D, _mm256_madd_epi16(e, e));
+ }
+
+ dat += 4 * dat_stride;
+ src += 4 * src_stride;
+ flt0 += 4 * flt0_stride;
+ flt1 += 4 * flt1_stride;
+ }
+ for (; i < height; ++i) {
+ const __m256i d0 = load_shuffled_u8_to_epi16(dat);
+ const __m256i s0 = load_shuffled_u8_to_epi16(src);
+ const __m256i flt0_16b =
+ _mm256_packs_epi32(yy_loadu_256(flt0), yy_loadu_256(flt0 + 8));
+ const __m256i flt1_16b =
+ _mm256_packs_epi32(yy_loadu_256(flt1), yy_loadu_256(flt1 + 8));
+ const __m256i e = calc_proj_err_r0_r1_avx2(d0, s0, flt0_16b, flt1_16b,
+ xq_coeff, rounding, shift);
+ sum32_A = _mm256_add_epi32(sum32_A, _mm256_madd_epi16(e, e));
+
+ dat += dat_stride;
+ src += src_stride;
+ flt0 += flt0_stride;
+ flt1 += flt1_stride;
+ }
+ __m256i sum32 = _mm256_add_epi32(_mm256_add_epi32(sum32_A, sum32_B),
+ _mm256_add_epi32(sum32_C, sum32_D));
+ const __m256i sum64_0 =
+ _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+ const __m256i sum64_1 =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+ sum64 = _mm256_add_epi64(sum64, _mm256_add_epi64(sum64_0, sum64_1));
+ } else if (width >= 32 && (width % 32 == 0)) {
+ int rows_per_batch = 4096 / width;
+ if (rows_per_batch < 1) rows_per_batch = 1;
+ for (i = 0; i < height;) {
+ int rows_to_do = height - i;
+ if (rows_to_do > rows_per_batch) rows_to_do = rows_per_batch;
+ const int next_i = i + rows_to_do;
+ __m256i sum32_A = _mm256_setzero_si256();
+ __m256i sum32_B = _mm256_setzero_si256();
+ for (; i < next_i; ++i) {
+ for (j = 0; j <= width - 32; j += 32) {
+ const __m256i d_A = load_shuffled_u8_to_epi16(dat + j);
+ const __m256i s_A = load_shuffled_u8_to_epi16(src + j);
+ const __m256i flt0_A = _mm256_packs_epi32(
+ yy_loadu_256(flt0 + j), yy_loadu_256(flt0 + j + 8));
+ const __m256i flt1_A = _mm256_packs_epi32(
+ yy_loadu_256(flt1 + j), yy_loadu_256(flt1 + j + 8));
+ const __m256i e_A = calc_proj_err_r0_r1_avx2(
+ d_A, s_A, flt0_A, flt1_A, xq_coeff, rounding, shift);
+ sum32_A = _mm256_add_epi32(sum32_A, _mm256_madd_epi16(e_A, e_A));
+
+ const __m256i d_B = load_shuffled_u8_to_epi16(dat + j + 16);
+ const __m256i s_B = load_shuffled_u8_to_epi16(src + j + 16);
+ const __m256i flt0_B = _mm256_packs_epi32(
+ yy_loadu_256(flt0 + j + 16), yy_loadu_256(flt0 + j + 24));
+ const __m256i flt1_B = _mm256_packs_epi32(
+ yy_loadu_256(flt1 + j + 16), yy_loadu_256(flt1 + j + 24));
+ const __m256i e_B = calc_proj_err_r0_r1_avx2(
+ d_B, s_B, flt0_B, flt1_B, xq_coeff, rounding, shift);
+ sum32_B = _mm256_add_epi32(sum32_B, _mm256_madd_epi16(e_B, e_B));
+ }
+ dat += dat_stride;
+ src += src_stride;
+ flt0 += flt0_stride;
+ flt1 += flt1_stride;
+ }
+ __m256i sum32 = _mm256_add_epi32(sum32_A, sum32_B);
+ const __m256i sum64_0 =
+ _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+ const __m256i sum64_1 =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+ sum64 = _mm256_add_epi64(sum64, _mm256_add_epi64(sum64_0, sum64_1));
+ }
+ } else {
+ // General fallback
+ for (i = 0; i < height; ++i) {
+ __m256i sum32 = _mm256_setzero_si256();
+ for (j = 0; j <= width - 16; j += 16) {
+ const __m256i d0 = load_shuffled_u8_to_epi16(dat + j);
+ const __m256i s0 = load_shuffled_u8_to_epi16(src + j);
+ const __m256i flt0_16b = _mm256_packs_epi32(
+ yy_loadu_256(flt0 + j), yy_loadu_256(flt0 + j + 8));
+ const __m256i flt1_16b = _mm256_packs_epi32(
+ yy_loadu_256(flt1 + j), yy_loadu_256(flt1 + j + 8));
+ const __m256i e0 = calc_proj_err_r0_r1_avx2(
+ d0, s0, flt0_16b, flt1_16b, xq_coeff, rounding, shift);
+ sum32 = _mm256_add_epi32(sum32, _mm256_madd_epi16(e0, e0));
+ }
+ for (k = j; k < width; ++k) {
+ const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+ int32_t v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u);
+ const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+ err += ((int64_t)e * e);
+ }
+ dat += dat_stride;
+ src += src_stride;
+ flt0 += flt0_stride;
+ flt1 += flt1_stride;
+ const __m256i sum64_0 =
+ _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+ const __m256i sum64_1 =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+ sum64 = _mm256_add_epi64(sum64, sum64_0);
+ sum64 = _mm256_add_epi64(sum64, sum64_1);
+ }
}
} else if (params->r[0] > 0 || params->r[1] > 0) {
const int xq_active = (params->r[0] > 0) ? xq[0] : xq[1];
@@ -1610,66 +1822,322 @@ int64_t av1_lowbd_pixel_proj_error_avx2(
pair_set_epi16(xq_active, -xq_active * (1 << SGRPROJ_RST_BITS));
const int32_t *flt = (params->r[0] > 0) ? flt0 : flt1;
const int flt_stride = (params->r[0] > 0) ? flt0_stride : flt1_stride;
- for (i = 0; i < height; ++i) {
+
+ if (width == 8) {
__m256i sum32 = _mm256_setzero_si256();
- for (j = 0; j <= width - 16; j += 16) {
- const __m256i d0 = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
- const __m256i s0 = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
- const __m256i flt_16b = _mm256_permute4x64_epi64(
- _mm256_packs_epi32(yy_loadu_256(flt + j),
- yy_loadu_256(flt + j + 8)),
- 0xd8);
- const __m256i v0 =
- _mm256_madd_epi16(xq_coeff, _mm256_unpacklo_epi16(flt_16b, d0));
- const __m256i v1 =
- _mm256_madd_epi16(xq_coeff, _mm256_unpackhi_epi16(flt_16b, d0));
- const __m256i vr0 =
- _mm256_srai_epi32(_mm256_add_epi32(v0, rounding), shift);
- const __m256i vr1 =
- _mm256_srai_epi32(_mm256_add_epi32(v1, rounding), shift);
- const __m256i e0 = _mm256_sub_epi16(
- _mm256_add_epi16(_mm256_packs_epi32(vr0, vr1), d0), s0);
+ const int height_even = height & ~1;
+ for (i = 0; i < height_even; i += 2) {
+ const uint8_t *dat_rowB = dat + dat_stride;
+ const uint8_t *src_rowB = src + src_stride;
+ const int32_t *flt_rowB = flt + flt_stride;
+
+ const __m256i d0 = load_shuffled_u8_dual8_to_epi16(dat, dat_rowB);
+ const __m256i s0 = load_shuffled_u8_dual8_to_epi16(src, src_rowB);
+ const __m256i flt_16b =
+ _mm256_packs_epi32(yy_loadu_256(flt), yy_loadu_256(flt_rowB));
+
+ const __m256i e0 = calc_proj_err_r0_or_r1_avx2(
+ d0, s0, flt_16b, xq_coeff, rounding, shift);
const __m256i err0 = _mm256_madd_epi16(e0, e0);
sum32 = _mm256_add_epi32(sum32, err0);
+
+ dat += 2 * dat_stride;
+ src += 2 * src_stride;
+ flt += 2 * flt_stride;
}
- for (k = j; k < width; ++k) {
- const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
- int32_t v = xq_active * (flt[k] - u);
- const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
- err += ((int64_t)e * e);
+ if (i < height) {
+ for (k = 0; k < 8; ++k) {
+ const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+ int32_t v = xq_active * (flt[k] - u);
+ const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+ err += ((int64_t)e * e);
+ }
}
- dat += dat_stride;
- src += src_stride;
- flt += flt_stride;
const __m256i sum64_0 =
_mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
const __m256i sum64_1 =
_mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
- sum64 = _mm256_add_epi64(sum64, sum64_0);
- sum64 = _mm256_add_epi64(sum64, sum64_1);
+ sum64 = _mm256_add_epi64(sum64, _mm256_add_epi64(sum64_0, sum64_1));
+ } else if (width == 16) {
+ __m256i sum32_A = _mm256_setzero_si256();
+ __m256i sum32_B = _mm256_setzero_si256();
+ __m256i sum32_C = _mm256_setzero_si256();
+ __m256i sum32_D = _mm256_setzero_si256();
+ const int height_v4 = height & ~3;
+ for (i = 0; i < height_v4; i += 4) {
+ const uint8_t *dat_rowB = dat + dat_stride;
+ const uint8_t *dat_rowC = dat_rowB + dat_stride;
+ const uint8_t *dat_rowD = dat_rowC + dat_stride;
+ const uint8_t *src_rowB = src + src_stride;
+ const uint8_t *src_rowC = src_rowB + src_stride;
+ const uint8_t *src_rowD = src_rowC + src_stride;
+ const int32_t *flt_rowB = flt + flt_stride;
+ const int32_t *flt_rowC = flt_rowB + flt_stride;
+ const int32_t *flt_rowD = flt_rowC + flt_stride;
+
+ // Row A
+ {
+ const __m256i d0 = load_shuffled_u8_to_epi16(dat);
+ const __m256i s0 = load_shuffled_u8_to_epi16(src);
+ const __m256i flt_16b =
+ _mm256_packs_epi32(yy_loadu_256(flt), yy_loadu_256(flt + 8));
+ const __m256i e = calc_proj_err_r0_or_r1_avx2(
+ d0, s0, flt_16b, xq_coeff, rounding, shift);
+ sum32_A = _mm256_add_epi32(sum32_A, _mm256_madd_epi16(e, e));
+ }
+ // Row B
+ {
+ const __m256i d0 = load_shuffled_u8_to_epi16(dat_rowB);
+ const __m256i s0 = load_shuffled_u8_to_epi16(src_rowB);
+ const __m256i flt_16b = _mm256_packs_epi32(
+ yy_loadu_256(flt_rowB), yy_loadu_256(flt_rowB + 8));
+ const __m256i e = calc_proj_err_r0_or_r1_avx2(
+ d0, s0, flt_16b, xq_coeff, rounding, shift);
+ sum32_B = _mm256_add_epi32(sum32_B, _mm256_madd_epi16(e, e));
+ }
+ // Row C
+ {
+ const __m256i d0 = load_shuffled_u8_to_epi16(dat_rowC);
+ const __m256i s0 = load_shuffled_u8_to_epi16(src_rowC);
+ const __m256i flt_16b = _mm256_packs_epi32(
+ yy_loadu_256(flt_rowC), yy_loadu_256(flt_rowC + 8));
+ const __m256i e = calc_proj_err_r0_or_r1_avx2(
+ d0, s0, flt_16b, xq_coeff, rounding, shift);
+ sum32_C = _mm256_add_epi32(sum32_C, _mm256_madd_epi16(e, e));
+ }
+ // Row D
+ {
+ const __m256i d0 = load_shuffled_u8_to_epi16(dat_rowD);
+ const __m256i s0 = load_shuffled_u8_to_epi16(src_rowD);
+ const __m256i flt_16b = _mm256_packs_epi32(
+ yy_loadu_256(flt_rowD), yy_loadu_256(flt_rowD + 8));
+ const __m256i e = calc_proj_err_r0_or_r1_avx2(
+ d0, s0, flt_16b, xq_coeff, rounding, shift);
+ sum32_D = _mm256_add_epi32(sum32_D, _mm256_madd_epi16(e, e));
+ }
+
+ dat += 4 * dat_stride;
+ src += 4 * src_stride;
+ flt += 4 * flt_stride;
+ }
+ for (; i < height; ++i) {
+ const __m256i d0 = load_shuffled_u8_to_epi16(dat);
+ const __m256i s0 = load_shuffled_u8_to_epi16(src);
+ const __m256i flt_16b =
+ _mm256_packs_epi32(yy_loadu_256(flt), yy_loadu_256(flt + 8));
+ const __m256i e = calc_proj_err_r0_or_r1_avx2(d0, s0, flt_16b, xq_coeff,
+ rounding, shift);
+ sum32_A = _mm256_add_epi32(sum32_A, _mm256_madd_epi16(e, e));
+
+ dat += dat_stride;
+ src += src_stride;
+ flt += flt_stride;
+ }
+ __m256i sum32 = _mm256_add_epi32(_mm256_add_epi32(sum32_A, sum32_B),
+ _mm256_add_epi32(sum32_C, sum32_D));
+ const __m256i sum64_0 =
+ _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+ const __m256i sum64_1 =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+ sum64 = _mm256_add_epi64(sum64, _mm256_add_epi64(sum64_0, sum64_1));
+ } else if (width >= 32 && (width % 32 == 0)) {
+ int rows_per_batch = 4096 / width;
+ if (rows_per_batch < 1) rows_per_batch = 1;
+ for (i = 0; i < height;) {
+ int rows_to_do = height - i;
+ if (rows_to_do > rows_per_batch) rows_to_do = rows_per_batch;
+ const int next_i = i + rows_to_do;
+ __m256i sum32_A = _mm256_setzero_si256();
+ __m256i sum32_B = _mm256_setzero_si256();
+ for (; i < next_i; ++i) {
+ for (j = 0; j <= width - 32; j += 32) {
+ const __m256i d_A = load_shuffled_u8_to_epi16(dat + j);
+ const __m256i s_A = load_shuffled_u8_to_epi16(src + j);
+ const __m256i flt_A = _mm256_packs_epi32(yy_loadu_256(flt + j),
+ yy_loadu_256(flt + j + 8));
+ const __m256i e_A = calc_proj_err_r0_or_r1_avx2(
+ d_A, s_A, flt_A, xq_coeff, rounding, shift);
+ sum32_A = _mm256_add_epi32(sum32_A, _mm256_madd_epi16(e_A, e_A));
+
+ const __m256i d_B = load_shuffled_u8_to_epi16(dat + j + 16);
+ const __m256i s_B = load_shuffled_u8_to_epi16(src + j + 16);
+ const __m256i flt_B = _mm256_packs_epi32(
+ yy_loadu_256(flt + j + 16), yy_loadu_256(flt + j + 24));
+ const __m256i e_B = calc_proj_err_r0_or_r1_avx2(
+ d_B, s_B, flt_B, xq_coeff, rounding, shift);
+ sum32_B = _mm256_add_epi32(sum32_B, _mm256_madd_epi16(e_B, e_B));
+ }
+ dat += dat_stride;
+ src += src_stride;
+ flt += flt_stride;
+ }
+ __m256i sum32 = _mm256_add_epi32(sum32_A, sum32_B);
+ const __m256i sum64_0 =
+ _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+ const __m256i sum64_1 =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+ sum64 = _mm256_add_epi64(sum64, _mm256_add_epi64(sum64_0, sum64_1));
+ }
+ } else {
+ // General fallback
+ for (i = 0; i < height; ++i) {
+ __m256i sum32 = _mm256_setzero_si256();
+ for (j = 0; j <= width - 16; j += 16) {
+ const __m256i d0 = load_shuffled_u8_to_epi16(dat + j);
+ const __m256i s0 = load_shuffled_u8_to_epi16(src + j);
+ const __m256i flt_16b = _mm256_packs_epi32(yy_loadu_256(flt + j),
+ yy_loadu_256(flt + j + 8));
+ const __m256i e0 = calc_proj_err_r0_or_r1_avx2(
+ d0, s0, flt_16b, xq_coeff, rounding, shift);
+ sum32 = _mm256_add_epi32(sum32, _mm256_madd_epi16(e0, e0));
+ }
+ for (k = j; k < width; ++k) {
+ const int32_t u = (int32_t)(dat[k] << SGRPROJ_RST_BITS);
+ int32_t v = xq_active * (flt[k] - u);
+ const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+ err += ((int64_t)e * e);
+ }
+ dat += dat_stride;
+ src += src_stride;
+ flt += flt_stride;
+ const __m256i sum64_0 =
+ _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+ const __m256i sum64_1 =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+ sum64 = _mm256_add_epi64(sum64, sum64_0);
+ sum64 = _mm256_add_epi64(sum64, sum64_1);
+ }
}
} else {
- __m256i sum32 = _mm256_setzero_si256();
- for (i = 0; i < height; ++i) {
- for (j = 0; j <= width - 16; j += 16) {
- const __m256i d0 = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
- const __m256i s0 = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
- const __m256i diff0 = _mm256_sub_epi16(d0, s0);
- const __m256i err0 = _mm256_madd_epi16(diff0, diff0);
- sum32 = _mm256_add_epi32(sum32, err0);
+ if (width == 8) {
+ __m256i sum32 = _mm256_setzero_si256();
+ const int height_even = height & ~1;
+ for (i = 0; i < height_even; i += 2) {
+ const uint8_t *dat_rowB = dat + dat_stride;
+ const uint8_t *src_rowB = src + src_stride;
+
+ const __m128i d_AB =
+ _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)dat),
+ _mm_loadl_epi64((const __m128i *)dat_rowB));
+ const __m128i s_AB =
+ _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)src),
+ _mm_loadl_epi64((const __m128i *)src_rowB));
+ const __m256i diff = _mm256_sub_epi16(_mm256_cvtepu8_epi16(d_AB),
+ _mm256_cvtepu8_epi16(s_AB));
+ sum32 = _mm256_add_epi32(sum32, _mm256_madd_epi16(diff, diff));
+ dat += 2 * dat_stride;
+ src += 2 * src_stride;
}
- for (k = j; k < width; ++k) {
- const int32_t e = (int32_t)(dat[k]) - src[k];
- err += ((int64_t)e * e);
+ if (i < height) {
+ const __m128i d_A = _mm_loadl_epi64((const __m128i *)dat);
+ const __m128i s_A = _mm_loadl_epi64((const __m128i *)src);
+ const __m256i diff = _mm256_sub_epi16(_mm256_cvtepu8_epi16(d_A),
+ _mm256_cvtepu8_epi16(s_A));
+ sum32 = _mm256_add_epi32(sum32, _mm256_madd_epi16(diff, diff));
}
- dat += dat_stride;
- src += src_stride;
+ const __m256i sum64_0 =
+ _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+ const __m256i sum64_1 =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+ sum64 = _mm256_add_epi64(sum64_0, sum64_1);
+ } else if (width >= 32 && (width % 32 == 0)) {
+ __m256i sum32_A = _mm256_setzero_si256();
+ __m256i sum32_B = _mm256_setzero_si256();
+ for (i = 0; i < height; ++i) {
+ for (j = 0; j <= width - 32; j += 32) {
+ const __m256i d_A = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
+ const __m256i s_A = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
+ const __m256i diff_A = _mm256_sub_epi16(d_A, s_A);
+ sum32_A =
+ _mm256_add_epi32(sum32_A, _mm256_madd_epi16(diff_A, diff_A));
+
+ const __m256i d_B = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j + 16));
+ const __m256i s_B = _mm256_cvtepu8_epi16(xx_loadu_128(src + j + 16));
+ const __m256i diff_B = _mm256_sub_epi16(d_B, s_B);
+ sum32_B =
+ _mm256_add_epi32(sum32_B, _mm256_madd_epi16(diff_B, diff_B));
+ }
+ dat += dat_stride;
+ src += src_stride;
+ }
+ __m256i sum32 = _mm256_add_epi32(sum32_A, sum32_B);
+ const __m256i sum64_0 =
+ _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+ const __m256i sum64_1 =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+ sum64 = _mm256_add_epi64(sum64_0, sum64_1);
+ } else if (width >= 16) {
+ __m256i sum32_A = _mm256_setzero_si256();
+ __m256i sum32_B = _mm256_setzero_si256();
+ const int height_even = height & ~1;
+ for (i = 0; i < height_even; i += 2) {
+ const uint8_t *dat_rowB = dat + dat_stride;
+ const uint8_t *src_rowB = src + src_stride;
+ for (j = 0; j <= width - 16; j += 16) {
+ const __m256i d_A = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
+ const __m256i s_A = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
+ const __m256i diff_A = _mm256_sub_epi16(d_A, s_A);
+ sum32_A =
+ _mm256_add_epi32(sum32_A, _mm256_madd_epi16(diff_A, diff_A));
+
+ const __m256i d_B = _mm256_cvtepu8_epi16(xx_loadu_128(dat_rowB + j));
+ const __m256i s_B = _mm256_cvtepu8_epi16(xx_loadu_128(src_rowB + j));
+ const __m256i diff_B = _mm256_sub_epi16(d_B, s_B);
+ sum32_B =
+ _mm256_add_epi32(sum32_B, _mm256_madd_epi16(diff_B, diff_B));
+ }
+ for (k = j; k < width; ++k) {
+ const int32_t e_A = (int32_t)dat[k] - src[k];
+ err += (int64_t)e_A * e_A;
+ const int32_t e_B = (int32_t)dat_rowB[k] - src_rowB[k];
+ err += (int64_t)e_B * e_B;
+ }
+ dat += 2 * dat_stride;
+ src += 2 * src_stride;
+ }
+ if (i < height) {
+ for (j = 0; j <= width - 16; j += 16) {
+ const __m256i d_A = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
+ const __m256i s_A = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
+ const __m256i diff_A = _mm256_sub_epi16(d_A, s_A);
+ sum32_A =
+ _mm256_add_epi32(sum32_A, _mm256_madd_epi16(diff_A, diff_A));
+ }
+ for (k = j; k < width; ++k) {
+ const int32_t e_A = (int32_t)dat[k] - src[k];
+ err += (int64_t)e_A * e_A;
+ }
+ }
+ __m256i sum32 = _mm256_add_epi32(sum32_A, sum32_B);
+ const __m256i sum64_0 =
+ _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+ const __m256i sum64_1 =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+ sum64 = _mm256_add_epi64(sum64_0, sum64_1);
+ } else {
+ // General fallback
+ __m256i sum32 = _mm256_setzero_si256();
+ for (i = 0; i < height; ++i) {
+ for (j = 0; j <= width - 16; j += 16) {
+ const __m256i d0 = _mm256_cvtepu8_epi16(xx_loadu_128(dat + j));
+ const __m256i s0 = _mm256_cvtepu8_epi16(xx_loadu_128(src + j));
+ const __m256i diff0 = _mm256_sub_epi16(d0, s0);
+ const __m256i err0 = _mm256_madd_epi16(diff0, diff0);
+ sum32 = _mm256_add_epi32(sum32, err0);
+ }
+ for (k = j; k < width; ++k) {
+ const int32_t e = (int32_t)(dat[k]) - src[k];
+ err += ((int64_t)e * e);
+ }
+ dat += dat_stride;
+ src += src_stride;
+ }
+ const __m256i sum64_0 =
+ _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
+ const __m256i sum64_1 =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
+ sum64 = _mm256_add_epi64(sum64_0, sum64_1);
}
- const __m256i sum64_0 =
- _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum32));
- const __m256i sum64_1 =
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum32, 1));
- sum64 = _mm256_add_epi64(sum64_0, sum64_1);
}
int64_t sum[4];
yy_storeu_256(sum, sum64);