Commit 90f0137453 for openssl.org

commit 90f0137453aaec5f09d26fda91c6025ae25e4130
Author: Simo Sorce <simo@redhat.com>
Date:   Wed Apr 9 09:35:20 2025 -0400

    Split the ML-DSA internal sigver functions

    Deconstruct the functions into 2 parts:
    - mu computation (if needed)
    - actual signing/verification

    Add a helper to compute mu, split into 3 parts
    (init/update/finalize), where the update part can be used to feed the
    message to be signed or verified in chunks of any size.

    Signed-off-by: Simo Sorce <simo@redhat.com>

    Reviewed-by: Viktor Dukhovni <viktor@openssl.org>
    Reviewed-by: Tomas Mraz <tomas@openssl.org>
    (Merged from https://github.com/openssl/openssl/pull/27342)

diff --git a/crypto/ml_dsa/ml_dsa_sign.c b/crypto/ml_dsa/ml_dsa_sign.c
index 346f635094..bbeb95e2a3 100644
--- a/crypto/ml_dsa/ml_dsa_sign.c
+++ b/crypto/ml_dsa/ml_dsa_sign.c
@@ -11,6 +11,9 @@
 #include <openssl/core_names.h>
 #include <openssl/params.h>
 #include <openssl/rand.h>
+#include <openssl/err.h>
+#include <openssl/proverr.h>
+#include "internal/common.h"
 #include "ml_dsa_local.h"
 #include "ml_dsa_key.h"
 #include "ml_dsa_matrix.h"
@@ -43,12 +46,115 @@ static void signature_init(ML_DSA_SIG *sig,
 }

 /*
- * FIPS 204, Algorithm 7, ML-DSA.Sign_internal()
- * @returns 1 on success and 0 on failure.
+ * @brief: Auxiliary functions to compute ML-DSA's mu = H(tr || M'),
+ * where tr is the public key hash and H is SHAKE256.
+ * See FIPS 204 Algorithm 2 Step 10 (and Algorithm 3 Step 5) as
+ * well as Algorithm 7 Step 6 (and Algorithm 8 Step 7).
+ *
+ * ML-DSA pure signatures are encoded as M' = 00 || ctx_len || ctx || msg
+ * where ctx is the empty string by default and ctx_len <= 255.
+ * ml_dsa_mu_init() absorbs tr and, when |encode| is set, 00 || ctx_len ||
+ * ctx. The message is then fed with one or more ml_dsa_mu_update() calls
+ * and ml_dsa_mu_finalize() squeezes out the mu value.
+ *
+ * @param key: A public or private ML-DSA key (supplies tr and SHAKE256).
+ * @param encode: if not set, assumes that M' is provided raw (through the
+ * update calls) and the following parameters are ignored.
+ * @param ctx: An optional context to add to the message encoding.
+ * @param ctx_len: The size of |ctx|. It must be in the range 0..255.
+ * @returns a new EVP_MD_CTX on success, which the caller must release with
+ * EVP_MD_CTX_free(); NULL on failure (including an oversized ctx_len).
+ */
+
+static EVP_MD_CTX *ml_dsa_mu_init(const ML_DSA_KEY *key, int encode,
+                                  const uint8_t *ctx, size_t ctx_len)
+{
+    EVP_MD_CTX *md_ctx;
+    uint8_t itb[2];
+
+    if (key == NULL)
+        return NULL;
+
+    md_ctx = EVP_MD_CTX_new();
+    if (md_ctx == NULL)
+        return NULL;
+
+    /* H(.. */
+    if (!EVP_DigestInit_ex2(md_ctx, key->shake256_md, NULL))
+        goto err;
+    /* ..tr (the public key hash) || .. */
+    if (!EVP_DigestUpdate(md_ctx, key->tr, sizeof(key->tr)))
+        goto err;
+    /* M' = .. (if !encode, the raw M' is fed later via updates) */
+    if (encode) {
+        if (ctx_len > ML_DSA_MAX_CONTEXT_STRING_LEN)
+            goto err;
+        /* IntegerToBytes(0, 1) .. */
+        itb[0] = 0;
+        /* || IntegerToBytes(|ctx|, 1) || .. (range checked above) */
+        itb[1] = (uint8_t)ctx_len;
+        if (!EVP_DigestUpdate(md_ctx, itb, 2))
+            goto err;
+        /* ctx || .. */
+        if (!EVP_DigestUpdate(md_ctx, ctx, ctx_len))
+            goto err;
+        /* .. msg) is absorbed later by ml_dsa_mu_update() */
+    }
+
+    return md_ctx;
+
+err:
+    EVP_MD_CTX_free(md_ctx);
+    return NULL;
+}
+
+/*
+ * @brief: Absorbs one more message chunk into the in-progress mu hash.
+ * May be called any number of times before ml_dsa_mu_finalize().
+ * @param md_ctx: The hashing context returned by ml_dsa_mu_init()
+ * @param msg: The next message chunk (chunks may be of any size)
+ * @param msg_len: The length of the msg buffer to process
+ * @returns 1 on success, 0 on error
+ */
+static int ml_dsa_mu_update(EVP_MD_CTX *md_ctx,
+                            const uint8_t *msg, size_t msg_len)
+{
+    return EVP_DigestUpdate(md_ctx, msg, msg_len);
+}
+
+/*
+ * @brief: Finalizes the mu computation, squeezing mu out of the hash.
+ * md_ctx is not freed here; the caller remains responsible for it.
+ * @param md_ctx: The hashing context returned by ml_dsa_mu_init()
+ * @param mu: The output buffer for mu
+ * @param mu_len: The size of the output buffer; must be ML_DSA_MU_BYTES
+ * @returns 1 on success, 0 on error (including a wrong mu_len)
+ */
+static int ml_dsa_mu_finalize(EVP_MD_CTX *md_ctx, uint8_t *mu, size_t mu_len)
+{
+    if (!ossl_assert(mu_len == ML_DSA_MU_BYTES)) {
+        ERR_raise(ERR_LIB_PROV, PROV_R_BAD_LENGTH);
+        return 0;
+    }
+    return EVP_DigestSqueeze(md_ctx, mu, mu_len);
+}
+
+/*
+ * @brief FIPS 204, Algorithm 7, ML-DSA.Sign_internal()
+ *
+ * This algorithm is decomposed into 2 steps: a set of functions to compute mu
+ * and then the actual signing function.
+ *
+ * @param priv: The private ML-DSA key
+ * @param mu: The pre-computed mu hash
+ * @param mu_len: The length of the mu buffer
+ * @param rnd: The random buffer
+ * @param rnd_len: The length of the random buffer
+ * @param out_sig: The output signature buffer
+ * @returns 1 on success, 0 on error
  */
-static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, int msg_is_mu,
-                                const uint8_t *encoded_msg,
-                                size_t encoded_msg_len,
+static int ml_dsa_sign_internal(const ML_DSA_KEY *priv,
+                                const uint8_t *mu, size_t mu_len,
                                 const uint8_t *rnd, size_t rnd_len,
                                 uint8_t *out_sig)
 {
@@ -63,25 +169,28 @@ static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, int msg_is_mu,
     size_t num_polys_k = 5 * k;
     size_t num_polys_l = 3 * l;
     size_t num_polys_k_by_l = k * l;
-    POLY *polys = NULL, *p, *c_ntt;
+    POLY *p, *c_ntt;
     VECTOR s1_ntt, s2_ntt, t0_ntt, w, w1, cs1, cs2, y;
     MATRIX a_ntt;
     ML_DSA_SIG sig;
-    uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu;
-    const size_t mu_len = sizeof(mu);
     uint8_t rho_prime[ML_DSA_RHO_PRIME_BYTES];
     uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4];
     size_t c_tilde_len = params->bit_strength >> 2;
     size_t kappa;

+    if (mu_len != ML_DSA_MU_BYTES) {
+        ERR_raise(ERR_LIB_PROV, PROV_R_BAD_LENGTH);
+        return 0;
+    }
+
     /*
      * Allocate a single blob for most of the variable size temporary variables.
      * Mostly used for VECTOR POLYNOMIALS (every POLY is 1K).
      */
     w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128);
     alloc_len = w1_encoded_len
-        + sizeof(*polys) * (1 + num_polys_k + num_polys_l
-                            + num_polys_k_by_l + num_polys_sig_k);
+        + sizeof(*p) * (1 + num_polys_k + num_polys_l
+                        + num_polys_k_by_l + num_polys_sig_k);
     alloc = OPENSSL_malloc(alloc_len);
     if (alloc == NULL)
         return 0;
@@ -110,17 +219,9 @@ static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, int msg_is_mu,

     if (!matrix_expand_A(md_ctx, priv->shake128_md, priv->rho, &a_ntt))
         goto err;
-    if (msg_is_mu) {
-        if (encoded_msg_len != mu_len)
-            goto err;
-        mu_ptr = (uint8_t *)encoded_msg;
-    } else {
-        if (!shake_xof_2(md_ctx, priv->shake256_md, priv->tr, sizeof(priv->tr),
-                         encoded_msg, encoded_msg_len, mu_ptr, mu_len))
-            goto err;
-    }
+
     if (!shake_xof_3(md_ctx, priv->shake256_md, priv->K, sizeof(priv->K),
-                     rnd, rnd_len, mu_ptr, mu_len,
+                     rnd, rnd_len, mu, mu_len,
                      rho_prime, sizeof(rho_prime)))
         goto err;

@@ -152,7 +253,7 @@ static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, int msg_is_mu,
         vector_high_bits(&w, gamma2, &w1);
         ossl_ml_dsa_w1_encode(&w1, gamma2, w1_encoded, w1_encoded_len);

-        if (!shake_xof_2(md_ctx, priv->shake256_md, mu_ptr, mu_len,
+        if (!shake_xof_2(md_ctx, priv->shake256_md, mu, mu_len,
                          w1_encoded, w1_encoded_len, c_tilde, c_tilde_len))
             break;

@@ -202,15 +303,26 @@ err:
 }

 /*
- * See FIPS 204, Algorithm 8, ML-DSA.Verify_internal().
+ * @brief FIPS 204, Algorithm 8, ML-DSA.Verify_internal().
+ *
+ * This algorithm is decomposed into 2 steps: a set of functions to compute mu
+ * and then the actual verification function.
+ *
+ * @param pub: The public ML-DSA key
+ * @param mu: The pre-computed mu hash
+ * @param mu_len: The length of the mu buffer
+ * @param sig_enc: The encoded signature to be verified
+ * @param sig_enc_len: The encoded signature length
+ * @returns 1 on success, 0 on error
  */
-static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, int msg_is_mu,
-                                  const uint8_t *msg_enc, size_t msg_enc_len,
-                                  const uint8_t *sig_enc, size_t sig_enc_len)
+static int ml_dsa_verify_internal(const ML_DSA_KEY *pub,
+                                  const uint8_t *mu, size_t mu_len,
+                                  const uint8_t *sig_enc,
+                                  size_t sig_enc_len)
 {
     int ret = 0;
     uint8_t *alloc = NULL, *w1_encoded;
-    POLY *polys = NULL, *p, *c_ntt;
+    POLY *p, *c_ntt;
     MATRIX a_ntt;
     VECTOR az_ntt, ct1_ntt, *z_ntt, *w1, *w_approx;
     ML_DSA_SIG sig;
@@ -223,21 +335,25 @@ static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, int msg_is_mu,
     size_t num_polys_k = 2 * k;
     size_t num_polys_l = 1 * l;
     size_t num_polys_k_by_l = k * l;
-    uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu;
-    const size_t mu_len = sizeof(mu);
     uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4];
     uint8_t c_tilde_sig[ML_DSA_MAX_LAMBDA / 4];
     EVP_MD_CTX *md_ctx = NULL;
     size_t c_tilde_len = params->bit_strength >> 2;
     uint32_t z_max;

+    if (mu_len != ML_DSA_MU_BYTES) {
+        ERR_raise(ERR_LIB_PROV, PROV_R_BAD_LENGTH);
+        return 0;
+    }
+
+
     /* Allocate space for all the POLYNOMIALS used by temporary VECTORS */
     w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128);
     alloc = OPENSSL_malloc(w1_encoded_len
-                           + sizeof(*polys) * (1 + num_polys_k
-                                               + num_polys_l
-                                               + num_polys_k_by_l
-                                               + num_polys_sig));
+                           + sizeof(*p) * (1 + num_polys_k
+                                           + num_polys_l
+                                           + num_polys_k_by_l
+                                           + num_polys_sig));
     if (alloc == NULL)
         return 0;
     md_ctx = EVP_MD_CTX_new();
@@ -258,16 +374,8 @@ static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, int msg_is_mu,
     if (!ossl_ml_dsa_sig_decode(&sig, sig_enc, sig_enc_len, pub->params)
             || !matrix_expand_A(md_ctx, pub->shake128_md, pub->rho, &a_ntt))
         goto err;
-    if (msg_is_mu) {
-        if (msg_enc_len != mu_len)
-            goto err;
-        mu_ptr = (uint8_t *)msg_enc;
-    } else {
-        if (!shake_xof_2(md_ctx, pub->shake256_md, pub->tr, sizeof(pub->tr),
-                         msg_enc, msg_enc_len, mu_ptr, mu_len))
-            goto err;
-    }
-    /* Compute verifiers challenge c_ntt = NTT(SampleInBall(c_tilde) */
+
+    /* Compute verifier's challenge c_ntt = NTT(SampleInBall(c_tilde)) */
     if (!poly_sample_in_ball_ntt(c_ntt, c_tilde_sig, c_tilde_len,
                                  md_ctx, pub->shake256_md, params->tau))
         goto err;
@@ -292,7 +400,7 @@ static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, int msg_is_mu,
     vector_use_hint(&sig.hint, w_approx, gamma2, w1);
     ossl_ml_dsa_w1_encode(w1, gamma2, w1_encoded, w1_encoded_len);

-    if (!shake_xof_3(md_ctx, pub->shake256_md, mu_ptr, mu_len,
+    if (!shake_xof_3(md_ctx, pub->shake256_md, mu, mu_len,
                      w1_encoded, w1_encoded_len, NULL, 0, c_tilde, c_tilde_len))
         goto err;

@@ -304,61 +412,6 @@ err:
     return ret;
 }

-/**
- * @brief Encode a message
- * See FIPS 204 Algorithm 2 Step 10 (and algorithm 3 Step 5).
- *
- * ML_DSA pure signatures are encoded as M' = 00 || ctx_len || ctx || msg
- * Where ctx is the empty string by default and ctx_len <= 255.
- *
- * Note this code could be shared with SLH_DSA
- *
- * @param msg A message to encode
- * @param msg_len The size of |msg|
- * @param ctx An optional context to add to the message encoding.
- * @param ctx_len The size of |ctx|. It must be in the range 0..255
- * @param encode Use the Pure signature encoding if this is 1, and dont encode
- *               if this value is 0.
- * @param tmp A small buffer that may be used if the message is small.
- * @param tmp_len The size of |tmp|
- * @param out_len The size of the returned encoded buffer.
- * @returns A buffer containing the encoded message. If the passed in
- * |tmp| buffer is big enough to hold the encoded message then it returns |tmp|
- * otherwise it allocates memory which must be freed by the caller. If |encode|
- * is 0 then it returns |msg|. NULL is returned if there is a failure.
- */
-static uint8_t *msg_encode(const uint8_t *msg, size_t msg_len,
-                           const uint8_t *ctx, size_t ctx_len, int encode,
-                           uint8_t *tmp, size_t tmp_len, size_t *out_len)
-{
-    uint8_t *encoded = NULL;
-    size_t encoded_len;
-
-    if (encode == 0) {
-        /* Raw message */
-        *out_len = msg_len;
-        return (uint8_t *)msg;
-    }
-    if (ctx_len > ML_DSA_MAX_CONTEXT_STRING_LEN)
-        return NULL;
-
-    /* Pure encoding */
-    encoded_len = 1 + 1 + ctx_len + msg_len;
-    *out_len = encoded_len;
-    if (encoded_len <= tmp_len) {
-        encoded = tmp;
-    } else {
-        encoded = OPENSSL_malloc(encoded_len);
-        if (encoded == NULL)
-            return NULL;
-    }
-    encoded[0] = 0;
-    encoded[1] = (uint8_t)ctx_len;
-    memcpy(&encoded[2], ctx, ctx_len);
-    memcpy(&encoded[2 + ctx_len], msg, msg_len);
-    return encoded;
-}
-
 /**
  * See FIPS 204 Section 5.2 Algorithm 2 ML-DSA.Sign()
  *
@@ -370,31 +423,43 @@ int ossl_ml_dsa_sign(const ML_DSA_KEY *priv, int msg_is_mu,
                      const uint8_t *rand, size_t rand_len, int encode,
                      unsigned char *sig, size_t *sig_len, size_t sig_size)
 {
-    int ret = 1;
-    uint8_t m_tmp[1024], *m = m_tmp, *alloced_m = NULL;
-    size_t m_len = 0;
+    EVP_MD_CTX *md_ctx = NULL;
+    uint8_t mu[ML_DSA_MU_BYTES];
+    const uint8_t *mu_ptr = mu;
+    size_t mu_len = sizeof(mu);
+    int ret = 0;

     if (ossl_ml_dsa_key_get_priv(priv) == NULL)
         return 0;
-    if (sig != NULL) {
-        if (sig_size < priv->params->sig_len)
-            return 0;
-        if (msg_is_mu) {
-            m = (uint8_t *)msg;
-            m_len = msg_len;
-        } else {
-            m = msg_encode(msg, msg_len, context, context_len, encode,
-                           m_tmp, sizeof(m_tmp), &m_len);
-            if (m == NULL)
-                return 0;
-            if (m != msg && m != m_tmp)
-                alloced_m = m;
-        }
-        ret = ml_dsa_sign_internal(priv, msg_is_mu, m, m_len, rand, rand_len, sig);
-        OPENSSL_free(alloced_m);
-    }
+
     if (sig_len != NULL)
         *sig_len = priv->params->sig_len;
+
+    if (sig == NULL)
+        return (sig_len != NULL) ? 1 : 0;
+
+    if (sig_size < priv->params->sig_len)
+        return 0;
+
+    if (msg_is_mu) {
+        mu_ptr = msg;
+        mu_len = msg_len;
+    } else {
+        md_ctx = ml_dsa_mu_init(priv, encode, context, context_len);
+        if (md_ctx == NULL)
+            return 0;
+
+        if (!ml_dsa_mu_update(md_ctx, msg, msg_len))
+            goto err;
+
+        if (!ml_dsa_mu_finalize(md_ctx, mu, mu_len))
+            goto err;
+    }
+
+    ret = ml_dsa_sign_internal(priv, mu_ptr, mu_len, rand, rand_len, sig);
+
+err:
+    EVP_MD_CTX_free(md_ctx);
     return ret;
 }

@@ -407,27 +472,32 @@ int ossl_ml_dsa_verify(const ML_DSA_KEY *pub, int msg_is_mu,
                        const uint8_t *context, size_t context_len, int encode,
                        const uint8_t *sig, size_t sig_len)
 {
-    uint8_t *m, *alloced_m = NULL;
-    size_t m_len;
-    uint8_t m_tmp[1024];
+    EVP_MD_CTX *md_ctx = NULL;
+    uint8_t mu[ML_DSA_MU_BYTES];
+    const uint8_t *mu_ptr = mu;
+    size_t mu_len = sizeof(mu);
     int ret = 0;

     if (ossl_ml_dsa_key_get_pub(pub) == NULL)
         return 0;

     if (msg_is_mu) {
-        m = (uint8_t *)msg;
-        m_len = msg_len;
+        mu_ptr = msg;
+        mu_len = msg_len;
     } else {
-        m = msg_encode(msg, msg_len, context, context_len, encode,
-                       m_tmp, sizeof(m_tmp), &m_len);
-        if (m == NULL)
+        md_ctx = ml_dsa_mu_init(pub, encode, context, context_len);
+        if (md_ctx == NULL)
             return 0;
-        if (m != msg && m != m_tmp)
-            alloced_m = m;
+
+        if (!ml_dsa_mu_update(md_ctx, msg, msg_len))
+            goto err;
+
+        if (!ml_dsa_mu_finalize(md_ctx, mu, mu_len))
+            goto err;
     }

-    ret = ml_dsa_verify_internal(pub, msg_is_mu, m, m_len, sig, sig_len);
-    OPENSSL_free(alloced_m);
+    ret = ml_dsa_verify_internal(pub, mu_ptr, mu_len, sig, sig_len);
+err:
+    EVP_MD_CTX_free(md_ctx);
     return ret;
 }