Commit 56a172d730 for aom

commit 56a172d7300768390612b021dd2c3b3696977686
Author: Rohan Baid <rohan.baid@ittiam.com>
Date:   Fri Mar 20 10:59:57 2026 +0530

    Enable AVX2 and SSE2 for av1_highbd_apply_temporal_filter()

    The AVX2 and SSE2 implementations are modified to support a TF_BLOCK_SIZE
    of 64x64. Speedups relative to the C implementation are as follows:

           tf_wgt_calc_lvl=0    tf_wgt_calc_lvl=1
    AVX2        10.48x               8.40x
    SSE2         9.74x               7.86x

    Change-Id: I29c3b6adef4ca8174e11f401be6ba21c7bac177c

diff --git a/av1/encoder/arm/highbd_temporal_filter_neon.c b/av1/encoder/arm/highbd_temporal_filter_neon.c
index 59cd3fc7fc..dddac6ad6f 100644
--- a/av1/encoder/arm/highbd_temporal_filter_neon.c
+++ b/av1/encoder/arm/highbd_temporal_filter_neon.c
@@ -289,6 +289,13 @@ void av1_highbd_apply_temporal_filter_neon(
     const int *subblock_mses, const int q_factor, const int filter_strength,
     int tf_wgt_calc_lvl, const uint8_t *pred8, uint32_t *accum,
     uint16_t *count) {
+  if (block_size == BLOCK_64X64) {
+    av1_apply_temporal_filter_c(frame_to_filter, mbd, block_size, mb_row,
+                                mb_col, num_planes, noise_levels, subblock_mvs,
+                                subblock_mses, q_factor, filter_strength,
+                                tf_wgt_calc_lvl, pred8, accum, count);
+    return;
+  }
   const int is_high_bitdepth = frame_to_filter->flags & YV12_FLAG_HIGHBITDEPTH;
   assert(TF_WINDOW_LENGTH == 5 && "Only support window length 5 with Neon!");
   assert(num_planes >= 1 && num_planes <= MAX_MB_PLANE);
diff --git a/av1/encoder/temporal_filter.c b/av1/encoder/temporal_filter.c
index 68c8ae6a8b..1dbb5bca71 100644
--- a/av1/encoder/temporal_filter.c
+++ b/av1/encoder/temporal_filter.c
@@ -1156,7 +1156,7 @@ void av1_tf_do_filtering_row(AV1_COMP *cpi, ThreadData *td, int mb_row) {
         // only supports 32x32 block size and 5x5 filtering window.
         if (is_frame_high_bitdepth(frame_to_filter)) {  // for high bit-depth
 #if CONFIG_AV1_HIGHBITDEPTH
-          if (!is_yuv422_format && TF_BLOCK_SIZE == BLOCK_32X32 &&
+          if (!is_yuv422_format && TF_BLOCK_SIZE == BLOCK_64X64 &&
               TF_WINDOW_LENGTH == 5) {
             av1_highbd_apply_temporal_filter(
                 frame_to_filter, mbd, block_size, mb_row, mb_col, num_planes,
diff --git a/av1/encoder/x86/highbd_temporal_filter_avx2.c b/av1/encoder/x86/highbd_temporal_filter_avx2.c
index 6232e9e3c2..acaf7678d9 100644
--- a/av1/encoder/x86/highbd_temporal_filter_avx2.c
+++ b/av1/encoder/x86/highbd_temporal_filter_avx2.c
@@ -26,78 +26,48 @@ DECLARE_ALIGNED(32, static const uint32_t, sse_bytemask[4][8]) = {
   { 0, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF }
 };

-static AOM_FORCE_INLINE void get_squared_error_16x16_avx2(
+static AOM_FORCE_INLINE void get_squared_error_avx2(
     const uint16_t *frame1, const unsigned int stride, const uint16_t *frame2,
     const unsigned int stride2, const int block_width, const int block_height,
     uint32_t *frame_sse, const unsigned int sse_stride) {
-  (void)block_width;
   const uint16_t *src1 = frame1;
   const uint16_t *src2 = frame2;
+  assert(block_width >= 32);
   uint32_t *dst = frame_sse + 2;
   for (int i = 0; i < block_height; i++) {
-    __m256i v_src1 = _mm256_loadu_si256((__m256i *)src1);
-    __m256i v_src2 = _mm256_loadu_si256((__m256i *)src2);
-    __m256i v_diff = _mm256_sub_epi16(v_src1, v_src2);
-    __m256i v_mullo = _mm256_mullo_epi16(v_diff, v_diff);
-    __m256i v_mulhi = _mm256_mulhi_epi16(v_diff, v_diff);
-
-    __m256i v_lo = _mm256_unpacklo_epi16(v_mullo, v_mulhi);
-    __m256i v_hi = _mm256_unpackhi_epi16(v_mullo, v_mulhi);
-    __m256i diff_lo =
-        _mm256_inserti128_si256(v_lo, _mm256_extracti128_si256(v_hi, 0), 1);
-    __m256i diff_hi =
-        _mm256_inserti128_si256(v_hi, _mm256_extracti128_si256(v_lo, 1), 0);
-
-    _mm256_storeu_si256((__m256i *)dst, diff_lo);
-    dst += 8;
-    _mm256_storeu_si256((__m256i *)dst, diff_hi);
-
-    src1 += stride, src2 += stride2;
-    dst += sse_stride - 8;
-  }
-}
-
-static AOM_FORCE_INLINE void get_squared_error_32x32_avx2(
-    const uint16_t *frame1, const unsigned int stride, const uint16_t *frame2,
-    const unsigned int stride2, const int block_width, const int block_height,
-    uint32_t *frame_sse, const unsigned int sse_stride) {
-  (void)block_width;
-  const uint16_t *src1 = frame1;
-  const uint16_t *src2 = frame2;
-  uint32_t *dst = frame_sse + 2;
-  for (int i = 0; i < block_height; i++) {
-    __m256i v_src1 = _mm256_loadu_si256((__m256i *)src1);
-    __m256i v_src2 = _mm256_loadu_si256((__m256i *)src2);
-    __m256i v_diff = _mm256_sub_epi16(v_src1, v_src2);
-    __m256i v_mullo = _mm256_mullo_epi16(v_diff, v_diff);
-    __m256i v_mulhi = _mm256_mulhi_epi16(v_diff, v_diff);
-
-    __m256i v_lo = _mm256_unpacklo_epi16(v_mullo, v_mulhi);
-    __m256i v_hi = _mm256_unpackhi_epi16(v_mullo, v_mulhi);
-    __m256i diff_lo =
-        _mm256_inserti128_si256(v_lo, _mm256_extracti128_si256(v_hi, 0), 1);
-    __m256i diff_hi =
-        _mm256_inserti128_si256(v_hi, _mm256_extracti128_si256(v_lo, 1), 0);
-
-    _mm256_storeu_si256((__m256i *)dst, diff_lo);
-    _mm256_storeu_si256((__m256i *)(dst + 8), diff_hi);
-
-    v_src1 = _mm256_loadu_si256((__m256i *)(src1 + 16));
-    v_src2 = _mm256_loadu_si256((__m256i *)(src2 + 16));
-    v_diff = _mm256_sub_epi16(v_src1, v_src2);
-    v_mullo = _mm256_mullo_epi16(v_diff, v_diff);
-    v_mulhi = _mm256_mulhi_epi16(v_diff, v_diff);
-
-    v_lo = _mm256_unpacklo_epi16(v_mullo, v_mulhi);
-    v_hi = _mm256_unpackhi_epi16(v_mullo, v_mulhi);
-    diff_lo =
-        _mm256_inserti128_si256(v_lo, _mm256_extracti128_si256(v_hi, 0), 1);
-    diff_hi =
-        _mm256_inserti128_si256(v_hi, _mm256_extracti128_si256(v_lo, 1), 0);
-
-    _mm256_storeu_si256((__m256i *)(dst + 16), diff_lo);
-    _mm256_storeu_si256((__m256i *)(dst + 24), diff_hi);
-
+    for (int j = 0; j < block_width; j += 32) {
+      __m256i v_src1 = _mm256_loadu_si256((__m256i *)(src1 + j));
+      __m256i v_src2 = _mm256_loadu_si256((__m256i *)(src2 + j));
+      __m256i v_diff = _mm256_sub_epi16(v_src1, v_src2);
+      __m256i v_mullo = _mm256_mullo_epi16(v_diff, v_diff);
+      __m256i v_mulhi = _mm256_mulhi_epi16(v_diff, v_diff);
+
+      __m256i v_lo = _mm256_unpacklo_epi16(v_mullo, v_mulhi);
+      __m256i v_hi = _mm256_unpackhi_epi16(v_mullo, v_mulhi);
+      __m256i diff_lo =
+          _mm256_inserti128_si256(v_lo, _mm256_extracti128_si256(v_hi, 0), 1);
+      __m256i diff_hi =
+          _mm256_inserti128_si256(v_hi, _mm256_extracti128_si256(v_lo, 1), 0);
+
+      _mm256_storeu_si256((__m256i *)(dst + j), diff_lo);
+      _mm256_storeu_si256((__m256i *)(dst + j + 8), diff_hi);
+
+      v_src1 = _mm256_loadu_si256((__m256i *)(src1 + j + 16));
+      v_src2 = _mm256_loadu_si256((__m256i *)(src2 + j + 16));
+      v_diff = _mm256_sub_epi16(v_src1, v_src2);
+      v_mullo = _mm256_mullo_epi16(v_diff, v_diff);
+      v_mulhi = _mm256_mulhi_epi16(v_diff, v_diff);
+
+      v_lo = _mm256_unpacklo_epi16(v_mullo, v_mulhi);
+      v_hi = _mm256_unpackhi_epi16(v_mullo, v_mulhi);
+      diff_lo =
+          _mm256_inserti128_si256(v_lo, _mm256_extracti128_si256(v_hi, 0), 1);
+      diff_hi =
+          _mm256_inserti128_si256(v_hi, _mm256_extracti128_si256(v_lo, 1), 0);
+
+      _mm256_storeu_si256((__m256i *)(dst + j + 16), diff_lo);
+      _mm256_storeu_si256((__m256i *)(dst + j + 24), diff_hi);
+    }
     src1 += stride;
     src2 += stride2;
     dst += sse_stride;
@@ -150,18 +120,13 @@ static void highbd_apply_temporal_filter(
     const double inv_num_ref_pixels, const double decay_factor,
     const double inv_factor, const double weight_factor, double *d_factor,
     int tf_wgt_calc_lvl) {
-  assert(((block_width == 16) || (block_width == 32)) &&
-         ((block_height == 16) || (block_height == 32)));
+  assert(((block_width == 64) || (block_width == 32)) &&
+         ((block_height == 64) || (block_height == 32)));

   uint32_t acc_5x5_sse[BH][BW];

-  if (block_width == 32) {
-    get_squared_error_32x32_avx2(frame1, stride, frame2, stride2, block_width,
-                                 block_height, frame_sse, SSE_STRIDE);
-  } else {
-    get_squared_error_16x16_avx2(frame1, stride, frame2, stride2, block_width,
-                                 block_height, frame_sse, SSE_STRIDE);
-  }
+  get_squared_error_avx2(frame1, stride, frame2, stride2, block_width,
+                         block_height, frame_sse, SSE_STRIDE);

   __m256i vsrc[5];

@@ -306,15 +271,17 @@ static void highbd_apply_temporal_filter(
     acc_5x5_sse[row][col + 3] = xx_mask_and_hadd(vsum, 3);
   }

-  double subblock_mses_scaled[4];
-  double d_factor_decayed[4];
-  for (int idx = 0; idx < 4; idx++) {
+  double subblock_mses_scaled[NUM_16X16];
+  double d_factor_decayed[NUM_16X16];
+  for (int idx = 0; idx < NUM_16X16; idx++) {
     subblock_mses_scaled[idx] = subblock_mses[idx] * inv_factor;
     d_factor_decayed[idx] = d_factor[idx] * decay_factor;
   }
   if (tf_wgt_calc_lvl == 0) {
     for (int i = 0, k = 0; i < block_height; i++) {
-      const int y_blk_raster_offset = (i >= block_height / 2) * 2;
+      const int y32_blk_raster_offset = (i >= (block_height >> 1)) << 1;
+      const int y16_blk_raster_offset =
+          ((i % (block_height >> 1)) >= (block_height >> 2)) << 1;
       for (int j = 0; j < block_width; j++, k++) {
         const int pixel_value = frame2[i * stride2 + j];
         uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];
@@ -323,7 +290,12 @@ static void highbd_apply_temporal_filter(
         diff_sse >>= ((bd - 8) * 2);

         const double window_error = diff_sse * inv_num_ref_pixels;
-        const int subblock_idx = y_blk_raster_offset + (j >= block_width / 2);
+        const int x32_blk_raster_offset = (j >= (block_width >> 1));
+        const int x16_blk_raster_offset =
+            ((j % (block_width >> 1)) >= (block_width >> 2));
+        const int subblock_idx =
+            ((y32_blk_raster_offset + x32_blk_raster_offset) << 2) +
+            (y16_blk_raster_offset + x16_blk_raster_offset);

         const double combined_error =
             weight_factor * window_error + subblock_mses_scaled[subblock_idx];
@@ -338,7 +310,9 @@ static void highbd_apply_temporal_filter(
     }
   } else {
     for (int i = 0, k = 0; i < block_height; i++) {
-      const int y_blk_raster_offset = (i >= block_height / 2) * 2;
+      const int y32_blk_raster_offset = (i >= (block_height >> 1)) << 1;
+      const int y16_blk_raster_offset =
+          ((i % (block_height >> 1)) >= (block_height >> 2)) << 1;
       for (int j = 0; j < block_width; j++, k++) {
         const int pixel_value = frame2[i * stride2 + j];
         uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];
@@ -347,7 +321,13 @@ static void highbd_apply_temporal_filter(
         diff_sse >>= ((bd - 8) * 2);

         const double window_error = diff_sse * inv_num_ref_pixels;
-        const int subblock_idx = y_blk_raster_offset + (j >= block_width / 2);
+        // subblock_idx = (y32 * 2 + x32) * 4 + (y16 * 2 + x16), all flags 0/1.
+        const int x32_blk_raster_offset = (j >= (block_width >> 1));
+        const int x16_blk_raster_offset =
+            ((j % (block_width >> 1)) >= (block_width >> 2));
+        const int subblock_idx =
+            ((y32_blk_raster_offset + x32_blk_raster_offset) << 2) +
+            (y16_blk_raster_offset + x16_blk_raster_offset);

         const double combined_error =
             weight_factor * window_error + subblock_mses_scaled[subblock_idx];
@@ -373,8 +353,8 @@ void av1_highbd_apply_temporal_filter_avx2(
     int tf_wgt_calc_lvl, const uint8_t *pred, uint32_t *accum,
     uint16_t *count) {
   const int is_high_bitdepth = frame_to_filter->flags & YV12_FLAG_HIGHBITDEPTH;
-  assert(block_size == BLOCK_32X32 && "Only support 32x32 block with sse2!");
-  assert(TF_WINDOW_LENGTH == 5 && "Only support window length 5 with sse2!");
+  assert(block_size == BLOCK_64X64 && "Only support 64x64  block with avx2!");
+  assert(TF_WINDOW_LENGTH == 5 && "Only support window length 5 with avx2!");
   assert(num_planes >= 1 && num_planes <= MAX_MB_PLANE);
   (void)is_high_bitdepth;

@@ -401,17 +381,21 @@ void av1_highbd_apply_temporal_filter_avx2(
   // Smaller strength -> smaller filtering weight.
   double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2);
   s_decay = CLIP(s_decay, 1e-5, 1);
-  double d_factor[4] = { 0 };
-  uint32_t frame_sse[SSE_STRIDE * BH] = { 0 };
-  uint32_t luma_sse_sum[BW * BH] = { 0 };
+  double d_factor[NUM_16X16] = { 0 };
+  uint32_t *frame_sse =
+      (uint32_t *)aom_memalign(32, sizeof(frame_sse[0]) * SSE_STRIDE * BH);
+  uint32_t *luma_sse_sum =
+      (uint32_t *)aom_memalign(32, sizeof(luma_sse_sum[0]) * BW * BH);
+  memset(frame_sse, 0, sizeof(frame_sse[0]) * SSE_STRIDE * BH);
+  memset(luma_sse_sum, 0, sizeof(luma_sse_sum[0]) * BW * BH);
   uint16_t *pred1 = CONVERT_TO_SHORTPTR(pred);

-  for (int subblock_idx = 0; subblock_idx < 4; subblock_idx++) {
+  double distance_threshold = min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD;
+  distance_threshold = AOMMAX(distance_threshold, 1);
+  for (int subblock_idx = 0; subblock_idx < NUM_16X16; subblock_idx++) {
     // Larger motion vector -> smaller filtering weight.
     const MV mv = subblock_mvs[subblock_idx];
     const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
-    double distance_threshold = min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD;
-    distance_threshold = AOMMAX(distance_threshold, 1);
     d_factor[subblock_idx] = distance / distance_threshold;
     d_factor[subblock_idx] = AOMMAX(d_factor[subblock_idx], 1);
   }
@@ -443,13 +427,36 @@ void av1_highbd_apply_temporal_filter_avx2(
     // will be more accurate. The luma sse sum is reused in both chroma
     // planes.
     if (plane == AOM_PLANE_U) {
-      for (unsigned int i = 0, k = 0; i < plane_h; i++) {
-        for (unsigned int j = 0; j < plane_w; j++, k++) {
-          for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
-            for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
-              const int yy = (i << ss_y_shift) + ii;  // Y-coord on Y-plane.
-              const int xx = (j << ss_x_shift) + jj;  // X-coord on Y-plane.
-              luma_sse_sum[i * BW + j] += frame_sse[yy * SSE_STRIDE + xx + 2];
+      if (ss_x_shift == 1 && ss_y_shift == 1) {
+        for (unsigned int i = 0; i < plane_h; i++) {
+          const uint32_t *src_0 = &frame_sse[2 * i * SSE_STRIDE + 2];
+          const uint32_t *src_1 = &frame_sse[(2 * i + 1) * SSE_STRIDE + 2];
+          for (unsigned int j = 0; j < plane_w; j += 8) {
+            const __m256i reg0_lo =
+                _mm256_loadu_si256((__m256i *)(src_0 + 2 * j));
+            const __m256i reg1_lo =
+                _mm256_loadu_si256((__m256i *)(src_1 + 2 * j));
+            const __m256i reg0_hi =
+                _mm256_loadu_si256((__m256i *)(src_0 + 2 * j + 8));
+            const __m256i reg1_hi =
+                _mm256_loadu_si256((__m256i *)(src_1 + 2 * j + 8));
+
+            const __m256i reg_0 = _mm256_add_epi32(reg0_lo, reg1_lo);
+            const __m256i reg_1 = _mm256_add_epi32(reg0_hi, reg1_hi);
+            __m256i res = _mm256_hadd_epi32(reg_0, reg_1);
+            res = _mm256_permute4x64_epi64(res, 0xD8);
+            _mm256_storeu_si256((__m256i *)&luma_sse_sum[i * BW + j], res);
+          }
+        }
+      } else {
+        for (unsigned int i = 0, k = 0; i < plane_h; i++) {
+          for (unsigned int j = 0; j < plane_w; j++, k++) {
+            for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
+              for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
+                const int yy = (i << ss_y_shift) + ii;  // Y-coord on Y-plane.
+                const int xx = (j << ss_x_shift) + jj;  // X-coord on Y-plane.
+                luma_sse_sum[i * BW + j] += frame_sse[yy * SSE_STRIDE + xx + 2];
+              }
             }
           }
         }
@@ -463,4 +470,6 @@ void av1_highbd_apply_temporal_filter_avx2(
         weight_factor, d_factor, tf_wgt_calc_lvl);
     plane_offset += plane_h * plane_w;
   }
+  aom_free(frame_sse);
+  aom_free(luma_sse_sum);
 }
diff --git a/av1/encoder/x86/highbd_temporal_filter_sse2.c b/av1/encoder/x86/highbd_temporal_filter_sse2.c
index b328dcac40..5cf18ab36a 100644
--- a/av1/encoder/x86/highbd_temporal_filter_sse2.c
+++ b/av1/encoder/x86/highbd_temporal_filter_sse2.c
@@ -98,8 +98,8 @@ static void highbd_apply_temporal_filter(
     const double inv_num_ref_pixels, const double decay_factor,
     const double inv_factor, const double weight_factor, double *d_factor,
     int tf_wgt_calc_lvl) {
-  assert(((block_width == 16) || (block_width == 32)) &&
-         ((block_height == 16) || (block_height == 32)));
+  assert(((block_width == 64) || (block_width == 32)) &&
+         ((block_height == 64) || (block_height == 32)));

   uint32_t acc_5x5_sse[BH][BW];

@@ -181,15 +181,17 @@ static void highbd_apply_temporal_filter(
     }
   }

-  double subblock_mses_scaled[4];
-  double d_factor_decayed[4];
-  for (int idx = 0; idx < 4; idx++) {
+  double subblock_mses_scaled[NUM_16X16];
+  double d_factor_decayed[NUM_16X16];
+  for (int idx = 0; idx < NUM_16X16; idx++) {
     subblock_mses_scaled[idx] = subblock_mses[idx] * inv_factor;
     d_factor_decayed[idx] = d_factor[idx] * decay_factor;
   }
   if (tf_wgt_calc_lvl == 0) {
     for (int i = 0, k = 0; i < block_height; i++) {
-      const int y_blk_raster_offset = (i >= block_height / 2) * 2;
+      const int y32_blk_raster_offset = (i >= (block_height >> 1)) << 1;
+      const int y16_blk_raster_offset =
+          ((i % (block_height >> 1)) >= (block_height >> 2)) << 1;
       for (int j = 0; j < block_width; j++, k++) {
         const int pixel_value = frame2[i * stride2 + j];
         uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];
@@ -198,7 +200,12 @@ static void highbd_apply_temporal_filter(
         diff_sse >>= ((bd - 8) * 2);

         const double window_error = diff_sse * inv_num_ref_pixels;
-        const int subblock_idx = y_blk_raster_offset + (j >= block_width / 2);
+        const int x32_blk_raster_offset = (j >= (block_width >> 1));
+        const int x16_blk_raster_offset =
+            ((j % (block_width >> 1)) >= (block_width >> 2));
+        const int subblock_idx =
+            ((y32_blk_raster_offset + x32_blk_raster_offset) << 2) +
+            (y16_blk_raster_offset + x16_blk_raster_offset);

         const double combined_error =
             weight_factor * window_error + subblock_mses_scaled[subblock_idx];
@@ -213,7 +220,9 @@ static void highbd_apply_temporal_filter(
     }
   } else {
     for (int i = 0, k = 0; i < block_height; i++) {
-      const int y_blk_raster_offset = (i >= block_height / 2) * 2;
+      const int y32_blk_raster_offset = (i >= (block_height >> 1)) << 1;
+      const int y16_blk_raster_offset =
+          ((i % (block_height >> 1)) >= (block_height >> 2)) << 1;
       for (int j = 0; j < block_width; j++, k++) {
         const int pixel_value = frame2[i * stride2 + j];
         uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];
@@ -222,7 +231,12 @@ static void highbd_apply_temporal_filter(
         diff_sse >>= ((bd - 8) * 2);

         const double window_error = diff_sse * inv_num_ref_pixels;
-        const int subblock_idx = y_blk_raster_offset + (j >= block_width / 2);
+        const int x32_blk_raster_offset = (j >= (block_width >> 1));
+        const int x16_blk_raster_offset =
+            ((j % (block_width >> 1)) >= (block_width >> 2));
+        const int subblock_idx =
+            ((y32_blk_raster_offset + x32_blk_raster_offset) << 2) +
+            (y16_blk_raster_offset + x16_blk_raster_offset);

         const double combined_error =
             weight_factor * window_error + subblock_mses_scaled[subblock_idx];
@@ -248,7 +262,7 @@ void av1_highbd_apply_temporal_filter_sse2(
     int tf_wgt_calc_lvl, const uint8_t *pred, uint32_t *accum,
     uint16_t *count) {
   const int is_high_bitdepth = frame_to_filter->flags & YV12_FLAG_HIGHBITDEPTH;
-  assert(block_size == BLOCK_32X32 && "Only support 32x32 block with sse2!");
+  assert(block_size == BLOCK_64X64 && "Only support 64x64 block with sse2!");
   assert(TF_WINDOW_LENGTH == 5 && "Only support window length 5 with sse2!");
   assert(num_planes >= 1 && num_planes <= MAX_MB_PLANE);
   (void)is_high_bitdepth;
@@ -276,17 +290,22 @@ void av1_highbd_apply_temporal_filter_sse2(
   // Smaller strength -> smaller filtering weight.
   double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2);
   s_decay = CLIP(s_decay, 1e-5, 1);
-  double d_factor[4] = { 0 };
-  uint32_t frame_sse[SSE_STRIDE * BH] = { 0 };
-  uint32_t luma_sse_sum[BW * BH] = { 0 };
+
+  double d_factor[NUM_16X16] = { 0 };
+  uint32_t *frame_sse =
+      (uint32_t *)aom_memalign(32, sizeof(frame_sse[0]) * SSE_STRIDE * BH);
+  uint32_t *luma_sse_sum =
+      (uint32_t *)aom_memalign(32, sizeof(luma_sse_sum[0]) * BW * BH);
+  memset(frame_sse, 0, sizeof(frame_sse[0]) * SSE_STRIDE * BH);
+  memset(luma_sse_sum, 0, sizeof(luma_sse_sum[0]) * BW * BH);
   uint16_t *pred1 = CONVERT_TO_SHORTPTR(pred);

-  for (int subblock_idx = 0; subblock_idx < 4; subblock_idx++) {
+  double distance_threshold = min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD;
+  distance_threshold = AOMMAX(distance_threshold, 1);
+  for (int subblock_idx = 0; subblock_idx < NUM_16X16; subblock_idx++) {
     // Larger motion vector -> smaller filtering weight.
     const MV mv = subblock_mvs[subblock_idx];
     const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
-    double distance_threshold = min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD;
-    distance_threshold = AOMMAX(distance_threshold, 1);
     d_factor[subblock_idx] = distance / distance_threshold;
     d_factor[subblock_idx] = AOMMAX(d_factor[subblock_idx], 1);
   }
@@ -338,4 +357,6 @@ void av1_highbd_apply_temporal_filter_sse2(
         weight_factor, d_factor, tf_wgt_calc_lvl);
     plane_offset += plane_h * plane_w;
   }
+  aom_free(frame_sse);
+  aom_free(luma_sse_sum);
 }
diff --git a/test/temporal_filter_test.cc b/test/temporal_filter_test.cc
index c9b4224adb..9927fc6eaa 100644
--- a/test/temporal_filter_test.cc
+++ b/test/temporal_filter_test.cc
@@ -526,6 +526,7 @@ void HBDTemporalFilterTest::RunTest(int isRandom, int run_times, int BD,
   static_assert(block_size == BLOCK_64X64, "");
   const int width = 64;
   const int height = 64;
+  const int pels = width * height;
   int num_planes = MAX_MB_PLANE;
   int subsampling_x = 0;
   int subsampling_y = 0;
@@ -557,18 +558,23 @@ void HBDTemporalFilterTest::RunTest(int isRandom, int run_times, int BD,
     }
     double sigma[MAX_MB_PLANE] = { 2.1002103677063437, 2.1002103677063437,
                                    2.1002103677063437 };
-    DECLARE_ALIGNED(16, unsigned int, accumulator_ref[1024 * 3]);
-    DECLARE_ALIGNED(16, uint16_t, count_ref[1024 * 3]);
-    memset(accumulator_ref, 0, 1024 * 3 * sizeof(accumulator_ref[0]));
-    memset(count_ref, 0, 1024 * 3 * sizeof(count_ref[0]));
-    DECLARE_ALIGNED(16, unsigned int, accumulator_mod[1024 * 3]);
-    DECLARE_ALIGNED(16, uint16_t, count_mod[1024 * 3]);
-    memset(accumulator_mod, 0, 1024 * 3 * sizeof(accumulator_mod[0]));
-    memset(count_mod, 0, 1024 * 3 * sizeof(count_mod[0]));
+    DECLARE_ALIGNED(16, unsigned int, accumulator_ref[pels * 3]);
+    DECLARE_ALIGNED(16, uint16_t, count_ref[pels * 3]);
+    memset(accumulator_ref, 0, pels * 3 * sizeof(accumulator_ref[0]));
+    memset(count_ref, 0, pels * 3 * sizeof(count_ref[0]));
+    DECLARE_ALIGNED(16, unsigned int, accumulator_mod[pels * 3]);
+    DECLARE_ALIGNED(16, uint16_t, count_mod[pels * 3]);
+    memset(accumulator_mod, 0, pels * 3 * sizeof(accumulator_mod[0]));
+    memset(count_mod, 0, pels * 3 * sizeof(count_mod[0]));

     static_assert(width == 64 && height == 64, "");
-    const MV subblock_mvs[4] = { { 0, 0 }, { 5, 5 }, { 7, 8 }, { 2, 10 } };
-    const int subblock_mses[4] = { 15, 16, 17, 18 };
+    const MV subblock_mvs[NUM_16X16] = {
+      { 0, 0 }, { 5, 5 },  { 7, 8 }, { 2, 10 }, { 0, 0 }, { 5, 5 },
+      { 7, 8 }, { 2, 10 }, { 0, 0 }, { 5, 5 },  { 7, 8 }, { 2, 10 },
+      { 0, 0 }, { 5, 5 },  { 7, 8 }, { 2, 10 }
+    };
+    const int subblock_mses[NUM_16X16] = { 15, 16, 17, 18, 15, 16, 17, 18,
+                                           15, 16, 17, 18, 15, 16, 17, 18 };
     const int q_factor = 12;
     const int filter_strength = 5;
     const int mb_row = 0;
@@ -582,11 +588,11 @@ void HBDTemporalFilterTest::RunTest(int isRandom, int run_times, int BD,
     frame_to_filter->heights[PLANE_TYPE_UV] = height >> subsampling_y;
     frame_to_filter->strides[PLANE_TYPE_Y] = stride;
     frame_to_filter->strides[PLANE_TYPE_UV] = stride >> subsampling_x;
-    DECLARE_ALIGNED(16, uint16_t, src[1024 * 3]);
+    DECLARE_ALIGNED(16, uint16_t, src[pels * 3]);
     frame_to_filter->buffer_alloc = CONVERT_TO_BYTEPTR(src);
     frame_to_filter->flags =
         YV12_FLAG_HIGHBITDEPTH;  // Only Hihgbd bit-depth test.
-    memcpy(src, src1_, 1024 * 3 * sizeof(uint16_t));
+    memcpy(src, src1_, pels * 3 * sizeof(uint16_t));

     std::unique_ptr<MACROBLOCKD> mbd(new (std::nothrow) MACROBLOCKD);
     ASSERT_NE(mbd, nullptr);
@@ -677,26 +683,24 @@ TEST_P(HBDTemporalFilterTest, DISABLED_Speed) {
   RunTest(1, 100000, 10, I444);
 }

-// av1_apply_temporal_filter_c works on 64x64 TF block now, the SIMD function
-// needs to be updated.
-// #if HAVE_SSE2
-// HBDTemporalFilterFuncParam HBDtemporal_filter_test_sse2[] = {
-//  HBDTemporalFilterFuncParam(&av1_highbd_apply_temporal_filter_c,
-//                             &av1_highbd_apply_temporal_filter_sse2)
-//};
-// INSTANTIATE_TEST_SUITE_P(SSE2, HBDTemporalFilterTest,
-//                         Combine(ValuesIn(HBDtemporal_filter_test_sse2),
-//                                 Values(0, 1)));
-// #endif  // HAVE_SSE2
-// #if HAVE_AVX2
-// HBDTemporalFilterFuncParam HBDtemporal_filter_test_avx2[] = {
-//  HBDTemporalFilterFuncParam(&av1_highbd_apply_temporal_filter_c,
-//                             &av1_highbd_apply_temporal_filter_avx2)
-//};
-// INSTANTIATE_TEST_SUITE_P(AVX2, HBDTemporalFilterTest,
-//                         Combine(ValuesIn(HBDtemporal_filter_test_avx2),
-//                                 Values(0, 1)));
-// #endif  // HAVE_AVX2
+#if HAVE_SSE2
+HBDTemporalFilterFuncParam HBDtemporal_filter_test_sse2[] = {
+  HBDTemporalFilterFuncParam(&av1_highbd_apply_temporal_filter_c,
+                             &av1_highbd_apply_temporal_filter_sse2)
+};
+INSTANTIATE_TEST_SUITE_P(SSE2, HBDTemporalFilterTest,
+                         Combine(ValuesIn(HBDtemporal_filter_test_sse2),
+                                 Values(0, 1)));
+#endif  // HAVE_SSE2
+#if HAVE_AVX2
+HBDTemporalFilterFuncParam HBDtemporal_filter_test_avx2[] = {
+  HBDTemporalFilterFuncParam(&av1_highbd_apply_temporal_filter_c,
+                             &av1_highbd_apply_temporal_filter_avx2)
+};
+INSTANTIATE_TEST_SUITE_P(AVX2, HBDTemporalFilterTest,
+                         Combine(ValuesIn(HBDtemporal_filter_test_avx2),
+                                 Values(0, 1)));
+#endif  // HAVE_AVX2

 // av1_apply_temporal_filter_c works on 64x64 TF block now, the SIMD function
 // needs to be updated.