Files
robots_proxy/tests/proxy_integration.rs
Zachary D. Rowitsch f01548ad8a
Some checks failed
CI / build (push) Has been cancelled
Add MITM integration tests and configurable robots failover
* add native-root HTTPS client for robots policy so HTTPS robots.txt fetches work under MITM
  * introduce `robots.missing_policy`, document all robots flags, and handle missing robots responses
  * add HTTPS MITM integration tests (self-signed CA, CONNECT tunnel verification)
  * document MITM CA generation/install steps and wire server listener to support upgrades
2025-11-28 15:49:00 -05:00

571 lines
21 KiB
Rust

use std::convert::Infallible;
use std::fs;
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use async_trait::async_trait;
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper::{Request, Response, StatusCode, Uri};
use hyper_rustls::HttpsConnectorBuilder;
use hyper_util::client::legacy::{Client, connect::HttpConnector as LegacyHttpConnector};
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder as AutoBuilder;
use rcgen::{BasicConstraints, Certificate, CertificateParams, IsCa, SanType};
use robots_proxy::mitm::MitmAuthority;
use robots_proxy::proxy::service::{ProxyService, build_service};
use robots_proxy::robots::{RobotsDecider, RobotsPolicy, RobotsVerdict};
use robots_proxy::settings::AppConfig;
use rustls::{self, ClientConfig, RootCertStore, ServerConfig, client::ServerName};
use rustls_022::pki_types::CertificateDer as CertificateDer022;
use rustls_022::{ClientConfig as ClientConfig022, RootCertStore as RootCertStore022};
use tempfile::TempDir;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tokio_rustls::{TlsAcceptor, TlsConnector};
use tower::ServiceExt;
/// Prefix printed to stderr when a test bails out early because binding a
/// local socket failed with `PermissionDenied`.
const TEST_SKIP_MSG: &str = "integration test skipped: unable to bind local socket";
/// A plain-HTTP request for a path that the origin's robots.txt allows must
/// come back from the proxy service as `200 OK` with the origin's `ok` body.
///
/// The test returns `Ok(())` without asserting anything when the origin
/// server cannot bind its socket because of `PermissionDenied`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn proxy_forwards_allowed_request() -> Result<(), Box<dyn std::error::Error>> {
    let origin = match TestHttpServer::spawn("User-agent: *\nAllow: /").await {
        Err(err) if err.kind() == io::ErrorKind::PermissionDenied => {
            eprintln!("{TEST_SKIP_MSG}: {err}");
            return Ok(());
        }
        // Any other bind failure is a genuine test error.
        outcome => outcome?,
    };

    let proxy = build_proxy_service();
    let target = format!("http://{}/allowed", origin.addr);
    let reply = proxy.clone().oneshot(build_request(&target)).await?;
    assert_eq!(reply.status(), StatusCode::OK);

    let payload = reply.into_body().collect().await?.to_bytes();
    assert_eq!(payload, Bytes::from_static(b"ok"));
    Ok(())
}
/// A plain-HTTP request for a path under `/private`, which the origin's
/// robots.txt disallows, must be answered by the proxy service with
/// `403 Forbidden`.
///
/// The test returns `Ok(())` without asserting anything when the origin
/// server cannot bind its socket because of `PermissionDenied`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn proxy_blocks_disallowed_request() -> Result<(), Box<dyn std::error::Error>> {
    let origin = match TestHttpServer::spawn("User-agent: *\nDisallow: /private").await {
        Err(err) if err.kind() == io::ErrorKind::PermissionDenied => {
            eprintln!("{TEST_SKIP_MSG}: {err}");
            return Ok(());
        }
        // Any other bind failure is a genuine test error.
        outcome => outcome?,
    };

    let proxy = build_proxy_service();
    let target = format!("http://{}/private/data", origin.addr);
    let reply = proxy.clone().oneshot(build_request(&target)).await?;
    assert_eq!(reply.status(), StatusCode::FORBIDDEN);
    Ok(())
}
/// Assembles a `ProxyService` with MITM disabled (`None`) and a real
/// `RobotsPolicy` built from the default configuration.
///
/// The upstream client accepts both `http` and `https` targets and trusts
/// the platform's native root certificates.
///
/// # Panics
/// Panics if the native roots cannot be loaded or the robots policy cannot
/// be constructed.
fn build_proxy_service() -> ProxyService {
    let upstream = {
        let mut tcp = LegacyHttpConnector::new();
        // Let the TLS layer above see non-`http` schemes.
        tcp.enforce_http(false);
        let tls = HttpsConnectorBuilder::new()
            .with_native_roots()
            .expect("native roots")
            .https_or_http()
            .enable_http1()
            .wrap_connector(tcp);
        Client::builder(TokioExecutor::new()).build(tls)
    };

    let mut settings = AppConfig::default();
    settings.listen_addr = "127.0.0.1:0".into();
    let settings = Arc::new(settings);

    let decider: Arc<dyn RobotsDecider> =
        Arc::new(RobotsPolicy::new(settings.clone()).expect("robots policy"));
    build_service(settings, upstream, Some(decider), None)
}
/// Builds a `GET` request for `uri` whose body is the placeholder produced
/// by `empty_incoming`.
///
/// # Panics
/// Panics if the request builder rejects `uri`.
fn build_request(uri: &str) -> Request<Incoming> {
    let body = empty_incoming();
    let builder = Request::builder().method("GET").uri(uri);
    builder.body(body).expect("request")
}
/// Fabricates an `Incoming` body value for requests that are handed straight
/// to the proxy service without going through a real connection.
///
/// SAFETY (unverified): `std::mem::zeroed` is only sound when the all-zero
/// bit pattern is a valid value of the target type. Nothing in this file
/// establishes that for `hyper::body::Incoming`, and its layout is a private
/// implementation detail of hyper, so this may be undefined behaviour.
/// NOTE(review): confirm against the pinned hyper version, or replace this
/// with a request that travels over a real connection (as the MITM tests in
/// this file do) so a genuine `Incoming` is produced.
fn empty_incoming() -> Incoming {
unsafe { std::mem::zeroed() }
}
/// Handle to a plain-HTTP origin server running on a background task.
/// Dropping the handle signals shutdown and aborts the task.
struct TestHttpServer {
// Local address the listener is bound to.
addr: SocketAddr,
// Sender used to ask the accept loop to stop; taken (set to `None`) on drop.
shutdown_tx: Option<oneshot::Sender<()>>,
// Join handle of the accept-loop task; taken and aborted on drop.
handle: Option<JoinHandle<()>>,
}
impl TestHttpServer {
async fn spawn(robots_body: &'static str) -> io::Result<Self> {
let listener = TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let (tx, mut rx) = oneshot::channel();
let robots_body = robots_body.to_string();
let handle = tokio::spawn(async move {
loop {
tokio::select! {
_ = &mut rx => break,
accept = listener.accept() => {
let (stream, _) = match accept {
Ok(pair) => pair,
Err(err) => {
eprintln!("test server accept error: {err}");
continue;
}
};
let robots_body = robots_body.clone();
tokio::spawn(async move {
let service = service_fn(move |req| {
let robots_body = robots_body.clone();
async move {
let payload = if req.uri().path() == "/robots.txt" {
robots_body.clone()
} else {
"ok".to_string()
};
let response = Response::builder()
.status(StatusCode::OK)
.body(Full::new(Bytes::from(payload)))
.unwrap();
Ok::<_, Infallible>(response)
}
});
let io = TokioIo::new(stream);
if let Err(err) = AutoBuilder::new(TokioExecutor::new())
.serve_connection(io, service)
.await
{
eprintln!("test server connection error: {err}");
}
});
}
}
}
});
Ok(Self { addr, shutdown_tx: Some(tx), handle: Some(handle) })
}
}
impl Drop for TestHttpServer {
    /// Signals the accept loop to stop and then aborts its task.
    fn drop(&mut self) {
        let signal = self.shutdown_tx.take();
        let task = self.handle.take();

        if let Some(signal) = signal {
            // The send result is deliberately ignored: a failure only means
            // the receiving side is already gone.
            let _ = signal.send(());
        }
        if let Some(task) = task {
            task.abort();
        }
    }
}
/// Tunnels an HTTPS request for `/blocked` through the MITM-enabled proxy
/// and expects the status of the decrypted exchange to be `403 Forbidden`.
///
/// The proxy is built with the static test policy, which disallows paths
/// starting with `/blocked`. The test returns `Ok(())` without asserting
/// anything when either the TLS origin or the proxy cannot bind its socket
/// because of `PermissionDenied`.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn https_mitm_blocks_disallowed_request() -> Result<(), Box<dyn std::error::Error>> {
    let ca = TestCa::new();
    let policy = static_policy();

    let origin = match TestTlsServer::spawn(&ca, "User-agent: *\nDisallow: /blocked").await {
        Err(err) if err.kind() == io::ErrorKind::PermissionDenied => {
            eprintln!("{TEST_SKIP_MSG}: {err}");
            return Ok(());
        }
        outcome => outcome?,
    };

    let service = build_mitm_proxy_service(policy.clone(), &ca);
    let proxy = match TestProxy::spawn(service).await {
        Err(err) if err.kind() == io::ErrorKind::PermissionDenied => {
            eprintln!("{TEST_SKIP_MSG}: {err}");
            return Ok(());
        }
        outcome => outcome?,
    };

    let (status, _body) = perform_https_request(proxy.addr, origin.addr, &ca, "/blocked").await?;
    assert_eq!(status, StatusCode::FORBIDDEN);
    Ok(())
}
/// Tunnels an HTTPS request for `/allowed` through the MITM-enabled proxy
/// and expects `200 OK` with the origin's `ok` body.
///
/// The proxy is built with the static test policy, which allows every path
/// that does not start with `/blocked`. The test returns `Ok(())` without
/// asserting anything when either the TLS origin or the proxy cannot bind
/// its socket because of `PermissionDenied`.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn https_mitm_allows_permitted_request() -> Result<(), Box<dyn std::error::Error>> {
    let ca = TestCa::new();
    let policy = static_policy();

    let origin = match TestTlsServer::spawn(&ca, "User-agent: *\nDisallow: /blocked").await {
        Err(err) if err.kind() == io::ErrorKind::PermissionDenied => {
            eprintln!("{TEST_SKIP_MSG}: {err}");
            return Ok(());
        }
        outcome => outcome?,
    };

    let service = build_mitm_proxy_service(policy.clone(), &ca);
    let proxy = match TestProxy::spawn(service).await {
        Err(err) if err.kind() == io::ErrorKind::PermissionDenied => {
            eprintln!("{TEST_SKIP_MSG}: {err}");
            return Ok(());
        }
        outcome => outcome?,
    };

    let (status, payload) =
        perform_https_request(proxy.addr, origin.addr, &ca, "/allowed").await?;
    assert_eq!(status, StatusCode::OK);
    assert_eq!(payload, b"ok".as_ref());
    Ok(())
}
/// Handle to a proxy service listening on a real socket in a background task.
/// Dropping the handle signals shutdown and aborts the task.
struct TestProxy {
// Local address clients connect to.
addr: SocketAddr,
// Sender used to ask the accept loop to stop; taken (set to `None`) on drop.
shutdown_tx: Option<oneshot::Sender<()>>,
// Join handle of the accept-loop task; taken and aborted on drop.
handle: Option<JoinHandle<()>>,
}
impl TestProxy {
    /// Binds a listener on `127.0.0.1` with an OS-assigned port and serves
    /// `service` on every accepted connection.
    ///
    /// Connections are served with upgrade support, which the `CONNECT`
    /// tunnelling used by the MITM tests depends on.
    ///
    /// # Errors
    /// Returns the I/O error from binding the listener or reading its local
    /// address.
    async fn spawn(service: ProxyService) -> io::Result<Self> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;
        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();

        let accept_loop = tokio::spawn(async move {
            loop {
                tokio::select! {
                    // Fires when the handle is dropped (signal sent or sender gone).
                    _ = &mut shutdown_rx => break,
                    incoming = listener.accept() => {
                        match incoming {
                            Ok((stream, peer)) => {
                                let svc = service.clone();
                                tokio::spawn(async move {
                                    let served = AutoBuilder::new(TokioExecutor::new())
                                        .serve_connection_with_upgrades(
                                            TokioIo::new(stream),
                                            // A fresh clone of the service handles each request.
                                            service_fn(move |req| svc.clone().oneshot(req)),
                                        )
                                        .await;
                                    if let Err(err) = served {
                                        eprintln!("proxy connection error from {peer}: {err}");
                                    }
                                });
                            }
                            // A failed accept is logged and the loop keeps going.
                            Err(err) => eprintln!("proxy accept error: {err}"),
                        }
                    }
                }
            }
        });

        Ok(Self { addr, shutdown_tx: Some(shutdown_tx), handle: Some(accept_loop) })
    }
}
impl Drop for TestProxy {
    /// Signals the accept loop to stop and then aborts its task.
    fn drop(&mut self) {
        let signal = self.shutdown_tx.take();
        let task = self.handle.take();

        if let Some(signal) = signal {
            // The send result is deliberately ignored: a failure only means
            // the receiving side is already gone.
            let _ = signal.send(());
        }
        if let Some(task) = task {
            task.abort();
        }
    }
}
/// Handle to an HTTPS origin server running on a background task.
/// Dropping the handle signals shutdown and aborts the task.
struct TestTlsServer {
// Local address the TLS listener is bound to.
addr: SocketAddr,
// Sender used to ask the accept loop to stop; taken (set to `None`) on drop.
shutdown_tx: Option<oneshot::Sender<()>>,
// Join handle of the accept-loop task; taken and aborted on drop.
handle: Option<JoinHandle<()>>,
}
impl TestTlsServer {
    /// Binds a listener on `127.0.0.1` with an OS-assigned port and serves
    /// HTTPS using a leaf certificate issued by `ca` for the bound IP.
    ///
    /// Every request gets a `200 OK`: `/robots.txt` is answered with
    /// `robots_body`, any other path with the literal body `ok`.
    ///
    /// # Errors
    /// Returns the I/O error from binding the listener or reading its local
    /// address.
    ///
    /// # Panics
    /// Panics if the issued certificate and key are rejected by the TLS
    /// server configuration.
    async fn spawn(ca: &TestCa, robots_body: &'static str) -> io::Result<Self> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;
        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();

        let (cert_chain, key) = ca.issue_server_cert(addr.ip());
        let tls_config = ServerConfig::builder()
            .with_safe_defaults()
            .with_no_client_auth()
            .with_single_cert(cert_chain, key)
            .expect("valid server cert");
        let acceptor = TlsAcceptor::from(Arc::new(tls_config));
        let robots_txt = robots_body.to_string();

        let accept_loop = tokio::spawn(async move {
            loop {
                tokio::select! {
                    // Fires when the handle is dropped (signal sent or sender gone).
                    _ = &mut shutdown_rx => break,
                    incoming = listener.accept() => {
                        match incoming {
                            Ok((stream, _peer)) => {
                                let acceptor = acceptor.clone();
                                let robots_txt = robots_txt.clone();
                                tokio::spawn(async move {
                                    // A failed handshake is logged and ends this connection task.
                                    let tls_stream = match acceptor.accept(stream).await {
                                        Ok(tls_stream) => tls_stream,
                                        Err(err) => {
                                            eprintln!("tls handshake error: {err}");
                                            return;
                                        }
                                    };
                                    let handler = service_fn(move |req| {
                                        let robots_txt = robots_txt.clone();
                                        async move {
                                            let payload = if req.uri().path() == "/robots.txt" {
                                                robots_txt.clone()
                                            } else {
                                                "ok".to_string()
                                            };
                                            let response = Response::builder()
                                                .status(StatusCode::OK)
                                                .body(Full::new(Bytes::from(payload)))
                                                .unwrap();
                                            Ok::<_, Infallible>(response)
                                        }
                                    });
                                    let served = AutoBuilder::new(TokioExecutor::new())
                                        .serve_connection(TokioIo::new(tls_stream), handler)
                                        .await;
                                    if let Err(err) = served {
                                        eprintln!("tls server connection error: {err}");
                                    }
                                });
                            }
                            // A failed accept is logged and the loop keeps going.
                            Err(err) => eprintln!("tls server accept error: {err}"),
                        }
                    }
                }
            }
        });

        Ok(Self { addr, shutdown_tx: Some(shutdown_tx), handle: Some(accept_loop) })
    }
}
impl Drop for TestTlsServer {
    /// Signals the accept loop to stop and then aborts its task.
    fn drop(&mut self) {
        let signal = self.shutdown_tx.take();
        let task = self.handle.take();

        if let Some(signal) = signal {
            // The send result is deliberately ignored: a failure only means
            // the receiving side is already gone.
            let _ = signal.send(());
        }
        if let Some(task) = task {
            task.abort();
        }
    }
}
/// Throwaway certificate authority for the MITM tests, with its certificate
/// and private key also written as PEM files into a temporary directory.
struct TestCa {
// Owns the temporary directory holding the PEM files; never read, only held.
_dir: TempDir,
// The CA certificate object, used as the signer for leaf certificates.
cert: Certificate,
// DER encoding of the CA certificate.
cert_der: Vec<u8>,
// Path of the PEM-encoded CA certificate (`ca.pem`) inside `_dir`.
cert_path: PathBuf,
// Path of the PEM-encoded CA private key (`ca.key`) inside `_dir`.
key_path: PathBuf,
}
impl TestCa {
    /// Generates an ECDSA P-256 CA certificate and writes its certificate
    /// and private key as `ca.pem` / `ca.key` into a fresh temporary
    /// directory.
    ///
    /// # Panics
    /// Panics if the temporary directory, the certificate, or either file
    /// cannot be created.
    fn new() -> Self {
        let dir = TempDir::new().expect("tempdir");

        let mut ca_params = CertificateParams::new(vec!["Robots Proxy Test CA".to_string()]);
        ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
        ca_params.key_usages = vec![
            rcgen::KeyUsagePurpose::KeyCertSign,
            rcgen::KeyUsagePurpose::CrlSign,
            rcgen::KeyUsagePurpose::DigitalSignature,
        ];
        ca_params.alg = &rcgen::PKCS_ECDSA_P256_SHA256;

        let cert = Certificate::from_params(ca_params).expect("ca params");
        let cert_der = cert.serialize_der().expect("ca der");
        let pem = cert.serialize_pem().expect("ca pem");
        let private_key_pem = cert.serialize_private_key_pem();

        let cert_path = dir.path().join("ca.pem");
        let key_path = dir.path().join("ca.key");
        fs::write(&cert_path, pem).expect("write cert");
        fs::write(&key_path, private_key_pem).expect("write key");

        Self { _dir: dir, cert, cert_der, cert_path, key_path }
    }

    /// Issues an ECDSA P-256 leaf certificate for `host`, signed by this CA.
    ///
    /// Returns the chain (leaf first, then the CA certificate) together with
    /// the leaf's private key in DER form. The only subject alternative name
    /// on the leaf is the IP address `host`.
    ///
    /// # Panics
    /// Panics if the leaf certificate cannot be generated or signed.
    fn issue_server_cert(&self, host: IpAddr) -> (Vec<rustls::Certificate>, rustls::PrivateKey) {
        let mut leaf_params = CertificateParams::new(vec![host.to_string()]);
        leaf_params.alg = &rcgen::PKCS_ECDSA_P256_SHA256;
        // Overrides the SAN list derived from the constructor argument.
        leaf_params.subject_alt_names = vec![SanType::IpAddress(host)];
        leaf_params.key_usages =
            vec![rcgen::KeyUsagePurpose::DigitalSignature, rcgen::KeyUsagePurpose::KeyEncipherment];

        let leaf = Certificate::from_params(leaf_params).expect("leaf params");
        let leaf_der = leaf.serialize_der_with_signer(&self.cert).expect("leaf der");
        let leaf_key = leaf.serialize_private_key_der();

        let chain = vec![rustls::Certificate(leaf_der), rustls::Certificate(self.cert_der.clone())];
        (chain, rustls::PrivateKey(leaf_key))
    }

    /// Path of the PEM-encoded CA certificate on disk.
    fn cert_path(&self) -> &Path {
        self.cert_path.as_path()
    }

    /// Path of the PEM-encoded CA private key on disk.
    fn key_path(&self) -> &Path {
        self.key_path.as_path()
    }

    /// DER bytes of the CA certificate, for installing it as a trust root.
    fn root_der(&self) -> &[u8] {
        self.cert_der.as_slice()
    }
}
/// Stateless robots decider for tests: its verdict depends only on the
/// request path, so no robots.txt needs to be fetched.
#[derive(Clone, Default)]
struct StaticRobotsPolicy;
#[async_trait]
impl RobotsDecider for StaticRobotsPolicy {
    /// Disallows (with `403 Forbidden` and reason `blocked`) every URI whose
    /// path starts with `/blocked`; allows everything else.
    async fn decide(&self, uri: &Uri) -> RobotsVerdict {
        let is_blocked = uri.path().starts_with("/blocked");
        if !is_blocked {
            return RobotsVerdict::Allowed;
        }
        RobotsVerdict::Disallowed { reason: "blocked", status: StatusCode::FORBIDDEN }
    }
}
fn static_policy() -> Arc<dyn RobotsDecider> {
Arc::new(StaticRobotsPolicy::default())
}
/// Assembles a `ProxyService` with MITM enabled, using `ca`'s on-disk
/// certificate and key as the signing authority and `policy` as the robots
/// decider.
///
/// The upstream client accepts both `http` and `https` targets and trusts
/// only the test CA, so it can reach the TLS origin that presents a
/// certificate issued by that CA.
///
/// # Panics
/// Panics if the MITM authority cannot be built from the configuration or
/// reports itself as disabled.
fn build_mitm_proxy_service(policy: Arc<dyn RobotsDecider>, ca: &TestCa) -> ProxyService {
    let upstream = {
        let mut tcp = LegacyHttpConnector::new();
        // Let the TLS layer above see non-`http` schemes.
        tcp.enforce_http(false);

        let mut trust = RootCertStore022::empty();
        trust.add_parsable_certificates([CertificateDer022::from(ca.root_der().to_vec())]);
        let tls_config = ClientConfig022::builder()
            .with_root_certificates(Arc::new(trust))
            .with_no_client_auth();

        let tls = HttpsConnectorBuilder::new()
            .with_tls_config(tls_config)
            .https_or_http()
            .enable_http1()
            .wrap_connector(tcp);
        Client::builder(TokioExecutor::new()).build(tls)
    };

    let mut settings = AppConfig::default();
    settings.listen_addr = "127.0.0.1:0".into();
    settings.mitm.enabled = true;
    settings.mitm.ca_cert_path = Some(ca.cert_path().to_path_buf());
    settings.mitm.ca_key_path = Some(ca.key_path().to_path_buf());
    settings.mitm.cert_cache_ttl_secs = 60;
    let settings = Arc::new(settings);

    let authority = MitmAuthority::from_config(&settings.mitm)
        .expect("mitm config")
        .expect("mitm enabled");
    build_service(settings, upstream, Some(policy), Some(authority))
}
async fn perform_https_request(
proxy_addr: SocketAddr,
target: SocketAddr,
ca: &TestCa,
path: &str,
) -> io::Result<(StatusCode, Vec<u8>)> {
let mut stream = tokio::net::TcpStream::connect(proxy_addr).await?;
let authority = format!("{}:{}", target.ip(), target.port());
let connect = format!("CONNECT {} HTTP/1.1\r\nHost: {}\r\n\r\n", authority, authority);
stream.write_all(connect.as_bytes()).await?;
stream.flush().await?;
let mut header = Vec::new();
read_until_headers(&mut stream, &mut header).await?;
let status = parse_status(&header)?;
if status != StatusCode::OK {
return Ok((status, Vec::new()));
}
let mut roots = RootCertStore::empty();
roots.add_parsable_certificates(&[ca.root_der().to_vec()]);
let client_config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth();
let connector = TlsConnector::from(Arc::new(client_config));
let server_name = match target.ip() {
IpAddr::V4(addr) => ServerName::IpAddress(addr.into()),
IpAddr::V6(addr) => ServerName::IpAddress(addr.into()),
};
let mut tls = connector.connect(server_name, stream).await?;
let request =
format!("GET {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n", path, authority);
tls.write_all(request.as_bytes()).await?;
tls.flush().await?;
let mut response = Vec::new();
tls.read_to_end(&mut response).await?;
let status = parse_status(&response)?;
let body = split_body(&response).to_vec();
Ok((status, body))
}
/// Reads from `stream` one byte at a time, appending to `buf`, until `buf`
/// ends with the HTTP header terminator `\r\n\r\n`.
///
/// Reading byte-by-byte guarantees that nothing past the terminator is
/// consumed, so the stream can be handed on (e.g. to a TLS handshake)
/// afterwards.
///
/// # Errors
/// * `UnexpectedEof` with the message `incomplete headers` if the stream
///   ends before the terminator is seen.
/// * Any other I/O error from the underlying stream, unchanged. These used
///   to be swallowed and misreported as `incomplete headers`, which hid the
///   real cause of a failure.
async fn read_until_headers<S>(stream: &mut S, buf: &mut Vec<u8>) -> io::Result<()>
where
    S: AsyncRead + Unpin,
{
    let mut byte = [0u8; 1];
    loop {
        match stream.read_exact(&mut byte).await {
            Ok(_) => {
                buf.push(byte[0]);
                if buf.ends_with(b"\r\n\r\n") {
                    return Ok(());
                }
            }
            // `read_exact` reports a premature end of stream as `UnexpectedEof`.
            Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => {
                return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "incomplete headers"));
            }
            Err(err) => return Err(err),
        }
    }
}
/// Extracts the status code from a raw HTTP/1.x response.
///
/// `buf` must contain the header terminator `\r\n\r\n`; only the bytes
/// before it are inspected. The status code is the second
/// whitespace-separated token of the first line.
///
/// # Errors
/// Returns `InvalidData` when the terminator is missing, the header bytes
/// are not UTF-8, there is no first line, the line has fewer than two
/// tokens, the second token is not a `u16`, or the number is rejected by
/// `StatusCode::from_u16`.
fn parse_status(buf: &[u8]) -> io::Result<StatusCode> {
    let invalid = |msg: &'static str| io::Error::new(io::ErrorKind::InvalidData, msg);

    let head_len = buf
        .windows(4)
        .position(|chunk| chunk == b"\r\n\r\n")
        .ok_or_else(|| invalid("missing headers"))?;
    let head = std::str::from_utf8(&buf[..head_len]).map_err(|_| invalid("headers not utf8"))?;
    let first_line = head.lines().next().ok_or_else(|| invalid("missing status line"))?;

    // Token 0 is the HTTP version, token 1 the numeric status.
    let mut tokens = first_line.trim().split_whitespace();
    let code_text = match (tokens.next(), tokens.next()) {
        (Some(_version), Some(code_text)) => code_text,
        _ => return Err(invalid("invalid status line")),
    };

    let code = code_text.parse::<u16>().map_err(|_| invalid("invalid status code"))?;
    StatusCode::from_u16(code).map_err(|_| invalid("unknown status code"))
}
/// Returns the bytes that follow the first `\r\n\r\n` in `buf`, i.e. the
/// body of a raw HTTP response. Returns an empty slice when the separator
/// is absent.
fn split_body(buf: &[u8]) -> &[u8] {
    const SEPARATOR: &[u8] = b"\r\n\r\n";
    match buf.windows(SEPARATOR.len()).position(|chunk| chunk == SEPARATOR) {
        Some(start) => &buf[start + SEPARATOR.len()..],
        None => &[],
    }
}