Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions ci-bench/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@ use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Instant;

use crate::benchmark::{
get_reported_instr_count, validate_benchmarks, Benchmark, BenchmarkKind, BenchmarkParams,
ResumptionKind,
};
use crate::callgrind::{CallgrindRunner, CountInstructions};
use crate::util::async_io::{self, AsyncRead, AsyncWrite};
use crate::util::transport::{
read_handshake_message, read_plaintext_to_end_bounded, send_handshake_message,
write_all_plaintext_bounded,
};
use crate::util::KeyType;
use anyhow::Context;
use async_trait::async_trait;
use clap::{Parser, Subcommand, ValueEnum};
Expand All @@ -26,17 +37,6 @@ use watfaq_rustls::{
CipherSuite, ClientConfig, ClientConnection, HandshakeKind, ProtocolVersion, RootCertStore,
ServerConfig, ServerConnection,
};
use crate::benchmark::{
get_reported_instr_count, validate_benchmarks, Benchmark, BenchmarkKind, BenchmarkParams,
ResumptionKind,
};
use crate::callgrind::{CallgrindRunner, CountInstructions};
use crate::util::async_io::{self, AsyncRead, AsyncWrite};
use crate::util::transport::{
read_handshake_message, read_plaintext_to_end_bounded, send_handshake_message,
write_all_plaintext_bounded,
};
use crate::util::KeyType;

mod benchmark;
mod callgrind;
Expand Down
3 changes: 1 addition & 2 deletions ci-bench/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,7 @@ pub mod async_io {

fn poll(mut self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
if !self.writer.inner.open.get() {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
return Poll::Ready(Err(io::Error::other(
"channel was closed",
)));
}
Expand Down
2 changes: 1 addition & 1 deletion examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ hickory-resolver = { workspace = true }
log = { workspace = true }
mio = { workspace = true }
rcgen = { workspace = true }
watfaq-rustls = { path = "../rustls", features = [ "logging" ]}
watfaq-rustls = { path = "../rustls", features = [ "logging" ] }
serde = { workspace = true }
tokio = { workspace = true }
webpki-roots = { workspace = true }
36 changes: 21 additions & 15 deletions examples/src/bin/reality-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ fn main() {
// Parse command line arguments
let args: Vec<String> = env::args().collect();
if args.len() != 5 {
eprintln!("Usage: {} <server_addr> <sni_servername> <public_key_base64> <short_id_hex>", args[0]);
eprintln!(
"Usage: {} <server_addr> <sni_servername> <public_key_base64> <short_id_hex>",
args[0]
);
eprintln!();
eprintln!("Parameters:");
eprintln!(" <server_addr> Real server address (e.g., tw04.ctg.wtf:443)");
Expand Down Expand Up @@ -101,11 +104,10 @@ fn main() {
let short_id_len = short_id.len();

// Create Reality configuration
let reality_config = RealityConfig::new(server_pubkey, short_id)
.unwrap_or_else(|e| {
eprintln!("Error creating Reality config: {}", e);
std::process::exit(1);
});
let reality_config = RealityConfig::new(server_pubkey, short_id).unwrap_or_else(|e| {
eprintln!("Error creating Reality config: {}", e);
std::process::exit(1);
});

println!("Reality configuration created successfully");
println!(" Server public key: {}", bytes_to_hex(&server_pubkey));
Expand All @@ -129,7 +131,10 @@ fn main() {
// Allow using SSLKEYLOGFILE for debugging
config.key_log = Arc::new(watfaq_rustls::KeyLogFile::new());

println!("\nConnecting to {} (SNI: {})...", &server_addr, &sni_servername);
println!(
"\nConnecting to {} (SNI: {})...",
&server_addr, &sni_servername
);

// Use SNI servername for TLS connection (for disguise/camouflage)
let server_name: pki_types::ServerName<'static> = sni_servername
Expand Down Expand Up @@ -191,34 +196,35 @@ fn main() {
println!("\nServer response:");
println!("----------------------------------------");
let mut plaintext = Vec::new();
tls.read_to_end(&mut plaintext).unwrap_or_else(|e| {
eprintln!("Error reading response: {}", e);
std::process::exit(1);
});
tls.read_to_end(&mut plaintext)
.unwrap_or_else(|e| {
eprintln!("Error reading response: {}", e);
std::process::exit(1);
});
stdout().write_all(&plaintext).unwrap();
println!("----------------------------------------");
println!("\nConnection closed successfully");
}

/// Helper function to convert hex string to bytes
fn hex_to_bytes(hex: &str) -> Result<Vec<u8>, &'static str> {
if hex.len() % 2 != 0 {
if !hex.len().is_multiple_of(2) {
return Err("Hex string must have even length");
}

let mut bytes = Vec::new();
for i in (0..hex.len()).step_by(2) {
let byte_str = &hex[i..i + 2];
let byte = u8::from_str_radix(byte_str, 16)
.map_err(|_| "Invalid hex character")?;
let byte = u8::from_str_radix(byte_str, 16).map_err(|_| "Invalid hex character")?;
bytes.push(byte);
}
Ok(bytes)
}

/// Helper function to convert bytes to hex string
fn bytes_to_hex(bytes: &[u8]) -> String {
bytes.iter()
bytes
.iter()
.map(|b| format!("{:02x}", b))
.collect::<Vec<_>>()
.join("")
Expand Down
2 changes: 1 addition & 1 deletion examples/src/bin/server_acceptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ impl TestPki {
&self,
serials: Vec<rcgen::SerialNumber>,
next_update_seconds: u64,
) -> CertificateRevocationListDer {
) -> CertificateRevocationListDer<'_> {
// In a real use-case you would want to set this to the current date/time.
let now = rcgen::date_time_ymd(2023, 1, 1);

Expand Down
1 change: 0 additions & 1 deletion rustls/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
/// for Rust Nightly.
///
/// See the comment in lib.rs to understand why we need this.

#[cfg_attr(feature = "read_buf", rustversion::not(nightly))]
fn main() {
println!("cargo:rustc-check-cfg=cfg(bench)");
Expand Down
93 changes: 48 additions & 45 deletions rustls/src/client/hs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ use core::ops::Deref;

use pki_types::ServerName;

use super::reality;
#[cfg(feature = "tls12")]
use super::tls12;
use super::Tls12Resumption;
use super::reality;
#[cfg(feature = "logging")]
use crate::bs_debug;
use crate::check::inappropriate_handshake_message;
Expand Down Expand Up @@ -132,7 +132,7 @@ where
// Set kx_state to X25519 group for Reality
let x25519_group = config
.find_kx_group(NamedGroup::X25519, ProtocolVersion::TLSv1_3)
.expect("X25519 group required for Reality");
.ok_or(Error::General("X25519 group required for Reality".into()))?;
cx.common.kx_state = KxState::Start(x25519_group);
None // Will be set later from reality_state
} else if config.supports_version(ProtocolVersion::TLSv1_3) {
Expand Down Expand Up @@ -171,7 +171,7 @@ where
}
Some(inner.session_id)
}
_ => None,
_ => None::<SessionId>,
}
} else {
debug!("Not resuming any session");
Expand Down Expand Up @@ -235,6 +235,7 @@ struct ExpectServerHello {
offered_key_share: Option<Box<dyn ActiveKeyExchange>>,
suite: Option<SupportedCipherSuite>,
ech_state: Option<EchState>,
reality_state: Option<reality::RealitySessionState>,
}

struct ExpectServerHelloOrHelloRetryRequest {
Expand Down Expand Up @@ -552,6 +553,47 @@ where
payload: HandshakePayload::ClientHello(chp_payload),
};

// Compute Reality session_id BEFORE PSK binder to avoid invalidating the binder
// Reality uses ClientHello with session_id=0 as AAD, so this order is safe
if let Some(ref reality) = reality_state {
// Step 1: Set session_id to zero temporarily
let mut buffer = Vec::new();
match &mut chp.payload {
HandshakePayload::ClientHello(c) => {
c.session_id = SessionId {
len: 32,
data: [0; 32],
};
}
_ => unreachable!(),
}

// Step 2: Encode ClientHello with zero session_id (for AAD)
chp.encode(&mut buffer);

// Step 3: Get HKDF-SHA256 provider
let hkdf = reality::get_hkdf_sha256_from_config(&config.provider.cipher_suites)?;

// Step 4: Compute Reality session_id
let session_id_data = reality.compute_session_id(
&input.random,
&buffer,
hkdf,
config.time_provider.as_ref(),
)?;

// Step 5: Update session_id with computed Reality value
match &mut chp.payload {
HandshakePayload::ClientHello(c) => {
c.session_id = SessionId {
len: 32,
data: session_id_data,
};
}
_ => unreachable!(),
}
}

let early_key_schedule = match (ech_state.as_mut(), tls13_session) {
// If we're performing ECH and resuming, then the PSK binder will have been dealt with
// separately, and we need to take the early_data_key_schedule computed for the inner hello.
Expand All @@ -561,7 +603,7 @@ where
.map(|schedule| (tls13_session.suite(), schedule)),

// When we're not doing ECH and resuming, then the PSK binder need to be filled in as
// normal.
// normal. Reality session_id has been set above, so PSK binder will see the correct value.
(_, Some(tls13_session)) => Some((
tls13_session.suite(),
tls13::fill_in_psk_binder(&tls13_session, &transcript_buffer, &mut chp),
Expand Down Expand Up @@ -596,46 +638,6 @@ where
}
}

// Compute Reality session_id if Reality is enabled
if let Some(ref reality) = reality_state {
// Step 1: Set session_id to zero temporarily
let mut buffer = Vec::new();
match &mut chp.payload {
HandshakePayload::ClientHello(c) => {
c.session_id = SessionId {
len: 32,
data: [0; 32],
};
}
_ => unreachable!(),
}

// Step 2: Encode ClientHello with zero session_id
chp.encode(&mut buffer);

// Step 3: Get HKDF-SHA256 provider
let hkdf = reality::get_hkdf_sha256_from_config(&config.provider.cipher_suites)?;

// Step 4: Compute Reality session_id
let session_id_data = reality.compute_session_id(
&input.random,
&buffer,
hkdf,
config.time_provider.as_ref(),
)?;

// Step 5: Update session_id
match &mut chp.payload {
HandshakePayload::ClientHello(c) => {
c.session_id = SessionId {
len: 32,
data: session_id_data,
};
}
_ => unreachable!(),
}
}

let ch = Message {
version: match retryreq {
// <https://datatracker.ietf.org/doc/html/rfc8446#section-5.1>:
Expand Down Expand Up @@ -698,6 +700,7 @@ where
offered_key_share: key_share,
suite,
ech_state,
reality_state,
};

Ok(if support_tls13 && retryreq.is_none() {
Expand Down Expand Up @@ -1266,7 +1269,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
self.next.input,
cx,
self.next.ech_state,
None, // Reality state not used in retry
self.next.reality_state,
)
}
}
Expand Down
Loading