Commit 198776f61f for openssl.org
commit 198776f61fa96ac084fe90c6f7fea7971d0263f6
Author: Tomasz Kantecki <tomasz.kantecki@intel.com>
Date: Wed Jan 28 16:50:38 2026 +0000
ML-DSA: optimize one vmovshdup from the NTT multiply operation
It applies to 7 out of 8 levels in NTT and NTT^-1.
It helps eliminate some zeta shuffles in NTT level 6 and NTT^-1 level 1.
Added small optimization in data shuffling between the levels.
Signed-off-by: Tomasz Kantecki <tomasz.kantecki@intel.com>
Reviewed-by: Saša Nedvědický <sashan@openssl.org>
Reviewed-by: Paul Dale <paul.dale@oracle.com>
Reviewed-by: Neil Horman <nhorman@openssl.org>
MergeDate: Wed Mar 11 15:47:46 2026
(Merged from https://github.com/openssl/openssl/pull/30160)
diff --git a/crypto/ml_dsa/asm/ml_dsa_ntt-x86_64.pl b/crypto/ml_dsa/asm/ml_dsa_ntt-x86_64.pl
index c412e5e6c2..6012b34bb7 100644
--- a/crypto/ml_dsa/asm/ml_dsa_ntt-x86_64.pl
+++ b/crypto/ml_dsa/asm/ml_dsa_ntt-x86_64.pl
@@ -108,7 +108,7 @@ ml_dsa_ntt_avx2_capable:
___
###############################################################################
-# multiply_8x8_mod_Q
+# multiply_mod_Q
#
# Description:
# The inputs (A and B) are in YMM registers, where each packs 8 32-bit integers.
@@ -123,6 +123,10 @@ ___
# tmp0-tmp2 - Temporary registers for intermediate values
# q_neg_inv- YMM containing -Q^{-1} mod 2^32 (for Montgomery reduction)
# q - YMM containing modulus Q
+# bcast32 – if 1, `inB` is assumed to have each qword formed by
+# repeating its low dword (DW|DW). Multiplication uses only
+# the low dword, avoiding the need for `vmovshdup` when
+# zetas are uniform across all lanes of `inA`.
#
# Output:
# out - Resulting 8 packed 32-bit integers, each (A * B) mod Q
@@ -133,18 +137,31 @@ ___
# Notes:
# inA or inB can also be used as out
###############################################################################
-sub multiply_8x8_mod_Q {
+sub multiply_mod_Q {
my ($inA, $inB, $out,
$tmp0, $tmp1, $tmp2,
- $q_neg_inv, $q) = @_;
+ $q_neg_inv, $q, $bcast32) = @_;
+
+ if (!defined($bcast32)) {
+ $bcast32 = 0;
+ }
$code .= <<___;
# Multiply A x B
vpmuludq $inA, $inB, $tmp0 # multiply even indexes
vmovshdup $inA, $tmp1
+___
+ if ($bcast32 == 0) {
+ $code .= <<___;
vmovshdup $inB, $tmp2
vpmuludq $tmp1, $tmp2, $tmp1 # multiply odd indexes
-
+___
+ } else {
+ $code .= <<___;
+ vpmuludq $tmp1, $inB, $tmp1 # multiply odd indexes of inA
+___
+ }
+ $code .= <<___;
# Montgomery reduction: t1 = (A x B)[31..0] x Qinv
vpmuludq $q_neg_inv, $tmp0, $out
vpmuludq $q_neg_inv, $tmp1, $tmp2
@@ -196,6 +213,7 @@ ___
# tmp0, tmp1, tmp2, tmp3 - Temporary registers for intermediate values
# q_neg_inv - YMM with -Q^{-1} mod 2^32
# q - YMM with modulus Q
+# level - current NTT processing level
#
# Output:
# n_even - Updated even coefficients after butterfly and reduction
@@ -209,14 +227,23 @@ sub ntt_butterfly {
$zetas,
$n_even, $n_odd,
$tmp0, $tmp1, $tmp2, $tmp3,
- $q_neg_inv, $q) = @_;
+ $q_neg_inv, $q, $level) = @_;
- &multiply_8x8_mod_Q($w_odd, # A
+ if ($level >= 7) {
+ # level 7: each zeta (B) dword is different
+ &multiply_mod_Q($w_odd, # A
$zetas, # B
$tmp0, # out (AxB)
$tmp1, $tmp2, $tmp3, # tmp
- $q_neg_inv, $q); # qinv, q
-
+ $q_neg_inv, $q, 0); # qinv, q, no-bcast32
+ } else {
+ # levels 0 to 6: same zeta (B) dwords within each qword
+ &multiply_mod_Q($w_odd, # A
+ $zetas, # B
+ $tmp0, # out (AxB)
+ $tmp1, $tmp2, $tmp3, # tmp
+ $q_neg_inv, $q, 1); # qinv, q, bcast32
+ }
$code .= <<___;
# t_odd = $tmp0
@@ -291,16 +318,16 @@ sub ntt_levels0to2 {
___
&ntt_butterfly("%ymm0", "%ymm4", "%ymm13", "%ymm0", "%ymm4", # w_even, w_odd, zetas, n_even, n_odd
"%ymm8", "%ymm9", "%ymm10", "%ymm11", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 0); # qinv, q, level
&ntt_butterfly("%ymm1", "%ymm5", "%ymm13", "%ymm1", "%ymm5", # w_even, w_odd, zetas, n_even, n_odd
"%ymm8", "%ymm9", "%ymm10", "%ymm11", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 0); # qinv, q, level
&ntt_butterfly("%ymm2", "%ymm6", "%ymm13", "%ymm2", "%ymm6", # w_even, w_odd, zetas, n_even, n_odd
"%ymm8", "%ymm9", "%ymm10", "%ymm11", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 0); # qinv, q, level
&ntt_butterfly("%ymm3", "%ymm7", "%ymm13", "%ymm3", "%ymm7", # w_even, w_odd, zetas, n_even, n_odd
"%ymm8", "%ymm9", "%ymm10", "%ymm11", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 0); # qinv, q, level
$code .= <<___;
# ==============================================================
@@ -311,19 +338,19 @@ ___
___
&ntt_butterfly("%ymm0", "%ymm2", "%ymm13", "%ymm0", "%ymm2", # w_even, w_odd, zetas, n_even, n_odd
"%ymm8", "%ymm9", "%ymm10", "%ymm11", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 1); # qinv, q, level
&ntt_butterfly("%ymm1", "%ymm3", "%ymm13", "%ymm1", "%ymm3", # w_even, w_odd, zetas, n_even, n_odd
"%ymm8", "%ymm9", "%ymm10", "%ymm11", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 1); # qinv, q, level
$code .= <<___;
vpbroadcastd 3*4(%r11), %ymm13
___
&ntt_butterfly("%ymm4", "%ymm6", "%ymm13", "%ymm4", "%ymm6", # w_even, w_odd, zetas, n_even, n_odd
"%ymm8", "%ymm9", "%ymm10", "%ymm11", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 1); # qinv, q, level
&ntt_butterfly("%ymm5", "%ymm7", "%ymm13", "%ymm5", "%ymm7", # w_even, w_odd, zetas, n_even, n_odd
"%ymm8", "%ymm9", "%ymm10", "%ymm11", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 1); # qinv, q, level
$code .= <<___;
# ==============================================================
@@ -334,25 +361,25 @@ $code .= <<___;
___
&ntt_butterfly("%ymm0", "%ymm1", "%ymm13", "%ymm0", "%ymm1", # w_even, w_odd, zetas, n_even, n_odd
"%ymm8", "%ymm9", "%ymm10", "%ymm11", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 2); # qinv, q, level
$code .= <<___;
vpbroadcastd 5*4(%r11), %ymm13
___
&ntt_butterfly("%ymm2", "%ymm3", "%ymm13", "%ymm2", "%ymm3", # w_even, w_odd, zetas, n_even, n_odd
"%ymm8", "%ymm9", "%ymm10", "%ymm11", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 2); # qinv, q, level
$code .= <<___;
vpbroadcastd 6*4(%r11), %ymm13
___
&ntt_butterfly("%ymm4", "%ymm5", "%ymm13", "%ymm4", "%ymm5", # w_even, w_odd, zetas, n_even, n_odd
"%ymm8", "%ymm9", "%ymm10", "%ymm11", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 2); # qinv, q, level
$code .= <<___;
vpbroadcastd 7*4(%r11), %ymm13
___
&ntt_butterfly("%ymm6", "%ymm7", "%ymm13", "%ymm6", "%ymm7", # w_even, w_odd, zetas, n_even, n_odd
"%ymm8", "%ymm9", "%ymm10", "%ymm11", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 2); # qinv, q, level
$code .= <<___;
vmovdqu %ymm0, $off+0*4(%rdi)
@@ -425,10 +452,10 @@ ___
# Process first 16 coefficients with zeta in ymm13
&ntt_butterfly("%ymm0", "%ymm2", "%ymm13", "%ymm0", "%ymm2", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 3); # qinv, q, level
&ntt_butterfly("%ymm1", "%ymm3", "%ymm13", "%ymm1", "%ymm3", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 3); # qinv, q, level
$code .= <<___;
# broadcast zetas
vpbroadcastd $l3+4(%r11), %ymm13 # zeta for coefficients 16-31
@@ -442,12 +469,10 @@ ___
# Process next 16 coefficients with zeta in ymm13
&ntt_butterfly("%ymm4", "%ymm6", "%ymm13", "%ymm4", "%ymm6", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
-$code .= <<___;
-___
+ "%ymm14", "%ymm15", 3); # qinv, q, level
&ntt_butterfly("%ymm5", "%ymm7", "%ymm13", "%ymm5", "%ymm7", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 3); # qinv, q, level
$code .= <<___;
# ==============================================================
@@ -472,39 +497,34 @@ $code .= <<___;
___
&ntt_butterfly("%ymm0", "%ymm1", "%ymm13", "%ymm0", "%ymm1", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 4); # qinv, q, level
$code .= <<___;
# broadcast zetas for next 8 coefficients
vpbroadcastd $l4+4(%r11), %ymm13 # zeta for coefficients 8-15
___
&ntt_butterfly("%ymm2", "%ymm3", "%ymm13", "%ymm2", "%ymm3", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 4); # qinv, q, level
$code .= <<___;
# broadcast zetas for next 8 coefficients
vpbroadcastd $l4+8(%r11), %ymm13 # zeta for coefficients 16-23
___
&ntt_butterfly("%ymm4", "%ymm5", "%ymm13", "%ymm4", "%ymm5", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 4); # qinv, q, level
$code .= <<___;
# broadcast zetas for next 8 coefficients
vpbroadcastd $l4+12(%r11), %ymm13 # zeta for coefficients 24-31
___
&ntt_butterfly("%ymm6", "%ymm7", "%ymm13", "%ymm6", "%ymm7", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 4); # qinv, q, level
$code .= <<___;
# ==============================================================
# level 5: offset = 4, step = 32
# zeta indexes = 32, 33, 34, ..., 62, 63
- # load zetas for first 8 coefficients
- vpbroadcastd $l5(%r11), %ymm13 # zetas for coefficients 0-3
- vpbroadcastd $l5+4(%r11), %ymm12 # zetas for coefficients 4-7
- vpblendd \$0xf0, %ymm12, %ymm13, %ymm13 # blend into ymm13
-
# prepare w_even and w_odd
# Input dword layout:
# ymm0 = [ 0 1 2 3 | 4 5 6 7] (even)
@@ -515,49 +535,57 @@ ___
vperm2i128 \$0x20, %ymm1, %ymm0, %ymm8
vperm2i128 \$0x31, %ymm1, %ymm0, %ymm1
+ # load zetas for first 8 coefficients
+ vpbroadcastd $l5(%r11), %ymm13 # zetas for coefficients 0-3
+ vpbroadcastd $l5+4(%r11), %ymm12 # zetas for coefficients 4-7
+ vpblendd \$0xf0, %ymm12, %ymm13, %ymm13 # blend into ymm13
+
___
&ntt_butterfly("%ymm8", "%ymm1", "%ymm13", "%ymm0", "%ymm1", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 5); # qinv, q, level
$code .= <<___;
+ # prepare w_even and w_odd
+ vperm2i128 \$0x20, %ymm3, %ymm2, %ymm8
+ vperm2i128 \$0x31, %ymm3, %ymm2, %ymm3
+
# load zetas for first 8 coefficients
- vpbroadcastd $l5+8(%r11), %ymm13 # zetas for coefficients 8-11
+ vpbroadcastd $l5+8(%r11), %ymm13 # zetas for coefficients 8-11
vpbroadcastd $l5+12(%r11), %ymm12 # zetas for coefficients 12-15
vpblendd \$0xf0, %ymm12, %ymm13, %ymm13 # blend into ymm13
- # prepare w_even and w_odd
- vperm2i128 \$0x20, %ymm3, %ymm2, %ymm8
- vperm2i128 \$0x31, %ymm3, %ymm2, %ymm3
___
&ntt_butterfly("%ymm8", "%ymm3", "%ymm13", "%ymm2", "%ymm3", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 5); # qinv, q, level
$code .= <<___;
+ # prepare w_even and w_odd
+ vperm2i128 \$0x20, %ymm5, %ymm4, %ymm8
+ vperm2i128 \$0x31, %ymm5, %ymm4, %ymm5
+
# load zetas for next 8 coefficients
vpbroadcastd $l5+16(%r11), %ymm13 # zetas for coefficients 16-19
vpbroadcastd $l5+20(%r11), %ymm12 # zetas for coefficients 20-23
vpblendd \$0xf0, %ymm12, %ymm13, %ymm13 # blend into ymm13
- # prepare w_even and w_odd
- vperm2i128 \$0x20, %ymm5, %ymm4, %ymm8
- vperm2i128 \$0x31, %ymm5, %ymm4, %ymm5
___
&ntt_butterfly("%ymm8", "%ymm5", "%ymm13", "%ymm4", "%ymm5", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 5); # qinv, q, level
$code .= <<___;
+ # prepare w_even and w_odd
+ vperm2i128 \$0x20, %ymm7, %ymm6, %ymm8
+ vperm2i128 \$0x31, %ymm7, %ymm6, %ymm7
+
# load zetas for next 8 coefficients
vpbroadcastd $l5+24(%r11), %ymm13 # zetas for coefficients 24-27
vpbroadcastd $l5+28(%r11), %ymm12 # zetas for coefficients 28-31
vpblendd \$0xf0, %ymm12, %ymm13, %ymm13 # blend into ymm13
- # prepare w_even and w_odd
- vperm2i128 \$0x20, %ymm7, %ymm6, %ymm8
- vperm2i128 \$0x31, %ymm7, %ymm6, %ymm7
___
&ntt_butterfly("%ymm8", "%ymm7", "%ymm13", "%ymm6", "%ymm7", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 5); # qinv, q, level
$code .= <<___;
# ==============================================================
@@ -573,13 +601,10 @@ ___
# ymm8 = [ 0 1 4 5] [ 8 9 12 13] (even)
# ymm1 = [ 2 3 6 7] [10 11 14 15] (odd)
- # load & prepare zetas for first 16 coefficients
- vmovdqu $l6(%r11), %ymm10
- vpshufd \$0xfa, %ymm10, %ymm11
- vpshufd \$0x50, %ymm10, %ymm10
- vperm2i128 \$0x20, %ymm11, %ymm10, %ymm13
- vperm2i128 \$0x31, %ymm11, %ymm10, %ymm12
- vmovdqu %ymm12, (%rsp)
+ # load & prepare zetas
+ # - it is enough that the 1st dword of each qword is populated
+ vmovdqu $l6(%r11), %xmm13
+ vpmovzxdq %xmm13, %ymm13
# prepare w_even and w_odd
vpunpcklqdq %ymm1, %ymm0, %ymm8
@@ -587,10 +612,11 @@ ___
___
&ntt_butterfly("%ymm8", "%ymm1", "%ymm13", "%ymm0", "%ymm1", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 6); # qinv, q, level
$code .= <<___;
- # load formatted zetas from the stack
- vmovdqu (%rsp), %ymm13
+ # load & prepare zetas
+ vmovdqu $l6+16(%r11), %xmm13
+ vpmovzxdq %xmm13, %ymm13
# prepare w_even and w_odd
vpunpcklqdq %ymm3, %ymm2, %ymm8
@@ -598,16 +624,12 @@ ___
___
&ntt_butterfly("%ymm8", "%ymm3", "%ymm13", "%ymm2", "%ymm3", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 6); # qinv, q, level
$code .= <<___;
- # load & prepare zetas for next 16 coefficients
- vmovdqu $l6+32(%r11), %ymm10
- vpshufd \$0xfa, %ymm10, %ymm11
- vpshufd \$0x50, %ymm10, %ymm10
- vperm2i128 \$0x20, %ymm11, %ymm10, %ymm13
- vperm2i128 \$0x31, %ymm11, %ymm10, %ymm12
- vmovdqu %ymm12, (%rsp)
+ # load & prepare zetas
+ vmovdqu $l6+32(%r11), %xmm13
+ vpmovzxdq %xmm13, %ymm13
# prepare w_even and w_odd
vpunpcklqdq %ymm5, %ymm4, %ymm8
@@ -615,10 +637,11 @@ ___
___
&ntt_butterfly("%ymm8", "%ymm5", "%ymm13", "%ymm4", "%ymm5", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 6); # qinv, q, level
$code .= <<___;
- # load formatted zetas from the stack
- vmovdqu (%rsp), %ymm13
+ # load & prepare zetas
+ vmovdqu $l6+48(%r11), %xmm13
+ vpmovzxdq %xmm13, %ymm13
# prepare w_even and w_odd
vpunpcklqdq %ymm7, %ymm6, %ymm8
@@ -626,7 +649,7 @@ $code .= <<___;
___
&ntt_butterfly("%ymm8", "%ymm7", "%ymm13", "%ymm6", "%ymm7", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 6); # qinv, q, level
$code .= <<___;
# ==============================================================
@@ -637,68 +660,64 @@ ___
# ymm0 = [ 0 1 4 5] [ 8 9 12 13]
# ymm1 = [ 2 3 6 7] [10 11 14 15]
# Required DWORD layout:
- # ymm8 = [ 0 2 4 6] [ 8 10 12 14] (even)
+ # ymm0 = [ 0 2 4 6] [ 8 10 12 14] (even)
# ymm1 = [ 1 3 5 7] [ 9 11 13 15] (odd)
- vpshufd \$0x88, %ymm0, %ymm8
- vpshufd \$0x88, %ymm1, %ymm9
- vpshufd \$0xDD, %ymm0, %ymm10
- vpshufd \$0xDD, %ymm1, %ymm11
- vpunpckldq %ymm9, %ymm8, %ymm0
- vpunpckldq %ymm11, %ymm10, %ymm1
+ vpunpckldq %ymm1, %ymm0, %ymm8 # ymm8 = [0 2 1 3] [ 8 10 9 11]
+ vpunpckhdq %ymm1, %ymm0, %ymm9 # ymm9 = [4 6 5 7] [12 14 13 15]
+
+ vshufps \$0xEE, %ymm9, %ymm8, %ymm1 # ymm1 = [1 3 5 7] [ 9 11 13 15]
+ vshufps \$0x44, %ymm9, %ymm8, %ymm0 # ymm0 = [0 2 4 6] [ 8 10 12 14]
# load zetas
vmovdqu $l7(%r11), %ymm13 # 8 zetas for coefficients 0-7
___
&ntt_butterfly("%ymm0", "%ymm1", "%ymm13", "%ymm0", "%ymm1", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 7); # qinv, q, level
$code .= <<___;
# load zetas
vmovdqu $l7+32(%r11), %ymm13 # 8 zetas for coefficients 8-15
# format w_even and w_odd
- vpshufd \$0x88, %ymm2, %ymm8
- vpshufd \$0x88, %ymm3, %ymm9
- vpshufd \$0xDD, %ymm2, %ymm10
- vpshufd \$0xDD, %ymm3, %ymm11
- vpunpckldq %ymm9, %ymm8, %ymm2
- vpunpckldq %ymm11, %ymm10, %ymm3
+ vpunpckldq %ymm3, %ymm2, %ymm8
+ vpunpckhdq %ymm3, %ymm2, %ymm9
+
+ vshufps \$0xEE, %ymm9, %ymm8, %ymm3
+ vshufps \$0x44, %ymm9, %ymm8, %ymm2
___
&ntt_butterfly("%ymm2", "%ymm3", "%ymm13", "%ymm2", "%ymm3", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 7); # qinv, q, level
$code .= <<___;
# load zetas
vmovdqu $l7+64(%r11), %ymm13 # 8 zetas for coefficients 16-23
# format w_even and w_odd
- vpshufd \$0x88, %ymm4, %ymm8
- vpshufd \$0x88, %ymm5, %ymm9
- vpshufd \$0xDD, %ymm4, %ymm10
- vpshufd \$0xDD, %ymm5, %ymm11
- vpunpckldq %ymm9, %ymm8, %ymm4
- vpunpckldq %ymm11, %ymm10, %ymm5
+ vpunpckldq %ymm5, %ymm4, %ymm8
+ vpunpckhdq %ymm5, %ymm4, %ymm9
+
+ vshufps \$0xEE, %ymm9, %ymm8, %ymm5
+ vshufps \$0x44, %ymm9, %ymm8, %ymm4
___
&ntt_butterfly("%ymm4", "%ymm5", "%ymm13", "%ymm4", "%ymm5", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 7); # qinv, q, level
$code .= <<___;
# load zetas
vmovdqu $l7+96(%r11), %ymm13 # 8 zetas for coefficients 24-31
# format w_even and w_odd
- vpshufd \$0x88, %ymm6, %ymm8
- vpshufd \$0x88, %ymm7, %ymm9
- vpshufd \$0xDD, %ymm6, %ymm10
- vpshufd \$0xDD, %ymm7, %ymm11
- vpunpckldq %ymm9, %ymm8, %ymm6
- vpunpckldq %ymm11, %ymm10, %ymm7
+ vpunpckldq %ymm7, %ymm6, %ymm8
+ vpunpckhdq %ymm7, %ymm6, %ymm9
+
+ vshufps \$0xEE, %ymm9, %ymm8, %ymm7
+ vshufps \$0x44, %ymm9, %ymm8, %ymm6
___
&ntt_butterfly("%ymm6", "%ymm7", "%ymm13", "%ymm6", "%ymm7", # w_even, w_odd, zetas, n_even, n_odd
"%ymm9", "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ "%ymm14", "%ymm15", 7); # qinv, q, level
$code .= <<___;
# Interleave and store first 16
@@ -771,6 +790,7 @@ ___
# n_odd - Output YMM for new odd coefficients
# q_neg_inv - YMM with -Q^{-1} mod 2^32 for Montgomery reduction
# q - YMM with modulus Q
+# level - current INTT processing level
#
# Output:
# n_even, n_odd
@@ -783,7 +803,7 @@ sub intt_butterfly {
my ($w_even, $w_odd, $zetas,
$tmp0, $tmp1, $tmp2,
$n_even, $n_odd,
- $q_neg_inv, $q) = @_;
+ $q_neg_inv, $q, $level) = @_;
$code .= <<___;
# n_even = reduce_once(w_even + w_odd)
@@ -798,11 +818,21 @@ $code .= <<___;
# Multiply n_odd by zetas (step root)
___
- &multiply_8x8_mod_Q($n_odd, # A
+ if ($level < 1) {
+ # level 0: each zeta (B) dword is different
+ &multiply_mod_Q($n_odd, # A
+ $zetas, # B
+ $n_odd, # out (AxB)
+ $tmp0, $tmp1, $tmp2, # tmp
+ $q_neg_inv, $q, 0); # qinv, q, no-bcast32
+ } else {
+ # levels 1 to 7: same zeta (B) dwords within each qword
+ &multiply_mod_Q($n_odd, # A
$zetas, # B
$n_odd, # out (AxB)
$tmp0, $tmp1, $tmp2, # tmp
- $q_neg_inv, $q); # qinv, q
+ $q_neg_inv, $q, 1); # qinv, q, bcast32
+ }
}
###############################################################################
@@ -882,7 +912,7 @@ ___
&intt_butterfly("%ymm0", "%ymm1", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm0", "%ymm1", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 0); # qinv, q, level
$code .= <<___;
@@ -910,7 +940,7 @@ ___
&intt_butterfly("%ymm2", "%ymm3", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm2", "%ymm3", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 0); # qinv, q, level
$code .= <<___;
# load w_even and w_odd
@@ -937,7 +967,7 @@ ___
&intt_butterfly("%ymm4", "%ymm5", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm4", "%ymm5", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 0); # qinv, q, level
$code .= <<___;
@@ -965,7 +995,7 @@ ___
&intt_butterfly("%ymm6", "%ymm7", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm6", "%ymm7", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 0); # qinv, q, level
$code .= <<___;
@@ -992,15 +1022,14 @@ $code .= <<___;
vshufps \$0xee, %ymm9, %ymm8, %ymm1
# load 4 zetas and populate across ymm
- vmovdqu $l1(%r11), %xmm13 # [0 1 2 3]
- vpmovzxdq %xmm13, %ymm13
- vmovsldup %ymm13, %ymm13 # [0 0 1 1] [2 2 3 3]
+ vmovdqu $l1(%r11), %xmm13
+ vpmovzxdq %xmm13, %ymm13 # [0 - 1 -] [2 - 3 -]
___
&intt_butterfly("%ymm0", "%ymm1", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm0", "%ymm1", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 1); # qinv, q, level
$code .= <<___;
@@ -1015,15 +1044,14 @@ $code .= <<___;
vshufps \$0xee, %ymm9, %ymm8, %ymm3
# load 4 zetas and populate across YMM
- vmovdqu $l1+16(%r11), %xmm13 # [0 1 2 3]
- vpmovzxdq %xmm13, %ymm13
- vmovsldup %ymm13, %ymm13 # [0 0 1 1] [2 2 3 3]
+ vmovdqu $l1+16(%r11), %xmm13
+ vpmovzxdq %xmm13, %ymm13 # [0 - 1 -] [2 - 3 -]
___
&intt_butterfly("%ymm2", "%ymm3", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm2", "%ymm3", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 1); # qinv, q, level
$code .= <<___;
# Interleave even/odd within each 128-bit lane:
@@ -1037,15 +1065,14 @@ $code .= <<___;
vshufps \$0xee, %ymm9, %ymm8, %ymm5
# load 4 zetas and populate across ymm
- vmovdqu $l1+32(%r11), %xmm13 # [0 1 2 3]
- vpmovzxdq %xmm13, %ymm13
- vmovsldup %ymm13, %ymm13 # [0 0 1 1] [2 2 3 3]
+ vmovdqu $l1+32(%r11), %xmm13
+ vpmovzxdq %xmm13, %ymm13 # [0 - 1 -] [2 - 3 -]
___
&intt_butterfly("%ymm4", "%ymm5", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm4", "%ymm5", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 1); # qinv, q, level
$code .= <<___;
@@ -1060,15 +1087,14 @@ $code .= <<___;
vshufps \$0xee, %ymm9, %ymm8, %ymm7
# load 4 zetas and populate across YMM
- vmovdqu $l1+48(%r11), %xmm13 # [0 1 2 3]
- vpmovzxdq %xmm13, %ymm13
- vmovsldup %ymm13, %ymm13 # [0 0 1 1] [2 2 3 3]
+ vmovdqu $l1+48(%r11), %xmm13
+ vpmovzxdq %xmm13, %ymm13 # [0 - 1 -] [2 - 3 -]
___
&intt_butterfly("%ymm6", "%ymm7", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm6", "%ymm7", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 1); # qinv, q, level
$code .= <<___;
@@ -1099,7 +1125,7 @@ ___
&intt_butterfly("%ymm8", "%ymm1", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm0", "%ymm1", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 2); # qinv, q, level
$code .= <<___;
vshufps \$0x44, %ymm3, %ymm2, %ymm8
@@ -1114,7 +1140,7 @@ ___
&intt_butterfly("%ymm8", "%ymm3", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm2", "%ymm3", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 2); # qinv, q, level
$code .= <<___;
vshufps \$0x44, %ymm5, %ymm4, %ymm8
@@ -1129,7 +1155,7 @@ ___
&intt_butterfly("%ymm8", "%ymm5", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm4", "%ymm5", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 2); # qinv, q, level
$code .= <<___;
vshufps \$0x44, %ymm7, %ymm6, %ymm8
@@ -1144,7 +1170,7 @@ ___
&intt_butterfly("%ymm8", "%ymm7", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm6", "%ymm7", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 2); # qinv, q, level
$code .= <<___;
@@ -1165,7 +1191,7 @@ ___
&intt_butterfly("%ymm8", "%ymm1", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm0", "%ymm1", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 3); # qinv, q, level
$code .= <<___;
@@ -1179,7 +1205,7 @@ ___
&intt_butterfly("%ymm8", "%ymm3", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm2", "%ymm3", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 3); # qinv, q, level
$code .= <<___;
vperm2i128 \$0x20, %ymm5, %ymm4, %ymm8 # [0,1,2,3 | 4,5,6,7]
@@ -1192,7 +1218,7 @@ ___
&intt_butterfly("%ymm8", "%ymm5", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm4", "%ymm5", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 3); # qinv, q, level
$code .= <<___;
@@ -1206,7 +1232,7 @@ ___
&intt_butterfly("%ymm8", "%ymm7", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm6", "%ymm7", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 3); # qinv, q, level
$code .= <<___;
@@ -1222,12 +1248,12 @@ ___
&intt_butterfly("%ymm0", "%ymm2", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm0", "%ymm2", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 4); # qinv, q, level
&intt_butterfly("%ymm1", "%ymm3", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm1", "%ymm3", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 4); # qinv, q, level
$code .= <<___;
# broadcast zetas
@@ -1237,12 +1263,12 @@ ___
&intt_butterfly("%ymm4", "%ymm6", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm4", "%ymm6", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 4); # qinv, q, level
&intt_butterfly("%ymm5", "%ymm7", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm5", "%ymm7", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 4); # qinv, q, level
$code .= <<___;
@@ -1317,14 +1343,14 @@ ___
&intt_butterfly("%ymm0", "%ymm1", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm0", "%ymm1", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 5); # qinv, q, level
$code .= <<___;
vpbroadcastd 249*4(%r11), %ymm13
___
&intt_butterfly("%ymm2", "%ymm3", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm2", "%ymm3", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 5); # qinv, q, level
$code .= <<___;
vpbroadcastd 250*4(%r11), %ymm13
@@ -1332,7 +1358,7 @@ ___
&intt_butterfly("%ymm4", "%ymm5", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm4", "%ymm5", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 5); # qinv, q, level
$code .= <<___;
vpbroadcastd 251*4(%r11), %ymm13
@@ -1340,7 +1366,7 @@ ___
&intt_butterfly("%ymm6", "%ymm7", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm6", "%ymm7", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 5); # qinv, q, level
$code .= <<___;
# ==============================================================
@@ -1351,22 +1377,22 @@ ___
&intt_butterfly("%ymm0", "%ymm2", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm0", "%ymm2", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 6); # qinv, q, level
&intt_butterfly("%ymm1", "%ymm3", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm1", "%ymm3", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 6); # qinv, q, level
$code .= <<___;
vpbroadcastd 253*4(%r11), %ymm13
___
&intt_butterfly("%ymm4", "%ymm6", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm4", "%ymm6", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 6); # qinv, q, level
&intt_butterfly("%ymm5", "%ymm7", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm5", "%ymm7", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 6); # qinv, q, level
$code .= <<___;
# ==============================================================
@@ -1377,19 +1403,19 @@ ___
&intt_butterfly("%ymm0", "%ymm4", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm0", "%ymm4", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 7); # qinv, q, level
&intt_butterfly("%ymm1", "%ymm5", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm1", "%ymm5", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 7); # qinv, q, level
&intt_butterfly("%ymm2", "%ymm6", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm2", "%ymm6", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 7); # qinv, q, level
&intt_butterfly("%ymm3", "%ymm7", "%ymm13", # w_even, w_odd, zetas
"%ymm10", "%ymm11", "%ymm12", # temporary
"%ymm3", "%ymm7", # n_even, n_odd
- "%ymm14", "%ymm15"); # qinv, q32
+ "%ymm14", "%ymm15", 7); # qinv, q, level
$code .= <<___;
# ==============================================================
@@ -1398,30 +1424,30 @@ $code .= <<___;
vpbroadcastd ml_dsa_inverse_degree_montgomery(%rip), %ymm13
___
- &multiply_8x8_mod_Q("%ymm0", "%ymm13", "%ymm0", # A, B, out (AxB)
- "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
- &multiply_8x8_mod_Q("%ymm4", "%ymm13", "%ymm4", # A, B, out (AxB)
- "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
- &multiply_8x8_mod_Q("%ymm1", "%ymm13", "%ymm1", # A, B, out (AxB)
- "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
- &multiply_8x8_mod_Q("%ymm5", "%ymm13", "%ymm5", # A, B, out (AxB)
- "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
- &multiply_8x8_mod_Q("%ymm2", "%ymm13", "%ymm2", # A, B, out (AxB)
- "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
- &multiply_8x8_mod_Q("%ymm6", "%ymm13", "%ymm6", # A, B, out (AxB)
- "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
- &multiply_8x8_mod_Q("%ymm3", "%ymm13", "%ymm3", # A, B, out (AxB)
- "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
- &multiply_8x8_mod_Q("%ymm7", "%ymm13", "%ymm7", # A, B, out (AxB)
- "%ymm10", "%ymm11", "%ymm12", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ &multiply_mod_Q("%ymm0", "%ymm13", "%ymm0", # A, B, out (AxB)
+ "%ymm10", "%ymm11", "%ymm12", # tmp
+ "%ymm14", "%ymm15", 1); # qinv, q, bcast32
+ &multiply_mod_Q("%ymm4", "%ymm13", "%ymm4", # A, B, out (AxB)
+ "%ymm10", "%ymm11", "%ymm12", # tmp
+ "%ymm14", "%ymm15", 1); # qinv, q, bcast32
+ &multiply_mod_Q("%ymm1", "%ymm13", "%ymm1", # A, B, out (AxB)
+ "%ymm10", "%ymm11", "%ymm12", # tmp
+ "%ymm14", "%ymm15", 1); # qinv, q, bcast32
+ &multiply_mod_Q("%ymm5", "%ymm13", "%ymm5", # A, B, out (AxB)
+ "%ymm10", "%ymm11", "%ymm12", # tmp
+ "%ymm14", "%ymm15", 1); # qinv, q, bcast32
+ &multiply_mod_Q("%ymm2", "%ymm13", "%ymm2", # A, B, out (AxB)
+ "%ymm10", "%ymm11", "%ymm12", # tmp
+ "%ymm14", "%ymm15", 1); # qinv, q, bcast32
+ &multiply_mod_Q("%ymm6", "%ymm13", "%ymm6", # A, B, out (AxB)
+ "%ymm10", "%ymm11", "%ymm12", # tmp
+ "%ymm14", "%ymm15", 1); # qinv, q, bcast32
+ &multiply_mod_Q("%ymm3", "%ymm13", "%ymm3", # A, B, out (AxB)
+ "%ymm10", "%ymm11", "%ymm12", # tmp
+ "%ymm14", "%ymm15", 1); # qinv, q, bcast32
+ &multiply_mod_Q("%ymm7", "%ymm13", "%ymm7", # A, B, out (AxB)
+ "%ymm10", "%ymm11", "%ymm12", # tmp
+ "%ymm14", "%ymm15", 1); # qinv, q, bcast32
$code .= <<___;
@@ -1579,9 +1605,9 @@ ml_dsa_poly_ntt_mult_avx2:
# multiply this part of input data
___
- &multiply_8x8_mod_Q("%ymm0", "%ymm1", "%ymm0", # A, B, out (AxB)
- "%ymm8", "%ymm9", "%ymm10", # tmp
- "%ymm14", "%ymm15"); # qinv, q
+ &multiply_mod_Q("%ymm0", "%ymm1", "%ymm0", # A, B, out (AxB)
+ "%ymm8", "%ymm9", "%ymm10", # tmp
+ "%ymm14", "%ymm15", 0); # qinv, q, no-bcast32
$code .= <<___;
# store result to output
@@ -1632,11 +1658,7 @@ $code .= <<___;
ml_dsa_poly_ntt_avx2:
.cfi_startproc
- sub \$32, %rsp
-.cfi_adjust_cfa_offset 32 # track rsp change so unwinder can find CFA
-
- # save input arguments
- mov %rdi, %r10
+ # move p_zetas to r11
mov %rsi, %r11
# load constants
@@ -1647,8 +1669,7 @@ ml_dsa_poly_ntt_avx2:
# - level 0: offset = 128, step = 1, zeta indexes = 1
# - level 1: offset = 64, step = 2, zeta indexes = 2, 3
# - level 2: offset = 32, step = 4, zeta indexes = 4, 5, 6, 7
-
- mov %r10, %rdi # p_coeffs
+ # p_coeffs already in rdi
___
&ntt_levels0to2(0*4);
@@ -1669,8 +1690,7 @@ $code .= <<___;
# zeta indexes = 64, 65, 66, ..., 126, 127
# - level 7: offset = 1, step = 128
# zeta indexes = 128, 129, 130, ..., 254, 255
-
- mov %r10, %rdi # p_even / p_coeff
+ # p_coeffs already in rdi
___
# arguments: coeff, l3, l4, l5, l6, l7
@@ -1682,10 +1702,6 @@ ___
$code .= <<___;
vzeroall
-
- lea 32(%rsp), %rsp
-.cfi_adjust_cfa_offset -32
-
ret
.cfi_endproc
.size ml_dsa_poly_ntt_avx2, .-ml_dsa_poly_ntt_avx2