Commit efff9b8278 for openssl.org

commit efff9b82788c561d7033ed043f22b9a11d04b29f
Author: Matt Caswell <matt@openssl.foundation>
Date:   Tue May 19 08:35:47 2026 +0100

    Reapply "Preserve connection custom extensions in SSL_set_SSL_CTX()"

    This reverts commit 7836b7d5b6a6b27a441c4e4c8564be6b270580c4.

    Fixes #31193

    Reviewed-by: Bob Beck <beck@openssl.org>
    Reviewed-by: Tomas Mraz <tomas@openssl.foundation>
    MergeDate: Mon Jun  8 07:51:54 2026
    (Merged from https://github.com/openssl/openssl/pull/31238)

diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c
index 66289ff075..c89f3e4017 100644
--- a/ssl/ssl_lib.c
+++ b/ssl/ssl_lib.c
@@ -5706,6 +5706,8 @@ SSL_CTX *SSL_set_SSL_CTX(SSL *ssl, SSL_CTX *ctx)
     new_cert = ssl_cert_dup(ctx->cert);
     if (new_cert == NULL)
         goto err;
+    if (!custom_exts_copy_conn(&new_cert->custext, &sc->cert->custext))
+        goto err;
     if (!custom_exts_copy_flags(&new_cert->custext, &sc->cert->custext))
         goto err;

diff --git a/ssl/ssl_local.h b/ssl/ssl_local.h
index 4a5dd98e7b..62b0017b9c 100644
--- a/ssl/ssl_local.h
+++ b/ssl/ssl_local.h
@@ -2104,6 +2104,11 @@ typedef struct {
  * corresponding ServerHello extension.
  */
 #define SSL_EXT_FLAG_SENT 0x2
+/*
+ * Indicates an extension that was set on the SSL object itself (not on its
+ * SSL_CTX) and so must be preserved when SSL_set_SSL_CTX() switches contexts.
+ */
+#define SSL_EXT_FLAG_CONN 0x4

 typedef struct {
     custom_ext_method *meths;
@@ -2917,6 +2922,8 @@ __owur int custom_ext_add(SSL_CONNECTION *s, int context, WPACKET *pkt, X509 *x,

 __owur int custom_exts_copy(custom_ext_methods *dst,
     const custom_ext_methods *src);
+__owur int custom_exts_copy_conn(custom_ext_methods *dst,
+    const custom_ext_methods *src);
 __owur int custom_exts_copy_flags(custom_ext_methods *dst,
     const custom_ext_methods *src);
 void custom_exts_free(custom_ext_methods *exts);
diff --git a/ssl/statem/extensions_cust.c b/ssl/statem/extensions_cust.c
index c4be3d5318..881e42be07 100644
--- a/ssl/statem/extensions_cust.c
+++ b/ssl/statem/extensions_cust.c
@@ -106,7 +106,7 @@ void custom_ext_init(custom_ext_methods *exts)
     custom_ext_method *meth = exts->meths;

     for (i = 0; i < exts->meths_count; i++, meth++)
-        meth->ext_flags = 0;
+        meth->ext_flags &= ~(SSL_EXT_FLAG_SENT | SSL_EXT_FLAG_RECEIVED);
 }

 /* Pass received custom extension data to the application for parsing. */
@@ -390,6 +390,56 @@ int custom_exts_copy(custom_ext_methods *dst, const custom_ext_methods *src)
     return 1;
 }

+/*
+ * Append to |dst| a copy of every |src| method flagged SSL_EXT_FLAG_CONN
+ * (i.e. set on the SSL object itself) so SSL_set_SSL_CTX() preserves them.
+ * Returns 1 on success or if there is nothing to copy, 0 on failure. On
+ * realloc failure |dst| is unchanged; if custom_ext_copy_old_cb() reports
+ * an error, |dst| is released with custom_exts_free().
+ */
+int custom_exts_copy_conn(custom_ext_methods *dst,
+    const custom_ext_methods *src)
+{
+    size_t i, meths_count = 0;
+    custom_ext_method *methdst;
+    int err = 0;
+
+    for (i = 0; i < src->meths_count; i++)
+        if ((src->meths[i].ext_flags & SSL_EXT_FLAG_CONN) != 0)
+            meths_count++;
+    if (meths_count == 0)
+        return 1;
+
+    methdst = OPENSSL_realloc(dst->meths,
+        (dst->meths_count + meths_count) * sizeof(custom_ext_method));
+    if (methdst == NULL)
+        return 0;
+
+    /*
+     * realloc() preserved dst's existing entries and may have freed the old
+     * array; only the appended entries go through custom_ext_copy_old_cb().
+     */
+    dst->meths = methdst;
+    methdst += dst->meths_count;
+
+    for (i = 0; i < src->meths_count; i++) {
+        custom_ext_method *methsrc = &src->meths[i];
+
+        if ((methsrc->ext_flags & SSL_EXT_FLAG_CONN) == 0)
+            continue;
+        memcpy(methdst, methsrc, sizeof(custom_ext_method));
+        custom_ext_copy_old_cb(methdst, methsrc, &err);
+        methdst++;
+    }
+    dst->meths_count += meths_count;
+
+    if (err) {
+        custom_exts_free(dst);
+        return 0;
+    }
+    return 1;
+}
+
 void custom_exts_free(custom_ext_methods *exts)
 {
     size_t i;
@@ -478,6 +528,7 @@ int ossl_tls_add_custom_ext_intern(SSL_CTX *ctx, custom_ext_methods *exts,
     meth->add_cb = add_cb;
     meth->free_cb = free_cb;
     meth->ext_type = ext_type;
+    meth->ext_flags = (ctx == NULL) ? SSL_EXT_FLAG_CONN : 0;
     meth->add_arg = add_arg;
     meth->parse_arg = parse_arg;
     exts->meths_count++;
diff --git a/test/sslapitest.c b/test/sslapitest.c
index e8c5da352f..efa2d3cd43 100644
--- a/test/sslapitest.c
+++ b/test/sslapitest.c
@@ -14079,10 +14079,11 @@ static int alert_cb(SSL *s, unsigned char alert_code, void *arg)
  * Test 1: Force a failure
  * Test 3: Use a CCM based ciphersuite
  * Test 4: fail yield_secret_cb to see double free
+ * Test 5: Normal run with SNI
  */
 static int test_quic_tls(int idx)
 {
-    SSL_CTX *sctx = NULL, *cctx = NULL;
+    SSL_CTX *sctx = NULL, *sctx2 = NULL, *cctx = NULL;
     SSL *serverssl = NULL, *clientssl = NULL;
     int testresult = 0;
     OSSL_DISPATCH qtdis[] = {
@@ -14110,6 +14111,7 @@ static int test_quic_tls(int idx)
     if (idx == 4)
         qtdis[3].function = (void (*)(void))yield_secret_cb_fail;

+    snicb = 0;
     memset(secret_history, 0, sizeof(secret_history));
     secret_history_idx = 0;
     memset(&sdata, 0, sizeof(sdata));
@@ -14124,6 +14126,18 @@ static int test_quic_tls(int idx)
             &sctx, &cctx, cert, privkey)))
         goto end;

+    if (idx == 5) {
+        if (!TEST_true(create_ssl_ctx_pair(libctx, TLS_server_method(), NULL,
+                TLS1_3_VERSION, 0,
+                &sctx2, NULL, cert, privkey)))
+            goto end;
+
+        /* Set up SNI */
+        if (!TEST_true(SSL_CTX_set_tlsext_servername_callback(sctx, sni_cb))
+            || !TEST_true(SSL_CTX_set_tlsext_servername_arg(sctx, sctx2)))
+            goto end;
+    }
+
     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl, NULL,
             NULL)))
         goto end;
@@ -14164,6 +14178,12 @@ static int test_quic_tls(int idx)
         goto end;
     }

+    /* We should have had the SNI callback called exactly once */
+    if (idx == 5) {
+        if (!TEST_int_eq(snicb, 1))
+            goto end;
+    }
+
     /* Check no problems during the handshake */
     if (!TEST_false(sdata.alert)
         || !TEST_false(cdata.alert)
@@ -14205,6 +14225,7 @@ static int test_quic_tls(int idx)
 end:
     SSL_free(serverssl);
     SSL_free(clientssl);
+    SSL_CTX_free(sctx2);
     SSL_CTX_free(sctx);
     SSL_CTX_free(cctx);

@@ -15258,7 +15279,7 @@ int setup_tests(void)
 #endif
     ADD_ALL_TESTS(test_alpn, 4);
 #if !defined(OSSL_NO_USABLE_TLS1_3)
-    ADD_ALL_TESTS(test_quic_tls, 5);
+    ADD_ALL_TESTS(test_quic_tls, 6);
     ADD_TEST(test_quic_tls_early_data);
 #endif
     ADD_ALL_TESTS(test_no_renegotiation, 2);