@@ -16,11 +16,20 @@ pub const MAX_ADDRESS_COUNT: usize = 4;
1616pub const MAX_SERVER_COUNT : usize = 4 ;
1717
1818const DNS_PORT : u16 = 53 ;
19+ const MDNS_DNS_PORT : u16 = 5353 ;
1920const MAX_NAME_LEN : usize = 255 ;
2021const RETRANSMIT_DELAY : Duration = Duration :: from_millis ( 1_000 ) ;
2122const MAX_RETRANSMIT_DELAY : Duration = Duration :: from_millis ( 10_000 ) ;
2223const RETRANSMIT_TIMEOUT : Duration = Duration :: from_millis ( 10_000 ) ; // Should generally be 2-10 secs
2324
25+ #[ cfg( feature = "proto-ipv6" ) ]
26+ const MDNS_IPV6_ADDR : IpAddress = IpAddress :: Ipv6 ( crate :: wire:: Ipv6Address ( [
27+ 0xff , 0x02 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0xfb ,
28+ ] ) ) ;
29+
30+ #[ cfg( feature = "proto-ipv4" ) ]
31+ const MDNS_IPV4_ADDR : IpAddress = IpAddress :: Ipv4 ( crate :: wire:: Ipv4Address ( [ 224 , 0 , 0 , 251 ] ) ) ;
32+
2433/// Error returned by [`Socket::start_query`]
2534#[ derive( Debug , PartialEq , Eq , Clone , Copy ) ]
2635#[ cfg_attr( feature = "defmt" , derive( defmt:: Format ) ) ]
@@ -81,6 +90,14 @@ struct PendingQuery {
8190 delay : Duration ,
8291
8392 server_idx : usize ,
93+ mdns : MulticastDns ,
94+ }
95+
96+ #[ derive( Debug ) ]
97+ pub enum MulticastDns {
98+ Disabled ,
99+ #[ cfg( feature = "socket-mdns" ) ]
100+ Enabled ,
84101}
85102
86103#[ derive( Debug ) ]
@@ -185,6 +202,7 @@ impl<'a> Socket<'a> {
185202 & mut self ,
186203 cx : & mut Context ,
187204 name : & str ,
205+ query_type : Type ,
188206 ) -> Result < QueryHandle , StartQueryError > {
189207 let mut name = name. as_bytes ( ) ;
190208
@@ -200,6 +218,13 @@ impl<'a> Socket<'a> {
200218
201219 let mut raw_name: Vec < u8 , MAX_NAME_LEN > = Vec :: new ( ) ;
202220
221+ let mut mdns = MulticastDns :: Disabled ;
222+ #[ cfg( feature = "socket-mdns" ) ]
223+ if name. split ( |& c| c == b'.' ) . last ( ) . unwrap ( ) == b"local" {
224+ net_trace ! ( "Starting a mDNS query" ) ;
225+ mdns = MulticastDns :: Enabled ;
226+ }
227+
203228 for s in name. split ( |& c| c == b'.' ) {
204229 if s. len ( ) > 63 {
205230 net_trace ! ( "invalid name: too long label" ) ;
@@ -224,7 +249,7 @@ impl<'a> Socket<'a> {
224249 . push ( 0x00 )
225250 . map_err ( |_| StartQueryError :: NameTooLong ) ?;
226251
227- self . start_query_raw ( cx, & raw_name)
252+ self . start_query_raw ( cx, & raw_name, query_type , mdns )
228253 }
229254
230255 /// Start a query with a raw (wire-format) DNS name.
@@ -235,19 +260,22 @@ impl<'a> Socket<'a> {
235260 & mut self ,
236261 cx : & mut Context ,
237262 raw_name : & [ u8 ] ,
263+ query_type : Type ,
264+ mdns : MulticastDns ,
238265 ) -> Result < QueryHandle , StartQueryError > {
239266 let handle = self . find_free_query ( ) . ok_or ( StartQueryError :: NoFreeSlot ) ?;
240267
241268 self . queries [ handle. 0 ] = Some ( DnsQuery {
242269 state : State :: Pending ( PendingQuery {
243270 name : Vec :: from_slice ( raw_name) . map_err ( |_| StartQueryError :: NameTooLong ) ?,
244- type_ : Type :: A ,
271+ type_ : query_type ,
245272 txid : cx. rand ( ) . rand_u16 ( ) ,
246273 port : cx. rand ( ) . rand_source_port ( ) ,
247274 delay : RETRANSMIT_DELAY ,
248275 timeout_at : None ,
249276 retransmit_at : Instant :: ZERO ,
250277 server_idx : 0 ,
278+ mdns,
251279 } ) ,
252280 #[ cfg( feature = "async" ) ]
253281 waker : WakerRegistration :: new ( ) ,
@@ -313,11 +341,12 @@ impl<'a> Socket<'a> {
313341 }
314342
315343 pub ( crate ) fn accepts ( & self , ip_repr : & IpRepr , udp_repr : & UdpRepr ) -> bool {
316- udp_repr. src_port == DNS_PORT
344+ ( udp_repr. src_port == DNS_PORT
317345 && self
318346 . servers
319347 . iter ( )
320- . any ( |server| * server == ip_repr. src_addr ( ) )
348+ . any ( |server| * server == ip_repr. src_addr ( ) ) )
349+ || ( udp_repr. src_port == MDNS_DNS_PORT )
321350 }
322351
323352 pub ( crate ) fn process (
@@ -482,6 +511,20 @@ impl<'a> Socket<'a> {
482511
483512 for q in self . queries . iter_mut ( ) . flatten ( ) {
484513 if let State :: Pending ( pq) = & mut q. state {
514+ // As per RFC 6762 any DNS query ending in .local. MUST be sent as mdns
515+ // so we internally overwrite the servers for any of those queries
516+ // in this function.
517+ let servers = match pq. mdns {
518+ #[ cfg( feature = "socket-mdns" ) ]
519+ MulticastDns :: Enabled => & [
520+ #[ cfg( feature = "proto-ipv6" ) ]
521+ MDNS_IPV6_ADDR ,
522+ #[ cfg( feature = "proto-ipv4" ) ]
523+ MDNS_IPV4_ADDR ,
524+ ] ,
525+ MulticastDns :: Disabled => self . servers . as_slice ( ) ,
526+ } ;
527+
485528 let timeout = if let Some ( timeout) = pq. timeout_at {
486529 timeout
487530 } else {
@@ -500,16 +543,15 @@ impl<'a> Socket<'a> {
500543 // Try next server. We check below whether we've tried all servers.
501544 pq. server_idx += 1 ;
502545 }
503-
504546 // Check if we've run out of servers to try.
505- if pq. server_idx >= self . servers . len ( ) {
547+ if pq. server_idx >= servers. len ( ) {
506548 net_trace ! ( "already tried all servers." ) ;
507549 q. set_state ( State :: Failure ) ;
508550 continue ;
509551 }
510552
511553 // Check so the IP address is valid
512- if self . servers [ pq. server_idx ] . is_unspecified ( ) {
554+ if servers[ pq. server_idx ] . is_unspecified ( ) {
513555 net_trace ! ( "invalid unspecified DNS server addr." ) ;
514556 q. set_state ( State :: Failure ) ;
515557 continue ;
@@ -526,20 +568,26 @@ impl<'a> Socket<'a> {
526568 opcode : Opcode :: Query ,
527569 question : Question {
528570 name : & pq. name ,
529- type_ : Type :: A ,
571+ type_ : pq . type_ ,
530572 } ,
531573 } ;
532574
533575 let mut payload = [ 0u8 ; 512 ] ;
534576 let payload = & mut payload[ ..repr. buffer_len ( ) ] ;
535577 repr. emit ( & mut Packet :: new_unchecked ( payload) ) ;
536578
579+ let dst_port = match pq. mdns {
580+ #[ cfg( feature = "socket-mdns" ) ]
581+ MulticastDns :: Enabled => MDNS_DNS_PORT ,
582+ MulticastDns :: Disabled => DNS_PORT ,
583+ } ;
584+
537585 let udp_repr = UdpRepr {
538586 src_port : pq. port ,
539- dst_port : 53 ,
587+ dst_port,
540588 } ;
541589
542- let dst_addr = self . servers [ pq. server_idx ] ;
590+ let dst_addr = servers[ pq. server_idx ] ;
543591 let src_addr = cx. get_source_address ( dst_addr) . unwrap ( ) ; // TODO remove unwrap
544592 let ip_repr = IpRepr :: new (
545593 src_addr,
@@ -550,7 +598,7 @@ impl<'a> Socket<'a> {
550598 ) ;
551599
552600 net_trace ! (
553- "sending {} octets to {:?}: {}" ,
601+ "sending {} octets to {} from port {}" ,
554602 payload. len( ) ,
555603 ip_repr. dst_addr( ) ,
556604 udp_repr. src_port
0 commit comments