View Javadoc
1   /*
2    * Copyright 2014 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.resolver.dns;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.channel.AddressedEnvelope;
20  import io.netty.channel.Channel;
21  import io.netty.channel.ChannelFuture;
22  import io.netty.channel.ChannelFutureListener;
23  import io.netty.channel.ChannelHandlerContext;
24  import io.netty.channel.ChannelInboundHandlerAdapter;
25  import io.netty.channel.ChannelPromise;
26  import io.netty.handler.codec.dns.AbstractDnsOptPseudoRrRecord;
27  import io.netty.handler.codec.dns.DnsQuery;
28  import io.netty.handler.codec.dns.DnsQuestion;
29  import io.netty.handler.codec.dns.DnsRecord;
30  import io.netty.handler.codec.dns.DnsRecordType;
31  import io.netty.handler.codec.dns.DnsResponse;
32  import io.netty.handler.codec.dns.DnsSection;
33  import io.netty.handler.codec.dns.TcpDnsQueryEncoder;
34  import io.netty.handler.codec.dns.TcpDnsResponseDecoder;
35  import io.netty.util.ReferenceCountUtil;
36  import io.netty.util.concurrent.Future;
37  import io.netty.util.concurrent.FutureListener;
38  import io.netty.util.concurrent.GenericFutureListener;
39  import io.netty.util.concurrent.Promise;
40  import io.netty.util.internal.SystemPropertyUtil;
41  import io.netty.util.internal.ThrowableUtil;
42  import io.netty.util.internal.logging.InternalLogger;
43  import io.netty.util.internal.logging.InternalLoggerFactory;
44  
45  import java.net.InetSocketAddress;
46  import java.net.SocketAddress;
47  import java.util.concurrent.CancellationException;
48  import java.util.concurrent.TimeUnit;
49  
50  import static io.netty.util.internal.ObjectUtil.checkNotNull;
51  
52  abstract class DnsQueryContext {
53  
54      private static final InternalLogger logger = InternalLoggerFactory.getInstance(DnsQueryContext.class);
55      private static final long ID_REUSE_ON_TIMEOUT_DELAY_MILLIS;
56  
57      static {
58          ID_REUSE_ON_TIMEOUT_DELAY_MILLIS =
59                  SystemPropertyUtil.getLong("io.netty.resolver.dns.idReuseOnTimeoutDelayMillis", 10000);
60          logger.debug("-Dio.netty.resolver.dns.idReuseOnTimeoutDelayMillis: {}", ID_REUSE_ON_TIMEOUT_DELAY_MILLIS);
61      }
62  
63      private static final TcpDnsQueryEncoder TCP_ENCODER = new TcpDnsQueryEncoder();
64  
65      private final Future<? extends Channel> channelReadyFuture;
66      private final Channel channel;
67      private final InetSocketAddress nameServerAddr;
68      private final DnsQueryContextManager queryContextManager;
69      private final Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise;
70  
71      private final DnsQuestion question;
72      private final DnsRecord[] additionals;
73      private final DnsRecord optResource;
74  
75      private final boolean recursionDesired;
76  
77      private final Bootstrap socketBootstrap;
78  
79      private final boolean retryWithTcpOnTimeout;
80      private final long queryTimeoutMillis;
81  
82      private volatile Future<?> timeoutFuture;
83  
84      private int id = Integer.MIN_VALUE;
85  
86      DnsQueryContext(Channel channel,
87                      Future<? extends Channel> channelReadyFuture,
88                      InetSocketAddress nameServerAddr,
89                      DnsQueryContextManager queryContextManager,
90                      int maxPayLoadSize,
91                      boolean recursionDesired,
92                      long queryTimeoutMillis,
93                      DnsQuestion question,
94                      DnsRecord[] additionals,
95                      Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise,
96                      Bootstrap socketBootstrap,
97                      boolean retryWithTcpOnTimeout) {
98          this.channel = checkNotNull(channel, "channel");
99          this.queryContextManager = checkNotNull(queryContextManager, "queryContextManager");
100         this.channelReadyFuture = checkNotNull(channelReadyFuture, "channelReadyFuture");
101         this.nameServerAddr = checkNotNull(nameServerAddr, "nameServerAddr");
102         this.question = checkNotNull(question, "question");
103         this.additionals = checkNotNull(additionals, "additionals");
104         this.promise = checkNotNull(promise, "promise");
105         this.recursionDesired = recursionDesired;
106         this.queryTimeoutMillis = queryTimeoutMillis;
107         this.socketBootstrap = socketBootstrap;
108         this.retryWithTcpOnTimeout = retryWithTcpOnTimeout;
109 
110         if (maxPayLoadSize > 0 &&
111                 // Only add the extra OPT record if there is not already one. This is required as only one is allowed
112                 // as per RFC:
113                 //  - https://datatracker.ietf.org/doc/html/rfc6891#section-6.1.1
114                 !hasOptRecord(additionals)) {
115             optResource = new AbstractDnsOptPseudoRrRecord(maxPayLoadSize, 0, 0) {
116                 // We may want to remove this in the future and let the user just specify the opt record in the query.
117             };
118         } else {
119             optResource = null;
120         }
121     }
122 
123     private static boolean hasOptRecord(DnsRecord[] additionals) {
124         if (additionals != null && additionals.length > 0) {
125             for (DnsRecord additional: additionals) {
126                 if (additional.type() == DnsRecordType.OPT) {
127                     return true;
128                 }
129             }
130         }
131         return false;
132     }
133 
134     /**
135      * Returns {@code true} if the query was completed already.
136      *
137      * @return {@code true} if done.
138      */
139     final boolean isDone() {
140         return promise.isDone();
141     }
142 
143     /**
144      * Returns the {@link DnsQuestion} that will be written as part of the {@link DnsQuery}.
145      *
146      * @return the question.
147      */
148     final DnsQuestion question() {
149         return question;
150     }
151 
152     /**
153      * Creates and returns a new {@link DnsQuery}.
154      *
155      * @param id                the transaction id to use.
156      * @param nameServerAddr    the nameserver to which the query will be send.
157      * @return                  the new query.
158      */
159     protected abstract DnsQuery newQuery(int id, InetSocketAddress nameServerAddr);
160 
161     /**
162      * Returns the protocol that is used for the query.
163      *
164      * @return  the protocol.
165      */
166     protected abstract String protocol();
167 
168     /**
169      * Write the query and return the {@link ChannelFuture} that is completed once the write completes.
170      *
171      * @param flush                 {@code true} if {@link Channel#flush()} should be called as well.
172      * @return                      the {@link ChannelFuture} that is notified once once the write completes.
173      */
174     final ChannelFuture writeQuery(boolean flush) {
175         assert id == Integer.MIN_VALUE : this.getClass().getSimpleName() +
176                 ".writeQuery(...) can only be executed once.";
177 
178         if ((id = queryContextManager.add(nameServerAddr, this)) == -1) {
179             // We did exhaust the id space, fail the query
180             IllegalStateException e = new IllegalStateException("query ID space exhausted: " + question());
181             finishFailure("failed to send a query via " + protocol(), e, false);
182             return channel.newFailedFuture(e);
183         }
184 
185         // Ensure we remove the id from the QueryContextManager once the query completes.
186         promise.addListener(new FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>() {
187             @Override
188             public void operationComplete(Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> future) {
189                 // Cancel the timeout task.
190                 Future<?> timeoutFuture = DnsQueryContext.this.timeoutFuture;
191                 if (timeoutFuture != null) {
192                     DnsQueryContext.this.timeoutFuture = null;
193                     timeoutFuture.cancel(false);
194                 }
195 
196                 Throwable cause = future.cause();
197                 if (cause instanceof DnsNameResolverTimeoutException || cause instanceof CancellationException) {
198                     // This query was failed due a timeout or cancellation. Let's delay the removal of the id to reduce
199                     // the risk of reusing the same id again while the remote nameserver might send the response after
200                     // the timeout.
201                     channel.eventLoop().schedule(new Runnable() {
202                         @Override
203                         public void run() {
204                             removeFromContextManager(nameServerAddr);
205                         }
206                     }, ID_REUSE_ON_TIMEOUT_DELAY_MILLIS, TimeUnit.MILLISECONDS);
207                 } else {
208                     // Remove the id from the manager as soon as the query completes. This may be because of success,
209                     // failure or cancellation
210                     removeFromContextManager(nameServerAddr);
211                 }
212             }
213         });
214         final DnsQuestion question = question();
215         final DnsQuery query = newQuery(id, nameServerAddr);
216 
217         query.setRecursionDesired(recursionDesired);
218 
219         query.addRecord(DnsSection.QUESTION, question);
220 
221         for (DnsRecord record: additionals) {
222             query.addRecord(DnsSection.ADDITIONAL, record);
223         }
224 
225         if (optResource != null) {
226             query.addRecord(DnsSection.ADDITIONAL, optResource);
227         }
228 
229         if (logger.isDebugEnabled()) {
230             logger.debug("{} WRITE: {}, [{}: {}], {}",
231                     channel, protocol(), id, nameServerAddr, question);
232         }
233 
234         return sendQuery(query, flush);
235     }
236 
237     private void removeFromContextManager(InetSocketAddress nameServerAddr) {
238         DnsQueryContext self = queryContextManager.remove(nameServerAddr, id);
239 
240         assert self == this : "Removed DnsQueryContext is not the correct instance";
241     }
242 
243     private ChannelFuture sendQuery(final DnsQuery query, final boolean flush) {
244         final ChannelPromise writePromise = channel.newPromise();
245         if (channelReadyFuture.isSuccess()) {
246             writeQuery(query, flush, writePromise);
247         } else {
248             Throwable cause = channelReadyFuture.cause();
249             if (cause != null) {
250                 // the promise failed before so we should also fail this query.
251                 failQuery(query, cause, writePromise);
252             } else {
253                 // The promise is not complete yet, let's delay the query.
254                 channelReadyFuture.addListener(new GenericFutureListener<Future<? super Channel>>() {
255                     @Override
256                     public void operationComplete(Future<? super Channel> future) {
257                         if (future.isSuccess()) {
258                             // If the query is done in a late fashion (as the channel was not ready yet) we always flush
259                             // to ensure we did not race with a previous flush() that was done when the Channel was not
260                             // ready yet.
261                             writeQuery(query, true, writePromise);
262                         } else {
263                             Throwable cause = future.cause();
264                             failQuery(query, cause, writePromise);
265                         }
266                     }
267                 });
268             }
269         }
270         return writePromise;
271     }
272 
273     private void failQuery(DnsQuery query, Throwable cause, ChannelPromise writePromise) {
274         try {
275             promise.tryFailure(cause);
276             writePromise.tryFailure(cause);
277         } finally {
278             ReferenceCountUtil.release(query);
279         }
280     }
281 
282     private void writeQuery(final DnsQuery query,
283                             final boolean flush, ChannelPromise promise) {
284         final ChannelFuture writeFuture = flush ? channel.writeAndFlush(query, promise) :
285                 channel.write(query, promise);
286         if (writeFuture.isDone()) {
287             onQueryWriteCompletion(queryTimeoutMillis, writeFuture);
288         } else {
289             writeFuture.addListener(new ChannelFutureListener() {
290                 @Override
291                 public void operationComplete(ChannelFuture future) {
292                     onQueryWriteCompletion(queryTimeoutMillis, writeFuture);
293                 }
294             });
295         }
296     }
297 
298     private void onQueryWriteCompletion(final long queryTimeoutMillis,
299                                         ChannelFuture writeFuture) {
300         if (!writeFuture.isSuccess()) {
301             finishFailure("failed to send a query '" + id + "' via " + protocol(), writeFuture.cause(), false);
302             return;
303         }
304 
305         // Schedule a query timeout task if necessary.
306         if (queryTimeoutMillis > 0) {
307             timeoutFuture = channel.eventLoop().schedule(new Runnable() {
308                 @Override
309                 public void run() {
310                     if (promise.isDone()) {
311                         // Received a response before the query times out.
312                         return;
313                     }
314 
315                     finishFailure("query '" + id + "' via " + protocol() + " timed out after " +
316                             queryTimeoutMillis + " milliseconds", null, true);
317                 }
318             }, queryTimeoutMillis, TimeUnit.MILLISECONDS);
319         }
320     }
321 
322     /**
323      * Notifies the original {@link Promise} that the response for the query was received.
324      * This method takes ownership of passed {@link AddressedEnvelope}.
325      */
326     void finishSuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope, boolean truncated) {
327         // Check if the response was not truncated or if a fallback to TCP is possible.
328         if (!truncated || !retryWithTcp(envelope)) {
329             final DnsResponse res = envelope.content();
330             if (res.count(DnsSection.QUESTION) != 1) {
331                 logger.warn("{} Received a DNS response with invalid number of questions. Expected: 1, found: {}",
332                         channel, envelope);
333             } else if (!question().equals(res.recordAt(DnsSection.QUESTION))) {
334                 logger.warn("{} Received a mismatching DNS response. Expected: [{}], found: {}",
335                         channel, question(), envelope);
336             } else if (trySuccess(envelope)) {
337                 return; // Ownership transferred, don't release
338             }
339             envelope.release();
340         }
341     }
342 
343     @SuppressWarnings("unchecked")
344     private boolean trySuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
345         return promise.trySuccess((AddressedEnvelope<DnsResponse, InetSocketAddress>) envelope);
346     }
347 
348     /**
349      * Notifies the original {@link Promise} that the query completes because of an failure.
350      */
351     final boolean finishFailure(String message, Throwable cause, boolean timeout) {
352         if (promise.isDone()) {
353             return false;
354         }
355         final DnsQuestion question = question();
356 
357         final StringBuilder buf = new StringBuilder(message.length() + 128);
358         buf.append('[')
359            .append(id)
360            .append(": ")
361            .append(nameServerAddr)
362            .append("] ")
363            .append(question)
364            .append(' ')
365            .append(message)
366            .append(" (no stack trace available)");
367 
368         final DnsNameResolverException e;
369         if (timeout) {
370             // This was caused by a timeout so use DnsNameResolverTimeoutException to allow the user to
371             // handle it special (like retry the query).
372             e = new DnsNameResolverTimeoutException(nameServerAddr, question, buf.toString());
373             if (retryWithTcpOnTimeout && retryWithTcp(e)) {
374                 // We did successfully retry with TCP.
375                 return false;
376             }
377         } else {
378             e = new DnsNameResolverException(nameServerAddr, question, buf.toString(), cause);
379         }
380         return promise.tryFailure(e);
381     }
382 
383     /**
384      * Retry the original query with TCP if possible.
385      *
386      * @param originalResult    the result of the original {@link DnsQueryContext}.
387      * @return                  {@code true} if retry via TCP is supported and so the ownership of
388      *                          {@code originalResult} was transferred, {@code false} otherwise.
389      */
390     private boolean retryWithTcp(final Object originalResult) {
391         if (socketBootstrap == null) {
392             return false;
393         }
394 
395         socketBootstrap.connect(nameServerAddr).addListener(new ChannelFutureListener() {
396             @Override
397             public void operationComplete(ChannelFuture future) {
398                 if (!future.isSuccess()) {
399                     logger.debug("{} Unable to fallback to TCP [{}: {}]",
400                             future.channel(), id, nameServerAddr, future.cause());
401 
402                     // TCP fallback failed, just use the truncated response or error.
403                     finishOriginal(originalResult, future);
404                     return;
405                 }
406                 final Channel tcpCh = future.channel();
407                 Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise =
408                         tcpCh.eventLoop().newPromise();
409                 final TcpDnsQueryContext tcpCtx = new TcpDnsQueryContext(tcpCh, channelReadyFuture,
410                         (InetSocketAddress) tcpCh.remoteAddress(), queryContextManager, 0,
411                         recursionDesired, queryTimeoutMillis, question(), additionals, promise);
412                 tcpCh.pipeline().addLast(TCP_ENCODER);
413                 tcpCh.pipeline().addLast(new TcpDnsResponseDecoder());
414                 tcpCh.pipeline().addLast(new ChannelInboundHandlerAdapter() {
415                     @Override
416                     public void channelRead(ChannelHandlerContext ctx, Object msg) {
417                         Channel tcpCh = ctx.channel();
418                         DnsResponse response = (DnsResponse) msg;
419                         int queryId = response.id();
420 
421                         if (logger.isDebugEnabled()) {
422                             logger.debug("{} RECEIVED: TCP [{}: {}], {}", tcpCh, queryId,
423                                     tcpCh.remoteAddress(), response);
424                         }
425 
426                         DnsQueryContext foundCtx = queryContextManager.get(nameServerAddr, queryId);
427                         if (foundCtx != null && foundCtx.isDone()) {
428                             logger.debug("{} Received a DNS response for a query that was timed out or cancelled " +
429                                     ": TCP [{}: {}]", tcpCh, queryId, nameServerAddr);
430                             response.release();
431                         } else if (foundCtx == tcpCtx) {
432                             tcpCtx.finishSuccess(new AddressedEnvelopeAdapter(
433                                     (InetSocketAddress) ctx.channel().remoteAddress(),
434                                     (InetSocketAddress) ctx.channel().localAddress(),
435                                     response), false);
436                         } else {
437                             response.release();
438                             tcpCtx.finishFailure("Received TCP DNS response with unexpected ID", null, false);
439                             if (logger.isDebugEnabled()) {
440                                 logger.debug("{} Received a DNS response with an unexpected ID: TCP [{}: {}]",
441                                         tcpCh, queryId, tcpCh.remoteAddress());
442                             }
443                         }
444                     }
445 
446                     @Override
447                     public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
448                         if (tcpCtx.finishFailure(
449                                 "TCP fallback error", cause, false) && logger.isDebugEnabled()) {
450                             logger.debug("{} Error during processing response: TCP [{}: {}]",
451                                     ctx.channel(), id,
452                                     ctx.channel().remoteAddress(), cause);
453                         }
454                     }
455                 });
456 
457                 promise.addListener(
458                         new FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>() {
459                             @Override
460                             public void operationComplete(
461                                     Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> future) {
462                                 if (future.isSuccess()) {
463                                     finishSuccess(future.getNow(), false);
464                                     // Release the original result.
465                                     ReferenceCountUtil.release(originalResult);
466                                 } else {
467                                     // TCP fallback failed, just use the truncated response or error.
468                                     finishOriginal(originalResult, future);
469                                 }
470                                 tcpCh.close();
471                             }
472                         });
473                 tcpCtx.writeQuery(true);
474             }
475         });
476         return true;
477     }
478 
479     @SuppressWarnings("unchecked")
480     private void finishOriginal(Object originalResult, Future<?> future) {
481         if (originalResult instanceof Throwable) {
482             Throwable error = (Throwable) originalResult;
483             ThrowableUtil.addSuppressed(error, future.cause());
484             promise.tryFailure(error);
485         } else {
486             finishSuccess((AddressedEnvelope<? extends DnsResponse, InetSocketAddress>) originalResult, false);
487         }
488     }
489 
490     private static final class AddressedEnvelopeAdapter implements AddressedEnvelope<DnsResponse, InetSocketAddress> {
491         private final InetSocketAddress sender;
492         private final InetSocketAddress recipient;
493         private final DnsResponse response;
494 
495         AddressedEnvelopeAdapter(InetSocketAddress sender, InetSocketAddress recipient, DnsResponse response) {
496             this.sender = sender;
497             this.recipient = recipient;
498             this.response = response;
499         }
500 
501         @Override
502         public DnsResponse content() {
503             return response;
504         }
505 
506         @Override
507         public InetSocketAddress sender() {
508             return sender;
509         }
510 
511         @Override
512         public InetSocketAddress recipient() {
513             return recipient;
514         }
515 
516         @Override
517         public AddressedEnvelope<DnsResponse, InetSocketAddress> retain() {
518             response.retain();
519             return this;
520         }
521 
522         @Override
523         public AddressedEnvelope<DnsResponse, InetSocketAddress> retain(int increment) {
524             response.retain(increment);
525             return this;
526         }
527 
528         @Override
529         public AddressedEnvelope<DnsResponse, InetSocketAddress> touch() {
530             response.touch();
531             return this;
532         }
533 
534         @Override
535         public AddressedEnvelope<DnsResponse, InetSocketAddress> touch(Object hint) {
536             response.touch(hint);
537             return this;
538         }
539 
540         @Override
541         public int refCnt() {
542             return response.refCnt();
543         }
544 
545         @Override
546         public boolean release() {
547             return response.release();
548         }
549 
550         @Override
551         public boolean release(int decrement) {
552             return response.release(decrement);
553         }
554 
555         @Override
556         public boolean equals(Object obj) {
557             if (this == obj) {
558                 return true;
559             }
560 
561             if (!(obj instanceof AddressedEnvelope)) {
562                 return false;
563             }
564 
565             @SuppressWarnings("unchecked")
566             final AddressedEnvelope<?, SocketAddress> that = (AddressedEnvelope<?, SocketAddress>) obj;
567             if (sender() == null) {
568                 if (that.sender() != null) {
569                     return false;
570                 }
571             } else if (!sender().equals(that.sender())) {
572                 return false;
573             }
574 
575             if (recipient() == null) {
576                 if (that.recipient() != null) {
577                     return false;
578                 }
579             } else if (!recipient().equals(that.recipient())) {
580                 return false;
581             }
582 
583             return response.equals(obj);
584         }
585 
586         @Override
587         public int hashCode() {
588             int hashCode = response.hashCode();
589             if (sender() != null) {
590                 hashCode = hashCode * 31 + sender().hashCode();
591             }
592             if (recipient() != null) {
593                 hashCode = hashCode * 31 + recipient().hashCode();
594             }
595             return hashCode;
596         }
597     }
598 }