Source code
Revision control
Copy as Markdown
Other Tools
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */
#include "tls_agent.h"
#include "databuffer.h"
#include "keyhi.h"
#include "pk11func.h"
#include "ssl.h"
#include "sslerr.h"
#include "sslexp.h"
#include "sslproto.h"
#include "tls_filter.h"
#include "tls_parser.h"
extern "C" {
// This is not something that should make you happy.
#include "libssl_internals.h"
}
#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
#include "gtest_utils.h"
#include "nss_scoped_ptrs.h"
extern std::string g_working_dir_path;
namespace nss_test {
// Printable names for the agent states, indexed by state value.
// NOTE(review): assumed to match the order of the State enum in tls_agent.h
// (STATE_INIT, STATE_CONNECTING, STATE_CONNECTED, STATE_ERROR) — confirm.
const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"};
// Certificate nicknames in the test database. LoadCertificate() looks these up
// with PK11_FindCertFromNickname(); server agents use their own name as the
// nickname (see EnsureTlsSetup()).
const std::string TlsAgent::kClient = "client"; // both sign and encrypt
const std::string TlsAgent::kRsa2048 = "rsa2048"; // bigger
const std::string TlsAgent::kRsa8192 = "rsa8192"; // biggest allowed
const std::string TlsAgent::kServerRsa = "rsa"; // both sign and encrypt
const std::string TlsAgent::kServerRsaSign = "rsa_sign";
const std::string TlsAgent::kServerRsaPss = "rsa_pss";
const std::string TlsAgent::kServerRsaDecrypt = "rsa_decrypt";
const std::string TlsAgent::kServerEcdsa256 = "ecdsa256";
const std::string TlsAgent::kServerEcdsa384 = "ecdsa384";
const std::string TlsAgent::kServerEcdsa521 = "ecdsa521";
const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa";
const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa";
const std::string TlsAgent::kServerDsa = "dsa";
// Certificates used to sign delegated credentials (see DelegateCredential()).
const std::string TlsAgent::kDelegatorEcdsa256 = "delegator_ecdsa256";
const std::string TlsAgent::kDelegatorRsae2048 = "delegator_rsae2048";
const std::string TlsAgent::kDelegatorRsaPss2048 = "delegator_rsa_pss2048";
// A fixed TLS 1.3 ServerHello body (no handshake header). Decoded, the bytes
// appear to be: legacy_version 0x0303, a 32-byte random, an empty
// session id, cipher suite 0x1301, null compression, and 0x2e bytes of
// extensions: key_share (0x0033, group 0x001d with a 32-byte key) and
// supported_versions (0x002b, selecting 0x0304).
// NOTE(review): not referenced in the visible part of the file; presumably
// used further down.
static const uint8_t kCannedTls13ServerHello[] = {
    0x03, 0x03, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3,
    0xf0, 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b,
    0xdf, 0xe5, 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76,
    0x08, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24,
    0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, 0x23, 0x17, 0x64, 0x23, 0x03,
    0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, 0x24, 0xa1, 0x6c, 0xa9,
    0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, 0xcb, 0xe3, 0x08,
    0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04};
// Creates an agent named |nm| that plays role |rl| over protocol variant
// |var|. The SSL socket is not created here; EnsureTlsSetup() creates it
// lazily. Every alert expectation starts at the default of a warning-level
// close_notify, which the destructor treats as "nothing pending".
TlsAgent::TlsAgent(const std::string& nm, Role rl, SSLProtocolVariant var)
    : name_(nm),
      variant_(var),
      role_(rl),
      server_key_bits_(0),
      // Transport adapter; EnsureTlsSetup() wraps its CreateFD() result in
      // the SSL socket.
      // NOTE(review): role_str() presumably reads role_, so role_ must be
      // declared before adapter_ in the header — confirm.
      adapter_(new DummyPrSocket(role_str(), var)),
      ssl_fd_(nullptr),
      state_(STATE_INIT),
      timer_handle_(nullptr),
      falsestart_enabled_(false),
      // 0 means "no expectation"; CheckPreliminaryInfo() adopts the
      // negotiated value.
      expected_version_(0),
      expected_cipher_suite_(0),
      expect_client_auth_(false),
      expect_ech_(false),
      expect_psk_(ssl_psk_none),
      can_falsestart_hook_called_(false),
      sni_hook_called_(false),
      auth_certificate_hook_called_(false),
      expected_received_alert_(kTlsAlertCloseNotify),
      expected_received_alert_level_(kTlsAlertWarning),
      expected_sent_alert_(kTlsAlertCloseNotify),
      expected_sent_alert_level_(kTlsAlertWarning),
      handshake_callback_called_(false),
      resumption_callback_called_(false),
      error_code_(0),
      send_ctr_(0),
      recv_ctr_(0),
      expect_readwrite_error_(false),
      handshake_callback_(),
      auth_certificate_callback_(),
      sni_callback_(),
      skip_version_checks_(false),
      resumption_token_(),
      policy_() {
  // Connected() fills in the channel and cipher-suite info later; start them
  // zeroed.
  memset(&info_, 0, sizeof(info_));
  memset(&csinfo_, 0, sizeof(csinfo_));
  // Start from the library's default version range for this variant.
  // SetVersionRange() may override it; EnsureTlsSetup() applies it.
  SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_);
  EXPECT_EQ(SECSuccess, rv);
}
// Tears the agent down. Any pending timer and poller registration are
// cancelled first. Then any alert expectation that was armed but never
// consumed by CheckAlert() (which re-arms the close_notify/warning default)
// is reported. ADD_FAILURE() is used instead of an ASSERT so that nothing
// throws out of a destructor.
TlsAgent::~TlsAgent() {
  if (timer_handle_) timer_handle_->Cancel();
  if (adapter_) Poller::Instance()->Cancel(READABLE_EVENT, adapter_);

  const bool receive_expectation_pending =
      expected_received_alert_ != kTlsAlertCloseNotify ||
      expected_received_alert_level_ != kTlsAlertWarning;
  const bool send_expectation_pending =
      expected_sent_alert_ != kTlsAlertCloseNotify ||
      expected_sent_alert_level_ != kTlsAlertWarning;

  if (receive_expectation_pending) {
    ADD_FAILURE() << "Wrong expected_received_alert status: " << role_str();
  }
  if (send_expectation_pending) {
    ADD_FAILURE() << "Wrong expected_sent_alert status: " << role_str();
  }
}
// Moves the agent to state |s| and logs the transition. Re-entering the
// current state is a silent no-op.
void TlsAgent::SetState(State s) {
  if (state_ != s) {
    LOG("Changing state from " << state_ << " to " << s);
    state_ = s;
  }
}
// Looks up the certificate with nickname |name| in the test database and the
// private key that belongs to it. The results go into |*cert| and |*priv|.
// Returns false, after recording a non-fatal gtest failure, if either
// out-parameter is null or if the certificate or key cannot be found.
/*static*/ bool TlsAgent::LoadCertificate(const std::string& name,
                                          ScopedCERTCertificate* cert,
                                          ScopedSECKEYPrivateKey* priv) {
  // Validate the out-parameters before touching them. The previous code only
  // checked them after cert->reset()/priv->reset() had already dereferenced
  // them.
  EXPECT_NE(nullptr, cert);
  EXPECT_NE(nullptr, priv);
  if (!cert || !priv) return false;

  cert->reset(PK11_FindCertFromNickname(name.c_str(), nullptr));
  EXPECT_NE(nullptr, cert->get()) << "no certificate named " << name;
  if (!cert->get()) return false;

  priv->reset(PK11_FindKeyByAnyCert(cert->get(), nullptr));
  EXPECT_NE(nullptr, priv->get()) << "no private key for " << name;
  if (!priv->get()) return false;

  return true;
}
// Loads the key pair belonging to the certificate with nickname |name|. The
// private key comes from LoadCertificate(). The public key is extracted from
// the certificate's subjectPublicKeyInfo. Returns false if either step fails.
/*static*/ bool TlsAgent::LoadKeyPairFromCert(const std::string& name,
                                              ScopedSECKEYPublicKey* pub,
                                              ScopedSECKEYPrivateKey* priv) {
  ScopedCERTCertificate cert;
  bool loaded = TlsAgent::LoadCertificate(name, &cert, priv);
  if (loaded) {
    pub->reset(SECKEY_ExtractPublicKey(&cert->subjectPublicKeyInfo));
    loaded = (pub->get() != nullptr);
  }
  return loaded;
}
void TlsAgent::DelegateCredential(const std::string& name,
const ScopedSECKEYPublicKey& dc_pub,
SSLSignatureScheme dc_cert_verify_alg,
PRUint32 dc_valid_for, PRTime now,
SECItem* dc) {
ScopedCERTCertificate cert;
ScopedSECKEYPrivateKey cert_priv;
EXPECT_TRUE(TlsAgent::LoadCertificate(name, &cert, &cert_priv))
<< "Could not load delegate certificate: " << name
<< "; test db corrupt?";
EXPECT_EQ(SECSuccess,
SSL_DelegateCredential(cert.get(), cert_priv.get(), dc_pub.get(),
dc_cert_verify_alg, dc_valid_for, now, dc));
}
void TlsAgent::EnableDelegatedCredentials() {
ASSERT_TRUE(EnsureTlsSetup());
SetOption(SSL_ENABLE_DELEGATED_CREDENTIALS, PR_TRUE);
}
// Configures this (server) agent to present a delegated credential. The
// credential carries the key pair of certificate |dc_name| and is signed by
// this agent's own certificate (name_) with the given scheme, validity
// period and start time.
void TlsAgent::AddDelegatedCredential(const std::string& dc_name,
                                      SSLSignatureScheme dc_cert_verify_alg,
                                      PRUint32 dc_valid_for, PRTime now) {
  ASSERT_TRUE(EnsureTlsSetup());

  // The key pair carried by the credential. Bail out if it is missing.
  // Otherwise the credential would be signed over a null public key and a
  // null private key would be configured.
  ScopedSECKEYPublicKey pub;
  ScopedSECKEYPrivateKey priv;
  ASSERT_TRUE(TlsAgent::LoadKeyPairFromCert(dc_name, &pub, &priv))
      << "Could not load delegated credential key pair: " << dc_name;

  StackSECItem dc;
  TlsAgent::DelegateCredential(name_, pub, dc_cert_verify_alg, dc_valid_for,
                               now, &dc);

  // The last two members carry the credential and its private key.
  // NOTE(review): presumably delegCred / delegCredPrivKey in
  // SSLExtraServerCertData — confirm against sslt.h.
  SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr,
                                       nullptr,       &dc,     priv.get()};
  EXPECT_TRUE(ConfigServerCert(name_, true, &extra_data));
}
bool TlsAgent::ConfigServerCert(const std::string& id, bool updateKeyBits,
const SSLExtraServerCertData* serverCertData) {
ScopedCERTCertificate cert;
ScopedSECKEYPrivateKey priv;
if (!TlsAgent::LoadCertificate(id, &cert, &priv)) {
return false;
}
if (updateKeyBits) {
ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get()));
EXPECT_NE(nullptr, pub.get());
if (!pub.get()) return false;
server_key_bits_ = SECKEY_PublicKeyStrengthInBits(pub.get());
}
SECStatus rv =
SSL_ConfigSecureServer(ssl_fd(), nullptr, nullptr, ssl_kea_null);
EXPECT_EQ(SECFailure, rv);
rv = SSL_ConfigServerCert(ssl_fd(), cert.get(), priv.get(), serverCertData,
serverCertData ? sizeof(*serverCertData) : 0);
return rv == SECSuccess;
}
// Lazily creates this agent's SSL (or DTLS) socket on top of the adapter's
// file descriptor and applies the common test configuration:
//  - the version range (unless version checks are skipped),
//  - an empty trust-anchor list,
//  - role-specific settings,
//  - the agent's callbacks.
// |modelSocket|, if non-null, is passed to SSL_ImportFD/DTLS_ImportFD as the
// model. Returns true if the socket already exists or setup succeeds.
// Returns false, with a recorded gtest failure, otherwise.
bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
  // Don't set up twice
  if (ssl_fd_) return true;
  // NOTE(review): presumably applies policy_/option_ while in scope — see
  // NssManagePolicy.
  NssManagePolicy policyManage(policy_, option_);

  ScopedPRFileDesc dummy_fd(adapter_->CreateFD());
  EXPECT_NE(nullptr, dummy_fd);
  if (!dummy_fd) {
    return false;
  }
  if (adapter_->variant() == ssl_variant_stream) {
    ssl_fd_.reset(SSL_ImportFD(modelSocket, dummy_fd.get()));
  } else {
    ssl_fd_.reset(DTLS_ImportFD(modelSocket, dummy_fd.get()));
  }
  EXPECT_NE(nullptr, ssl_fd_);
  if (!ssl_fd_) {
    return false;
  }
  dummy_fd.release();  // Now subsumed by ssl_fd_.

  SECStatus rv;
  if (!skip_version_checks_) {
    rv = SSL_VersionRangeSet(ssl_fd(), &vrange_);
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;
  }

  ScopedCERTCertList anchors(CERT_NewCertList());
  rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get());
  // Report this failure like every other step instead of returning false
  // without any diagnostic.
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  if (role_ == SERVER) {
    // A missing server certificate is recorded but does not abort setup.
    EXPECT_TRUE(ConfigServerCert(name_, true));
    rv = SSL_SNISocketConfigHook(ssl_fd(), SniHook, this);
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;
    rv = SSL_SetMaxEarlyDataSize(ssl_fd(), 1024);
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;
  } else {
    rv = SSL_SetURL(ssl_fd(), "server");
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;
  }

  rv = SSL_AuthCertificateHook(ssl_fd(), AuthCertificateHook, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;
  rv = SSL_AlertReceivedCallback(ssl_fd(), AlertReceivedCallback, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;
  rv = SSL_AlertSentCallback(ssl_fd(), AlertSentCallback, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;
  rv = SSL_HandshakeCallback(ssl_fd(), HandshakeCallback, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  // All these tests depend on having this disabled to start with.
  SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_FALSE);
  return true;
}
// Installs the stored external resumption token on the socket, if there is
// one. The call may fail with SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR for an
// expired or malformed token; the stored token is then discarded. Returns
// false only when a rejected token was needed because resumption is expected.
bool TlsAgent::MaybeSetResumptionToken() {
  if (resumption_token_.empty()) {
    return true;
  }

  LOG("setting external resumption token");
  SECStatus rv = SSL_SetResumptionToken(ssl_fd(), resumption_token_.data(),
                                        resumption_token_.size());
  // Success is only mandatory when resumption is expected. Otherwise a token
  // may be rejected for reasons unrelated to its validity.
  if (expect_psk_ == ssl_psk_resume) {
    EXPECT_EQ(SECSuccess, rv);
  }
  if (rv == SECSuccess) {
    return true;
  }

  EXPECT_EQ(SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR, PORT_GetError());
  resumption_token_.clear();
  EXPECT_FALSE(expect_psk_ == ssl_psk_resume);
  return expect_psk_ != ssl_psk_resume;
}
// Attaches the anti-replay context |ctx| to this agent's socket.
void TlsAgent::SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx) {
  SSLAntiReplayContext* const context = ctx.get();
  EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(ssl_fd(), context));
}
// Configures how this client agent answers a certificate request.
// |callbackType| selects the behavior of GetClientAuthDataHook, and
// |callbackSuccess| decides whether a certificate is supplied. The defaults
// (declared in the header) are a synchronous callback that succeeds. With
// kNone and failure, no hook is installed at all.
void TlsAgent::SetupClientAuth(ClientAuthCallbackType callbackType,
                               bool callbackSuccess) {
  EXPECT_TRUE(EnsureTlsSetup());
  ASSERT_EQ(CLIENT, role_);
  client_auth_callback_type_ = callbackType;
  client_auth_callback_success_ = callbackSuccess;

  const bool installHook =
      callbackType != ClientAuthCallbackType::kNone || callbackSuccess;
  if (installHook) {
    EXPECT_EQ(SECSuccess,
              SSL_GetClientAuthDataHook(ssl_fd(), GetClientAuthDataHook,
                                        static_cast<void*>(this)));
  }
}
void CheckCertReqAgainstDefaultCAs(const CERTDistNames* caNames) {
ScopedCERTDistNames expected(CERT_GetSSLCACerts(nullptr));
ASSERT_EQ(expected->nnames, caNames->nnames);
for (size_t i = 0; i < static_cast<size_t>(expected->nnames); ++i) {
EXPECT_EQ(SECEqual,
SECITEM_CompareItem(&(expected->names[i]), &(caNames->names[i])));
}
}
// Complete processing of Client Certificate Selection
// A No-op if the agent is using synchronous client cert selection.
// Otherwise, calls SSL_ClientCertCallbackComplete.
// kAsyncDelay triggers a call to SSL_ForceHandshake prior to completion to
// ensure that the socket is correctly blocked.
void TlsAgent::ClientAuthCallbackComplete() {
ASSERT_EQ(CLIENT, role_);
if (client_auth_callback_type_ != ClientAuthCallbackType::kAsyncDelay &&
client_auth_callback_type_ != ClientAuthCallbackType::kAsyncImmediate) {
return;
}
client_auth_callback_fired_++;
EXPECT_TRUE(client_auth_callback_awaiting_);
std::cerr << "client: calling SSL_ClientCertCallbackComplete with status "
<< (client_auth_callback_success_ ? "success" : "failed")
<< std::endl;
client_auth_callback_awaiting_ = false;
if (client_auth_callback_type_ == ClientAuthCallbackType::kAsyncDelay) {
std::cerr
<< "Running Handshake prior to running SSL_ClientCertCallbackComplete"
<< std::endl;
SECStatus rv = SSL_ForceHandshake(ssl_fd());
EXPECT_EQ(rv, SECFailure);
EXPECT_EQ(PORT_GetError(), PR_WOULD_BLOCK_ERROR);
}
ScopedCERTCertificate cert;
ScopedSECKEYPrivateKey priv;
if (client_auth_callback_success_) {
ASSERT_TRUE(TlsAgent::LoadCertificate(name(), &cert, &priv));
EXPECT_EQ(SECSuccess,
SSL_ClientCertCallbackComplete(ssl_fd(), SECSuccess,
priv.release(), cert.release()));
} else {
EXPECT_EQ(SECSuccess, SSL_ClientCertCallbackComplete(ssl_fd(), SECFailure,
nullptr, nullptr));
}
}
// Client-certificate hook installed by SetupClientAuth(); |self| is the
// TlsAgent. Behavior by callback type:
//  - kAsyncDelay / kAsyncImmediate: flag the agent as awaiting completion
//    and return SECWouldBlock. ClientAuthCallbackComplete() finishes later.
//  - kSync / kNone: return SECFailure if failure was configured. Otherwise
//    return this agent's certificate and key in |clientCert| / |clientKey|.
// |fd| and |caNames| are currently unused (the CA-list check is disabled).
SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd,
                                          CERTDistNames* caNames,
                                          CERTCertificate** clientCert,
                                          SECKEYPrivateKey** clientKey) {
  TlsAgent* agent = static_cast<TlsAgent*>(self);
  EXPECT_EQ(CLIENT, agent->role_);
  agent->client_auth_callback_fired_++;

  const ClientAuthCallbackType type = agent->client_auth_callback_type_;
  const bool isAsync = type == ClientAuthCallbackType::kAsyncDelay ||
                       type == ClientAuthCallbackType::kAsyncImmediate;
  const bool isSync = type == ClientAuthCallbackType::kSync ||
                      type == ClientAuthCallbackType::kNone;

  if (isAsync) {
    std::cerr << "Waiting for complete call" << std::endl;
    agent->client_auth_callback_awaiting_ = true;
    return SECWouldBlock;
  }

  if (isSync) {
    // None && Success is treated as Sync and Success.
    if (!agent->client_auth_callback_success_) {
      return SECFailure;
    }
    ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd()));
    EXPECT_TRUE(peerCert) << "Client should be able to see the server cert";
    // CheckCertReqAgainstDefaultCAs(caNames);

    ScopedCERTCertificate cert;
    ScopedSECKEYPrivateKey priv;
    if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) {
      return SECFailure;
    }
    *clientCert = cert.release();
    *clientKey = priv.release();
    return SECSuccess;
  }

  /* This is unreachable, but some old compilers can't tell that. */
  PORT_Assert(0);
  PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
  return SECFailure;
}
// Returns whether the client-auth event counter equals |expected|. The counter
// goes up once per GetClientAuthDataHook call and once per asynchronous
// ClientAuthCallbackComplete() call.
bool TlsAgent::CheckClientAuthCallbacksCompleted(uint8_t expected) {
  EXPECT_EQ(CLIENT, role_);
  const bool matches = (client_auth_callback_fired_ == expected);
  return matches;
}
bool TlsAgent::GetPeerChainLength(size_t* count) {
SECItemArray* chain = nullptr;
SECStatus rv = SSL_PeerCertificateChainDER(ssl_fd(), &chain);
if (rv != SECSuccess) return false;
*count = chain->len;
SECITEM_FreeArray(chain, true);
return true;
}
void TlsAgent::CheckPeerChainFunctionConsistency() {
SECItemArray* derChain = nullptr;
SECStatus rv = SSL_PeerCertificateChainDER(ssl_fd(), &derChain);
PRErrorCode err1 = PR_GetError();
CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd());
PRErrorCode err2 = PR_GetError();
if (rv != SECSuccess) {
ASSERT_EQ(nullptr, chain);
ASSERT_EQ(nullptr, derChain);
ASSERT_EQ(err1, SSL_ERROR_NO_CERTIFICATE);
ASSERT_EQ(err2, SSL_ERROR_NO_CERTIFICATE);
return;
}
ASSERT_NE(nullptr, chain);
ASSERT_NE(nullptr, derChain);
unsigned int count = 0;
for (PRCList* cursor = PR_NEXT_LINK(&chain->list);
count < derChain->len && cursor != &chain->list;
cursor = PR_NEXT_LINK(cursor)) {
CERTCertListNode* node = (CERTCertListNode*)cursor;
EXPECT_TRUE(
SECITEM_ItemsAreEqual(&node->cert->derCert, &derChain->items[count]));
++count;
}
ASSERT_EQ(count, derChain->len);
SECITEM_FreeArray(derChain, true);
CERT_DestroyCertList(chain);
}
// Checks the cipher suite recorded by Connected() against |suite|.
void TlsAgent::CheckCipherSuite(uint16_t suite) {
  const auto negotiated = csinfo_.cipherSuite;
  EXPECT_EQ(negotiated, suite);
}
void TlsAgent::RequestClientAuth(bool requireAuth) {
ASSERT_EQ(SERVER, role_);
SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE);
SetOption(SSL_REQUIRE_CERTIFICATE, requireAuth ? PR_TRUE : PR_FALSE);
EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook(
ssl_fd(), &TlsAgent::ClientAuthenticated, this));
expect_client_auth_ = true;
}
// Creates the socket (optionally from |model|), resets the handshake in the
// agent's role, and enters STATE_CONNECTING.
void TlsAgent::StartConnect(PRFileDesc* model) {
  EXPECT_TRUE(EnsureTlsSetup(model));
  const PRBool asServer = (role_ == SERVER) ? PR_TRUE : PR_FALSE;
  EXPECT_EQ(SECSuccess, SSL_ResetHandshake(ssl_fd(), asServer));
  SetState(STATE_CONNECTING);
}
void TlsAgent::DisableAllCiphers() {
for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
SECStatus rv =
SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_FALSE);
EXPECT_EQ(SECSuccess, rv);
}
}
// Not actually all groups, just the ones that we are actually willing
// to use.
const std::vector<SSLNamedGroup> kAllDHEGroups = {
    ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
    ssl_grp_ec_secp521r1, ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072,
    ssl_grp_ffdhe_4096, ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192,
    ssl_grp_kem_xyber768d00, ssl_grp_kem_mlkem768x25519,
};
// Elliptic-curve groups plus the hybrid KEM groups. Configured by
// EnableGroupsByKeyExchange(ssl_kea_ecdh) and, for ECDH/ECDSA/TLS 1.3 auth
// types, by EnableGroupsByAuthType().
const std::vector<SSLNamedGroup> kECDHEGroups = {
    ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
    ssl_grp_ec_secp521r1, ssl_grp_kem_xyber768d00, ssl_grp_kem_mlkem768x25519,
};
// Finite-field DHE groups. Configured by EnableGroupsByKeyExchange(ssl_kea_dh).
const std::vector<SSLNamedGroup> kFFDHEGroups = {
    ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_ffdhe_4096,
    ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192};
// Defined because the big DHE groups are ridiculously slow.
// Compared with kAllDHEGroups, this omits secp521r1 and the 4096/6144/8192-bit
// FFDHE groups.
const std::vector<SSLNamedGroup> kFasterDHEGroups = {
    ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
    ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_kem_xyber768d00,
    ssl_grp_kem_mlkem768x25519,
};
// Hybrid KEM groups only. Configured by
// EnableGroupsByKeyExchange(ssl_kea_ecdh_hybrid).
const std::vector<SSLNamedGroup> kEcdhHybridGroups = {
    ssl_grp_kem_xyber768d00,
    ssl_grp_kem_mlkem768x25519,
};
// Enables every implemented cipher suite whose key exchange is |kea|. Suites
// with keaType ssl_kea_tls13_any (TLS 1.3 suites) are enabled as well.
// Suites that are already enabled are left alone.
void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) {
  EXPECT_TRUE(EnsureTlsSetup());
  for (size_t idx = 0; idx < SSL_NumImplementedCiphers; ++idx) {
    const PRUint16 suite = SSL_ImplementedCiphers[idx];
    SSLCipherSuiteInfo suiteInfo;
    ASSERT_EQ(SECSuccess,
              SSL_GetCipherSuiteInfo(suite, &suiteInfo, sizeof(suiteInfo)));
    EXPECT_EQ(sizeof(suiteInfo), suiteInfo.length);

    const bool wanted = suiteInfo.keaType == kea ||
                        suiteInfo.keaType == ssl_kea_tls13_any;
    if (!wanted) {
      continue;
    }
    EXPECT_EQ(SECSuccess, SSL_CipherPrefSet(ssl_fd(), suite, PR_TRUE));
  }
}
// Restricts the named groups to those that fit key exchange |kea|:
// finite-field groups for DHE, EC (plus hybrid) groups for ECDHE, hybrid
// groups only for ECDH-hybrid. For any other |kea| the groups are unchanged.
void TlsAgent::EnableGroupsByKeyExchange(SSLKEAType kea) {
  if (kea == ssl_kea_dh) {
    ConfigNamedGroups(kFFDHEGroups);
  } else if (kea == ssl_kea_ecdh) {
    ConfigNamedGroups(kECDHEGroups);
  } else if (kea == ssl_kea_ecdh_hybrid) {
    ConfigNamedGroups(kEcdhHybridGroups);
  }
}
// For auth types that need elliptic-curve groups (ECDH-RSA, ECDH-ECDSA,
// ECDSA, TLS 1.3 any), restricts the groups to kECDHEGroups. Other auth types
// leave the configuration alone.
void TlsAgent::EnableGroupsByAuthType(SSLAuthType authType) {
  switch (authType) {
    case ssl_auth_ecdh_rsa:
    case ssl_auth_ecdh_ecdsa:
    case ssl_auth_ecdsa:
    case ssl_auth_tls13_any:
      ConfigNamedGroups(kECDHEGroups);
      break;
    default:
      break;
  }
}
// Enables every implemented cipher suite whose authentication type is
// |authType|. Suites with keaType ssl_kea_tls13_any are enabled as well.
void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) {
  EXPECT_TRUE(EnsureTlsSetup());
  for (size_t idx = 0; idx < SSL_NumImplementedCiphers; ++idx) {
    const PRUint16 suite = SSL_ImplementedCiphers[idx];
    SSLCipherSuiteInfo suiteInfo;
    ASSERT_EQ(SECSuccess,
              SSL_GetCipherSuiteInfo(suite, &suiteInfo, sizeof(suiteInfo)));
    const bool wanted = suiteInfo.authType == authType ||
                        suiteInfo.keaType == ssl_kea_tls13_any;
    if (wanted) {
      EXPECT_EQ(SECSuccess, SSL_CipherPrefSet(ssl_fd(), suite, PR_TRUE));
    }
  }
}
void TlsAgent::EnableSingleCipher(uint16_t cipher) {
DisableAllCiphers();
SECStatus rv = SSL_CipherPrefSet(ssl_fd(), cipher, PR_TRUE);
EXPECT_EQ(SECSuccess, rv);
}
// Replaces this agent's enabled named groups with |groups|.
void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) {
  EXPECT_TRUE(EnsureTlsSetup());
  // Use groups.data() rather than &groups[0]. Indexing an empty vector is
  // undefined behavior, whereas data() is always valid to pass together with
  // a size of 0.
  SECStatus rv = SSL_NamedGroupConfig(ssl_fd(), groups.data(),
                                      static_cast<unsigned int>(groups.size()));
  EXPECT_EQ(SECSuccess, rv);
}
void TlsAgent::Set0RttEnabled(bool en) {
SetOption(SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE);
}
void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
vrange_.min = minver;
vrange_.max = maxver;
if (ssl_fd()) {
SECStatus rv = SSL_VersionRangeSet(ssl_fd(), &vrange_);
EXPECT_EQ(SECSuccess, rv);
}
}
// Resumption-token callback installed by
// TlsAgent::SetResumptionTokenCallback(); |ctx| is the TlsAgent. Copies the
// |len|-byte token into the agent and records that the callback ran. Returns
// SECFailure if no token is supplied.
SECStatus ResumptionTokenCallback(PRFileDesc* fd,
                                  const PRUint8* resumptionToken,
                                  unsigned int len, void* ctx) {
  EXPECT_NE(nullptr, resumptionToken);
  if (resumptionToken == nullptr) {
    return SECFailure;
  }
  TlsAgent* agent = reinterpret_cast<TlsAgent*>(ctx);
  const std::vector<uint8_t> new_token(resumptionToken,
                                       resumptionToken + len);
  agent->SetResumptionToken(new_token);
  agent->SetResumptionCallbackCalled();
  return SECSuccess;
}
void TlsAgent::SetResumptionTokenCallback() {
EXPECT_TRUE(EnsureTlsSetup());
SECStatus rv =
SSL_SetResumptionTokenCallback(ssl_fd(), ResumptionTokenCallback, this);
EXPECT_EQ(SECSuccess, rv);
}
void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) {
*minver = vrange_.min;
*maxver = vrange_.max;
}
// Sets the version that CheckPreliminaryInfo()/Connected() must observe.
void TlsAgent::SetExpectedVersion(uint16_t ver) {
  expected_version_ = ver;
}
// Overrides the server key size compared by CheckAuthType().
void TlsAgent::SetServerKeyBits(uint16_t bits) {
  server_key_bits_ = bits;
}
// Marks that reads/writes on this agent are expected to fail.
void TlsAgent::ExpectReadWriteError() {
  expect_readwrite_error_ = true;
}
// Makes EnsureTlsSetup() skip applying the configured version range.
void TlsAgent::SkipVersionChecks() {
  skip_version_checks_ = true;
}
void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes,
size_t count) {
EXPECT_TRUE(EnsureTlsSetup());
EXPECT_LE(count, SSL_SignatureMaxCount());
EXPECT_EQ(SECSuccess,
SSL_SignatureSchemePrefSet(ssl_fd(), schemes,
static_cast<unsigned int>(count)));
EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd(), schemes, 0))
<< "setting no schemes should fail and do nothing";
std::vector<SSLSignatureScheme> configuredSchemes(count);
unsigned int configuredCount;
EXPECT_EQ(SECFailure,
SSL_SignatureSchemePrefGet(ssl_fd(), nullptr, &configuredCount, 1))
<< "get schemes, schemes is nullptr";
EXPECT_EQ(SECFailure,
SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0],
&configuredCount, 0))
<< "get schemes, too little space";
EXPECT_EQ(SECFailure,
SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], nullptr,
configuredSchemes.size()))
<< "get schemes, countOut is nullptr";
EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefGet(
ssl_fd(), &configuredSchemes[0], &configuredCount,
configuredSchemes.size()));
// SignatureSchemePrefSet drops unsupported algorithms silently, so the
// number that are configured might be fewer.
EXPECT_LE(configuredCount, count);
unsigned int i = 0;
for (unsigned int j = 0; j < count && i < configuredCount; ++j) {
if (i < configuredCount && schemes[j] == configuredSchemes[i]) {
++i;
}
}
EXPECT_EQ(i, configuredCount) << "schemes in use were all set";
}
// Checks the negotiated key exchange: its type must be |kea| and its group
// |kea_group| with |kea_size| bits. A |kea_size| of 0 means "use the size
// implied by the group" (for RSA key exchange, the server key size). Size and
// group are not compared for ssl_grp_ffdhe_custom.
void TlsAgent::CheckKEA(SSLKEAType kea, SSLNamedGroup kea_group,
                        size_t kea_size) const {
  EXPECT_EQ(STATE_CONNECTED, state_);
  EXPECT_EQ(kea, info_.keaType);
  if (kea_size == 0) {
    switch (kea_group) {
      // The hybrid KEM groups are expected to report the X25519 size.
      case ssl_grp_ec_curve25519:
      case ssl_grp_kem_xyber768d00:
      case ssl_grp_kem_mlkem768x25519:
        kea_size = 255;
        break;
      case ssl_grp_ec_secp256r1:
        kea_size = 256;
        break;
      case ssl_grp_ec_secp384r1:
        kea_size = 384;
        break;
      // secp521r1 and the larger FFDHE groups are configured by
      // kECDHEGroups/kFFDHEGroups/kAllDHEGroups. Without these cases they
      // would fail with "need to update group sizes".
      case ssl_grp_ec_secp521r1:
        kea_size = 521;
        break;
      case ssl_grp_ffdhe_2048:
        kea_size = 2048;
        break;
      case ssl_grp_ffdhe_3072:
        kea_size = 3072;
        break;
      case ssl_grp_ffdhe_4096:
        kea_size = 4096;
        break;
      case ssl_grp_ffdhe_6144:
        kea_size = 6144;
        break;
      case ssl_grp_ffdhe_8192:
        kea_size = 8192;
        break;
      case ssl_grp_ffdhe_custom:
        break;
      default:
        if (kea == ssl_kea_rsa) {
          kea_size = server_key_bits_;
        } else {
          EXPECT_TRUE(false) << "need to update group sizes";
        }
        break;
    }
  }
  if (kea_group != ssl_grp_ffdhe_custom) {
    EXPECT_EQ(kea_size, info_.keaKeyBits);
    EXPECT_EQ(kea_group, info_.keaGroup);
  }
}
// Checks info_.originalKeaGroup against |kea_group|; skipped for
// ssl_grp_ffdhe_custom.
void TlsAgent::CheckOriginalKEA(SSLNamedGroup kea_group) const {
  if (kea_group == ssl_grp_ffdhe_custom) {
    return;
  }
  EXPECT_EQ(kea_group, info_.originalKeaGroup);
}
// Checks the negotiated authentication: the auth type, the server key size
// (except for PSK), and the signature scheme. Below TLS 1.2 the scheme is
// forced to the legacy value for RSA and ECDSA signing. Before TLS 1.3 the
// legacy csinfo_.authAlgorithm is checked as well.
void TlsAgent::CheckAuthType(SSLAuthType auth,
                             SSLSignatureScheme sig_scheme) const {
  EXPECT_EQ(STATE_CONNECTED, state_);
  EXPECT_EQ(auth, info_.authType);
  if (auth != ssl_auth_psk) {
    EXPECT_EQ(server_key_bits_, info_.authKeyBits);
  }

  if (expected_version_ < SSL_LIBRARY_VERSION_TLS_1_2) {
    if (auth == ssl_auth_rsa_sign) {
      sig_scheme = ssl_sig_rsa_pkcs1_sha1md5;
    } else if (auth == ssl_auth_ecdsa) {
      sig_scheme = ssl_sig_ecdsa_sha1;
    }
  }
  EXPECT_EQ(sig_scheme, info_.signatureScheme);

  if (info_.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) {
    return;
  }

  // authAlgorithm is the older counterpart of authType. For several auth
  // types it reports a different value.
  if (auth == ssl_auth_rsa_sign || auth == ssl_auth_rsa_pss) {
    EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
        << "authAlgorithm for RSA is always decrypt";
  } else if (auth == ssl_auth_ecdh_rsa) {
    EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
        << "authAlgorithm for ECDH_RSA is RSA decrypt (i.e., wrong)";
  } else if (auth == ssl_auth_ecdh_ecdsa) {
    EXPECT_EQ(ssl_auth_ecdsa, csinfo_.authAlgorithm)
        << "authAlgorithm for ECDH_ECDSA is ECDSA (i.e., wrong)";
  } else {
    EXPECT_EQ(auth, csinfo_.authAlgorithm)
        << "authAlgorithm is (usually) the same as authType";
  }
}
void TlsAgent::EnableFalseStart() {
EXPECT_TRUE(EnsureTlsSetup());
falsestart_enabled_ = true;
EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback(
ssl_fd(), CanFalseStartCallback, this));
SetOption(SSL_ENABLE_FALSE_START, PR_TRUE);
}
// Expectations checked by Connected() once the handshake completes.
void TlsAgent::ExpectEch(bool expected) {
  expect_ech_ = expected;
}
void TlsAgent::ExpectPsk(SSLPskType psk) {
  expect_psk_ = psk;
}
// Shorthand for ExpectPsk(ssl_psk_resume).
void TlsAgent::ExpectResumption() {
  expect_psk_ = ssl_psk_resume;
}
void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
EXPECT_TRUE(EnsureTlsSetup());
EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len));
}
void TlsAgent::AddPsk(const ScopedPK11SymKey& psk, std::string label,
SSLHashType hash, uint16_t zeroRttSuite) {
EXPECT_TRUE(EnsureTlsSetup());
EXPECT_EQ(SECSuccess, SSL_AddExternalPsk0Rtt(
ssl_fd(), psk.get(),
reinterpret_cast<const uint8_t*>(label.data()),
label.length(), hash, zeroRttSuite, 1000));
}
void TlsAgent::RemovePsk(std::string label) {
EXPECT_EQ(SECSuccess,
SSL_RemoveExternalPsk(
ssl_fd(), reinterpret_cast<const uint8_t*>(label.data()),
label.length()));
}
// Checks the ALPN outcome. The state must be |expected_state|. If no protocol
// was negotiated (SSL_NEXT_PROTO_NO_SUPPORT), |expected| must be empty.
// Otherwise |expected| must be non-empty and equal the chosen protocol.
void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
                         const std::string& expected) const {
  // ALPN protocol names carry a one-byte length, so 255 bytes holds any
  // value. The previous 10-byte buffer made SSL_GetNextProto fail for longer
  // names. The outputs are initialized so nothing indeterminate is read if
  // the call fails.
  SSLNextProtoState alpn_state = SSL_NEXT_PROTO_NO_SUPPORT;
  unsigned char chosen[255];
  unsigned int chosen_len = 0;
  SECStatus rv = SSL_GetNextProto(ssl_fd(), &alpn_state, chosen, &chosen_len,
                                  sizeof(chosen));
  EXPECT_EQ(SECSuccess, rv);
  EXPECT_EQ(expected_state, alpn_state);
  if (alpn_state == SSL_NEXT_PROTO_NO_SUPPORT) {
    EXPECT_EQ("", expected);
  } else {
    EXPECT_NE("", expected);
    EXPECT_EQ(expected,
              std::string(reinterpret_cast<const char*>(chosen), chosen_len));
  }
}
void TlsAgent::CheckEpochs(uint16_t expected_read,
uint16_t expected_write) const {
uint16_t read_epoch = 0;
uint16_t write_epoch = 0;
EXPECT_EQ(SECSuccess,
SSL_GetCurrentEpoch(ssl_fd(), &read_epoch, &write_epoch));
EXPECT_EQ(expected_read, read_epoch) << role_str() << " read epoch";
EXPECT_EQ(expected_write, write_epoch) << role_str() << " write epoch";
}
void TlsAgent::EnableSrtp() {
EXPECT_TRUE(EnsureTlsSetup());
const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80,
SRTP_AES128_CM_HMAC_SHA1_32};
EXPECT_EQ(SECSuccess,
SSL_SetSRTPCiphers(ssl_fd(), ciphers, PR_ARRAY_SIZE(ciphers)));
}
// Checks that SRTP_AES128_CM_HMAC_SHA1_80 (the first profile EnableSrtp()
// offers) was negotiated.
void TlsAgent::CheckSrtp() const {
  // Initialized so a failed query does not leave an indeterminate value to be
  // compared and printed below.
  uint16_t actual = 0;
  EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd(), &actual));
  EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual);
}
// Checks that the agent is in STATE_ERROR and that the recorded error code is
// |expected|. On mismatch, both codes are reported by name.
void TlsAgent::CheckErrorCode(int32_t expected) const {
  EXPECT_EQ(STATE_ERROR, state_);
  const char* actual_name = PORT_ErrorToName(error_code_);
  const char* expected_name = PORT_ErrorToName(expected);
  EXPECT_EQ(expected, error_code_)
      << "Got error code " << actual_name << " expecting " << expected_name
      << std::endl;
}
// Default level for |alert|: close_notify is a warning and everything else
// is fatal.
static uint8_t GetExpectedAlertLevel(uint8_t alert) {
  return (alert == kTlsAlertCloseNotify) ? kTlsAlertWarning : kTlsAlertFatal;
}
// Arms the expectation for the next received alert. A |level| of 0 means
// "derive it from the alert" (see GetExpectedAlertLevel).
void TlsAgent::ExpectReceiveAlert(uint8_t alert, uint8_t level) {
  expected_received_alert_ = alert;
  expected_received_alert_level_ =
      (level != 0) ? level : GetExpectedAlertLevel(alert);
}
// Arms the expectation for the next sent alert. A |level| of 0 means "derive
// it from the alert" (see GetExpectedAlertLevel).
void TlsAgent::ExpectSendAlert(uint8_t alert, uint8_t level) {
  expected_sent_alert_ = alert;
  expected_sent_alert_level_ =
      (level != 0) ? level : GetExpectedAlertLevel(alert);
}
void TlsAgent::CheckAlert(bool sent, const SSLAlert* alert) {
LOG(((alert->level == kTlsAlertWarning) ? "Warning" : "Fatal")
<< " alert " << (sent ? "sent" : "received") << ": "
<< static_cast<int>(alert->description));
auto& expected = sent ? expected_sent_alert_ : expected_received_alert_;
auto& expected_level =
sent ? expected_sent_alert_level_ : expected_received_alert_level_;
/* Silently pass close_notify in case the test has already ended. */
if (expected == kTlsAlertCloseNotify && expected_level == kTlsAlertWarning &&
alert->description == expected && alert->level == expected_level) {
return;
}
EXPECT_EQ(expected, alert->description);
EXPECT_EQ(expected_level, alert->level);
expected = kTlsAlertCloseNotify;
expected_level = kTlsAlertWarning;
}
// Requires that no error has been recorded yet. It then waits (via WAIT_, for
// up to |delay|) until one is recorded and checks that it is |expected|.
void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const {
  ASSERT_EQ(0, error_code_);
  WAIT_(error_code_ != 0, delay);
  const char* actual_name = PORT_ErrorToName(error_code_);
  const char* expected_name = PORT_ErrorToName(expected);
  EXPECT_EQ(expected, error_code_)
      << "Got error code " << actual_name << " expecting " << expected_name
      << std::endl;
}
void TlsAgent::CheckPreliminaryInfo() {
SSLPreliminaryChannelInfo preinfo;
EXPECT_EQ(SECSuccess,
SSL_GetPreliminaryChannelInfo(ssl_fd(), &preinfo, sizeof(preinfo)));
EXPECT_EQ(sizeof(preinfo), preinfo.length);
EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version);
// A version of 0 is invalid and indicates no expectation. This value is
// initialized to 0 so that tests that don't explicitly set an expected
// version can negotiate a version.
if (!expected_version_) {
expected_version_ = preinfo.protocolVersion;
}
EXPECT_EQ(expected_version_, preinfo.protocolVersion);
// As with the version; 0 is the null cipher suite (and also invalid).
if (!expected_cipher_suite_) {
expected_cipher_suite_ = preinfo.cipherSuite;
}
EXPECT_EQ(expected_cipher_suite_, preinfo.cipherSuite);
}
// Check that all the expected callbacks have been called.
// Called from Connected(). Each expectation depends on the agent's role,
// expect_psk_, falsestart_enabled_ and expected_version_.
void TlsAgent::CheckCallbacks() const {
  // If false start happens, the handshake is reported as being complete at the
  // point that false start happens.
  if (expect_psk_ == ssl_psk_resume || !falsestart_enabled_) {
    EXPECT_TRUE(handshake_callback_called_);
  }
  // These callbacks shouldn't fire if we are resuming, except on TLS 1.3.
  if (role_ == SERVER) {
    // The SNI hook must have fired exactly when server_name was negotiated on
    // a non-resumed handshake, or whenever TLS 1.3 is expected.
    PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd(), ssl_server_name_xtn);
    EXPECT_EQ(((expect_psk_ != ssl_psk_resume && have_sni) ||
               expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3),
              sni_hook_called_);
  } else {
    // The auth-certificate hook fires only when no PSK (resumption or
    // external) is expected.
    EXPECT_EQ(expect_psk_ == ssl_psk_none, auth_certificate_hook_called_);
    // Note that this isn't unconditionally called, even with false start on.
    // But the callback is only skipped if a cipher that is ridiculously weak
    // (80 bits) is chosen. Don't test that: plan to remove bad ciphers.
    EXPECT_EQ(falsestart_enabled_ && expect_psk_ != ssl_psk_resume,
              can_falsestart_hook_called_);
  }
}
// Clears the expected version and cipher suite so that the next
// CheckPreliminaryInfo() adopts whatever is negotiated.
void TlsAgent::ResetPreliminaryInfo() {
  expected_cipher_suite_ = 0;
  expected_version_ = 0;
}
void TlsAgent::UpdatePreliminaryChannelInfo() {
SECStatus rv =
SSL_GetPreliminaryChannelInfo(ssl_fd(), &pre_info_, sizeof(pre_info_));
EXPECT_EQ(SECSuccess, rv);
EXPECT_EQ(sizeof(pre_info_), pre_info_.length);
}
// Checks the number of cipher specs the socket still holds (as counted by
// SSLInt_CountCipherSpecs) against the count expected for this variant,
// version, role, resumption and false-start state. On mismatch, the specs
// are printed for debugging.
void TlsAgent::ValidateCipherSpecs() {
  PRInt32 cipherSpecs = SSLInt_CountCipherSpecs(ssl_fd());
  // We use one ciphersuite in each direction.
  PRInt32 expected = 2;
  if (variant_ == ssl_variant_datagram) {
    // For DTLS 1.3, the client retains the cipher spec for early data and the
    // handshake so that it can retransmit EndOfEarlyData and its final flight.
    // It also retains the handshake read cipher spec so that it can read ACKs
    // from the server. The server retains the handshake read cipher spec so it
    // can read the client's retransmitted Finished.
    if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
      if (role_ == CLIENT) {
        expected = info_.earlyDataAccepted ? 5 : 4;
      } else {
        expected = 3;
      }
    } else {
      // For DTLS 1.1 and 1.2, the last endpoint to send maintains a cipher spec
      // until the holddown timer runs down.
      // (On resumption that is the client; on a full handshake the server.)
      if (expect_psk_ == ssl_psk_resume) {
        if (role_ == CLIENT) {
          expected = 3;
        }
      } else {
        if (role_ == SERVER) {
          expected = 3;
        }
      }
    }
  }
  // This function will be run before the handshake completes if false start is
  // enabled. In that case, the client will still be reading cleartext, but
  // will have a spec prepared for reading ciphertext. With DTLS, the client
  // will also have a spec retained for retransmission of handshake messages.
  if (role_ == CLIENT && falsestart_enabled_ && !handshake_callback_called_) {
    // False start is only expected below TLS 1.3.
    EXPECT_GT(SSL_LIBRARY_VERSION_TLS_1_3, expected_version_);
    expected = (variant_ == ssl_variant_datagram) ? 4 : 3;
  }
  EXPECT_EQ(expected, cipherSpecs);
  if (expected != cipherSpecs) {
    SSLInt_PrintCipherSpecs(role_str().c_str(), ssl_fd());
  }
}
// Handles handshake completion. It runs the preliminary-info and callback
// checks, captures the final channel and cipher-suite info into info_ and
// csinfo_, verifies them against the expectations, validates the cipher-spec
// count, and moves to STATE_CONNECTED. It only acts once: later calls while
// connected return immediately.
void TlsAgent::Connected() {
  if (state_ == STATE_CONNECTED) {
    return;
  }
  LOG("Handshake success");
  // Also adopts the negotiated version/suite if no expectation was set.
  CheckPreliminaryInfo();
  CheckCallbacks();
  SECStatus rv = SSL_GetChannelInfo(ssl_fd(), &info_, sizeof(info_));
  EXPECT_EQ(SECSuccess, rv);
  EXPECT_EQ(sizeof(info_), info_.length);
  EXPECT_EQ(expect_psk_ == ssl_psk_resume, info_.resumed == PR_TRUE);
  EXPECT_EQ(expect_psk_, info_.pskType);
  EXPECT_EQ(expect_ech_, info_.echAccepted);
  // Preliminary values are exposed through callbacks during the handshake.
  // If either expected values were set or the callbacks were called, check
  // that the final values are correct.
  UpdatePreliminaryChannelInfo();
  EXPECT_EQ(expected_version_, info_.protocolVersion);
  EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite);
  rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_));
  EXPECT_EQ(SECSuccess, rv);
  EXPECT_EQ(sizeof(csinfo_), csinfo_.length);
  // Reads info_.earlyDataAccepted, so this must follow SSL_GetChannelInfo().
  ValidateCipherSpecs();
  SetState(STATE_CONNECTED);
}
// Verifies that no client-auth callback is still pending, and that the number
// of completed client-auth callbacks fits the configured callback type over
// `handshakes` handshakes:
//  - kNone: `handshakes` when client_auth_callback_success_ is set, else 0.
//  - kSync: `handshakes`.
//  - kAsyncDelay / kAsyncImmediate: 2 * `handshakes`.
void TlsAgent::CheckClientAuthCompleted(uint8_t handshakes) {
  EXPECT_FALSE(client_auth_callback_awaiting_);
  switch (client_auth_callback_type_) {
    case ClientAuthCallbackType::kNone:
      // Both outcomes are spelled out here instead of silently falling
      // through into the kSync case.
      if (client_auth_callback_success_) {
        EXPECT_TRUE(CheckClientAuthCallbacksCompleted(handshakes));
      } else {
        EXPECT_TRUE(CheckClientAuthCallbacksCompleted(0));
      }
      break;
    case ClientAuthCallbackType::kSync:
      EXPECT_TRUE(CheckClientAuthCallbacksCompleted(handshakes));
      break;
    case ClientAuthCallbackType::kAsyncDelay:
    case ClientAuthCallbackType::kAsyncImmediate:
      EXPECT_TRUE(CheckClientAuthCallbacksCompleted(2 * handshakes));
      break;
  }
}
void TlsAgent::EnableExtendedMasterSecret() {
SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
}
// Asserts whether the connection reported using the extended master secret.
// For TLS 1.3 and later the caller's expectation is ignored and `true` is
// required instead.
void TlsAgent::CheckExtendedMasterSecret(bool expected) {
  // version() is evaluated first and unconditionally, as before.
  expected = version() >= SSL_LIBRARY_VERSION_TLS_1_3 || expected;
  ASSERT_EQ(expected, info_.extendedMasterSecretUsed != PR_FALSE)
      << "unexpected extended master secret state for " << name_;
}
// Asserts whether early data was accepted on this connection. Below TLS 1.3
// the caller's expectation is ignored and `false` is required instead.
void TlsAgent::CheckEarlyDataAccepted(bool expected) {
  // version() is evaluated first and unconditionally, as before.
  expected = version() >= SSL_LIBRARY_VERSION_TLS_1_3 && expected;
  ASSERT_EQ(expected, info_.earlyDataAccepted != PR_FALSE)
      << "unexpected early data state for " << name_;
}
// Asserts that SSLInt_CheckSecretsDestroyed reports PR_TRUE for this agent's
// socket.
void TlsAgent::CheckSecretsDestroyed() {
ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd()));
}
// Passes `ver` to SSL_SetDowngradeCheckVersion on this agent's socket, after
// EnsureTlsSetup() succeeds, and asserts that the call succeeds.
void TlsAgent::SetDowngradeCheckVersion(uint16_t ver) {
ASSERT_TRUE(EnsureTlsSetup());
SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), ver);
ASSERT_EQ(SECSuccess, rv);
}
void TlsAgent::Handshake() {
LOGV("Handshake");
SECStatus rv = SSL_ForceHandshake(ssl_fd());
if (client_auth_callback_awaiting_) {
ClientAuthCallbackComplete();
rv = SSL_ForceHandshake(ssl_fd());
}
if (rv == SECSuccess) {
Connected();
Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
&TlsAgent::ReadableCallback);
return;
}
int32_t err = PR_GetError();
if (err == PR_WOULD_BLOCK_ERROR) {
LOGV("Would have blocked");
if (variant_ == ssl_variant_datagram) {
if (timer_handle_) {
timer_handle_->Cancel();
timer_handle_ = nullptr;
}
PRIntervalTime timeout;
rv = DTLS_GetHandshakeTimeout(ssl_fd(), &timeout);
if (rv == SECSuccess) {
Poller::Instance()->SetTimer(
timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_);
}
}
Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
&TlsAgent::ReadableCallback);
return;
}
if (err != 0) {
LOG("Handshake failed with error " << PORT_ErrorToName(err) << ": "
<< PORT_ErrorToString(err));
}
error_code_ = err;
SetState(STATE_ERROR);
}
// Expects the agent to be connected, then moves it back to STATE_CONNECTING
// ahead of a renegotiation (StartRenegotiate calls this first).
void TlsAgent::PrepareForRenegotiate() {
EXPECT_EQ(STATE_CONNECTED, state_);
SetState(STATE_CONNECTING);
}
void TlsAgent::StartRenegotiate() {
PrepareForRenegotiate();
SECStatus rv = SSL_ReHandshake(ssl_fd(), PR_TRUE);
EXPECT_EQ(SECSuccess, rv);
}
// Hands `buf` straight to the peer adapter's PacketReceived, without going
// through this agent's SSL socket. If the peer is gone, this only logs.
void TlsAgent::SendDirect(const DataBuffer& buf) {
  LOG("Send Direct " << buf);
  if (auto peer = adapter_->peer().lock()) {
    peer->PacketReceived(buf);
    return;
  }
  LOG("Send Direct peer absent");
}
// Serializes `record` (its header followed by its payload) and delivers the
// bytes to the peer with SendDirect. It expects the full header and payload
// to have been written.
void TlsAgent::SendRecordDirect(const TlsRecord& record) {
DataBuffer buf;
auto rv = record.header.Write(&buf, 0, record.buffer);
EXPECT_EQ(record.header.header_length() + record.buffer.len(), rv);
SendDirect(buf);
}
// Returns true for every error code except the two that only mean "try again
// later": PR_WOULD_BLOCK_ERROR and SSL_ERROR_RX_SHORT_DTLS_READ.
static bool ErrorIsFatal(PRErrorCode code) {
  switch (code) {
    case PR_WOULD_BLOCK_ERROR:
    case SSL_ERROR_RX_SHORT_DTLS_READ:
      return false;
    default:
      return true;
  }
}
void TlsAgent::SendData(size_t bytes, size_t blocksize) {
uint8_t block[16385]; // One larger than the maximum record size.
ASSERT_LE(blocksize, sizeof(block));
while (bytes) {
size_t tosend = std::min(blocksize, bytes);
for (size_t i = 0; i < tosend; ++i) {
block[i] = 0xff & send_ctr_;
++send_ctr_;
}
SendBuffer(DataBuffer(block, tosend));
bytes -= tosend;
}
}
// Writes `buf` to the SSL socket with PR_Write.
// Normally the whole buffer must be accepted. When a read/write error is
// expected (expect_readwrite_error_), the write must instead fail with an
// error other than PR_WOULD_BLOCK_ERROR. That error is recorded in
// error_code_ and the expectation is cleared.
void TlsAgent::SendBuffer(const DataBuffer& buf) {
  LOGV("Writing " << buf.len() << " bytes");
  int32_t rv = PR_Write(ssl_fd(), buf.data(), buf.len());
  if (expect_readwrite_error_) {
    // Capture this write's error immediately. Checking error_code_ at this
    // point would test a stale value left by an earlier operation.
    int32_t err = PR_GetError();
    EXPECT_GT(0, rv);
    EXPECT_NE(PR_WOULD_BLOCK_ERROR, err);
    error_code_ = err;
    expect_readwrite_error_ = false;
  } else {
    ASSERT_EQ(buf.len(), static_cast<size_t>(rv));
  }
}
bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
uint64_t seq, uint8_t ct,
const DataBuffer& buf) {
// Ensure that we are doing TLS 1.3.
EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3);
if (variant_ != ssl_variant_datagram) {
ADD_FAILURE();
return false;
}
LOGV("Encrypting " << buf.len() << " bytes");
uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
kCtDtlsCiphertextLengthPresent;
TlsRecordHeader header(variant_, expected_version_, dtls13_ct, seq);
TlsRecordHeader out_header(header);
DataBuffer padded = buf;
padded.Write(padded.len(), ct, 1);
DataBuffer ciphertext;
if (!spec->Protect(header, padded, &ciphertext, &out_header)) {
return false;
}
DataBuffer record;
auto rv = out_header.Write(&record, 0, ciphertext);
EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv);
SendDirect(record);
return true;
}
// Reads up to `amount` bytes of application data from the SSL socket and
// checks each byte against the expected counter pattern (recv_ctr_ & 0xff).
// The loop stops early on a zero or negative read. When a read/write error
// was expected, a non-would-block error is stored in error_code_ and the
// expectation is cleared; fatal errors also move the agent to STATE_ERROR.
void TlsAgent::ReadBytes(size_t amount) {
  uint8_t block[16384];

  size_t remaining = amount;
  while (remaining > 0) {
    // Ask only for what is still outstanding. Asking for `amount` on every
    // pass could return more than `remaining` bytes and wrap the unsigned
    // counter below to a huge value.
    int32_t rv =
        PR_Read(ssl_fd(), block, (std::min)(remaining, sizeof(block)));
    LOGV("ReadBytes " << rv);

    if (rv > 0) {
      size_t count = static_cast<size_t>(rv);
      for (size_t i = 0; i < count; ++i) {
        ASSERT_EQ(recv_ctr_ & 0xff, block[i]);
        recv_ctr_++;
      }
      remaining -= count;
    } else {
      PRErrorCode err = 0;
      if (rv < 0) {
        err = PR_GetError();
        if (err != 0) {
          LOG("Read error " << PORT_ErrorToName(err) << ": "
                            << PORT_ErrorToString(err));
        }
        if (err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) {
          if (ErrorIsFatal(err)) {
            SetState(STATE_ERROR);
          }
          error_code_ = err;
          expect_readwrite_error_ = false;
        }
      }
      if (err != 0 && ErrorIsFatal(err)) {
        // If we hit a fatal error, we're done.
        remaining = 0;
      }
      break;
    }
  }

  // Re-arm the readable callback unless everything requested was read or a
  // fatal error occurred. A zero-length read leaves `remaining` non-zero and
  // therefore also re-arms.
  if (remaining) {
    LOGV("Re-arming");
    Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
                             &TlsAgent::ReadableCallback);
  }
}
// Sets the counter that SendData uses to generate payload bytes to `bytes`.
void TlsAgent::ResetSentBytes(size_t bytes) {
  send_ctr_ = bytes;
}
// Sets SSL socket option `option` to `value` via SSL_OptionSet, after
// asserting that EnsureTlsSetup() succeeds. It expects SSL_OptionSet to
// succeed.
void TlsAgent::SetOption(int32_t option, int value) {
ASSERT_TRUE(EnsureTlsSetup());
EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), option, value));
}
// Maps the resumption mode bits onto socket options. RESUME_SESSIONID
// controls SSL_NO_CACHE (inverted), and RESUME_TICKET controls
// SSL_ENABLE_SESSION_TICKETS.
void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
  const bool want_session_ids = (mode & RESUME_SESSIONID) != 0;
  const bool want_tickets = (mode & RESUME_TICKET) != 0;
  // SSL_NO_CACHE is a negative switch: it is set when session IDs are unwanted.
  SetOption(SSL_NO_CACHE, want_session_ids ? PR_FALSE : PR_TRUE);
  SetOption(SSL_ENABLE_SESSION_TICKETS, want_tickets ? PR_TRUE : PR_FALSE);
}
// Server-only: asserts this agent is a server, then turns on
// SSL_REUSE_SERVER_ECDHE_KEY.
void TlsAgent::EnableECDHEServerKeyReuse() {
ASSERT_EQ(TlsAgent::SERVER, role_);
SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_TRUE);
}
// Role names used to instantiate parameterized agent tests; kTlsRolesAll
// yields "CLIENT" and then "SERVER".
static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"};
::testing::internal::ParamGenerator<std::string>
TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr);
void TlsAgentTestBase::SetUp() {
SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
}
// Per-test teardown: releases the agent first, then clears and shuts down the
// server session cache.
void TlsAgentTestBase::TearDown() {
  agent_.reset();
  SSL_ClearSessionCache();
  SSL_ShutdownServerSessionIDCache();
}
void TlsAgentTestBase::Reset(const std::string& server_name) {
agent_.reset(
new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name,
role_, variant_));
if (version_) {
agent_->SetVersionRange(version_, version_);
}
agent_->adapter()->SetPeer(sink_adapter_);
agent_->StartConnect();
}
void TlsAgentTestBase::EnsureInit() {
if (!agent_) {
Reset();
}
const std::vector<SSLNamedGroup> groups = {
ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
ssl_grp_ffdhe_2048};
agent_->ConfigNamedGroups(groups);
}
// Makes sure the agent exists, then tells it to expect to send `alert`.
void TlsAgentTestBase::ExpectAlert(uint8_t alert) {
EnsureInit();
agent_->ExpectSendAlert(alert);
}
// Delivers `buffer` to the agent as a received packet and runs one
// Handshake() step. It then asserts that the agent is in `expected_state`.
// When that state is STATE_ERROR, it also asserts that the recorded error
// equals `error_code`.
void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer,
TlsAgent::State expected_state,
int32_t error_code) {
std::cerr << "Process message: " << buffer << std::endl;
EnsureInit();
agent_->adapter()->PacketReceived(buffer);
agent_->Handshake();
ASSERT_EQ(expected_state, agent_->state());
if (expected_state == TlsAgent::STATE_ERROR) {
ASSERT_EQ(error_code, agent_->error_code());
}
}
// Serializes one unprotected record into `out`, writing from offset 0.
// The header layout depends on the variant and version:
//  - Stream (TLS): type(1) | version(2) | length(2) | payload.
//  - DTLS >= 1.3 with a DTLSCiphertext type: application_data is first
//    rewritten to the unified-header type (16-bit seqno, length present).
//    Layout: (type | 2 epoch bits)(1) | low 16 bits of sequence_number(2) |
//    length(2) | payload.
//  - Any other DTLS record: type(1) | DTLS wire version(2) |
//    sequence_number as high and low 32-bit halves(4 + 4) | length(2) |
//    payload.
// The epoch bits are taken from bits 48-49 of sequence_number. NOTE(review):
// this presumes callers pack the epoch into the top 16 bits; confirm.
void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type,
uint16_t version, const uint8_t* buf,
size_t len, DataBuffer* out,
uint64_t sequence_number) {
// Fixup the content type for DTLSCiphertext
if (variant == ssl_variant_datagram &&
version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
type == ssl_ct_application_data) {
type = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
kCtDtlsCiphertextLengthPresent;
}
size_t index = 0;
if (variant == ssl_variant_stream) {
index = out->Write(index, type, 1);
index = out->Write(index, version, 2);
} else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
(type & kCtDtlsCiphertextMask) == kCtDtlsCiphertext) {
// Unified header: low epoch bits folded into the first byte.
uint32_t epoch = (sequence_number >> 48) & 0x3;
index = out->Write(index, type | epoch, 1);
// Only the low 16 bits of the sequence number are carried.
uint32_t seqno = sequence_number & ((1ULL << 16) - 1);
index = out->Write(index, seqno, 2);
} else {
index = out->Write(index, type, 1);
index = out->Write(index, TlsVersionToDtlsVersion(version), 2);
index = out->Write(index, sequence_number >> 32, 4);
index = out->Write(index, sequence_number & PR_UINT32_MAX, 4);
}
index = out->Write(index, len, 2);
out->Write(index, buf, len);
}
void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version,
const uint8_t* buf, size_t len,
DataBuffer* out, uint64_t seq_num) const {
MakeRecord(variant_, type, version, buf, len, out, seq_num);
}
// Builds a complete (unfragmented) handshake message. Offset 0 and length 0
// are passed to MakeHandshakeMessageFragment, which treats a zero fragment
// length as the full hs_len.
void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type,
                                            const uint8_t* data, size_t hs_len,
                                            DataBuffer* out,
                                            uint64_t seq_num) const {
  MakeHandshakeMessageFragment(hs_type, data, hs_len, out, seq_num,
                               /*fragment_offset=*/0, /*fragment_length=*/0);
}
// Writes a handshake message (not wrapped in a record) into `out` from offset
// 0. The header is: type(1) | total length hs_len(3). For DTLS it also has
// message sequence seq_num(2) | fragment_offset(3) | fragment_length(3).
// It is followed by fragment_length body bytes. A fragment_length of 0 means
// the whole message (hs_len).
// Body bytes are copied from the start of `data` when it is non-null, and are
// otherwise filled with 0x01. NOTE(review): `data` is read from its beginning,
// not from fragment_offset; callers presumably pass an already-offset pointer.
void TlsAgentTestBase::MakeHandshakeMessageFragment(
uint8_t hs_type, const uint8_t* data, size_t hs_len, DataBuffer* out,
uint64_t seq_num, uint32_t fragment_offset,
uint32_t fragment_length) const {
size_t index = 0;
if (!fragment_length) fragment_length = hs_len;
index = out->Write(index, hs_type, 1); // Handshake record type.
index = out->Write(index, hs_len, 3); // Handshake length
if (variant_ == ssl_variant_datagram) {
index = out->Write(index, seq_num, 2);
index = out->Write(index, fragment_offset, 3);
index = out->Write(index, fragment_length, 3);
}
if (data) {
index = out->Write(index, data, fragment_length);
} else {
for (size_t i = 0; i < fragment_length; ++i) {
index = out->Write(index, 1, 1);
}
}
}
// Writes a stream-format handshake record into `out`, always using the TLS
// layout regardless of variant_. The record version bytes are 3, 1 (0x0301).
// The record holds one handshake message of type hs_type whose hs_len body
// bytes are all 0x01.
void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type,
size_t hs_len,
DataBuffer* out) {
size_t index = 0;
index = out->Write(index, ssl_ct_handshake, 1); // Content Type
index = out->Write(index, 3, 1); // Version high
index = out->Write(index, 1, 1); // Version low
index = out->Write(index, 4 + hs_len, 2); // Length
index = out->Write(index, hs_type, 1); // Handshake record type.
index = out->Write(index, hs_len, 3); // Handshake length
for (size_t i = 0; i < hs_len; ++i) {
index = out->Write(index, 1, 1);
}
}
// Returns the canned TLS 1.3 ServerHello body.
// For DTLS, two fields are rewritten to their DTLS wire values:
//  - the leading 2-byte version is set to DTLS 1.2;
//  - the trailing 2-byte version (expected to be TLS 1.3) is set to DTLS 1.3.
DataBuffer TlsAgentTestBase::MakeCannedTls13ServerHello() {
  DataBuffer sh(kCannedTls13ServerHello, sizeof(kCannedTls13ServerHello));
  if (variant_ != ssl_variant_datagram) {
    return sh;
  }

  sh.Write(0, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 2);
  // The version should be at the end.
  uint32_t v;
  EXPECT_TRUE(sh.Read(sh.len() - 2, 2, &v));
  EXPECT_EQ(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_3), v);
  sh.Write(sh.len() - 2, SSL_LIBRARY_VERSION_DTLS_1_3_WIRE, 2);
  return sh;
}
} // namespace nss_test