hickory_proto/h2/
h2_client_stream.rs

1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use alloc::boxed::Box;
9use alloc::string::String;
10use alloc::sync::Arc;
11use core::fmt::{self, Display};
12use core::future::Future;
13use core::ops::DerefMut;
14use core::pin::Pin;
15use core::str::FromStr;
16use core::task::{Context, Poll};
17use std::io;
18use std::net::SocketAddr;
19
20use bytes::{Buf, Bytes, BytesMut};
21use futures_util::future::{FutureExt, TryFutureExt};
22use futures_util::ready;
23use futures_util::stream::Stream;
24use h2::client::{Connection, SendRequest};
25use http::header::{self, CONTENT_LENGTH};
26use rustls::ClientConfig;
27use rustls::pki_types::ServerName;
28use tokio::time::{error, timeout};
29use tokio_rustls::{TlsConnector, client::TlsStream as TokioTlsClientStream};
30use tracing::{debug, warn};
31
32use crate::error::ProtoError;
33use crate::http::Version;
34use crate::runtime::RuntimeProvider;
35use crate::runtime::iocompat::AsyncIoStdAsTokio;
36use crate::tcp::DnsTcpStream;
37use crate::xfer::{CONNECT_TIMEOUT, DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream};
38
/// ALPN protocol identifier for HTTP/2, advertised when the caller's rustls
/// `ClientConfig` does not already specify any ALPN protocols.
const ALPN_H2: &[u8] = b"h2";
40
/// A DNS client connection for DNS-over-HTTPS
///
/// Each DNS query is sent as a separate HTTP/2 request through the `h2` handle.
#[derive(Clone)]
#[must_use = "futures do nothing unless polled"]
pub struct HttpsClientStream {
    // Corresponds to the dns-name of the HTTPS server; also passed to the HTTP request builder
    name_server_name: Arc<str>,
    // HTTP path of the DoH endpoint (the builder's `http_endpoint`, typically `/dns-query`)
    query_path: Arc<str>,
    // Socket address of the remote DoH server (used for display/logging here)
    name_server: SocketAddr,
    // Handle used to open a new HTTP/2 stream per query
    h2: SendRequest<Bytes>,
    // Set by `shutdown()`; `send_message` panics and `poll_next` ends once this is true
    is_shutdown: bool,
}
52
53impl Display for HttpsClientStream {
54    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
55        write!(
56            formatter,
57            "HTTPS({},{})",
58            self.name_server, self.name_server_name
59        )
60    }
61}
62
63impl HttpsClientStream {
64    async fn inner_send(
65        h2: SendRequest<Bytes>,
66        message: Bytes,
67        name_server_name: Arc<str>,
68        query_path: Arc<str>,
69    ) -> Result<DnsResponse, ProtoError> {
70        let mut h2 = match h2.ready().await {
71            Ok(h2) => h2,
72            Err(err) => {
73                // TODO: make specific error
74                return Err(ProtoError::from(format!("h2 send_request error: {err}")));
75            }
76        };
77
78        // build up the http request
79        let request = crate::http::request::new(
80            Version::Http2,
81            &name_server_name,
82            &query_path,
83            message.remaining(),
84        );
85
86        let request =
87            request.map_err(|err| ProtoError::from(format!("bad http request: {err}")))?;
88
89        debug!("request: {:#?}", request);
90
91        // Send the request
92        let (response_future, mut send_stream) = h2
93            .send_request(request, false)
94            .map_err(|err| ProtoError::from(format!("h2 send_request error: {err}")))?;
95
96        send_stream
97            .send_data(message, true)
98            .map_err(|e| ProtoError::from(format!("h2 send_data error: {e}")))?;
99
100        let mut response_stream = response_future
101            .await
102            .map_err(|err| ProtoError::from(format!("received a stream error: {err}")))?;
103
104        debug!("got response: {:#?}", response_stream);
105
106        // get the length of packet
107        let content_length = response_stream
108            .headers()
109            .get(CONTENT_LENGTH)
110            .map(|v| v.to_str())
111            .transpose()
112            .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?
113            .map(usize::from_str)
114            .transpose()
115            .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?;
116
117        // TODO: what is a good max here?
118        // clamp(512, 4096) says make sure it is at least 512 bytes, and min 4096 says it is at most 4k
119        // just a little protection from malicious actors.
120        let mut response_bytes =
121            BytesMut::with_capacity(content_length.unwrap_or(512).clamp(512, 4_096));
122
123        while let Some(partial_bytes) = response_stream.body_mut().data().await {
124            let partial_bytes =
125                partial_bytes.map_err(|e| ProtoError::from(format!("bad http request: {e}")))?;
126
127            debug!("got bytes: {}", partial_bytes.len());
128            response_bytes.extend(partial_bytes);
129
130            // assert the length
131            if let Some(content_length) = content_length {
132                if response_bytes.len() >= content_length {
133                    break;
134                }
135            }
136        }
137
138        // assert the length
139        if let Some(content_length) = content_length {
140            if response_bytes.len() != content_length {
141                // TODO: make explicit error type
142                return Err(ProtoError::from(format!(
143                    "expected byte length: {}, got: {}",
144                    content_length,
145                    response_bytes.len()
146                )));
147            }
148        }
149
150        // Was it a successful request?
151        if !response_stream.status().is_success() {
152            let error_string = String::from_utf8_lossy(response_bytes.as_ref());
153
154            // TODO: make explicit error type
155            return Err(ProtoError::from(format!(
156                "http unsuccessful code: {}, message: {}",
157                response_stream.status(),
158                error_string
159            )));
160        } else {
161            // verify content type
162            {
163                // in the case that the ContentType is not specified, we assume it's the standard DNS format
164                let content_type = response_stream
165                    .headers()
166                    .get(header::CONTENT_TYPE)
167                    .map(|h| {
168                        h.to_str().map_err(|err| {
169                            // TODO: make explicit error type
170                            ProtoError::from(format!("ContentType header not a string: {err}"))
171                        })
172                    })
173                    .unwrap_or(Ok(crate::http::MIME_APPLICATION_DNS))?;
174
175                if content_type != crate::http::MIME_APPLICATION_DNS {
176                    return Err(ProtoError::from(format!(
177                        "ContentType unsupported (must be '{}'): '{}'",
178                        crate::http::MIME_APPLICATION_DNS,
179                        content_type
180                    )));
181                }
182            }
183        };
184
185        // and finally convert the bytes into a DNS message
186        DnsResponse::from_buffer(response_bytes.to_vec())
187    }
188}
189
190impl DnsRequestSender for HttpsClientStream {
191    /// This indicates that the HTTP message was successfully sent, and we now have the response.RecvStream
192    ///
193    /// If the request fails, this will return the error, and it should be assumed that the Stream portion of
194    ///   this will have no date.
195    ///
196    /// ```text
197    /// RFC 8484              DNS Queries over HTTPS (DoH)          October 2018
198    ///
199    ///
200    /// 4.2.  The HTTP Response
201    ///
202    ///    The only response type defined in this document is "application/dns-
203    ///    message", but it is possible that other response formats will be
204    ///    defined in the future.  A DoH server MUST be able to process
205    ///    "application/dns-message" request messages.
206    ///
207    ///    Different response media types will provide more or less information
208    ///    from a DNS response.  For example, one response type might include
209    ///    information from the DNS header bytes while another might omit it.
210    ///    The amount and type of information that a media type gives are solely
211    ///    up to the format, which is not defined in this protocol.
212    ///
213    ///    Each DNS request-response pair is mapped to one HTTP exchange.  The
214    ///    responses may be processed and transported in any order using HTTP's
215    ///    multi-streaming functionality (see Section 5 of [RFC7540]).
216    ///
217    ///    Section 5.1 discusses the relationship between DNS and HTTP response
218    ///    caching.
219    ///
220    /// 4.2.1.  Handling DNS and HTTP Errors
221    ///
222    ///    DNS response codes indicate either success or failure for the DNS
223    ///    query.  A successful HTTP response with a 2xx status code (see
224    ///    Section 6.3 of [RFC7231]) is used for any valid DNS response,
225    ///    regardless of the DNS response code.  For example, a successful 2xx
226    ///    HTTP status code is used even with a DNS message whose DNS response
227    ///    code indicates failure, such as SERVFAIL or NXDOMAIN.
228    ///
229    ///    HTTP responses with non-successful HTTP status codes do not contain
230    ///    replies to the original DNS question in the HTTP request.  DoH
231    ///    clients need to use the same semantic processing of non-successful
232    ///    HTTP status codes as other HTTP clients.  This might mean that the
233    ///    DoH client retries the query with the same DoH server, such as if
234    ///    there are authorization failures (HTTP status code 401; see
235    ///    Section 3.1 of [RFC7235]).  It could also mean that the DoH client
236    ///    retries with a different DoH server, such as for unsupported media
237    ///    types (HTTP status code 415; see Section 6.5.13 of [RFC7231]), or
238    ///    where the server cannot generate a representation suitable for the
239    ///    client (HTTP status code 406; see Section 6.5.6 of [RFC7231]), and so
240    ///    on.
241    /// ```
242    fn send_message(&mut self, mut request: DnsRequest) -> DnsResponseStream {
243        if self.is_shutdown {
244            panic!("can not send messages after stream is shutdown")
245        }
246
247        // per the RFC, a zero id allows for the HTTP packet to be cached better
248        request.set_id(0);
249
250        let bytes = match request.to_vec() {
251            Ok(bytes) => bytes,
252            Err(err) => return err.into(),
253        };
254
255        Box::pin(Self::inner_send(
256            self.h2.clone(),
257            Bytes::from(bytes),
258            Arc::clone(&self.name_server_name),
259            Arc::clone(&self.query_path),
260        ))
261        .into()
262    }
263
264    fn shutdown(&mut self) {
265        self.is_shutdown = true;
266    }
267
268    fn is_shutdown(&self) -> bool {
269        self.is_shutdown
270    }
271}
272
273impl Stream for HttpsClientStream {
274    type Item = Result<(), ProtoError>;
275
276    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
277        if self.is_shutdown {
278            return Poll::Ready(None);
279        }
280
281        // just checking if the connection is ok
282        match self.h2.poll_ready(cx) {
283            Poll::Ready(Ok(())) => Poll::Ready(Some(Ok(()))),
284            Poll::Pending => Poll::Pending,
285            Poll::Ready(Err(e)) => Poll::Ready(Some(Err(ProtoError::from(format!(
286                "h2 stream errored: {e}",
287            ))))),
288        }
289    }
290}
291
292/// A HTTPS connection builder for DNS-over-HTTPS
293#[derive(Clone)]
294pub struct HttpsClientStreamBuilder<P> {
295    provider: P,
296    client_config: Arc<ClientConfig>,
297    bind_addr: Option<SocketAddr>,
298}
299
300impl<P: RuntimeProvider> HttpsClientStreamBuilder<P> {
301    /// Constructs a new TlsStreamBuilder with the associated ClientConfig
302    pub fn with_client_config(client_config: Arc<ClientConfig>, provider: P) -> Self {
303        Self {
304            provider,
305            client_config,
306            bind_addr: None,
307        }
308    }
309
310    /// Sets the address to connect from.
311    pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
312        self.bind_addr = Some(bind_addr);
313    }
314
315    /// Creates a new HttpsStream to the specified name_server
316    ///
317    /// # Arguments
318    ///
319    /// * `name_server` - IP and Port for the remote DNS resolver
320    /// * `dns_name` - The DNS name associated with a certificate
321    /// * `http_endpoint` - The HTTP endpoint where the remote DNS resolver provides service, typically `/dns-query`
322    pub fn build(
323        mut self,
324        name_server: SocketAddr,
325        dns_name: String,
326        http_endpoint: String,
327    ) -> HttpsClientConnect<P::Tcp> {
328        // ensure the ALPN protocol is set correctly
329        if self.client_config.alpn_protocols.is_empty() {
330            let mut client_config = (*self.client_config).clone();
331            client_config.alpn_protocols = vec![ALPN_H2.to_vec()];
332
333            self.client_config = Arc::new(client_config);
334        }
335
336        let tls = TlsConfig {
337            client_config: self.client_config,
338            dns_name: Arc::from(dns_name),
339            http_endpoint: Arc::from(http_endpoint),
340        };
341
342        let connect = self.provider.connect_tcp(name_server, self.bind_addr, None);
343        HttpsClientConnect(HttpsClientConnectState::TcpConnecting {
344            connect,
345            name_server,
346            tls: Some(tls),
347        })
348    }
349}
350
351/// A future that resolves to an HttpsClientStream
352pub struct HttpsClientConnect<S>(HttpsClientConnectState<S>)
353where
354    S: DnsTcpStream;
355
356impl<S: DnsTcpStream> HttpsClientConnect<S> {
357    /// Creates a new HttpsStream with existing connection
358    pub fn new<F>(
359        future: F,
360        mut client_config: Arc<ClientConfig>,
361        name_server: SocketAddr,
362        dns_name: String,
363        http_endpoint: String,
364    ) -> Self
365    where
366        S: DnsTcpStream,
367        F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
368    {
369        // ensure the ALPN protocol is set correctly
370        if client_config.alpn_protocols.is_empty() {
371            let mut client_cfg = (*client_config).clone();
372            client_cfg.alpn_protocols = vec![ALPN_H2.to_vec()];
373
374            client_config = Arc::new(client_cfg);
375        }
376
377        let tls = TlsConfig {
378            client_config,
379            dns_name: Arc::from(dns_name),
380            http_endpoint: Arc::from(http_endpoint),
381        };
382
383        Self(HttpsClientConnectState::TcpConnecting {
384            connect: Box::pin(future),
385            name_server,
386            tls: Some(tls),
387        })
388    }
389}
390
391impl<S> Future for HttpsClientConnect<S>
392where
393    S: DnsTcpStream,
394{
395    type Output = Result<HttpsClientStream, ProtoError>;
396
397    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
398        self.0.poll_unpin(cx)
399    }
400}
401
/// TLS and HTTP settings carried from the builder into the connection state machine.
struct TlsConfig {
    /// rustls configuration used to build the `TlsConnector`
    client_config: Arc<ClientConfig>,
    /// TLS server name; also becomes the stream's `name_server_name` for HTTP requests
    dns_name: Arc<str>,
    /// HTTP path of the DoH endpoint; becomes the stream's `query_path`
    http_endpoint: Arc<str>,
}
407
/// States of the DoH connection setup: TCP connect, then TLS handshake, then HTTP/2 handshake.
// NOTE(review): the reason for allowing `large_enum_variant` is not evident here — presumably the
// unboxed `Connected` payload is much larger than the boxed-future variants; confirm.
#[allow(clippy::large_enum_variant)]
// The boxed future types below are spelled out inline rather than via type aliases.
#[allow(clippy::type_complexity)]
enum HttpsClientConnectState<S>
where
    S: DnsTcpStream,
{
    /// Waiting for the TCP connection; `tls` is taken (left `None`) once it completes.
    TcpConnecting {
        connect: Pin<Box<dyn Future<Output = io::Result<S>> + Send>>,
        name_server: SocketAddr,
        tls: Option<TlsConfig>,
    },
    /// Running the TLS handshake, wrapped in a `CONNECT_TIMEOUT` timeout (hence `Elapsed`).
    TlsConnecting {
        // TODO: also abstract away Tokio TLS in RuntimeProvider.
        tls: Pin<
            Box<
                dyn Future<
                        Output = Result<
                            Result<TokioTlsClientStream<AsyncIoStdAsTokio<S>>, io::Error>,
                            error::Elapsed,
                        >,
                    > + Send,
            >,
        >,
        name_server_name: Arc<str>,
        name_server: SocketAddr,
        query_path: Arc<str>,
    },
    /// Running the HTTP/2 client handshake over the established TLS stream.
    H2Handshake {
        handshake: Pin<
            Box<
                dyn Future<
                        Output = Result<
                            (
                                SendRequest<Bytes>,
                                Connection<TokioTlsClientStream<AsyncIoStdAsTokio<S>>, Bytes>,
                            ),
                            h2::Error,
                        >,
                    > + Send,
            >,
        >,
        name_server_name: Arc<str>,
        name_server: SocketAddr,
        query_path: Arc<str>,
    },
    /// Setup finished; the stream is taken out exactly once when the future completes.
    Connected(Option<HttpsClientStream>),
    /// Setup failed; the error is taken out exactly once when the future completes.
    Errored(Option<ProtoError>),
}
456
/// Drives the connection through TCP connect -> TLS handshake -> HTTP/2 handshake and
/// resolves to a ready [`HttpsClientStream`].
impl<S> Future for HttpsClientConnectState<S>
where
    S: DnsTcpStream,
{
    type Output = Result<HttpsClientStream, ProtoError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Each arm either returns or computes the next state; the loop then immediately polls
        // the new state, so several transitions can happen within a single `poll` call.
        // NOTE(review): the `?` / early-return error paths leave the current state in place, so
        // re-polling after an error would re-poll an already finished inner future.
        loop {
            let next = match &mut *self.as_mut() {
                Self::TcpConnecting {
                    connect,
                    name_server,
                    tls,
                } => {
                    let tcp = ready!(connect.poll_unpin(cx))?;

                    debug!("tcp connection established to: {}", name_server);
                    let tls = tls
                        .take()
                        .expect("programming error, tls should not be None here");
                    let name_server_name = Arc::clone(&tls.dns_name);
                    let query_path = Arc::clone(&tls.http_endpoint);

                    // The configured name must parse as a rustls `ServerName` (DNS name or IP
                    // literal); otherwise setup fails without attempting TLS.
                    match ServerName::try_from(&*tls.dns_name) {
                        Ok(dns_name) => Self::TlsConnecting {
                            name_server_name,
                            name_server: *name_server,
                            // The whole TLS handshake is bounded by CONNECT_TIMEOUT.
                            tls: Box::pin(timeout(
                                CONNECT_TIMEOUT,
                                TlsConnector::from(tls.client_config)
                                    .connect(dns_name.to_owned(), AsyncIoStdAsTokio(tcp)),
                            )),
                            query_path,
                        },
                        Err(_) => Self::Errored(Some(ProtoError::from(format!(
                            "bad dns_name: {}",
                            &tls.dns_name
                        )))),
                    }
                }
                Self::TlsConnecting {
                    name_server_name,
                    name_server,
                    query_path,
                    tls,
                } => {
                    // Outer `Err` is the timeout elapsing; inner `Err` is the TLS/IO failure.
                    let Ok(res) = ready!(tls.poll_unpin(cx)) else {
                        return Poll::Ready(Err(format!(
                            "TLS handshake timed out after {CONNECT_TIMEOUT:?}"
                        )
                        .into()));
                    };
                    let tls = res?;
                    debug!("tls connection established to: {}", name_server);
                    // Server push is disabled for the DoH client connection.
                    let mut handshake = h2::client::Builder::new();
                    handshake.enable_push(false);

                    let handshake = handshake.handshake(tls);
                    Self::H2Handshake {
                        name_server_name: Arc::clone(name_server_name),
                        name_server: *name_server,
                        query_path: Arc::clone(query_path),
                        handshake: Box::pin(handshake),
                    }
                }
                Self::H2Handshake {
                    name_server_name,
                    name_server,
                    query_path,
                    handshake,
                } => {
                    let (send_request, connection) = ready!(
                        handshake
                            .poll_unpin(cx)
                            .map_err(|e| ProtoError::from(format!("h2 handshake error: {e}")))
                    )?;

                    // TODO: hand this back for others to run rather than spawning here?
                    // The h2 `Connection` is spawned as a detached Tokio task (so this must be
                    // polled inside a Tokio runtime); its failure is only logged via `warn!`.
                    debug!("h2 connection established to: {}", name_server);
                    tokio::spawn(
                        connection
                            .map_err(|e| warn!("h2 connection failed: {e}"))
                            .map(|_: Result<(), ()>| ()),
                    );

                    Self::Connected(Some(HttpsClientStream {
                        name_server_name: Arc::clone(name_server_name),
                        name_server: *name_server,
                        query_path: Arc::clone(query_path),
                        h2: send_request,
                        is_shutdown: false,
                    }))
                }
                // Terminal states yield their value once; polling again panics.
                Self::Connected(conn) => {
                    return Poll::Ready(Ok(conn.take().expect("cannot poll after complete")));
                }
                Self::Errored(err) => {
                    return Poll::Ready(Err(err.take().expect("cannot poll after complete")));
                }
            };

            *self.as_mut().deref_mut() = next;
        }
    }
}
562
563/// A future that resolves to
564pub struct HttpsClientResponse(
565    Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>,
566);
567
568impl Future for HttpsClientResponse {
569    type Output = Result<DnsResponse, ProtoError>;
570
571    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
572        self.0.as_mut().poll(cx).map_err(ProtoError::from)
573    }
574}
575
#[cfg(any(feature = "webpki-roots", feature = "rustls-platform-verifier"))]
#[cfg(test)]
mod tests {
    use alloc::string::ToString;
    use std::net::SocketAddr;

    use rustls::KeyLogFile;
    use test_support::subscribe;

    use crate::op::{Edns, Message, Query};
    use crate::rr::{Name, RecordType};
    use crate::runtime::TokioRuntimeProvider;
    use crate::rustls::client_config;
    use crate::xfer::{DnsRequestOptions, FirstAnswer};

    use super::*;

    #[tokio::test]
    async fn test_https_google() {
        subscribe();

        let google = SocketAddr::from(([8, 8, 8, 8], 443));
        let mut client_config = client_config_h2();
        client_config.key_log = Arc::new(KeyLogFile::new());

        let mut https = connect_doh(client_config, google, "dns.google".to_string()).await;
        assert_a_then_aaaa(&mut https).await;
    }

    #[tokio::test]
    async fn test_https_google_with_pure_ip_address_server() {
        subscribe();

        let google = SocketAddr::from(([8, 8, 8, 8], 443));
        let mut client_config = client_config_h2();
        client_config.key_log = Arc::new(KeyLogFile::new());

        // the TLS server name is the bare IP address rather than a DNS name
        let mut https = connect_doh(client_config, google, google.ip().to_string()).await;
        assert_a_then_aaaa(&mut https).await;
    }

    #[tokio::test]
    #[ignore = "cloudflare has been unreliable as a public test service"]
    async fn test_https_cloudflare() {
        subscribe();

        let cloudflare = SocketAddr::from(([1, 1, 1, 1], 443));
        let mut https = connect_doh(
            client_config_h2(),
            cloudflare,
            "cloudflare-dns.com".to_string(),
        )
        .await;
        assert_a_then_aaaa(&mut https).await;
    }

    /// Connects to the `/dns-query` endpoint on `server`, using `dns_name` as the TLS server
    /// name, and panics if the connection cannot be established.
    async fn connect_doh(
        client_config: ClientConfig,
        server: SocketAddr,
        dns_name: String,
    ) -> HttpsClientStream {
        let provider = TokioRuntimeProvider::new();
        let https_builder =
            HttpsClientStreamBuilder::with_client_config(Arc::new(client_config), provider);
        https_builder
            .build(server, dns_name, "/dns-query".to_string())
            .await
            .expect("https connect failed")
    }

    /// Queries `www.example.com.` for A and then AAAA over the same connection, asserting that
    /// each response carries at least one record of the requested type.
    async fn assert_a_then_aaaa(https: &mut HttpsClientStream) {
        let response = https
            .send_message(example_request(RecordType::A))
            .first_answer()
            .await
            .expect("send_message failed");

        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.data().as_a().is_some())
        );

        // assert that the connection works for a second query
        let response = https
            .send_message(example_request(RecordType::AAAA))
            .first_answer()
            .await
            .expect("send_message failed");

        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.data().as_aaaa().is_some())
        );
    }

    /// Builds a recursive query for `www.example.com.` of `record_type`, with EDNS version 0
    /// advertising a 1232-byte payload.
    fn example_request(record_type: RecordType) -> DnsRequest {
        let mut message = Message::new();
        let query = Query::query(Name::from_str("www.example.com.").unwrap(), record_type);
        message.add_query(query);
        message.set_recursion_desired(true);
        let mut edns = Edns::new();
        edns.set_version(0);
        edns.set_max_payload(1232);
        *message.extensions_mut() = Some(edns);

        DnsRequest::new(message, DnsRequestOptions::default())
    }

    fn client_config_h2() -> ClientConfig {
        let mut config = client_config();
        config.alpn_protocols = vec![ALPN_H2.to_vec()];
        config
    }
}