Some checks failed
CI / build (push) Has been cancelled
* Add a native-root HTTPS client for the robots policy so that HTTPS robots.txt fetches work under MITM.
* Introduce `robots.missing_policy`, document all robots flags, and handle missing robots responses.
* Add HTTPS MITM integration tests (self-signed CA, CONNECT tunnel verification).
* Document the MITM CA generation/install steps and wire the server listener to support upgrades.
571 lines
21 KiB
Rust
571 lines
21 KiB
Rust
use std::convert::Infallible;
|
|
use std::fs;
|
|
use std::io;
|
|
use std::net::{IpAddr, SocketAddr};
|
|
use std::path::{Path, PathBuf};
|
|
use std::sync::Arc;
|
|
|
|
use async_trait::async_trait;
|
|
use bytes::Bytes;
|
|
use http_body_util::{BodyExt, Full};
|
|
use hyper::body::Incoming;
|
|
use hyper::service::service_fn;
|
|
use hyper::{Request, Response, StatusCode, Uri};
|
|
use hyper_rustls::HttpsConnectorBuilder;
|
|
use hyper_util::client::legacy::{Client, connect::HttpConnector as LegacyHttpConnector};
|
|
use hyper_util::rt::{TokioExecutor, TokioIo};
|
|
use hyper_util::server::conn::auto::Builder as AutoBuilder;
|
|
use rcgen::{BasicConstraints, Certificate, CertificateParams, IsCa, SanType};
|
|
use robots_proxy::mitm::MitmAuthority;
|
|
use robots_proxy::proxy::service::{ProxyService, build_service};
|
|
use robots_proxy::robots::{RobotsDecider, RobotsPolicy, RobotsVerdict};
|
|
use robots_proxy::settings::AppConfig;
|
|
use rustls::{self, ClientConfig, RootCertStore, ServerConfig, client::ServerName};
|
|
use rustls_022::pki_types::CertificateDer as CertificateDer022;
|
|
use rustls_022::{ClientConfig as ClientConfig022, RootCertStore as RootCertStore022};
|
|
use tempfile::TempDir;
|
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
|
|
use tokio::net::TcpListener;
|
|
use tokio::sync::oneshot;
|
|
use tokio::task::JoinHandle;
|
|
use tokio_rustls::{TlsAcceptor, TlsConnector};
|
|
use tower::ServiceExt;
|
|
|
|
/// Prefix written to stderr when a test returns early because spawning a local test listener
/// failed with `PermissionDenied`.
const TEST_SKIP_MSG: &str = "integration test skipped: unable to bind local socket";
|
|
|
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
|
async fn proxy_forwards_allowed_request() -> Result<(), Box<dyn std::error::Error>> {
|
|
let server = match TestHttpServer::spawn("User-agent: *\nAllow: /").await {
|
|
Ok(server) => server,
|
|
Err(err) if err.kind() == io::ErrorKind::PermissionDenied => {
|
|
eprintln!("{TEST_SKIP_MSG}: {err}");
|
|
return Ok(());
|
|
}
|
|
Err(err) => return Err(err.into()),
|
|
};
|
|
|
|
let service = build_proxy_service();
|
|
let uri = format!("http://{}/allowed", server.addr);
|
|
let response = service.clone().oneshot(build_request(&uri)).await?;
|
|
|
|
assert_eq!(response.status(), StatusCode::OK);
|
|
let body = response.into_body().collect().await?.to_bytes();
|
|
assert_eq!(body, Bytes::from_static(b"ok"));
|
|
Ok(())
|
|
}
|
|
|
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
|
async fn proxy_blocks_disallowed_request() -> Result<(), Box<dyn std::error::Error>> {
|
|
let server = match TestHttpServer::spawn("User-agent: *\nDisallow: /private").await {
|
|
Ok(server) => server,
|
|
Err(err) if err.kind() == io::ErrorKind::PermissionDenied => {
|
|
eprintln!("{TEST_SKIP_MSG}: {err}");
|
|
return Ok(());
|
|
}
|
|
Err(err) => return Err(err.into()),
|
|
};
|
|
|
|
let service = build_proxy_service();
|
|
let uri = format!("http://{}/private/data", server.addr);
|
|
let response = service.clone().oneshot(build_request(&uri)).await?;
|
|
|
|
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
|
Ok(())
|
|
}
|
|
|
|
fn build_proxy_service() -> ProxyService {
|
|
let mut connector = LegacyHttpConnector::new();
|
|
connector.enforce_http(false);
|
|
let https = HttpsConnectorBuilder::new()
|
|
.with_native_roots()
|
|
.expect("native roots")
|
|
.https_or_http()
|
|
.enable_http1()
|
|
.wrap_connector(connector);
|
|
let client = Client::builder(TokioExecutor::new()).build(https);
|
|
|
|
let mut config = AppConfig::default();
|
|
config.listen_addr = "127.0.0.1:0".into();
|
|
let config = Arc::new(config);
|
|
|
|
let policy: Arc<dyn RobotsDecider> =
|
|
Arc::new(RobotsPolicy::new(config.clone()).expect("robots policy"));
|
|
build_service(config, client, Some(policy), None)
|
|
}
|
|
|
|
fn build_request(uri: &str) -> Request<Incoming> {
|
|
Request::builder().method("GET").uri(uri).body(empty_incoming()).expect("request")
|
|
}
|
|
|
|
/// Fabricates the [`Incoming`] body attached to the requests the tests send through the service.
///
/// NOTE(review): the value is produced by zero-filling the type's memory rather than through a
/// constructor. Nothing in this file establishes that an all-zero bit pattern is a valid
/// `Incoming`; `std::mem::zeroed` documents that using it for a type without that guarantee is
/// undefined behavior. Confirm against the hyper version in use, or obtain a real body instead.
fn empty_incoming() -> Incoming {
    // SAFETY: unverified — relies on all-zero bytes being a valid `Incoming` (see note above).
    unsafe { std::mem::zeroed::<Incoming>() }
}
|
|
|
|
struct TestHttpServer {
|
|
addr: SocketAddr,
|
|
shutdown_tx: Option<oneshot::Sender<()>>,
|
|
handle: Option<JoinHandle<()>>,
|
|
}
|
|
|
|
impl TestHttpServer {
|
|
async fn spawn(robots_body: &'static str) -> io::Result<Self> {
|
|
let listener = TcpListener::bind("127.0.0.1:0").await?;
|
|
let addr = listener.local_addr()?;
|
|
let (tx, mut rx) = oneshot::channel();
|
|
let robots_body = robots_body.to_string();
|
|
|
|
let handle = tokio::spawn(async move {
|
|
loop {
|
|
tokio::select! {
|
|
_ = &mut rx => break,
|
|
accept = listener.accept() => {
|
|
let (stream, _) = match accept {
|
|
Ok(pair) => pair,
|
|
Err(err) => {
|
|
eprintln!("test server accept error: {err}");
|
|
continue;
|
|
}
|
|
};
|
|
let robots_body = robots_body.clone();
|
|
tokio::spawn(async move {
|
|
let service = service_fn(move |req| {
|
|
let robots_body = robots_body.clone();
|
|
async move {
|
|
let payload = if req.uri().path() == "/robots.txt" {
|
|
robots_body.clone()
|
|
} else {
|
|
"ok".to_string()
|
|
};
|
|
let response = Response::builder()
|
|
.status(StatusCode::OK)
|
|
.body(Full::new(Bytes::from(payload)))
|
|
.unwrap();
|
|
Ok::<_, Infallible>(response)
|
|
}
|
|
});
|
|
|
|
let io = TokioIo::new(stream);
|
|
if let Err(err) = AutoBuilder::new(TokioExecutor::new())
|
|
.serve_connection(io, service)
|
|
.await
|
|
{
|
|
eprintln!("test server connection error: {err}");
|
|
}
|
|
});
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
Ok(Self { addr, shutdown_tx: Some(tx), handle: Some(handle) })
|
|
}
|
|
}
|
|
|
|
impl Drop for TestHttpServer {
|
|
fn drop(&mut self) {
|
|
if let Some(tx) = self.shutdown_tx.take() {
|
|
let _ = tx.send(());
|
|
}
|
|
if let Some(handle) = self.handle.take() {
|
|
handle.abort();
|
|
}
|
|
}
|
|
}
|
|
|
|
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
|
async fn https_mitm_blocks_disallowed_request() -> Result<(), Box<dyn std::error::Error>> {
|
|
let ca = TestCa::new();
|
|
let policy = static_policy();
|
|
let tls_server = match TestTlsServer::spawn(&ca, "User-agent: *\nDisallow: /blocked").await {
|
|
Ok(server) => server,
|
|
Err(err) if err.kind() == io::ErrorKind::PermissionDenied => {
|
|
eprintln!("{TEST_SKIP_MSG}: {err}");
|
|
return Ok(());
|
|
}
|
|
Err(err) => return Err(err.into()),
|
|
};
|
|
|
|
let service = build_mitm_proxy_service(policy.clone(), &ca);
|
|
let proxy = match TestProxy::spawn(service).await {
|
|
Ok(proxy) => proxy,
|
|
Err(err) if err.kind() == io::ErrorKind::PermissionDenied => {
|
|
eprintln!("{TEST_SKIP_MSG}: {err}");
|
|
return Ok(());
|
|
}
|
|
Err(err) => return Err(err.into()),
|
|
};
|
|
|
|
let (status, _) = perform_https_request(proxy.addr, tls_server.addr, &ca, "/blocked").await?;
|
|
assert_eq!(status, StatusCode::FORBIDDEN);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
|
async fn https_mitm_allows_permitted_request() -> Result<(), Box<dyn std::error::Error>> {
|
|
let ca = TestCa::new();
|
|
let policy = static_policy();
|
|
let tls_server = match TestTlsServer::spawn(&ca, "User-agent: *\nDisallow: /blocked").await {
|
|
Ok(server) => server,
|
|
Err(err) if err.kind() == io::ErrorKind::PermissionDenied => {
|
|
eprintln!("{TEST_SKIP_MSG}: {err}");
|
|
return Ok(());
|
|
}
|
|
Err(err) => return Err(err.into()),
|
|
};
|
|
|
|
let service = build_mitm_proxy_service(policy.clone(), &ca);
|
|
let proxy = match TestProxy::spawn(service).await {
|
|
Ok(proxy) => proxy,
|
|
Err(err) if err.kind() == io::ErrorKind::PermissionDenied => {
|
|
eprintln!("{TEST_SKIP_MSG}: {err}");
|
|
return Ok(());
|
|
}
|
|
Err(err) => return Err(err.into()),
|
|
};
|
|
|
|
let (status, body) =
|
|
perform_https_request(proxy.addr, tls_server.addr, &ca, "/allowed").await?;
|
|
assert_eq!(status, StatusCode::OK);
|
|
assert_eq!(body, b"ok".as_ref());
|
|
|
|
Ok(())
|
|
}
|
|
|
|
struct TestProxy {
|
|
addr: SocketAddr,
|
|
shutdown_tx: Option<oneshot::Sender<()>>,
|
|
handle: Option<JoinHandle<()>>,
|
|
}
|
|
|
|
impl TestProxy {
|
|
async fn spawn(service: ProxyService) -> io::Result<Self> {
|
|
let listener = TcpListener::bind("127.0.0.1:0").await?;
|
|
let addr = listener.local_addr()?;
|
|
let (tx, mut rx) = oneshot::channel();
|
|
let service = service;
|
|
|
|
let handle = tokio::spawn(async move {
|
|
loop {
|
|
tokio::select! {
|
|
_ = &mut rx => break,
|
|
accept = listener.accept() => {
|
|
let (stream, peer) = match accept {
|
|
Ok(pair) => pair,
|
|
Err(err) => {
|
|
eprintln!("proxy accept error: {err}");
|
|
continue;
|
|
}
|
|
};
|
|
let svc = service.clone();
|
|
tokio::spawn(async move {
|
|
let io = TokioIo::new(stream);
|
|
let builder = AutoBuilder::new(TokioExecutor::new());
|
|
let connection = builder.serve_connection_with_upgrades(
|
|
io,
|
|
service_fn(move |req| svc.clone().oneshot(req)),
|
|
);
|
|
if let Err(err) = connection.await {
|
|
eprintln!("proxy connection error from {peer}: {err}");
|
|
}
|
|
});
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
Ok(Self { addr, shutdown_tx: Some(tx), handle: Some(handle) })
|
|
}
|
|
}
|
|
|
|
impl Drop for TestProxy {
|
|
fn drop(&mut self) {
|
|
if let Some(tx) = self.shutdown_tx.take() {
|
|
let _ = tx.send(());
|
|
}
|
|
if let Some(handle) = self.handle.take() {
|
|
handle.abort();
|
|
}
|
|
}
|
|
}
|
|
|
|
struct TestTlsServer {
|
|
addr: SocketAddr,
|
|
shutdown_tx: Option<oneshot::Sender<()>>,
|
|
handle: Option<JoinHandle<()>>,
|
|
}
|
|
|
|
impl TestTlsServer {
|
|
async fn spawn(ca: &TestCa, robots_body: &'static str) -> io::Result<Self> {
|
|
let listener = TcpListener::bind("127.0.0.1:0").await?;
|
|
let addr = listener.local_addr()?;
|
|
let (tx, mut rx) = oneshot::channel();
|
|
let (cert_chain, key) = ca.issue_server_cert(addr.ip());
|
|
let server_config = ServerConfig::builder()
|
|
.with_safe_defaults()
|
|
.with_no_client_auth()
|
|
.with_single_cert(cert_chain, key)
|
|
.expect("valid server cert");
|
|
let acceptor = TlsAcceptor::from(Arc::new(server_config));
|
|
let robots_body = robots_body.to_string();
|
|
|
|
let handle = tokio::spawn(async move {
|
|
loop {
|
|
tokio::select! {
|
|
_ = &mut rx => break,
|
|
accept = listener.accept() => {
|
|
let (stream, _) = match accept {
|
|
Ok(pair) => pair,
|
|
Err(err) => {
|
|
eprintln!("tls server accept error: {err}");
|
|
continue;
|
|
}
|
|
};
|
|
let acceptor = acceptor.clone();
|
|
let robots_body = robots_body.clone();
|
|
tokio::spawn(async move {
|
|
match acceptor.accept(stream).await {
|
|
Ok(tls_stream) => {
|
|
let io = TokioIo::new(tls_stream);
|
|
let service = service_fn(move |req| {
|
|
let body = robots_body.clone();
|
|
async move {
|
|
let payload = if req.uri().path() == "/robots.txt" {
|
|
body.clone()
|
|
} else {
|
|
"ok".to_string()
|
|
};
|
|
let response = Response::builder()
|
|
.status(StatusCode::OK)
|
|
.body(Full::new(Bytes::from(payload)))
|
|
.unwrap();
|
|
Ok::<_, Infallible>(response)
|
|
}
|
|
});
|
|
|
|
if let Err(err) = AutoBuilder::new(TokioExecutor::new())
|
|
.serve_connection(io, service)
|
|
.await
|
|
{
|
|
eprintln!("tls server connection error: {err}");
|
|
}
|
|
}
|
|
Err(err) => eprintln!("tls handshake error: {err}"),
|
|
}
|
|
});
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
Ok(Self { addr, shutdown_tx: Some(tx), handle: Some(handle) })
|
|
}
|
|
}
|
|
|
|
impl Drop for TestTlsServer {
|
|
fn drop(&mut self) {
|
|
if let Some(tx) = self.shutdown_tx.take() {
|
|
let _ = tx.send(());
|
|
}
|
|
if let Some(handle) = self.handle.take() {
|
|
handle.abort();
|
|
}
|
|
}
|
|
}
|
|
|
|
struct TestCa {
|
|
_dir: TempDir,
|
|
cert: Certificate,
|
|
cert_der: Vec<u8>,
|
|
cert_path: PathBuf,
|
|
key_path: PathBuf,
|
|
}
|
|
|
|
impl TestCa {
|
|
fn new() -> Self {
|
|
let dir = TempDir::new().expect("tempdir");
|
|
let mut params = CertificateParams::new(vec!["Robots Proxy Test CA".to_string()]);
|
|
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
|
|
params.key_usages = vec![
|
|
rcgen::KeyUsagePurpose::KeyCertSign,
|
|
rcgen::KeyUsagePurpose::CrlSign,
|
|
rcgen::KeyUsagePurpose::DigitalSignature,
|
|
];
|
|
params.alg = &rcgen::PKCS_ECDSA_P256_SHA256;
|
|
let cert = Certificate::from_params(params).expect("ca params");
|
|
let cert_der = cert.serialize_der().expect("ca der");
|
|
let cert_pem = cert.serialize_pem().expect("ca pem");
|
|
let key_pem = cert.serialize_private_key_pem();
|
|
let cert_path = dir.path().join("ca.pem");
|
|
let key_path = dir.path().join("ca.key");
|
|
fs::write(&cert_path, cert_pem).expect("write cert");
|
|
fs::write(&key_path, key_pem).expect("write key");
|
|
|
|
Self { _dir: dir, cert, cert_der, cert_path, key_path }
|
|
}
|
|
|
|
fn issue_server_cert(&self, host: IpAddr) -> (Vec<rustls::Certificate>, rustls::PrivateKey) {
|
|
let mut params = CertificateParams::new(vec![host.to_string()]);
|
|
params.alg = &rcgen::PKCS_ECDSA_P256_SHA256;
|
|
params.subject_alt_names = vec![SanType::IpAddress(host)];
|
|
params.key_usages =
|
|
vec![rcgen::KeyUsagePurpose::DigitalSignature, rcgen::KeyUsagePurpose::KeyEncipherment];
|
|
let cert = Certificate::from_params(params).expect("leaf params");
|
|
let der = cert.serialize_der_with_signer(&self.cert).expect("leaf der");
|
|
let key = cert.serialize_private_key_der();
|
|
let mut chain = Vec::new();
|
|
chain.push(rustls::Certificate(der));
|
|
chain.push(rustls::Certificate(self.cert_der.clone()));
|
|
(chain, rustls::PrivateKey(key))
|
|
}
|
|
|
|
fn cert_path(&self) -> &Path {
|
|
&self.cert_path
|
|
}
|
|
|
|
fn key_path(&self) -> &Path {
|
|
&self.key_path
|
|
}
|
|
|
|
fn root_der(&self) -> &[u8] {
|
|
&self.cert_der
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Default)]
|
|
struct StaticRobotsPolicy;
|
|
|
|
#[async_trait]
|
|
impl RobotsDecider for StaticRobotsPolicy {
|
|
async fn decide(&self, uri: &Uri) -> RobotsVerdict {
|
|
if uri.path().starts_with("/blocked") {
|
|
RobotsVerdict::Disallowed { reason: "blocked", status: StatusCode::FORBIDDEN }
|
|
} else {
|
|
RobotsVerdict::Allowed
|
|
}
|
|
}
|
|
}
|
|
|
|
fn static_policy() -> Arc<dyn RobotsDecider> {
|
|
Arc::new(StaticRobotsPolicy::default())
|
|
}
|
|
|
|
fn build_mitm_proxy_service(policy: Arc<dyn RobotsDecider>, ca: &TestCa) -> ProxyService {
|
|
let mut connector = LegacyHttpConnector::new();
|
|
connector.enforce_http(false);
|
|
|
|
let mut roots = RootCertStore022::empty();
|
|
roots.add_parsable_certificates([CertificateDer022::from(ca.root_der().to_vec())]);
|
|
let client_config =
|
|
ClientConfig022::builder().with_root_certificates(Arc::new(roots)).with_no_client_auth();
|
|
|
|
let https = HttpsConnectorBuilder::new()
|
|
.with_tls_config(client_config)
|
|
.https_or_http()
|
|
.enable_http1()
|
|
.wrap_connector(connector);
|
|
let client = Client::builder(TokioExecutor::new()).build(https);
|
|
|
|
let mut config = AppConfig::default();
|
|
config.listen_addr = "127.0.0.1:0".into();
|
|
config.mitm.enabled = true;
|
|
config.mitm.ca_cert_path = Some(ca.cert_path().to_path_buf());
|
|
config.mitm.ca_key_path = Some(ca.key_path().to_path_buf());
|
|
config.mitm.cert_cache_ttl_secs = 60;
|
|
let config = Arc::new(config);
|
|
|
|
let mitm =
|
|
MitmAuthority::from_config(&config.mitm).expect("mitm config").expect("mitm enabled");
|
|
build_service(config, client, Some(policy), Some(mitm))
|
|
}
|
|
|
|
async fn perform_https_request(
|
|
proxy_addr: SocketAddr,
|
|
target: SocketAddr,
|
|
ca: &TestCa,
|
|
path: &str,
|
|
) -> io::Result<(StatusCode, Vec<u8>)> {
|
|
let mut stream = tokio::net::TcpStream::connect(proxy_addr).await?;
|
|
let authority = format!("{}:{}", target.ip(), target.port());
|
|
let connect = format!("CONNECT {} HTTP/1.1\r\nHost: {}\r\n\r\n", authority, authority);
|
|
stream.write_all(connect.as_bytes()).await?;
|
|
stream.flush().await?;
|
|
|
|
let mut header = Vec::new();
|
|
read_until_headers(&mut stream, &mut header).await?;
|
|
let status = parse_status(&header)?;
|
|
if status != StatusCode::OK {
|
|
return Ok((status, Vec::new()));
|
|
}
|
|
|
|
let mut roots = RootCertStore::empty();
|
|
roots.add_parsable_certificates(&[ca.root_der().to_vec()]);
|
|
let client_config = ClientConfig::builder()
|
|
.with_safe_defaults()
|
|
.with_root_certificates(roots)
|
|
.with_no_client_auth();
|
|
let connector = TlsConnector::from(Arc::new(client_config));
|
|
let server_name = match target.ip() {
|
|
IpAddr::V4(addr) => ServerName::IpAddress(addr.into()),
|
|
IpAddr::V6(addr) => ServerName::IpAddress(addr.into()),
|
|
};
|
|
let mut tls = connector.connect(server_name, stream).await?;
|
|
|
|
let request =
|
|
format!("GET {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n", path, authority);
|
|
tls.write_all(request.as_bytes()).await?;
|
|
tls.flush().await?;
|
|
|
|
let mut response = Vec::new();
|
|
tls.read_to_end(&mut response).await?;
|
|
let status = parse_status(&response)?;
|
|
let body = split_body(&response).to_vec();
|
|
Ok((status, body))
|
|
}
|
|
|
|
async fn read_until_headers<S>(stream: &mut S, buf: &mut Vec<u8>) -> io::Result<()>
|
|
where
|
|
S: AsyncRead + Unpin,
|
|
{
|
|
let mut byte = [0u8; 1];
|
|
while stream.read_exact(&mut byte).await.is_ok() {
|
|
buf.push(byte[0]);
|
|
if buf.ends_with(b"\r\n\r\n") {
|
|
return Ok(());
|
|
}
|
|
}
|
|
Err(io::Error::new(io::ErrorKind::UnexpectedEof, "incomplete headers"))
|
|
}
|
|
|
|
fn parse_status(buf: &[u8]) -> io::Result<StatusCode> {
|
|
let header_end = buf
|
|
.windows(4)
|
|
.position(|window| window == b"\r\n\r\n")
|
|
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing headers"))?;
|
|
let header = &buf[..header_end];
|
|
let header_str = std::str::from_utf8(header)
|
|
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "headers not utf8"))?;
|
|
let status_line = header_str
|
|
.lines()
|
|
.next()
|
|
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing status line"))?;
|
|
let status_str = status_line.trim();
|
|
let parts: Vec<_> = status_str.split_whitespace().collect();
|
|
if parts.len() < 2 {
|
|
return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid status line"));
|
|
}
|
|
let code: u16 = parts[1]
|
|
.parse()
|
|
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid status code"))?;
|
|
StatusCode::from_u16(code)
|
|
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "unknown status code"))
|
|
}
|
|
|
|
/// Returns the part of an HTTP message that follows the first `\r\n\r\n`, or an empty slice when
/// the message contains no such terminator.
fn split_body(buf: &[u8]) -> &[u8] {
    const HEAD_TERMINATOR: &[u8] = b"\r\n\r\n";

    let terminator_at =
        buf.windows(HEAD_TERMINATOR.len()).position(|chunk| chunk == HEAD_TERMINATOR);
    match terminator_at {
        Some(start) => &buf[start + HEAD_TERMINATOR.len()..],
        None => &[],
    }
}
|