View Javadoc
1   /*
2    * Copyright 2014 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.resolver.dns;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.channel.AddressedEnvelope;
20  import io.netty.channel.Channel;
21  import io.netty.channel.ChannelFuture;
22  import io.netty.channel.ChannelFutureListener;
23  import io.netty.channel.ChannelHandlerContext;
24  import io.netty.channel.ChannelInboundHandlerAdapter;
25  import io.netty.channel.ChannelPromise;
26  import io.netty.handler.codec.dns.AbstractDnsOptPseudoRrRecord;
27  import io.netty.handler.codec.dns.DnsQuery;
28  import io.netty.handler.codec.dns.DnsQuestion;
29  import io.netty.handler.codec.dns.DnsRecord;
30  import io.netty.handler.codec.dns.DnsRecordType;
31  import io.netty.handler.codec.dns.DnsResponse;
32  import io.netty.handler.codec.dns.DnsSection;
33  import io.netty.handler.codec.dns.TcpDnsQueryEncoder;
34  import io.netty.handler.codec.dns.TcpDnsResponseDecoder;
35  import io.netty.util.ReferenceCountUtil;
36  import io.netty.util.concurrent.Future;
37  import io.netty.util.concurrent.FutureListener;
38  import io.netty.util.concurrent.Promise;
39  import io.netty.util.internal.SystemPropertyUtil;
40  import io.netty.util.internal.ThrowableUtil;
41  import io.netty.util.internal.logging.InternalLogger;
42  import io.netty.util.internal.logging.InternalLoggerFactory;
43  
44  import java.net.InetSocketAddress;
45  import java.net.SocketAddress;
46  import java.util.concurrent.CancellationException;
47  import java.util.concurrent.TimeUnit;
48  
49  import static io.netty.util.internal.ObjectUtil.checkNotNull;
50  
51  abstract class DnsQueryContext {
52  
53      private static final InternalLogger logger = InternalLoggerFactory.getInstance(DnsQueryContext.class);
54      private static final long ID_REUSE_ON_TIMEOUT_DELAY_MILLIS;
55  
56      static {
57          ID_REUSE_ON_TIMEOUT_DELAY_MILLIS =
58                  SystemPropertyUtil.getLong("io.netty.resolver.dns.idReuseOnTimeoutDelayMillis", 10000);
59          logger.debug("-Dio.netty.resolver.dns.idReuseOnTimeoutDelayMillis: {}", ID_REUSE_ON_TIMEOUT_DELAY_MILLIS);
60      }
61  
62      private static final TcpDnsQueryEncoder TCP_ENCODER = new TcpDnsQueryEncoder();
63  
64      private final Channel channel;
65      private final InetSocketAddress nameServerAddr;
66      private final DnsQueryContextManager queryContextManager;
67      private final DnsQueryLifecycleObserver queryLifecycleObserver;
68      private final Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise;
69  
70      private final DnsQuestion question;
71      private final DnsRecord[] additionals;
72      private final DnsRecord optResource;
73  
74      private final boolean recursionDesired;
75  
76      private final Bootstrap socketBootstrap;
77  
78      private final boolean retryWithTcpOnTimeout;
79      private final long queryTimeoutMillis;
80  
81      private volatile Future<?> timeoutFuture;
82  
83      private int id = Integer.MIN_VALUE;
84  
85      DnsQueryContext(Channel channel,
86                      InetSocketAddress nameServerAddr,
87                      DnsQueryContextManager queryContextManager,
88                      DnsQueryLifecycleObserver queryLifecycleObserver,
89                      int maxPayLoadSize,
90                      boolean recursionDesired,
91                      long queryTimeoutMillis,
92                      DnsQuestion question,
93                      DnsRecord[] additionals,
94                      Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise,
95                      Bootstrap socketBootstrap,
96                      boolean retryWithTcpOnTimeout) {
97          this.channel = checkNotNull(channel, "channel");
98          this.queryContextManager = checkNotNull(queryContextManager, "queryContextManager");
99          this.queryLifecycleObserver = checkNotNull(queryLifecycleObserver, "queryLifecycleObserver");
100         this.nameServerAddr = checkNotNull(nameServerAddr, "nameServerAddr");
101         this.question = checkNotNull(question, "question");
102         this.additionals = checkNotNull(additionals, "additionals");
103         this.promise = checkNotNull(promise, "promise");
104         this.recursionDesired = recursionDesired;
105         this.queryTimeoutMillis = queryTimeoutMillis;
106         this.socketBootstrap = socketBootstrap;
107         this.retryWithTcpOnTimeout = retryWithTcpOnTimeout;
108 
109         if (maxPayLoadSize > 0 &&
110                 // Only add the extra OPT record if there is not already one. This is required as only one is allowed
111                 // as per RFC:
112                 //  - https://datatracker.ietf.org/doc/html/rfc6891#section-6.1.1
113                 !hasOptRecord(additionals)) {
114             optResource = new AbstractDnsOptPseudoRrRecord(maxPayLoadSize, 0, 0) {
115                 // We may want to remove this in the future and let the user just specify the opt record in the query.
116             };
117         } else {
118             optResource = null;
119         }
120     }
121 
122     private static boolean hasOptRecord(DnsRecord[] additionals) {
123         if (additionals != null && additionals.length > 0) {
124             for (DnsRecord additional: additionals) {
125                 if (additional.type() == DnsRecordType.OPT) {
126                     return true;
127                 }
128             }
129         }
130         return false;
131     }
132 
133     /**
134      * Returns {@code true} if the query was completed already.
135      *
136      * @return {@code true} if done.
137      */
138     final boolean isDone() {
139         return promise.isDone();
140     }
141 
142     /**
143      * Returns the {@link DnsQuestion} that will be written as part of the {@link DnsQuery}.
144      *
145      * @return the question.
146      */
147     final DnsQuestion question() {
148         return question;
149     }
150 
151     /**
152      * Creates and returns a new {@link DnsQuery}.
153      *
154      * @param id                the transaction id to use.
155      * @param nameServerAddr    the nameserver to which the query will be send.
156      * @return                  the new query.
157      */
158     protected abstract DnsQuery newQuery(int id, InetSocketAddress nameServerAddr);
159 
160     /**
161      * Returns the protocol that is used for the query.
162      *
163      * @return  the protocol.
164      */
165     protected abstract String protocol();
166 
167     /**
168      * Write the query and return the {@link ChannelFuture} that is completed once the write completes.
169      *
170      * @param flush                 {@code true} if {@link Channel#flush()} should be called as well.
171      */
172     final void writeQuery(boolean flush) {
173         assert id == Integer.MIN_VALUE : this.getClass().getSimpleName() +
174                 ".writeQuery(...) can only be executed once.";
175 
176         if ((id = queryContextManager.add(nameServerAddr, this)) == -1) {
177             // We did exhaust the id space, fail the query
178             IllegalStateException e = new IllegalStateException("query ID space exhausted: " + question());
179             finishFailure("failed to send a query via " + protocol(), e, false);
180             queryLifecycleObserver.queryWritten(nameServerAddr, channel.newFailedFuture(e));
181         }
182 
183         // Ensure we remove the id from the QueryContextManager once the query completes.
184         promise.addListener((FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>) future -> {
185             // Cancel the timeout task.
186             Future<?> timeoutFuture = DnsQueryContext.this.timeoutFuture;
187             if (timeoutFuture != null) {
188                 DnsQueryContext.this.timeoutFuture = null;
189                 timeoutFuture.cancel(false);
190             }
191 
192             Throwable cause = future.cause();
193             if (cause instanceof DnsNameResolverTimeoutException || cause instanceof CancellationException) {
194                 // This query was failed due a timeout or cancellation. Let's delay the removal of the id to reduce
195                 // the risk of reusing the same id again while the remote nameserver might send the response after
196                 // the timeout.
197                 channel.eventLoop().schedule(new Runnable() {
198                     @Override
199                     public void run() {
200                         removeFromContextManager(nameServerAddr);
201                     }
202                 }, ID_REUSE_ON_TIMEOUT_DELAY_MILLIS, TimeUnit.MILLISECONDS);
203             } else {
204                 // Remove the id from the manager as soon as the query completes. This may be because of success,
205                 // failure or cancellation
206                 removeFromContextManager(nameServerAddr);
207             }
208         });
209         final DnsQuestion question = question();
210         final DnsQuery query = newQuery(id, nameServerAddr);
211 
212         query.setRecursionDesired(recursionDesired);
213 
214         query.addRecord(DnsSection.QUESTION, question);
215 
216         for (DnsRecord record: additionals) {
217             query.addRecord(DnsSection.ADDITIONAL, record);
218         }
219 
220         if (optResource != null) {
221             query.addRecord(DnsSection.ADDITIONAL, optResource);
222         }
223 
224         if (logger.isDebugEnabled()) {
225             logger.debug("{} WRITE: {}, [{}: {}], {}",
226                     channel, protocol(), id, nameServerAddr, question);
227         }
228 
229         ChannelFuture f = sendQuery(query, flush);
230         queryLifecycleObserver.queryWritten(nameServerAddr, f);
231     }
232 
233     private void removeFromContextManager(InetSocketAddress nameServerAddr) {
234         DnsQueryContext self = queryContextManager.remove(nameServerAddr, id);
235 
236         assert self == this : "Removed DnsQueryContext is not the correct instance";
237     }
238 
239     private ChannelFuture sendQuery(final DnsQuery query, final boolean flush) {
240         final ChannelPromise writePromise = channel.newPromise();
241         writeQuery(query, flush, writePromise);
242         return writePromise;
243     }
244 
245     private void writeQuery(final DnsQuery query,
246                             final boolean flush, ChannelPromise promise) {
247         final ChannelFuture writeFuture = flush ? channel.writeAndFlush(query, promise) :
248                 channel.write(query, promise);
249         if (writeFuture.isDone()) {
250             onQueryWriteCompletion(queryTimeoutMillis, writeFuture);
251         } else {
252             writeFuture.addListener((ChannelFutureListener) future ->
253                     onQueryWriteCompletion(queryTimeoutMillis, future));
254         }
255     }
256 
257     private void onQueryWriteCompletion(final long queryTimeoutMillis,
258                                         ChannelFuture writeFuture) {
259         if (!writeFuture.isSuccess()) {
260             finishFailure("failed to send a query '" + id + "' via " + protocol(), writeFuture.cause(), false);
261             return;
262         }
263 
264         // Schedule a query timeout task if necessary.
265         if (queryTimeoutMillis > 0) {
266             timeoutFuture = channel.eventLoop().schedule(new Runnable() {
267                 @Override
268                 public void run() {
269                     if (promise.isDone()) {
270                         // Received a response before the query times out.
271                         return;
272                     }
273 
274                     finishFailure("query '" + id + "' via " + protocol() + " timed out after " +
275                             queryTimeoutMillis + " milliseconds", null, true);
276                 }
277             }, queryTimeoutMillis, TimeUnit.MILLISECONDS);
278         }
279     }
280 
281     /**
282      * Notifies the original {@link Promise} that the response for the query was received.
283      * This method takes ownership of passed {@link AddressedEnvelope}.
284      */
285     void finishSuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope, boolean truncated) {
286         // Check if the response was not truncated or if a fallback to TCP is possible.
287         if (!truncated || !retryWithTcp(envelope)) {
288             final DnsResponse res = envelope.content();
289             if (res.count(DnsSection.QUESTION) != 1) {
290                 logger.warn("{} Received a DNS response with invalid number of questions. Expected: 1, found: {}",
291                         channel, envelope);
292             } else if (!question().equals(res.recordAt(DnsSection.QUESTION))) {
293                 logger.warn("{} Received a mismatching DNS response. Expected: [{}], found: {}",
294                         channel, question(), envelope);
295             } else if (trySuccess(envelope)) {
296                 return; // Ownership transferred, don't release
297             }
298             envelope.release();
299         }
300     }
301 
302     @SuppressWarnings("unchecked")
303     private boolean trySuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
304         return promise.trySuccess((AddressedEnvelope<DnsResponse, InetSocketAddress>) envelope);
305     }
306 
307     /**
308      * Notifies the original {@link Promise} that the query completes because of an failure.
309      */
310     final boolean finishFailure(String message, Throwable cause, boolean timeout) {
311         if (promise.isDone()) {
312             return false;
313         }
314         final DnsQuestion question = question();
315 
316         final StringBuilder buf = new StringBuilder(message.length() + 128);
317         buf.append('[')
318            .append(id)
319            .append(": ")
320            .append(nameServerAddr)
321            .append("] ")
322            .append(question)
323            .append(' ')
324            .append(message)
325            .append(" (no stack trace available)");
326 
327         final DnsNameResolverException e;
328         if (timeout) {
329             // This was caused by a timeout so use DnsNameResolverTimeoutException to allow the user to
330             // handle it special (like retry the query).
331             e = new DnsNameResolverTimeoutException(nameServerAddr, question, buf.toString());
332             if (retryWithTcpOnTimeout && retryWithTcp(e)) {
333                 // We did successfully retry with TCP.
334                 return false;
335             }
336         } else {
337             e = new DnsNameResolverException(nameServerAddr, question, buf.toString(), cause);
338         }
339         return promise.tryFailure(e);
340     }
341 
342     /**
343      * Retry the original query with TCP if possible.
344      *
345      * @param originalResult    the result of the original {@link DnsQueryContext}.
346      * @return                  {@code true} if retry via TCP is supported and so the ownership of
347      *                          {@code originalResult} was transferred, {@code false} otherwise.
348      */
349     private boolean retryWithTcp(final Object originalResult) {
350         if (socketBootstrap == null) {
351             return false;
352         }
353 
354         socketBootstrap.connect(nameServerAddr).addListener((ChannelFutureListener) future -> {
355             if (!future.isSuccess()) {
356                 logger.debug("{} Unable to fallback to TCP [{}: {}]",
357                         future.channel(), id, nameServerAddr, future.cause());
358                 // TCP fallback failed, just use the truncated response or error.
359                 finishOriginal(originalResult, future);
360                 return;
361             }
362             final Channel tcpCh = future.channel();
363             Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise =
364                     tcpCh.eventLoop().newPromise();
365             final TcpDnsQueryContext tcpCtx = new TcpDnsQueryContext(tcpCh,
366                     (InetSocketAddress) tcpCh.remoteAddress(), queryContextManager, queryLifecycleObserver, 0,
367                     recursionDesired, queryTimeoutMillis, question(), additionals, promise);
368             tcpCh.pipeline().addLast(TCP_ENCODER);
369             tcpCh.pipeline().addLast(new TcpDnsResponseDecoder());
370             tcpCh.pipeline().addLast(new ChannelInboundHandlerAdapter() {
371                 @Override
372                 public void channelRead(ChannelHandlerContext ctx, Object msg) {
373                     Channel tcpCh = ctx.channel();
374                     DnsResponse response = (DnsResponse) msg;
375                     int queryId = response.id();
376 
377                     if (logger.isDebugEnabled()) {
378                         logger.debug("{} RECEIVED: TCP [{}: {}], {}", tcpCh, queryId,
379                                 tcpCh.remoteAddress(), response);
380                     }
381 
382                     DnsQueryContext foundCtx = queryContextManager.get(nameServerAddr, queryId);
383                     if (foundCtx != null && foundCtx.isDone()) {
384                         logger.debug("{} Received a DNS response for a query that was timed out or cancelled " +
385                                 ": TCP [{}: {}]", tcpCh, queryId, nameServerAddr);
386                         response.release();
387                     } else if (foundCtx == tcpCtx) {
388                         tcpCtx.finishSuccess(new AddressedEnvelopeAdapter(
389                                 (InetSocketAddress) ctx.channel().remoteAddress(),
390                                 (InetSocketAddress) ctx.channel().localAddress(),
391                                 response), false);
392                     } else {
393                         response.release();
394                         tcpCtx.finishFailure("Received TCP DNS response with unexpected ID", null, false);
395                         if (logger.isDebugEnabled()) {
396                             logger.debug("{} Received a DNS response with an unexpected ID: TCP [{}: {}]",
397                                     tcpCh, queryId, tcpCh.remoteAddress());
398                         }
399                     }
400                 }
401 
402                 @Override
403                 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
404                     if (tcpCtx.finishFailure(
405                             "TCP fallback error", cause, false) && logger.isDebugEnabled()) {
406                         logger.debug("{} Error during processing response: TCP [{}: {}]",
407                                 ctx.channel(), id,
408                                 ctx.channel().remoteAddress(), cause);
409                     }
410                 }
411             });
412 
413             promise.addListener(
414                     (FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>) future1 -> {
415                         if (future1.isSuccess()) {
416                             finishSuccess(future1.getNow(), false);
417                             // Release the original result.
418                             ReferenceCountUtil.release(originalResult);
419                         } else {
420                             // TCP fallback failed, just use the truncated response or error.
421                             finishOriginal(originalResult, future1);
422                         }
423                         tcpCh.close();
424                     });
425             tcpCtx.writeQuery(true);
426         });
427         return true;
428     }
429 
430     @SuppressWarnings("unchecked")
431     private void finishOriginal(Object originalResult, Future<?> future) {
432         if (originalResult instanceof Throwable) {
433             Throwable error = (Throwable) originalResult;
434             ThrowableUtil.addSuppressed(error, future.cause());
435             promise.tryFailure(error);
436         } else {
437             finishSuccess((AddressedEnvelope<? extends DnsResponse, InetSocketAddress>) originalResult, false);
438         }
439     }
440 
441     private static final class AddressedEnvelopeAdapter implements AddressedEnvelope<DnsResponse, InetSocketAddress> {
442         private final InetSocketAddress sender;
443         private final InetSocketAddress recipient;
444         private final DnsResponse response;
445 
446         AddressedEnvelopeAdapter(InetSocketAddress sender, InetSocketAddress recipient, DnsResponse response) {
447             this.sender = sender;
448             this.recipient = recipient;
449             this.response = response;
450         }
451 
452         @Override
453         public DnsResponse content() {
454             return response;
455         }
456 
457         @Override
458         public InetSocketAddress sender() {
459             return sender;
460         }
461 
462         @Override
463         public InetSocketAddress recipient() {
464             return recipient;
465         }
466 
467         @Override
468         public AddressedEnvelope<DnsResponse, InetSocketAddress> retain() {
469             response.retain();
470             return this;
471         }
472 
473         @Override
474         public AddressedEnvelope<DnsResponse, InetSocketAddress> retain(int increment) {
475             response.retain(increment);
476             return this;
477         }
478 
479         @Override
480         public AddressedEnvelope<DnsResponse, InetSocketAddress> touch() {
481             response.touch();
482             return this;
483         }
484 
485         @Override
486         public AddressedEnvelope<DnsResponse, InetSocketAddress> touch(Object hint) {
487             response.touch(hint);
488             return this;
489         }
490 
491         @Override
492         public int refCnt() {
493             return response.refCnt();
494         }
495 
496         @Override
497         public boolean release() {
498             return response.release();
499         }
500 
501         @Override
502         public boolean release(int decrement) {
503             return response.release(decrement);
504         }
505 
506         @Override
507         public boolean equals(Object obj) {
508             if (this == obj) {
509                 return true;
510             }
511 
512             if (!(obj instanceof AddressedEnvelope)) {
513                 return false;
514             }
515 
516             @SuppressWarnings("unchecked")
517             final AddressedEnvelope<?, SocketAddress> that = (AddressedEnvelope<?, SocketAddress>) obj;
518             if (sender() == null) {
519                 if (that.sender() != null) {
520                     return false;
521                 }
522             } else if (!sender().equals(that.sender())) {
523                 return false;
524             }
525 
526             if (recipient() == null) {
527                 if (that.recipient() != null) {
528                     return false;
529                 }
530             } else if (!recipient().equals(that.recipient())) {
531                 return false;
532             }
533 
534             return response.equals(obj);
535         }
536 
537         @Override
538         public int hashCode() {
539             int hashCode = response.hashCode();
540             if (sender() != null) {
541                 hashCode = hashCode * 31 + sender().hashCode();
542             }
543             if (recipient() != null) {
544                 hashCode = hashCode * 31 + recipient().hashCode();
545             }
546             return hashCode;
547         }
548     }
549 }