Commit 198776f61f for openssl.org

commit 198776f61fa96ac084fe90c6f7fea7971d0263f6
Author: Tomasz Kantecki <tomasz.kantecki@intel.com>
Date:   Wed Jan 28 16:50:38 2026 +0000

    ML-DSA: optimize one vmovshdup from the NTT multiply operation

    It applies to 7 out of 8 levels in NTT and NTT^-1.
    It helps eliminate some zeta shuffles in NTT level 6 and NTT^-1 level 1.
    Added small optimization in data shuffling between the levels.

    Signed-off-by: Tomasz Kantecki <tomasz.kantecki@intel.com>

    Reviewed-by: Saša Nedvědický <sashan@openssl.org>
    Reviewed-by: Paul Dale <paul.dale@oracle.com>
    Reviewed-by: Neil Horman <nhorman@openssl.org>
    MergeDate: Wed Mar 11 15:47:46 2026
    (Merged from https://github.com/openssl/openssl/pull/30160)

diff --git a/crypto/ml_dsa/asm/ml_dsa_ntt-x86_64.pl b/crypto/ml_dsa/asm/ml_dsa_ntt-x86_64.pl
index c412e5e6c2..6012b34bb7 100644
--- a/crypto/ml_dsa/asm/ml_dsa_ntt-x86_64.pl
+++ b/crypto/ml_dsa/asm/ml_dsa_ntt-x86_64.pl
@@ -108,7 +108,7 @@ ml_dsa_ntt_avx2_capable:
 ___

 ###############################################################################
-# multiply_8x8_mod_Q
+# multiply_mod_Q
 #
 # Description:
 #   The inputs (A and B) are in YMM registers, where each packs 8 32-bit integers.
@@ -123,6 +123,10 @@ ___
 #   tmp0-tmp2 - Temporary registers for intermediate values
 #   q_neg_inv- YMM containing -Q^{-1} mod 2^32 (for Montgomery reduction)
 #   q        - YMM containing modulus Q
+#   bcast32  - if 1, only the low dword of each qword of `inB` is read
+#              (the high dword is ignored), so that one value multiplies
+#              both dwords of the matching `inA` qword. This avoids the
+#              `vmovshdup` on `inB`. Optional; defaults to 0.
 #
 # Output:
 #   out      - Resulting 8 packed 32-bit integers, each (A * B) mod Q
@@ -133,18 +137,31 @@ ___
 # Notes:
 #   inA or inB can also be used as out
 ###############################################################################
-sub multiply_8x8_mod_Q {
+sub multiply_mod_Q {
     my ($inA, $inB, $out,
         $tmp0, $tmp1, $tmp2,
-        $q_neg_inv, $q) = @_;
+        $q_neg_inv, $q, $bcast32) = @_;
+
+    if (!defined($bcast32)) {
+        $bcast32 = 0;
+    }

     $code .= <<___;
     # Multiply A x B
     vpmuludq $inA, $inB, $tmp0  # multiply even indexes
     vmovshdup $inA, $tmp1
+___
+    if ($bcast32 == 0) {
+        $code .= <<___;
     vmovshdup $inB, $tmp2
     vpmuludq $tmp1, $tmp2, $tmp1  # multiply odd indexes
-
+___
+    } else {
+        $code .= <<___;
+    vpmuludq $tmp1, $inB, $tmp1   # multiply odd indexes of inA
+___
+    }
+        $code .= <<___;
     # Montgomery reduction: t1 = (A x B)[31..0] x Qinv
     vpmuludq $q_neg_inv, $tmp0, $out
     vpmuludq $q_neg_inv, $tmp1, $tmp2
@@ -196,6 +213,7 @@ ___
 #   tmp0, tmp1, tmp2, tmp3 - Temporary registers for intermediate values
 #   q_neg_inv - YMM with -Q^{-1} mod 2^32
 #   q       - YMM with modulus Q
+#   level   - current NTT processing level
 #
 # Output:
 #   n_even  - Updated even coefficients after butterfly and reduction
@@ -209,14 +227,23 @@ sub ntt_butterfly {
         $zetas,
         $n_even, $n_odd,
         $tmp0, $tmp1, $tmp2, $tmp3,
-        $q_neg_inv, $q) = @_;
+        $q_neg_inv, $q, $level) = @_;

-    &multiply_8x8_mod_Q($w_odd, # A
+    if ($level >= 7) {
+        # level 7: each zeta (B) dword is different
+        &multiply_mod_Q($w_odd, # A
                         $zetas, # B
                         $tmp0,  # out (AxB)
                         $tmp1, $tmp2, $tmp3, # tmp
-                        $q_neg_inv, $q);  # qinv, q
-
+                        $q_neg_inv, $q, 0);  # qinv, q, no-bcast32
+    } else {
+        # levels 0 to 6: same zeta (B) dwords within each qword
+        &multiply_mod_Q($w_odd, # A
+                        $zetas, # B
+                        $tmp0,  # out (AxB)
+                        $tmp1, $tmp2, $tmp3, # tmp
+                        $q_neg_inv, $q, 1);  # qinv, q, bcast32
+    }
     $code .= <<___;

     # t_odd = $tmp0
@@ -291,16 +318,16 @@ sub ntt_levels0to2 {
 ___
     &ntt_butterfly("%ymm0", "%ymm4", "%ymm13", "%ymm0", "%ymm4",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm8", "%ymm9", "%ymm10", "%ymm11",            # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 0);                          # qinv, q, level
     &ntt_butterfly("%ymm1", "%ymm5", "%ymm13", "%ymm1", "%ymm5",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm8", "%ymm9", "%ymm10", "%ymm11",            # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 0);                          # qinv, q, level
     &ntt_butterfly("%ymm2", "%ymm6", "%ymm13", "%ymm2", "%ymm6",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm8", "%ymm9", "%ymm10", "%ymm11",            # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 0);                          # qinv, q, level
     &ntt_butterfly("%ymm3", "%ymm7", "%ymm13", "%ymm3", "%ymm7",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm8", "%ymm9", "%ymm10", "%ymm11",            # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 0);                          # qinv, q, level
     $code .= <<___;

     # ==============================================================
@@ -311,19 +338,19 @@ ___
 ___
     &ntt_butterfly("%ymm0", "%ymm2", "%ymm13", "%ymm0", "%ymm2",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm8", "%ymm9", "%ymm10", "%ymm11",            # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 1);                          # qinv, q, level
     &ntt_butterfly("%ymm1", "%ymm3", "%ymm13", "%ymm1", "%ymm3",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm8", "%ymm9", "%ymm10", "%ymm11",            # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 1);                          # qinv, q, level
     $code .= <<___;
     vpbroadcastd 3*4(%r11), %ymm13
 ___
     &ntt_butterfly("%ymm4", "%ymm6", "%ymm13", "%ymm4", "%ymm6",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm8", "%ymm9", "%ymm10", "%ymm11",            # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 1);                          # qinv, q, level
     &ntt_butterfly("%ymm5", "%ymm7", "%ymm13", "%ymm5", "%ymm7",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm8", "%ymm9", "%ymm10", "%ymm11",            # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 1);                          # qinv, q, level
 $code .= <<___;

     # ==============================================================
@@ -334,25 +361,25 @@ $code .= <<___;
 ___
     &ntt_butterfly("%ymm0", "%ymm1", "%ymm13", "%ymm0", "%ymm1",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm8", "%ymm9", "%ymm10", "%ymm11",            # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 2);                          # qinv, q, level
     $code .= <<___;
     vpbroadcastd 5*4(%r11), %ymm13
 ___
     &ntt_butterfly("%ymm2", "%ymm3", "%ymm13", "%ymm2", "%ymm3",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm8", "%ymm9", "%ymm10", "%ymm11",            # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 2);                          # qinv, q, level
     $code .= <<___;
     vpbroadcastd 6*4(%r11), %ymm13
 ___
     &ntt_butterfly("%ymm4", "%ymm5", "%ymm13", "%ymm4", "%ymm5",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm8", "%ymm9", "%ymm10", "%ymm11",            # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 2);                          # qinv, q, level
     $code .= <<___;
     vpbroadcastd 7*4(%r11), %ymm13
 ___
     &ntt_butterfly("%ymm6", "%ymm7", "%ymm13", "%ymm6", "%ymm7",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm8", "%ymm9", "%ymm10", "%ymm11",            # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 2);                          # qinv, q, level
 $code .= <<___;

     vmovdqu %ymm0, $off+0*4(%rdi)
@@ -425,10 +452,10 @@ ___
     # Process first 16 coefficients with zeta in ymm13
     &ntt_butterfly("%ymm0", "%ymm2", "%ymm13", "%ymm0", "%ymm2",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 3);                          # qinv, q, level
     &ntt_butterfly("%ymm1", "%ymm3", "%ymm13", "%ymm1", "%ymm3",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 3);                          # qinv, q, level
 $code .= <<___;
     # broadcast zetas
     vpbroadcastd $l3+4(%r11), %ymm13  # zeta for coefficients 16-31
@@ -442,12 +469,10 @@ ___
     # Process next 16 coefficients with zeta in ymm13
     &ntt_butterfly("%ymm4", "%ymm6", "%ymm13", "%ymm4", "%ymm6",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
-$code .= <<___;
-___
+                   "%ymm14", "%ymm15", 3);                          # qinv, q, level
     &ntt_butterfly("%ymm5", "%ymm7", "%ymm13", "%ymm5", "%ymm7",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 3);                          # qinv, q, level
 $code .= <<___;

     # ==============================================================
@@ -472,39 +497,34 @@ $code .= <<___;
 ___
     &ntt_butterfly("%ymm0", "%ymm1", "%ymm13", "%ymm0", "%ymm1",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 4);                          # qinv, q, level
 $code .= <<___;
     # broadcast zetas for next 8 coefficients
     vpbroadcastd $l4+4(%r11), %ymm13    # zeta for coefficients 8-15
 ___
     &ntt_butterfly("%ymm2", "%ymm3", "%ymm13", "%ymm2", "%ymm3",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 4);                          # qinv, q, level
     $code .= <<___;
     # broadcast zetas for next 8 coefficients
     vpbroadcastd $l4+8(%r11), %ymm13    # zeta for coefficients 16-23
 ___
     &ntt_butterfly("%ymm4", "%ymm5", "%ymm13", "%ymm4", "%ymm5",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 4);                          # qinv, q, level
 $code .= <<___;
     # broadcast zetas for next 8 coefficients
     vpbroadcastd $l4+12(%r11), %ymm13   # zeta for coefficients 24-31
 ___
     &ntt_butterfly("%ymm6", "%ymm7", "%ymm13", "%ymm6", "%ymm7",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 4);                          # qinv, q, level
     $code .= <<___;

     # ==============================================================
     # level 5: offset = 4, step = 32
     # zeta indexes = 32, 33, 34, ..., 62, 63

-    # load zetas for first 8 coefficients
-    vpbroadcastd $l5(%r11), %ymm13         # zetas for coefficients 0-3
-    vpbroadcastd $l5+4(%r11), %ymm12       # zetas for coefficients 4-7
-    vpblendd \$0xf0, %ymm12, %ymm13, %ymm13 # blend into ymm13
-
     # prepare w_even and w_odd
     # Input dword layout:
     #   ymm0 = [ 0  1  2  3 |  4  5  6  7] (even)
@@ -515,49 +535,57 @@ ___
     vperm2i128 \$0x20, %ymm1, %ymm0, %ymm8
     vperm2i128 \$0x31, %ymm1, %ymm0, %ymm1

+    # load zetas for first 8 coefficients
+    vpbroadcastd $l5(%r11), %ymm13         # zetas for coefficients 0-3
+    vpbroadcastd $l5+4(%r11), %ymm12       # zetas for coefficients 4-7
+    vpblendd \$0xf0, %ymm12, %ymm13, %ymm13 # blend into ymm13
+
 ___
     &ntt_butterfly("%ymm8", "%ymm1", "%ymm13", "%ymm0", "%ymm1",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 5);                          # qinv, q, level
 $code .= <<___;
+    # prepare w_even and w_odd
+    vperm2i128 \$0x20, %ymm3, %ymm2, %ymm8
+    vperm2i128 \$0x31, %ymm3, %ymm2, %ymm3
+
     # load zetas for first 8 coefficients
-    vpbroadcastd $l5+8(%r11), %ymm13       # zetas for coefficients 8-11
+    vpbroadcastd $l5+8(%r11), %ymm13        # zetas for coefficients 8-11
     vpbroadcastd $l5+12(%r11), %ymm12       # zetas for coefficients 12-15
     vpblendd \$0xf0, %ymm12, %ymm13, %ymm13 # blend into ymm13

-    # prepare w_even and w_odd
-    vperm2i128 \$0x20, %ymm3, %ymm2, %ymm8
-    vperm2i128 \$0x31, %ymm3, %ymm2, %ymm3
 ___
     &ntt_butterfly("%ymm8", "%ymm3", "%ymm13", "%ymm2", "%ymm3",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 5);                          # qinv, q, level
     $code .= <<___;
+    # prepare w_even and w_odd
+    vperm2i128 \$0x20, %ymm5, %ymm4, %ymm8
+    vperm2i128 \$0x31, %ymm5, %ymm4, %ymm5
+
     # load zetas for next 8 coefficients
     vpbroadcastd $l5+16(%r11), %ymm13       # zetas for coefficients 16-19
     vpbroadcastd $l5+20(%r11), %ymm12       # zetas for coefficients 20-23
     vpblendd \$0xf0, %ymm12, %ymm13, %ymm13 # blend into ymm13

-    # prepare w_even and w_odd
-    vperm2i128 \$0x20, %ymm5, %ymm4, %ymm8
-    vperm2i128 \$0x31, %ymm5, %ymm4, %ymm5
 ___
     &ntt_butterfly("%ymm8", "%ymm5", "%ymm13", "%ymm4", "%ymm5",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 5);                          # qinv, q, level
 $code .= <<___;
+    # prepare w_even and w_odd
+    vperm2i128 \$0x20, %ymm7, %ymm6, %ymm8
+    vperm2i128 \$0x31, %ymm7, %ymm6, %ymm7
+
     # load zetas for next 8 coefficients
     vpbroadcastd $l5+24(%r11), %ymm13       # zetas for coefficients 24-27
     vpbroadcastd $l5+28(%r11), %ymm12       # zetas for coefficients 28-31
     vpblendd \$0xf0, %ymm12, %ymm13, %ymm13 # blend into ymm13

-    # prepare w_even and w_odd
-    vperm2i128 \$0x20, %ymm7, %ymm6, %ymm8
-    vperm2i128 \$0x31, %ymm7, %ymm6, %ymm7
 ___
     &ntt_butterfly("%ymm8", "%ymm7", "%ymm13", "%ymm6", "%ymm7",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 5);                          # qinv, q, level
     $code .= <<___;

     # ==============================================================
@@ -573,13 +601,10 @@ ___
     #   ymm8 = [ 0  1  4  5] [ 8  9 12 13] (even)
     #   ymm1 = [ 2  3  6  7] [10 11 14 15] (odd)

-    # load & prepare zetas for first 16 coefficients
-    vmovdqu $l6(%r11), %ymm10
-    vpshufd \$0xfa, %ymm10, %ymm11
-    vpshufd \$0x50, %ymm10, %ymm10
-    vperm2i128 \$0x20, %ymm11, %ymm10, %ymm13
-    vperm2i128 \$0x31, %ymm11, %ymm10, %ymm12
-    vmovdqu %ymm12, (%rsp)
+    # load & prepare zetas
+    # - it is enough that the 1st dword of each qword is populated
+    vmovdqu $l6(%r11), %xmm13
+    vpmovzxdq %xmm13, %ymm13

     # prepare w_even and w_odd
     vpunpcklqdq %ymm1, %ymm0, %ymm8
@@ -587,10 +612,11 @@ ___
 ___
     &ntt_butterfly("%ymm8", "%ymm1", "%ymm13", "%ymm0", "%ymm1",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 6);                          # qinv, q, level
     $code .= <<___;
-    # load formatted zetas from the stack
-    vmovdqu (%rsp), %ymm13
+    # load & prepare zetas
+    vmovdqu $l6+16(%r11), %xmm13
+    vpmovzxdq %xmm13, %ymm13

     # prepare w_even and w_odd
     vpunpcklqdq %ymm3, %ymm2, %ymm8
@@ -598,16 +624,12 @@ ___
 ___
     &ntt_butterfly("%ymm8", "%ymm3", "%ymm13", "%ymm2", "%ymm3",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 6);                          # qinv, q, level
     $code .= <<___;

-    # load & prepare zetas for next 16 coefficients
-    vmovdqu $l6+32(%r11), %ymm10
-    vpshufd \$0xfa, %ymm10, %ymm11
-    vpshufd \$0x50, %ymm10, %ymm10
-    vperm2i128 \$0x20, %ymm11, %ymm10, %ymm13
-    vperm2i128 \$0x31, %ymm11, %ymm10, %ymm12
-    vmovdqu %ymm12, (%rsp)
+    # load & prepare zetas
+    vmovdqu $l6+32(%r11), %xmm13
+    vpmovzxdq %xmm13, %ymm13

     # prepare w_even and w_odd
     vpunpcklqdq %ymm5, %ymm4, %ymm8
@@ -615,10 +637,11 @@ ___
 ___
     &ntt_butterfly("%ymm8", "%ymm5", "%ymm13", "%ymm4", "%ymm5",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 6);                          # qinv, q, level
 $code .= <<___;
-    # load formatted zetas from the stack
-    vmovdqu (%rsp), %ymm13
+    # load & prepare zetas
+    vmovdqu $l6+48(%r11), %xmm13
+    vpmovzxdq %xmm13, %ymm13

     # prepare w_even and w_odd
     vpunpcklqdq %ymm7, %ymm6, %ymm8
@@ -626,7 +649,7 @@ $code .= <<___;
 ___
     &ntt_butterfly("%ymm8", "%ymm7", "%ymm13", "%ymm6", "%ymm7",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 6);                          # qinv, q, level
     $code .= <<___;

     # ==============================================================
@@ -637,68 +660,64 @@ ___
     #   ymm0 = [ 0  1  4  5] [ 8  9 12 13]
     #   ymm1 = [ 2  3  6  7] [10 11 14 15]
     # Required DWORD layout:
-    #   ymm8 = [ 0  2  4  6] [ 8 10 12 14] (even)
+    #   ymm0 = [ 0  2  4  6] [ 8 10 12 14] (even)
     #   ymm1 = [ 1  3  5  7] [ 9 11 13 15] (odd)

-    vpshufd \$0x88, %ymm0, %ymm8
-    vpshufd \$0x88, %ymm1, %ymm9
-    vpshufd \$0xDD, %ymm0, %ymm10
-    vpshufd \$0xDD, %ymm1, %ymm11
-    vpunpckldq %ymm9, %ymm8, %ymm0
-    vpunpckldq %ymm11, %ymm10, %ymm1
+    vpunpckldq   %ymm1, %ymm0, %ymm8      # ymm8 = [0 2 1 3] [ 8 10  9 11]
+    vpunpckhdq   %ymm1, %ymm0, %ymm9      # ymm9 = [4 6 5 7] [12 14 13 15]
+
+    vshufps      \$0xEE, %ymm9, %ymm8, %ymm1   # ymm1 = [1 3 5 7] [ 9 11 13 15]
+    vshufps      \$0x44, %ymm9, %ymm8, %ymm0   # ymm0 = [0 2 4 6] [ 8 10 12 14]

     # load zetas
     vmovdqu $l7(%r11), %ymm13          # 8 zetas for coefficients 0-7
 ___
     &ntt_butterfly("%ymm0", "%ymm1", "%ymm13", "%ymm0", "%ymm1",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 7);                          # qinv, q, level
 $code .= <<___;
     # load zetas
     vmovdqu $l7+32(%r11), %ymm13       # 8 zetas for coefficients 8-15

     # format w_even and w_odd
-    vpshufd \$0x88, %ymm2, %ymm8
-    vpshufd \$0x88, %ymm3, %ymm9
-    vpshufd \$0xDD, %ymm2, %ymm10
-    vpshufd \$0xDD, %ymm3, %ymm11
-    vpunpckldq %ymm9, %ymm8, %ymm2
-    vpunpckldq %ymm11, %ymm10, %ymm3
+    vpunpckldq   %ymm3, %ymm2, %ymm8
+    vpunpckhdq   %ymm3, %ymm2, %ymm9
+
+    vshufps      \$0xEE, %ymm9, %ymm8, %ymm3
+    vshufps      \$0x44, %ymm9, %ymm8, %ymm2
 ___
     &ntt_butterfly("%ymm2", "%ymm3", "%ymm13", "%ymm2", "%ymm3",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 7);                          # qinv, q, level
     $code .= <<___;

     # load zetas
     vmovdqu $l7+64(%r11), %ymm13       # 8 zetas for coefficients 16-23

     # format w_even and w_odd
-    vpshufd \$0x88, %ymm4, %ymm8
-    vpshufd \$0x88, %ymm5, %ymm9
-    vpshufd \$0xDD, %ymm4, %ymm10
-    vpshufd \$0xDD, %ymm5, %ymm11
-    vpunpckldq %ymm9, %ymm8, %ymm4
-    vpunpckldq %ymm11, %ymm10, %ymm5
+    vpunpckldq   %ymm5, %ymm4, %ymm8
+    vpunpckhdq   %ymm5, %ymm4, %ymm9
+
+    vshufps      \$0xEE, %ymm9, %ymm8, %ymm5
+    vshufps      \$0x44, %ymm9, %ymm8, %ymm4
 ___
     &ntt_butterfly("%ymm4", "%ymm5", "%ymm13", "%ymm4", "%ymm5",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 7);                          # qinv, q, level
 $code .= <<___;
     # load zetas
     vmovdqu $l7+96(%r11), %ymm13       # 8 zetas for coefficients 24-31

     # format w_even and w_odd
-    vpshufd \$0x88, %ymm6, %ymm8
-    vpshufd \$0x88, %ymm7, %ymm9
-    vpshufd \$0xDD, %ymm6, %ymm10
-    vpshufd \$0xDD, %ymm7, %ymm11
-    vpunpckldq %ymm9, %ymm8, %ymm6
-    vpunpckldq %ymm11, %ymm10, %ymm7
+    vpunpckldq   %ymm7, %ymm6, %ymm8
+    vpunpckhdq   %ymm7, %ymm6, %ymm9
+
+    vshufps      \$0xEE, %ymm9, %ymm8, %ymm7
+    vshufps      \$0x44, %ymm9, %ymm8, %ymm6
 ___
     &ntt_butterfly("%ymm6", "%ymm7", "%ymm13", "%ymm6", "%ymm7",    # w_even, w_odd, zetas, n_even, n_odd
                    "%ymm9", "%ymm10", "%ymm11", "%ymm12",           # tmp
-                   "%ymm14", "%ymm15");                             # qinv, q
+                   "%ymm14", "%ymm15", 7);                          # qinv, q, level
     $code .= <<___;

     # Interleave and store first 16
@@ -771,6 +790,7 @@ ___
 #   n_odd     - Output YMM for new odd coefficients
 #   q_neg_inv - YMM with -Q^{-1} mod 2^32 for Montgomery reduction
 #   q         - YMM with modulus Q
+#   level     - current INTT processing level
 #
 # Output:
 #   n_even, n_odd
@@ -783,7 +803,7 @@ sub intt_butterfly {
     my ($w_even, $w_odd, $zetas,
         $tmp0, $tmp1, $tmp2,
         $n_even, $n_odd,
-        $q_neg_inv, $q) = @_;
+        $q_neg_inv, $q, $level) = @_;

 $code .= <<___;
     # n_even = reduce_once(w_even + w_odd)
@@ -798,11 +818,21 @@ $code .= <<___;

     # Multiply n_odd by zetas (step root)
 ___
-    &multiply_8x8_mod_Q($n_odd,     # A
+    if ($level < 1) {
+        # level 0: each zeta (B) dword is different
+        &multiply_mod_Q($n_odd,     # A
+                        $zetas,     # B
+                        $n_odd,     # out (AxB)
+                        $tmp0, $tmp1, $tmp2,    # tmp
+                        $q_neg_inv, $q, 0);     # qinv, q, no-bcast32
+    } else {
+        # levels 1 to 7: same zeta (B) dwords within each qword
+        &multiply_mod_Q($n_odd,     # A
                         $zetas,     # B
                         $n_odd,     # out (AxB)
                         $tmp0, $tmp1, $tmp2,    # tmp
-                        $q_neg_inv, $q);        # qinv, q
+                        $q_neg_inv, $q, 1);     # qinv, q, bcast32
+    }
 }

 ###############################################################################
@@ -882,7 +912,7 @@ ___
     &intt_butterfly("%ymm0", "%ymm1", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm0", "%ymm1",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 0);         # qinv, q, level

 $code .= <<___;

@@ -910,7 +940,7 @@ ___
     &intt_butterfly("%ymm2", "%ymm3", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm2", "%ymm3",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 0);         # qinv, q, level

 $code .= <<___;
     # load w_even and w_odd
@@ -937,7 +967,7 @@ ___
     &intt_butterfly("%ymm4", "%ymm5", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm4", "%ymm5",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 0);         # qinv, q, level

 $code .= <<___;

@@ -965,7 +995,7 @@ ___
     &intt_butterfly("%ymm6", "%ymm7", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm6", "%ymm7",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 0);         # qinv, q, level

 $code .= <<___;

@@ -992,15 +1022,14 @@ $code .= <<___;
     vshufps \$0xee, %ymm9, %ymm8, %ymm1

     # load 4 zetas and populate across ymm
-    vmovdqu $l1(%r11), %xmm13               # [0 1 2 3]
-    vpmovzxdq %xmm13, %ymm13
-    vmovsldup %ymm13, %ymm13                # [0 0 1 1] [2 2 3 3]
+    vmovdqu $l1(%r11), %xmm13
+    vpmovzxdq %xmm13, %ymm13            # [0 - 1 -] [2 - 3 -]
 ___

     &intt_butterfly("%ymm0", "%ymm1", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm0", "%ymm1",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 1);         # qinv, q, level

 $code .= <<___;

@@ -1015,15 +1044,14 @@ $code .= <<___;
     vshufps \$0xee, %ymm9, %ymm8, %ymm3

     # load 4 zetas and populate across YMM
-    vmovdqu $l1+16(%r11), %xmm13            # [0 1 2 3]
-    vpmovzxdq %xmm13, %ymm13
-    vmovsldup %ymm13, %ymm13                # [0 0 1 1] [2 2 3 3]
+    vmovdqu $l1+16(%r11), %xmm13
+    vpmovzxdq %xmm13, %ymm13            # [0 - 1 -] [2 - 3 -]
 ___

     &intt_butterfly("%ymm2", "%ymm3", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm2", "%ymm3",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 1);         # qinv, q, level

 $code .= <<___;
     # Interleave even/odd within each 128-bit lane:
@@ -1037,15 +1065,14 @@ $code .= <<___;
     vshufps \$0xee, %ymm9, %ymm8, %ymm5

     # load 4 zetas and populate across ymm
-    vmovdqu $l1+32(%r11), %xmm13             # [0 1 2 3]
-    vpmovzxdq %xmm13, %ymm13
-    vmovsldup %ymm13, %ymm13                # [0 0 1 1] [2 2 3 3]
+    vmovdqu $l1+32(%r11), %xmm13
+    vpmovzxdq %xmm13, %ymm13            # [0 - 1 -] [2 - 3 -]
 ___

     &intt_butterfly("%ymm4", "%ymm5", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm4", "%ymm5",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 1);         # qinv, q, level

 $code .= <<___;

@@ -1060,15 +1087,14 @@ $code .= <<___;
     vshufps \$0xee, %ymm9, %ymm8, %ymm7

     # load 4 zetas and populate across YMM
-    vmovdqu $l1+48(%r11), %xmm13            # [0 1 2 3]
-    vpmovzxdq %xmm13, %ymm13
-    vmovsldup %ymm13, %ymm13                # [0 0 1 1] [2 2 3 3]
+    vmovdqu $l1+48(%r11), %xmm13
+    vpmovzxdq %xmm13, %ymm13            # [0 - 1 -] [2 - 3 -]
 ___

     &intt_butterfly("%ymm6", "%ymm7", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm6", "%ymm7",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 1);         # qinv, q, level

 $code .= <<___;

@@ -1099,7 +1125,7 @@ ___
     &intt_butterfly("%ymm8", "%ymm1", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm0", "%ymm1",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 2);         # qinv, q, level

 $code .= <<___;
     vshufps \$0x44, %ymm3, %ymm2, %ymm8
@@ -1114,7 +1140,7 @@ ___
     &intt_butterfly("%ymm8", "%ymm3", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm2", "%ymm3",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 2);         # qinv, q, level

 $code .= <<___;
     vshufps \$0x44, %ymm5, %ymm4, %ymm8
@@ -1129,7 +1155,7 @@ ___
     &intt_butterfly("%ymm8", "%ymm5", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm4", "%ymm5",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 2);         # qinv, q, level

 $code .= <<___;
     vshufps \$0x44, %ymm7, %ymm6, %ymm8
@@ -1144,7 +1170,7 @@ ___
     &intt_butterfly("%ymm8", "%ymm7", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm6", "%ymm7",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 2);         # qinv, q, level

 $code .= <<___;

@@ -1165,7 +1191,7 @@ ___
     &intt_butterfly("%ymm8", "%ymm1", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm0", "%ymm1",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 3);         # qinv, q, level

 $code .= <<___;

@@ -1179,7 +1205,7 @@ ___
     &intt_butterfly("%ymm8", "%ymm3", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm2", "%ymm3",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 3);         # qinv, q, level

 $code .= <<___;
     vperm2i128 \$0x20, %ymm5, %ymm4, %ymm8      # [0,1,2,3 | 4,5,6,7]
@@ -1192,7 +1218,7 @@ ___
     &intt_butterfly("%ymm8", "%ymm5", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm4", "%ymm5",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 3);         # qinv, q, level

 $code .= <<___;

@@ -1206,7 +1232,7 @@ ___
     &intt_butterfly("%ymm8", "%ymm7", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm6", "%ymm7",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 3);         # qinv, q, level

 $code .= <<___;

@@ -1222,12 +1248,12 @@ ___
     &intt_butterfly("%ymm0", "%ymm2", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm0", "%ymm2",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 4);         # qinv, q, level

     &intt_butterfly("%ymm1", "%ymm3", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm1", "%ymm3",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 4);         # qinv, q, level

 $code .= <<___;
     # broadcast zetas
@@ -1237,12 +1263,12 @@ ___
     &intt_butterfly("%ymm4", "%ymm6", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm4", "%ymm6",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 4);         # qinv, q, level

     &intt_butterfly("%ymm5", "%ymm7", "%ymm13",     # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12",   # temporary
                     "%ymm5", "%ymm7",               # n_even, n_odd
-                    "%ymm14", "%ymm15");            # qinv, q32
+                    "%ymm14", "%ymm15", 4);         # qinv, q, level

 $code .= <<___;

@@ -1317,14 +1343,14 @@ ___
     &intt_butterfly("%ymm0", "%ymm1", "%ymm13", # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12", # temporary
                     "%ymm0", "%ymm1",           # n_even, n_odd
-                    "%ymm14", "%ymm15");        # qinv, q32
+                    "%ymm14", "%ymm15", 5);     # qinv, q, level
     $code .= <<___;
     vpbroadcastd 249*4(%r11), %ymm13
 ___
     &intt_butterfly("%ymm2", "%ymm3", "%ymm13", # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12", # temporary
                     "%ymm2", "%ymm3",           # n_even, n_odd
-                    "%ymm14", "%ymm15");        # qinv, q32
+                    "%ymm14", "%ymm15", 5);     # qinv, q, level

     $code .= <<___;
     vpbroadcastd 250*4(%r11), %ymm13
@@ -1332,7 +1358,7 @@ ___
     &intt_butterfly("%ymm4", "%ymm5", "%ymm13", # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12", # temporary
                     "%ymm4", "%ymm5",           # n_even, n_odd
-                    "%ymm14", "%ymm15");        # qinv, q32
+                    "%ymm14", "%ymm15", 5);     # qinv, q, level

     $code .= <<___;
     vpbroadcastd 251*4(%r11), %ymm13
@@ -1340,7 +1366,7 @@ ___
     &intt_butterfly("%ymm6", "%ymm7", "%ymm13", # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12", # temporary
                     "%ymm6", "%ymm7",           # n_even, n_odd
-                    "%ymm14", "%ymm15");        # qinv, q32
+                    "%ymm14", "%ymm15", 5);     # qinv, q, level
     $code .= <<___;

     # ==============================================================
@@ -1351,22 +1377,22 @@ ___
     &intt_butterfly("%ymm0", "%ymm2", "%ymm13", # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12", # temporary
                     "%ymm0", "%ymm2",           # n_even, n_odd
-                    "%ymm14", "%ymm15");        # qinv, q32
+                    "%ymm14", "%ymm15", 6);     # qinv, q, level
     &intt_butterfly("%ymm1", "%ymm3", "%ymm13", # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12", # temporary
                     "%ymm1", "%ymm3",           # n_even, n_odd
-                    "%ymm14", "%ymm15");        # qinv, q32
+                    "%ymm14", "%ymm15", 6);     # qinv, q, level
     $code .= <<___;
     vpbroadcastd 253*4(%r11), %ymm13
 ___
     &intt_butterfly("%ymm4", "%ymm6", "%ymm13", # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12", # temporary
                     "%ymm4", "%ymm6",           # n_even, n_odd
-                    "%ymm14", "%ymm15");        # qinv, q32
+                    "%ymm14", "%ymm15", 6);     # qinv, q, level
     &intt_butterfly("%ymm5", "%ymm7", "%ymm13", # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12", # temporary
                     "%ymm5", "%ymm7",           # n_even, n_odd
-                    "%ymm14", "%ymm15");        # qinv, q32
+                    "%ymm14", "%ymm15", 6);     # qinv, q, level
 $code .= <<___;

     # ==============================================================
@@ -1377,19 +1403,19 @@ ___
     &intt_butterfly("%ymm0", "%ymm4", "%ymm13", # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12", # temporary
                     "%ymm0", "%ymm4",           # n_even, n_odd
-                    "%ymm14", "%ymm15");        # qinv, q32
+                    "%ymm14", "%ymm15", 7);     # qinv, q, level
     &intt_butterfly("%ymm1", "%ymm5", "%ymm13", # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12", # temporary
                     "%ymm1", "%ymm5",           # n_even, n_odd
-                    "%ymm14", "%ymm15");        # qinv, q32
+                    "%ymm14", "%ymm15", 7);     # qinv, q, level
     &intt_butterfly("%ymm2", "%ymm6", "%ymm13", # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12", # temporary
                     "%ymm2", "%ymm6",           # n_even, n_odd
-                    "%ymm14", "%ymm15");        # qinv, q32
+                    "%ymm14", "%ymm15", 7);     # qinv, q, level
     &intt_butterfly("%ymm3", "%ymm7", "%ymm13", # w_even, w_odd, zetas
                     "%ymm10", "%ymm11", "%ymm12", # temporary
                     "%ymm3", "%ymm7",           # n_even, n_odd
-                    "%ymm14", "%ymm15");        # qinv, q32
+                    "%ymm14", "%ymm15", 7);     # qinv, q, level
 $code .= <<___;

     # ==============================================================
@@ -1398,30 +1424,30 @@ $code .= <<___;
     vpbroadcastd ml_dsa_inverse_degree_montgomery(%rip), %ymm13
 ___

-    &multiply_8x8_mod_Q("%ymm0", "%ymm13", "%ymm0", # A, B, out (AxB)
-                        "%ymm10", "%ymm11", "%ymm12", # tmp
-                        "%ymm14", "%ymm15");        # qinv, q
-    &multiply_8x8_mod_Q("%ymm4", "%ymm13", "%ymm4", # A, B, out (AxB)
-                        "%ymm10", "%ymm11", "%ymm12", # tmp
-                        "%ymm14", "%ymm15");        # qinv, q
-    &multiply_8x8_mod_Q("%ymm1", "%ymm13", "%ymm1", # A, B, out (AxB)
-                        "%ymm10", "%ymm11", "%ymm12", # tmp
-                        "%ymm14", "%ymm15");        # qinv, q
-    &multiply_8x8_mod_Q("%ymm5", "%ymm13", "%ymm5", # A, B, out (AxB)
-                        "%ymm10", "%ymm11", "%ymm12", # tmp
-                        "%ymm14", "%ymm15");        # qinv, q
-    &multiply_8x8_mod_Q("%ymm2", "%ymm13", "%ymm2", # A, B, out (AxB)
-                        "%ymm10", "%ymm11", "%ymm12", # tmp
-                        "%ymm14", "%ymm15");        # qinv, q
-    &multiply_8x8_mod_Q("%ymm6", "%ymm13", "%ymm6", # A, B, out (AxB)
-                        "%ymm10", "%ymm11", "%ymm12", # tmp
-                        "%ymm14", "%ymm15");        # qinv, q
-    &multiply_8x8_mod_Q("%ymm3", "%ymm13", "%ymm3", # A, B, out (AxB)
-                        "%ymm10", "%ymm11", "%ymm12", # tmp
-                        "%ymm14", "%ymm15");        # qinv, q
-    &multiply_8x8_mod_Q("%ymm7", "%ymm13", "%ymm7", # A, B, out (AxB)
-                        "%ymm10", "%ymm11", "%ymm12", # tmp
-                        "%ymm14", "%ymm15");        # qinv, q
+    &multiply_mod_Q("%ymm0", "%ymm13", "%ymm0",   # A, B, out (AxB)
+                    "%ymm10", "%ymm11", "%ymm12", # tmp
+                    "%ymm14", "%ymm15", 1);       # qinv, q, bcast32
+    &multiply_mod_Q("%ymm4", "%ymm13", "%ymm4",   # A, B, out (AxB)
+                    "%ymm10", "%ymm11", "%ymm12", # tmp
+                    "%ymm14", "%ymm15", 1);       # qinv, q, bcast32
+    &multiply_mod_Q("%ymm1", "%ymm13", "%ymm1",   # A, B, out (AxB)
+                    "%ymm10", "%ymm11", "%ymm12", # tmp
+                    "%ymm14", "%ymm15", 1);       # qinv, q, bcast32
+    &multiply_mod_Q("%ymm5", "%ymm13", "%ymm5",   # A, B, out (AxB)
+                    "%ymm10", "%ymm11", "%ymm12", # tmp
+                    "%ymm14", "%ymm15", 1);       # qinv, q, bcast32
+    &multiply_mod_Q("%ymm2", "%ymm13", "%ymm2",   # A, B, out (AxB)
+                    "%ymm10", "%ymm11", "%ymm12", # tmp
+                    "%ymm14", "%ymm15", 1);       # qinv, q, bcast32
+    &multiply_mod_Q("%ymm6", "%ymm13", "%ymm6",   # A, B, out (AxB)
+                    "%ymm10", "%ymm11", "%ymm12", # tmp
+                    "%ymm14", "%ymm15", 1);       # qinv, q, bcast32
+    &multiply_mod_Q("%ymm3", "%ymm13", "%ymm3",   # A, B, out (AxB)
+                    "%ymm10", "%ymm11", "%ymm12", # tmp
+                    "%ymm14", "%ymm15", 1);       # qinv, q, bcast32
+    &multiply_mod_Q("%ymm7", "%ymm13", "%ymm7",   # A, B, out (AxB)
+                    "%ymm10", "%ymm11", "%ymm12", # tmp
+                    "%ymm14", "%ymm15", 1);       # qinv, q, bcast32

 $code .= <<___;

@@ -1579,9 +1605,9 @@ ml_dsa_poly_ntt_mult_avx2:
     # multiply this part of input data
 ___

-    &multiply_8x8_mod_Q("%ymm0", "%ymm1", "%ymm0",  # A, B, out (AxB)
-                        "%ymm8", "%ymm9", "%ymm10", # tmp
-                        "%ymm14", "%ymm15");        # qinv, q
+    &multiply_mod_Q("%ymm0", "%ymm1", "%ymm0",  # A, B, out (AxB)
+                    "%ymm8", "%ymm9", "%ymm10", # tmp
+                    "%ymm14", "%ymm15", 0);     # qinv, q, bcast32

 $code .= <<___;
     # store result to output
@@ -1632,11 +1658,7 @@ $code .= <<___;
 ml_dsa_poly_ntt_avx2:
 .cfi_startproc

-    sub \$32, %rsp
-.cfi_adjust_cfa_offset 32   # track rsp change so unwinder can find CFA
-
-    # save input arguments
-    mov %rdi, %r10
+    # move p_zetas to r11
     mov %rsi, %r11

     # load constants
@@ -1647,8 +1669,7 @@ ml_dsa_poly_ntt_avx2:
     # - level 0: offset = 128, step = 1, zeta indexes = 1
     # - level 1: offset = 64, step = 2, zeta indexes = 2, 3
     # - level 2: offset = 32, step = 4, zeta indexes = 4, 5, 6, 7
-
-    mov %r10, %rdi                  # p_coeffs
+    # p_coeffs already in rdi
 ___

     &ntt_levels0to2(0*4);
@@ -1669,8 +1690,7 @@ $code .= <<___;
     #     zeta indexes = 64, 65, 66, ..., 126, 127
     # - level 7: offset = 1, step = 128
     #     zeta indexes = 128, 129, 130, ..., 254, 255
-
-    mov %r10, %rdi                  # p_even / p_coeff
+    # p_coeffs already in rdi
 ___

     # arguments:    coeff,   l3,    l4,    l5,    l6,    l7
@@ -1682,10 +1702,6 @@ ___
 $code .= <<___;

     vzeroall
-
-    lea 32(%rsp), %rsp
-.cfi_adjust_cfa_offset -32
-
     ret
 .cfi_endproc
 .size   ml_dsa_poly_ntt_avx2, .-ml_dsa_poly_ntt_avx2