1use alloc::boxed::Box;
9use alloc::string::String;
10use alloc::sync::Arc;
11use core::fmt::{self, Display};
12use core::future::Future;
13use core::ops::DerefMut;
14use core::pin::Pin;
15use core::str::FromStr;
16use core::task::{Context, Poll};
17use std::io;
18use std::net::SocketAddr;
19
20use bytes::{Buf, Bytes, BytesMut};
21use futures_util::future::{FutureExt, TryFutureExt};
22use futures_util::ready;
23use futures_util::stream::Stream;
24use h2::client::{Connection, SendRequest};
25use http::header::{self, CONTENT_LENGTH};
26use rustls::ClientConfig;
27use rustls::pki_types::ServerName;
28use tokio::time::{error, timeout};
29use tokio_rustls::{TlsConnector, client::TlsStream as TokioTlsClientStream};
30use tracing::{debug, warn};
31
32use crate::error::ProtoError;
33use crate::http::Version;
34use crate::runtime::RuntimeProvider;
35use crate::runtime::iocompat::AsyncIoStdAsTokio;
36use crate::tcp::DnsTcpStream;
37use crate::xfer::{CONNECT_TIMEOUT, DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream};
38
/// ALPN protocol identifier for HTTP/2, placed in `ClientConfig::alpn_protocols` when the
/// caller has not configured any protocols.
const ALPN_H2: &[u8] = "h2".as_bytes();
40
/// A DNS-over-HTTPS client connection over HTTP/2.
///
/// Each request sent through [`DnsRequestSender::send_message`] is issued with
/// `SendRequest::send_request` on the shared HTTP/2 connection held in `h2`.
#[derive(Clone)]
#[must_use = "futures do nothing unless polled"]
pub struct HttpsClientStream {
    /// Server name used for the TLS handshake; also passed to the HTTP request builder.
    name_server_name: Arc<str>,
    /// HTTP path of the DoH endpoint (the tests use `/dns-query`).
    query_path: Arc<str>,
    /// Socket address of the name server; shown by the `Display` impl.
    name_server: SocketAddr,
    /// Handle for issuing requests on the HTTP/2 connection.
    h2: SendRequest<Bytes>,
    /// Set by `shutdown`; afterwards `send_message` panics and `poll_next` yields `None`.
    is_shutdown: bool,
}
52
53impl Display for HttpsClientStream {
54 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
55 write!(
56 formatter,
57 "HTTPS({},{})",
58 self.name_server, self.name_server_name
59 )
60 }
61}
62
63impl HttpsClientStream {
64 async fn inner_send(
65 h2: SendRequest<Bytes>,
66 message: Bytes,
67 name_server_name: Arc<str>,
68 query_path: Arc<str>,
69 ) -> Result<DnsResponse, ProtoError> {
70 let mut h2 = match h2.ready().await {
71 Ok(h2) => h2,
72 Err(err) => {
73 return Err(ProtoError::from(format!("h2 send_request error: {err}")));
75 }
76 };
77
78 let request = crate::http::request::new(
80 Version::Http2,
81 &name_server_name,
82 &query_path,
83 message.remaining(),
84 );
85
86 let request =
87 request.map_err(|err| ProtoError::from(format!("bad http request: {err}")))?;
88
89 debug!("request: {:#?}", request);
90
91 let (response_future, mut send_stream) = h2
93 .send_request(request, false)
94 .map_err(|err| ProtoError::from(format!("h2 send_request error: {err}")))?;
95
96 send_stream
97 .send_data(message, true)
98 .map_err(|e| ProtoError::from(format!("h2 send_data error: {e}")))?;
99
100 let mut response_stream = response_future
101 .await
102 .map_err(|err| ProtoError::from(format!("received a stream error: {err}")))?;
103
104 debug!("got response: {:#?}", response_stream);
105
106 let content_length = response_stream
108 .headers()
109 .get(CONTENT_LENGTH)
110 .map(|v| v.to_str())
111 .transpose()
112 .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?
113 .map(usize::from_str)
114 .transpose()
115 .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?;
116
117 let mut response_bytes =
121 BytesMut::with_capacity(content_length.unwrap_or(512).clamp(512, 4_096));
122
123 while let Some(partial_bytes) = response_stream.body_mut().data().await {
124 let partial_bytes =
125 partial_bytes.map_err(|e| ProtoError::from(format!("bad http request: {e}")))?;
126
127 debug!("got bytes: {}", partial_bytes.len());
128 response_bytes.extend(partial_bytes);
129
130 if let Some(content_length) = content_length {
132 if response_bytes.len() >= content_length {
133 break;
134 }
135 }
136 }
137
138 if let Some(content_length) = content_length {
140 if response_bytes.len() != content_length {
141 return Err(ProtoError::from(format!(
143 "expected byte length: {}, got: {}",
144 content_length,
145 response_bytes.len()
146 )));
147 }
148 }
149
150 if !response_stream.status().is_success() {
152 let error_string = String::from_utf8_lossy(response_bytes.as_ref());
153
154 return Err(ProtoError::from(format!(
156 "http unsuccessful code: {}, message: {}",
157 response_stream.status(),
158 error_string
159 )));
160 } else {
161 {
163 let content_type = response_stream
165 .headers()
166 .get(header::CONTENT_TYPE)
167 .map(|h| {
168 h.to_str().map_err(|err| {
169 ProtoError::from(format!("ContentType header not a string: {err}"))
171 })
172 })
173 .unwrap_or(Ok(crate::http::MIME_APPLICATION_DNS))?;
174
175 if content_type != crate::http::MIME_APPLICATION_DNS {
176 return Err(ProtoError::from(format!(
177 "ContentType unsupported (must be '{}'): '{}'",
178 crate::http::MIME_APPLICATION_DNS,
179 content_type
180 )));
181 }
182 }
183 };
184
185 DnsResponse::from_buffer(response_bytes.to_vec())
187 }
188}
189
190impl DnsRequestSender for HttpsClientStream {
191 fn send_message(&mut self, mut request: DnsRequest) -> DnsResponseStream {
243 if self.is_shutdown {
244 panic!("can not send messages after stream is shutdown")
245 }
246
247 request.set_id(0);
249
250 let bytes = match request.to_vec() {
251 Ok(bytes) => bytes,
252 Err(err) => return err.into(),
253 };
254
255 Box::pin(Self::inner_send(
256 self.h2.clone(),
257 Bytes::from(bytes),
258 Arc::clone(&self.name_server_name),
259 Arc::clone(&self.query_path),
260 ))
261 .into()
262 }
263
264 fn shutdown(&mut self) {
265 self.is_shutdown = true;
266 }
267
268 fn is_shutdown(&self) -> bool {
269 self.is_shutdown
270 }
271}
272
273impl Stream for HttpsClientStream {
274 type Item = Result<(), ProtoError>;
275
276 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
277 if self.is_shutdown {
278 return Poll::Ready(None);
279 }
280
281 match self.h2.poll_ready(cx) {
283 Poll::Ready(Ok(())) => Poll::Ready(Some(Ok(()))),
284 Poll::Pending => Poll::Pending,
285 Poll::Ready(Err(e)) => Poll::Ready(Some(Err(ProtoError::from(format!(
286 "h2 stream errored: {e}",
287 ))))),
288 }
289 }
290}
291
/// Builder producing [`HttpsClientConnect`] futures for DNS-over-HTTPS connections.
#[derive(Clone)]
pub struct HttpsClientStreamBuilder<P> {
    /// Runtime provider whose `connect_tcp` opens the TCP connection.
    provider: P,
    /// rustls configuration; `build` adds the `h2` ALPN protocol if none is configured.
    client_config: Arc<ClientConfig>,
    /// Optional local address, passed as the bind address to `connect_tcp`.
    bind_addr: Option<SocketAddr>,
}
299
300impl<P: RuntimeProvider> HttpsClientStreamBuilder<P> {
301 pub fn with_client_config(client_config: Arc<ClientConfig>, provider: P) -> Self {
303 Self {
304 provider,
305 client_config,
306 bind_addr: None,
307 }
308 }
309
310 pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
312 self.bind_addr = Some(bind_addr);
313 }
314
315 pub fn build(
323 mut self,
324 name_server: SocketAddr,
325 dns_name: String,
326 http_endpoint: String,
327 ) -> HttpsClientConnect<P::Tcp> {
328 if self.client_config.alpn_protocols.is_empty() {
330 let mut client_config = (*self.client_config).clone();
331 client_config.alpn_protocols = vec![ALPN_H2.to_vec()];
332
333 self.client_config = Arc::new(client_config);
334 }
335
336 let tls = TlsConfig {
337 client_config: self.client_config,
338 dns_name: Arc::from(dns_name),
339 http_endpoint: Arc::from(http_endpoint),
340 };
341
342 let connect = self.provider.connect_tcp(name_server, self.bind_addr, None);
343 HttpsClientConnect(HttpsClientConnectState::TcpConnecting {
344 connect,
345 name_server,
346 tls: Some(tls),
347 })
348 }
349}
350
/// Future resolving to an [`HttpsClientStream`] once the TCP connection, TLS handshake and
/// HTTP/2 handshake have all completed.
pub struct HttpsClientConnect<S>(HttpsClientConnectState<S>)
where
    S: DnsTcpStream;
355
356impl<S: DnsTcpStream> HttpsClientConnect<S> {
357 pub fn new<F>(
359 future: F,
360 mut client_config: Arc<ClientConfig>,
361 name_server: SocketAddr,
362 dns_name: String,
363 http_endpoint: String,
364 ) -> Self
365 where
366 S: DnsTcpStream,
367 F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
368 {
369 if client_config.alpn_protocols.is_empty() {
371 let mut client_cfg = (*client_config).clone();
372 client_cfg.alpn_protocols = vec![ALPN_H2.to_vec()];
373
374 client_config = Arc::new(client_cfg);
375 }
376
377 let tls = TlsConfig {
378 client_config,
379 dns_name: Arc::from(dns_name),
380 http_endpoint: Arc::from(http_endpoint),
381 };
382
383 Self(HttpsClientConnectState::TcpConnecting {
384 connect: Box::pin(future),
385 name_server,
386 tls: Some(tls),
387 })
388 }
389}
390
391impl<S> Future for HttpsClientConnect<S>
392where
393 S: DnsTcpStream,
394{
395 type Output = Result<HttpsClientStream, ProtoError>;
396
397 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
398 self.0.poll_unpin(cx)
399 }
400}
401
/// TLS and HTTP parameters carried in the `TcpConnecting` state until the TCP stream is ready.
struct TlsConfig {
    /// rustls configuration used to build the `TlsConnector`.
    client_config: Arc<ClientConfig>,
    /// Server name for the TLS handshake; becomes the stream's `name_server_name`.
    dns_name: Arc<str>,
    /// HTTP path of the DoH endpoint; becomes the stream's `query_path`.
    http_endpoint: Arc<str>,
}
407
/// States of the connection future: TCP connect, then TLS handshake, then HTTP/2 handshake,
/// then a terminal `Connected` or `Errored` state holding the result.
// The variants spell out long boxed-future types and differ in size; these clippy lints are
// silenced for this private state machine rather than introducing type aliases/boxing.
#[allow(clippy::large_enum_variant)]
#[allow(clippy::type_complexity)]
enum HttpsClientConnectState<S>
where
    S: DnsTcpStream,
{
    /// Waiting for the TCP connection to be established.
    TcpConnecting {
        /// Future yielding the connected TCP stream.
        connect: Pin<Box<dyn Future<Output = io::Result<S>> + Send>>,
        /// Remote address, logged and carried into later states.
        name_server: SocketAddr,
        /// TLS settings; wrapped in `Option` so `poll` can `take()` them by value.
        tls: Option<TlsConfig>,
    },
    /// Waiting for the TLS handshake; the outer `Result` reports the connect timeout.
    TlsConnecting {
        tls: Pin<
            Box<
                dyn Future<
                    Output = Result<
                        Result<TokioTlsClientStream<AsyncIoStdAsTokio<S>>, io::Error>,
                        error::Elapsed,
                    >,
                > + Send,
            >,
        >,
        name_server_name: Arc<str>,
        name_server: SocketAddr,
        query_path: Arc<str>,
    },
    /// Waiting for the HTTP/2 handshake over the established TLS stream.
    H2Handshake {
        handshake: Pin<
            Box<
                dyn Future<
                    Output = Result<
                        (
                            SendRequest<Bytes>,
                            Connection<TokioTlsClientStream<AsyncIoStdAsTokio<S>>, Bytes>,
                        ),
                        h2::Error,
                    >,
                > + Send,
            >,
        >,
        name_server_name: Arc<str>,
        name_server: SocketAddr,
        query_path: Arc<str>,
    },
    /// Finished successfully; the stream is taken out on the poll that returns it.
    Connected(Option<HttpsClientStream>),
    /// Finished with an error; the error is taken out on the poll that returns it.
    Errored(Option<ProtoError>),
}
456
impl<S> Future for HttpsClientConnectState<S>
where
    S: DnsTcpStream,
{
    type Output = Result<HttpsClientStream, ProtoError>;

    /// Drives the connection through TCP connect, TLS handshake and HTTP/2 handshake.
    ///
    /// Each loop iteration polls the current state's inner future. When it completes, the next
    /// state is built, stored into `self`, and polled on the following iteration. `Pending`
    /// from an inner future is returned via `ready!`. Errors from inner futures (and the TLS
    /// timeout) are returned directly without changing state.
    ///
    /// # Panics
    ///
    /// Panics with "cannot poll after complete" if polled again after a result has been
    /// returned from the `Connected` or `Errored` state.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            let next = match &mut *self.as_mut() {
                Self::TcpConnecting {
                    connect,
                    name_server,
                    tls,
                } => {
                    // io::Error from the TCP connect is converted into ProtoError by `?`.
                    let tcp = ready!(connect.poll_unpin(cx))?;

                    debug!("tcp connection established to: {}", name_server);
                    // Move the TLS settings out by value; this state is replaced right after.
                    let tls = tls
                        .take()
                        .expect("programming error, tls should not be None here");
                    let name_server_name = Arc::clone(&tls.dns_name);
                    let query_path = Arc::clone(&tls.http_endpoint);

                    // The configured name must parse as a rustls ServerName (DNS name or IP
                    // literal) to be handed to the TLS connector; otherwise finish as Errored.
                    match ServerName::try_from(&*tls.dns_name) {
                        Ok(dns_name) => Self::TlsConnecting {
                            name_server_name,
                            name_server: *name_server,
                            // The whole TLS handshake is bounded by CONNECT_TIMEOUT.
                            tls: Box::pin(timeout(
                                CONNECT_TIMEOUT,
                                TlsConnector::from(tls.client_config)
                                    .connect(dns_name.to_owned(), AsyncIoStdAsTokio(tcp)),
                            )),
                            query_path,
                        },
                        Err(_) => Self::Errored(Some(ProtoError::from(format!(
                            "bad dns_name: {}",
                            &tls.dns_name
                        )))),
                    }
                }
                Self::TlsConnecting {
                    name_server_name,
                    name_server,
                    query_path,
                    tls,
                } => {
                    // Outer Err means the timeout elapsed; the inner io::Result is the handshake.
                    let Ok(res) = ready!(tls.poll_unpin(cx)) else {
                        return Poll::Ready(Err(format!(
                            "TLS handshake timed out after {CONNECT_TIMEOUT:?}"
                        )
                        .into()));
                    };
                    let tls = res?;
                    debug!("tls connection established to: {}", name_server);
                    // Server push is disabled; responses are only read from request streams.
                    let mut handshake = h2::client::Builder::new();
                    handshake.enable_push(false);

                    let handshake = handshake.handshake(tls);
                    Self::H2Handshake {
                        name_server_name: Arc::clone(name_server_name),
                        name_server: *name_server,
                        query_path: Arc::clone(query_path),
                        handshake: Box::pin(handshake),
                    }
                }
                Self::H2Handshake {
                    name_server_name,
                    name_server,
                    query_path,
                    handshake,
                } => {
                    let (send_request, connection) = ready!(
                        handshake
                            .poll_unpin(cx)
                            .map_err(|e| ProtoError::from(format!("h2 handshake error: {e}")))
                    )?;

                    debug!("h2 connection established to: {}", name_server);
                    // The h2 Connection must be driven for SendRequest to make progress, so it
                    // runs as a detached tokio task; its failure is only logged.
                    tokio::spawn(
                        connection
                            .map_err(|e| warn!("h2 connection failed: {e}"))
                            .map(|_: Result<(), ()>| ()),
                    );

                    Self::Connected(Some(HttpsClientStream {
                        name_server_name: Arc::clone(name_server_name),
                        name_server: *name_server,
                        query_path: Arc::clone(query_path),
                        h2: send_request,
                        is_shutdown: false,
                    }))
                }
                Self::Connected(conn) => {
                    return Poll::Ready(Ok(conn.take().expect("cannot poll after complete")));
                }
                Self::Errored(err) => {
                    return Poll::Ready(Err(err.take().expect("cannot poll after complete")));
                }
            };

            // Install the next state; the loop polls it immediately.
            *self.as_mut().deref_mut() = next;
        }
    }
}
562
/// Wrapper around a boxed, `Send` future that yields a `DnsResponse` or a `ProtoError`.
pub struct HttpsClientResponse(
    Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>,
);
567
568impl Future for HttpsClientResponse {
569 type Output = Result<DnsResponse, ProtoError>;
570
571 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
572 self.0.as_mut().poll(cx).map_err(ProtoError::from)
573 }
574}
575
#[cfg(any(feature = "webpki-roots", feature = "rustls-platform-verifier"))]
#[cfg(test)]
mod tests {
    use alloc::string::ToString;
    use std::net::SocketAddr;

    use rustls::KeyLogFile;
    use test_support::subscribe;

    use crate::op::{Edns, Message, Query};
    use crate::rr::{Name, RecordType};
    use crate::runtime::TokioRuntimeProvider;
    use crate::rustls::client_config;
    use crate::xfer::{DnsRequestOptions, FirstAnswer};

    use super::*;

    #[tokio::test]
    async fn test_https_google() {
        subscribe();

        let google = SocketAddr::from(([8, 8, 8, 8], 443));
        let https = connect_https(google, "dns.google", key_logging_client_config()).await;

        assert_resolves_a_and_aaaa(https).await;
    }

    #[tokio::test]
    async fn test_https_google_with_pure_ip_address_server() {
        subscribe();

        let google = SocketAddr::from(([8, 8, 8, 8], 443));
        let https = connect_https(
            google,
            &google.ip().to_string(),
            key_logging_client_config(),
        )
        .await;

        assert_resolves_a_and_aaaa(https).await;
    }

    #[tokio::test]
    #[ignore = "cloudflare has been unreliable as a public test service"]
    async fn test_https_cloudflare() {
        subscribe();

        let cloudflare = SocketAddr::from(([1, 1, 1, 1], 443));
        let https = connect_https(cloudflare, "cloudflare-dns.com", client_config_h2()).await;

        assert_resolves_a_and_aaaa(https).await;
    }

    /// Connects to the `/dns-query` endpoint at `name_server`, using `dns_name` as the TLS
    /// server name.
    async fn connect_https(
        name_server: SocketAddr,
        dns_name: &str,
        client_config: ClientConfig,
    ) -> HttpsClientStream {
        let provider = TokioRuntimeProvider::new();
        HttpsClientStreamBuilder::with_client_config(Arc::new(client_config), provider)
            .build(name_server, dns_name.to_string(), "/dns-query".to_string())
            .await
            .expect("https connect failed")
    }

    /// Sends A and then AAAA queries for `www.example.com.` and checks each answer section
    /// contains a record of the requested type.
    async fn assert_resolves_a_and_aaaa(mut https: HttpsClientStream) {
        let response = https
            .send_message(example_request(RecordType::A))
            .first_answer()
            .await
            .expect("send_message failed");

        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.data().as_a().is_some())
        );

        let response = https
            .send_message(example_request(RecordType::AAAA))
            .first_answer()
            .await
            .expect("send_message failed");

        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.data().as_aaaa().is_some())
        );
    }

    /// Builds a recursion-desired query for `www.example.com.` with EDNS version 0 and a
    /// 1232-byte max payload.
    fn example_request(record_type: RecordType) -> DnsRequest {
        let mut message = Message::new();
        message.add_query(Query::query(
            Name::from_str("www.example.com.").unwrap(),
            record_type,
        ));
        message.set_recursion_desired(true);

        let mut edns = Edns::new();
        edns.set_version(0);
        edns.set_max_payload(1232);
        *message.extensions_mut() = Some(edns);

        DnsRequest::new(message, DnsRequestOptions::default())
    }

    /// HTTP/2 client config that also writes TLS session keys via rustls' `KeyLogFile`.
    fn key_logging_client_config() -> ClientConfig {
        let mut client_config = client_config_h2();
        client_config.key_log = Arc::new(KeyLogFile::new());
        client_config
    }

    fn client_config_h2() -> ClientConfig {
        let mut config = client_config();
        config.alpn_protocols = vec![ALPN_H2.to_vec()];
        config
    }
}