Commit 3b3680c639 for openssl.org

commit 3b3680c63936a0c316d25f783eee7406bca1b320
Author: Marcel Cornu <marcel.d.cornu@intel.com>
Date:   Wed Feb 25 16:44:07 2026 +0000

    ML-DSA: Refactor to use function pointers for AVX2/scalar NTT

    Replace inline AVX2 capability checks with function pointers that are
    initialized once, on the first call to an NTT operation, using
    CRYPTO_THREAD_run_once.

    Reviewed-by: Saša Nedvědický <sashan@openssl.org>
    Reviewed-by: Paul Dale <paul.dale@oracle.com>
    Reviewed-by: Neil Horman <nhorman@openssl.org>
    MergeDate: Wed Mar 11 15:47:43 2026
    (Merged from https://github.com/openssl/openssl/pull/30160)

diff --git a/crypto/ml_dsa/ml_dsa_ntt.c b/crypto/ml_dsa/ml_dsa_ntt.c
index 41e8b27e81..dc81f81822 100644
--- a/crypto/ml_dsa/ml_dsa_ntt.c
+++ b/crypto/ml_dsa/ml_dsa_ntt.c
@@ -1,5 +1,5 @@
 /*
- * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 2024-2026 The OpenSSL Project Authors. All Rights Reserved.
  *
  * Licensed under the Apache License 2.0 (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
@@ -9,15 +9,40 @@

 #include "ml_dsa_local.h"
 #include "ml_dsa_poly.h"
+#include <openssl/crypto.h>

 /* Assembly function declarations for AVX2 implementations */
 #if !defined(OPENSSL_NO_ASM) && (defined(__x86_64) || defined(__x86_64__) || defined(_M_AMD64) || defined(_M_X64))
+#define ML_DSA_NTT_ASM
 int ml_dsa_ntt_avx2_capable(void);
 void ml_dsa_poly_ntt_avx2(uint32_t *p_coeff, const uint32_t *p_zetas);
 void ml_dsa_poly_ntt_inverse_avx2(uint32_t *p_coeff);
 void ml_dsa_poly_ntt_mult_avx2(const uint32_t *a, const uint32_t *b, uint32_t *out);
 #endif

+/*
+ * Dispatch table for the three NTT operations: each pointer starts out at
+ * the portable C routine and ml_dsa_ntt_init() may repoint it at AVX2 code.
+ */
+typedef void (*ml_dsa_poly_ntt_fn)(POLY *p);
+typedef void (*ml_dsa_poly_ntt_inverse_fn)(POLY *p);
+typedef void (*ml_dsa_poly_ntt_mult_fn)(const POLY *lhs, const POLY *rhs,
+    POLY *out);
+
+/* Portable C versions, defined later in this file */
+static void poly_ntt_mult_scalar(const POLY *lhs, const POLY *rhs, POLY *out);
+static void poly_ntt_scalar(POLY *p);
+static void poly_ntt_inverse_scalar(POLY *p);
+
+/*
+ * Active implementations; only ml_dsa_ntt_init() ever reassigns them.
+ */
+static ml_dsa_poly_ntt_fn poly_ntt_impl = poly_ntt_scalar;
+static ml_dsa_poly_ntt_inverse_fn poly_ntt_inverse_impl = poly_ntt_inverse_scalar;
+static ml_dsa_poly_ntt_mult_fn poly_ntt_mult_impl = poly_ntt_mult_scalar;
+
+static CRYPTO_ONCE ml_dsa_ntt_once = CRYPTO_ONCE_STATIC_INIT;
+
 /*
  * This file has multiple parts required for fast matrix multiplication,
  * 1) NTT (See https://eprint.iacr.org/2024/585.pdf)
@@ -108,52 +133,24 @@ static uint32_t reduce_montgomery(uint64_t a)
 }

 /*
- * @brief Multiply two polynomials in the number theoretically transformed state.
- * See FIPS 204, Algorithm 45, MultiplyNTT()
- * This function has been modified to use montgomery multiplication
- *
- * @param lhs A polynomial multiplicand
- * @param rhs A polynomial multiplier
- * @param out The returned result of the polynomial multiply
+ * Scalar (fallback) NTT implementations, used unless ml_dsa_ntt_init()
+ * selects AVX2. This one is FIPS 204 MultiplyNTT() (Algorithm 45).
  */
-void ossl_ml_dsa_poly_ntt_mult(const POLY *lhs, const POLY *rhs, POLY *out)
+static void poly_ntt_mult_scalar(const POLY *lhs, const POLY *rhs, POLY *out)
 {
     int i;

-#if !defined(OPENSSL_NO_ASM) && (defined(__x86_64) || defined(__x86_64__) || defined(_M_AMD64) || defined(_M_X64))
-    if (ml_dsa_ntt_avx2_capable()) {
-        ml_dsa_poly_ntt_mult_avx2(&lhs->coeff[0], &rhs->coeff[0], &out->coeff[0]);
-        return;
-    }
-#endif
-
     for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
-        out->coeff[i] = reduce_montgomery((uint64_t)lhs->coeff[i] * (uint64_t)rhs->coeff[i]);
+        out->coeff[i] = reduce_montgomery((uint64_t)lhs->coeff[i]
+            * (uint64_t)rhs->coeff[i]);
 }

-/*
- * In place number theoretic transform of a given polynomial.
- *
- * See FIPS 204, Algorithm 41, NTT()
- * This function uses montgomery multiplication.
- *
- * @param p a polynomial that is used as the input, that is replaced with
- *        the NTT of the polynomial
- */
-void ossl_ml_dsa_poly_ntt(POLY *p)
+static void poly_ntt_scalar(POLY *p)
 {
     int i, j, k;
     int step;
     int offset = ML_DSA_NUM_POLY_COEFFICIENTS;

-#if !defined(OPENSSL_NO_ASM) && (defined(__x86_64) || defined(__x86_64__) || defined(_M_AMD64) || defined(_M_X64))
-    if (ml_dsa_ntt_avx2_capable()) {
-        ml_dsa_poly_ntt_avx2(&p->coeff[0], zetas_montgomery);
-        return;
-    }
-#endif
-
-    /* Fallback implementation */
     /* Step: 1, 2, 4, 8, ..., 128 */
     for (step = 1; step < ML_DSA_NUM_POLY_COEFFICIENTS; step <<= 1) {
         k = 0;
@@ -174,14 +171,7 @@ void ossl_ml_dsa_poly_ntt(POLY *p)
     }
 }

-/*
- * @brief In place inverse number theoretic transform of a given polynomial.
- * See FIPS 204, Algorithm 42,  NTT^-1()
- *
- * @param p a polynomial that is used as the input, that is overwritten with
- *          the inverse of the NTT.
- */
-void ossl_ml_dsa_poly_ntt_inverse(POLY *p)
+static void poly_ntt_inverse_scalar(POLY *p)
 {
     /*
      * Step: 128, 64, 32, 16, ..., 1
@@ -194,14 +184,6 @@ void ossl_ml_dsa_poly_ntt_inverse(POLY *p)
      */
     static const uint32_t inverse_degree_montgomery = 41978;

-#if !defined(OPENSSL_NO_ASM) && (defined(__x86_64) || defined(__x86_64__) || defined(_M_AMD64) || defined(_M_X64))
-    if (ml_dsa_ntt_avx2_capable()) {
-        ml_dsa_poly_ntt_inverse_avx2(&p->coeff[0]);
-        return;
-    }
-#endif
-
-    /* Fallback implementation */
     for (offset = 1; offset < ML_DSA_NUM_POLY_COEFFICIENTS; offset <<= 1) {
         step >>= 1;
         k = 0;
@@ -220,5 +202,85 @@ void ossl_ml_dsa_poly_ntt_inverse(POLY *p)
         }
     }
     for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
-        p->coeff[i] = reduce_montgomery((uint64_t)p->coeff[i] * (uint64_t)inverse_degree_montgomery);
+        p->coeff[i] = reduce_montgomery((uint64_t)p->coeff[i]
+            * (uint64_t)inverse_degree_montgomery);
+}
+
+/*
+ * Adapters from the POLY-based dispatch signatures to the AVX2 assembly.
+ */
+#ifdef ML_DSA_NTT_ASM
+static void poly_ntt_mult_avx2_wrapper(const POLY *lhs, const POLY *rhs,
+    POLY *out)
+{
+    ml_dsa_poly_ntt_mult_avx2(lhs->coeff, rhs->coeff, out->coeff);
+}
+
+static void poly_ntt_avx2_wrapper(POLY *p)
+{
+    ml_dsa_poly_ntt_avx2(p->coeff, zetas_montgomery);
+}
+
+static void poly_ntt_inverse_avx2_wrapper(POLY *p)
+{
+    ml_dsa_poly_ntt_inverse_avx2(p->coeff);
+}
+#endif
+
+/*
+ * One-time backend selection (run through CRYPTO_THREAD_run_once()): with the
+ * x86-64 assembly, switch to AVX2 when ml_dsa_ntt_avx2_capable() says so.
+ */
+static void ml_dsa_ntt_init(void)
+{
+#ifdef ML_DSA_NTT_ASM
+    if (!ml_dsa_ntt_avx2_capable())
+        return;
+    poly_ntt_mult_impl = poly_ntt_mult_avx2_wrapper;
+    poly_ntt_impl = poly_ntt_avx2_wrapper;
+    poly_ntt_inverse_impl = poly_ntt_inverse_avx2_wrapper;
+#endif
+}
+
+/*
+ * @brief Multiply two NTT-domain polynomials: FIPS 204, Algorithm 45,
+ * MultiplyNTT(), adapted to use montgomery multiplication.
+ * Uses the routine picked once by ml_dsa_ntt_init() (scalar by default);
+ * the run-once result is ignored since a failed setup keeps that default.
+ * @param lhs A polynomial multiplicand
+ * @param rhs A polynomial multiplier
+ * @param out The returned result of the polynomial multiply
+ */
+void ossl_ml_dsa_poly_ntt_mult(const POLY *lhs, const POLY *rhs, POLY *out)
+{
+    (void)CRYPTO_THREAD_run_once(&ml_dsa_ntt_once, ml_dsa_ntt_init);
+    (*poly_ntt_mult_impl)(lhs, rhs, out);
+}
+
+/*
+ * @brief In place number theoretic transform of a polynomial, using
+ * montgomery multiplication. See FIPS 204, Algorithm 41, NTT().
+ * Uses the routine picked once by ml_dsa_ntt_init() (scalar by default);
+ * the run-once result is ignored since a failed setup keeps that default.
+ *
+ * @param p a polynomial that is used as the input, that is replaced with
+ *        the NTT of the polynomial
+ */
+void ossl_ml_dsa_poly_ntt(POLY *p)
+{
+    (void)CRYPTO_THREAD_run_once(&ml_dsa_ntt_once, ml_dsa_ntt_init);
+    (*poly_ntt_impl)(p);
+}
+
+/*
+ * @brief In place inverse number theoretic transform of a polynomial.
+ * See FIPS 204, Algorithm 42, NTT^-1().
+ * Uses the routine picked once by ml_dsa_ntt_init() (scalar by default);
+ * the run-once result is ignored since a failed setup keeps that default.
+ * @param p the input polynomial; overwritten with its inverse NTT
+ */
+void ossl_ml_dsa_poly_ntt_inverse(POLY *p)
+{
+    (void)CRYPTO_THREAD_run_once(&ml_dsa_ntt_once, ml_dsa_ntt_init);
+    (*poly_ntt_inverse_impl)(p);
 }