Commit 3b3680c639 for openssl.org
commit 3b3680c63936a0c316d25f783eee7406bca1b320
Author: Marcel Cornu <marcel.d.cornu@intel.com>
Date: Wed Feb 25 16:44:07 2026 +0000
ML-DSA: Refactor to use function pointers for AVX2/scalar NTT
Replace inline AVX2 capability checks with function pointers that are
initialized once at startup using CRYPTO_THREAD_run_once.
Reviewed-by: Saša Nedvědický <sashan@openssl.org>
Reviewed-by: Paul Dale <paul.dale@oracle.com>
Reviewed-by: Neil Horman <nhorman@openssl.org>
MergeDate: Wed Mar 11 15:47:43 2026
(Merged from https://github.com/openssl/openssl/pull/30160)
diff --git a/crypto/ml_dsa/ml_dsa_ntt.c b/crypto/ml_dsa/ml_dsa_ntt.c
index 41e8b27e81..dc81f81822 100644
--- a/crypto/ml_dsa/ml_dsa_ntt.c
+++ b/crypto/ml_dsa/ml_dsa_ntt.c
@@ -1,5 +1,5 @@
/*
- * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 2024-2026 The OpenSSL Project Authors. All Rights Reserved.
*
* Licensed under the Apache License 2.0 (the "License"). You may not use
* this file except in compliance with the License. You can obtain a copy
@@ -9,15 +9,40 @@
#include "ml_dsa_local.h"
#include "ml_dsa_poly.h"
+#include <openssl/crypto.h>
/* Assembly function declarations for AVX2 implementations */
#if !defined(OPENSSL_NO_ASM) && (defined(__x86_64) || defined(__x86_64__) || defined(_M_AMD64) || defined(_M_X64))
+#define ML_DSA_NTT_ASM
int ml_dsa_ntt_avx2_capable(void);
void ml_dsa_poly_ntt_avx2(uint32_t *p_coeff, const uint32_t *p_zetas);
void ml_dsa_poly_ntt_inverse_avx2(uint32_t *p_coeff);
void ml_dsa_poly_ntt_mult_avx2(const uint32_t *a, const uint32_t *b, uint32_t *out);
#endif
+/*
+ * Function pointer types for NTT operations.
+ * These allow selecting AVX2 or scalar implementations at initialization time.
+ */
+typedef void (*ml_dsa_poly_ntt_fn)(POLY *p);
+typedef void (*ml_dsa_poly_ntt_inverse_fn)(POLY *p);
+typedef void (*ml_dsa_poly_ntt_mult_fn)(const POLY *lhs, const POLY *rhs,
+ POLY *out);
+
+/* Forward declarations of scalar NTT functions */
+static void poly_ntt_scalar(POLY *p);
+static void poly_ntt_inverse_scalar(POLY *p);
+static void poly_ntt_mult_scalar(const POLY *lhs, const POLY *rhs, POLY *out);
+
+/*
+ * NTT function pointers - initialized to scalar implementations by default.
+ */
+static ml_dsa_poly_ntt_fn poly_ntt_impl = poly_ntt_scalar;
+static ml_dsa_poly_ntt_inverse_fn poly_ntt_inverse_impl = poly_ntt_inverse_scalar;
+static ml_dsa_poly_ntt_mult_fn poly_ntt_mult_impl = poly_ntt_mult_scalar;
+
+static CRYPTO_ONCE ml_dsa_ntt_once = CRYPTO_ONCE_STATIC_INIT;
+
/*
* This file has multiple parts required for fast matrix multiplication,
* 1) NTT (See https://eprint.iacr.org/2024/585.pdf)
@@ -108,52 +133,24 @@ static uint32_t reduce_montgomery(uint64_t a)
}
/*
- * @brief Multiply two polynomials in the number theoretically transformed state.
- * See FIPS 204, Algorithm 45, MultiplyNTT()
- * This function has been modified to use montgomery multiplication
- *
- * @param lhs A polynomial multiplicand
- * @param rhs A polynomial multiplier
- * @param out The returned result of the polynomial multiply
+ * Scalar (fallback) implementations of NTT operations.
+ * These are used when AVX2 is not available.
*/
-void ossl_ml_dsa_poly_ntt_mult(const POLY *lhs, const POLY *rhs, POLY *out)
+static void poly_ntt_mult_scalar(const POLY *lhs, const POLY *rhs, POLY *out)
{
int i;
-#if !defined(OPENSSL_NO_ASM) && (defined(__x86_64) || defined(__x86_64__) || defined(_M_AMD64) || defined(_M_X64))
- if (ml_dsa_ntt_avx2_capable()) {
- ml_dsa_poly_ntt_mult_avx2(&lhs->coeff[0], &rhs->coeff[0], &out->coeff[0]);
- return;
- }
-#endif
-
for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
- out->coeff[i] = reduce_montgomery((uint64_t)lhs->coeff[i] * (uint64_t)rhs->coeff[i]);
+ out->coeff[i] = reduce_montgomery((uint64_t)lhs->coeff[i]
+ * (uint64_t)rhs->coeff[i]);
}
-/*
- * In place number theoretic transform of a given polynomial.
- *
- * See FIPS 204, Algorithm 41, NTT()
- * This function uses montgomery multiplication.
- *
- * @param p a polynomial that is used as the input, that is replaced with
- * the NTT of the polynomial
- */
-void ossl_ml_dsa_poly_ntt(POLY *p)
+static void poly_ntt_scalar(POLY *p)
{
int i, j, k;
int step;
int offset = ML_DSA_NUM_POLY_COEFFICIENTS;
-#if !defined(OPENSSL_NO_ASM) && (defined(__x86_64) || defined(__x86_64__) || defined(_M_AMD64) || defined(_M_X64))
- if (ml_dsa_ntt_avx2_capable()) {
- ml_dsa_poly_ntt_avx2(&p->coeff[0], zetas_montgomery);
- return;
- }
-#endif
-
- /* Fallback implementation */
/* Step: 1, 2, 4, 8, ..., 128 */
for (step = 1; step < ML_DSA_NUM_POLY_COEFFICIENTS; step <<= 1) {
k = 0;
@@ -174,14 +171,7 @@ void ossl_ml_dsa_poly_ntt(POLY *p)
}
}
-/*
- * @brief In place inverse number theoretic transform of a given polynomial.
- * See FIPS 204, Algorithm 42, NTT^-1()
- *
- * @param p a polynomial that is used as the input, that is overwritten with
- * the inverse of the NTT.
- */
-void ossl_ml_dsa_poly_ntt_inverse(POLY *p)
+static void poly_ntt_inverse_scalar(POLY *p)
{
/*
* Step: 128, 64, 32, 16, ..., 1
@@ -194,14 +184,6 @@ void ossl_ml_dsa_poly_ntt_inverse(POLY *p)
*/
static const uint32_t inverse_degree_montgomery = 41978;
-#if !defined(OPENSSL_NO_ASM) && (defined(__x86_64) || defined(__x86_64__) || defined(_M_AMD64) || defined(_M_X64))
- if (ml_dsa_ntt_avx2_capable()) {
- ml_dsa_poly_ntt_inverse_avx2(&p->coeff[0]);
- return;
- }
-#endif
-
- /* Fallback implementation */
for (offset = 1; offset < ML_DSA_NUM_POLY_COEFFICIENTS; offset <<= 1) {
step >>= 1;
k = 0;
@@ -220,5 +202,85 @@ void ossl_ml_dsa_poly_ntt_inverse(POLY *p)
}
}
for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
- p->coeff[i] = reduce_montgomery((uint64_t)p->coeff[i] * (uint64_t)inverse_degree_montgomery);
+ p->coeff[i] = reduce_montgomery((uint64_t)p->coeff[i]
+ * (uint64_t)inverse_degree_montgomery);
+}
+
+/*
+ * AVX2 wrapper functions
+ */
+#ifdef ML_DSA_NTT_ASM
+static void poly_ntt_mult_avx2_wrapper(const POLY *lhs, const POLY *rhs,
+ POLY *out)
+{
+ ml_dsa_poly_ntt_mult_avx2(&lhs->coeff[0], &rhs->coeff[0], &out->coeff[0]);
+}
+
+static void poly_ntt_avx2_wrapper(POLY *p)
+{
+ ml_dsa_poly_ntt_avx2(&p->coeff[0], zetas_montgomery);
+}
+
+static void poly_ntt_inverse_avx2_wrapper(POLY *p)
+{
+ ml_dsa_poly_ntt_inverse_avx2(&p->coeff[0]);
+}
+#endif
+
+/*
+ * Initialize NTT function pointers to AVX2 implementations if available.
+ * Scalar implementations are used by default.
+ */
+static void ml_dsa_ntt_init(void)
+{
+#ifdef ML_DSA_NTT_ASM
+ if (ml_dsa_ntt_avx2_capable()) {
+ poly_ntt_impl = poly_ntt_avx2_wrapper;
+ poly_ntt_inverse_impl = poly_ntt_inverse_avx2_wrapper;
+ poly_ntt_mult_impl = poly_ntt_mult_avx2_wrapper;
+ }
+#endif
+}
+
+/*
+ * @brief Multiply two polynomials in the number theoretically transformed state.
+ * See FIPS 204, Algorithm 45, MultiplyNTT()
+ * This function has been modified to use montgomery multiplication
+ *
+ * @param lhs A polynomial multiplicand
+ * @param rhs A polynomial multiplier
+ * @param out The returned result of the polynomial multiply
+ */
+void ossl_ml_dsa_poly_ntt_mult(const POLY *lhs, const POLY *rhs, POLY *out)
+{
+ (void)CRYPTO_THREAD_run_once(&ml_dsa_ntt_once, ml_dsa_ntt_init);
+ poly_ntt_mult_impl(lhs, rhs, out);
+}
+
+/*
+ * In place number theoretic transform of a given polynomial.
+ *
+ * See FIPS 204, Algorithm 41, NTT()
+ * This function uses montgomery multiplication.
+ *
+ * @param p a polynomial that is used as the input, that is replaced with
+ * the NTT of the polynomial
+ */
+void ossl_ml_dsa_poly_ntt(POLY *p)
+{
+ (void)CRYPTO_THREAD_run_once(&ml_dsa_ntt_once, ml_dsa_ntt_init);
+ poly_ntt_impl(p);
+}
+
+/*
+ * @brief In place inverse number theoretic transform of a given polynomial.
+ * See FIPS 204, Algorithm 42, NTT^-1()
+ *
+ * @param p a polynomial that is used as the input, that is overwritten with
+ * the inverse of the NTT.
+ */
+void ossl_ml_dsa_poly_ntt_inverse(POLY *p)
+{
+ (void)CRYPTO_THREAD_run_once(&ml_dsa_ntt_once, ml_dsa_ntt_init);
+ poly_ntt_inverse_impl(p);
}