View Javadoc
1   /*
2    * Copyright 2014 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.resolver.dns;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.channel.AddressedEnvelope;
20  import io.netty.channel.Channel;
21  import io.netty.channel.ChannelFuture;
22  import io.netty.channel.ChannelFutureListener;
23  import io.netty.channel.ChannelHandlerContext;
24  import io.netty.channel.ChannelInboundHandlerAdapter;
25  import io.netty.channel.ChannelPromise;
26  import io.netty.handler.codec.dns.AbstractDnsOptPseudoRrRecord;
27  import io.netty.handler.codec.dns.DnsQuery;
28  import io.netty.handler.codec.dns.DnsQuestion;
29  import io.netty.handler.codec.dns.DnsRecord;
30  import io.netty.handler.codec.dns.DnsRecordType;
31  import io.netty.handler.codec.dns.DnsResponse;
32  import io.netty.handler.codec.dns.DnsSection;
33  import io.netty.handler.codec.dns.TcpDnsQueryEncoder;
34  import io.netty.handler.codec.dns.TcpDnsResponseDecoder;
35  import io.netty.util.ReferenceCountUtil;
36  import io.netty.util.concurrent.Future;
37  import io.netty.util.concurrent.FutureListener;
38  import io.netty.util.concurrent.Promise;
39  import io.netty.util.internal.SystemPropertyUtil;
40  import io.netty.util.internal.ThrowableUtil;
41  import io.netty.util.internal.logging.InternalLogger;
42  import io.netty.util.internal.logging.InternalLoggerFactory;
43  
44  import java.net.InetSocketAddress;
45  import java.net.SocketAddress;
46  import java.util.concurrent.CancellationException;
47  import java.util.concurrent.TimeUnit;
48  
49  import static io.netty.util.internal.ObjectUtil.checkNotNull;
50  
51  abstract class DnsQueryContext {
52  
53      private static final InternalLogger logger = InternalLoggerFactory.getInstance(DnsQueryContext.class);
54      private static final long ID_REUSE_ON_TIMEOUT_DELAY_MILLIS;
55  
56      static {
57          ID_REUSE_ON_TIMEOUT_DELAY_MILLIS =
58                  SystemPropertyUtil.getLong("io.netty.resolver.dns.idReuseOnTimeoutDelayMillis", 10000);
59          logger.debug("-Dio.netty.resolver.dns.idReuseOnTimeoutDelayMillis: {}", ID_REUSE_ON_TIMEOUT_DELAY_MILLIS);
60      }
61  
62      private static final TcpDnsQueryEncoder TCP_ENCODER = new TcpDnsQueryEncoder();
63  
64      private final Channel channel;
65      private final InetSocketAddress nameServerAddr;
66      private final DnsQueryContextManager queryContextManager;
67      private final Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise;
68  
69      private final DnsQuestion question;
70      private final DnsRecord[] additionals;
71      private final DnsRecord optResource;
72  
73      private final boolean recursionDesired;
74  
75      private final Bootstrap socketBootstrap;
76  
77      private final boolean retryWithTcpOnTimeout;
78      private final long queryTimeoutMillis;
79  
80      private volatile Future<?> timeoutFuture;
81  
82      private int id = Integer.MIN_VALUE;
83  
84      DnsQueryContext(Channel channel,
85                      InetSocketAddress nameServerAddr,
86                      DnsQueryContextManager queryContextManager,
87                      int maxPayLoadSize,
88                      boolean recursionDesired,
89                      long queryTimeoutMillis,
90                      DnsQuestion question,
91                      DnsRecord[] additionals,
92                      Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise,
93                      Bootstrap socketBootstrap,
94                      boolean retryWithTcpOnTimeout) {
95          this.channel = checkNotNull(channel, "channel");
96          this.queryContextManager = checkNotNull(queryContextManager, "queryContextManager");
97          this.nameServerAddr = checkNotNull(nameServerAddr, "nameServerAddr");
98          this.question = checkNotNull(question, "question");
99          this.additionals = checkNotNull(additionals, "additionals");
100         this.promise = checkNotNull(promise, "promise");
101         this.recursionDesired = recursionDesired;
102         this.queryTimeoutMillis = queryTimeoutMillis;
103         this.socketBootstrap = socketBootstrap;
104         this.retryWithTcpOnTimeout = retryWithTcpOnTimeout;
105 
106         if (maxPayLoadSize > 0 &&
107                 // Only add the extra OPT record if there is not already one. This is required as only one is allowed
108                 // as per RFC:
109                 //  - https://datatracker.ietf.org/doc/html/rfc6891#section-6.1.1
110                 !hasOptRecord(additionals)) {
111             optResource = new AbstractDnsOptPseudoRrRecord(maxPayLoadSize, 0, 0) {
112                 // We may want to remove this in the future and let the user just specify the opt record in the query.
113             };
114         } else {
115             optResource = null;
116         }
117     }
118 
119     private static boolean hasOptRecord(DnsRecord[] additionals) {
120         if (additionals != null && additionals.length > 0) {
121             for (DnsRecord additional: additionals) {
122                 if (additional.type() == DnsRecordType.OPT) {
123                     return true;
124                 }
125             }
126         }
127         return false;
128     }
129 
130     /**
131      * Returns {@code true} if the query was completed already.
132      *
133      * @return {@code true} if done.
134      */
135     final boolean isDone() {
136         return promise.isDone();
137     }
138 
139     /**
140      * Returns the {@link DnsQuestion} that will be written as part of the {@link DnsQuery}.
141      *
142      * @return the question.
143      */
144     final DnsQuestion question() {
145         return question;
146     }
147 
148     /**
149      * Creates and returns a new {@link DnsQuery}.
150      *
151      * @param id                the transaction id to use.
152      * @param nameServerAddr    the nameserver to which the query will be send.
153      * @return                  the new query.
154      */
155     protected abstract DnsQuery newQuery(int id, InetSocketAddress nameServerAddr);
156 
157     /**
158      * Returns the protocol that is used for the query.
159      *
160      * @return  the protocol.
161      */
162     protected abstract String protocol();
163 
164     /**
165      * Write the query and return the {@link ChannelFuture} that is completed once the write completes.
166      *
167      * @param flush                 {@code true} if {@link Channel#flush()} should be called as well.
168      * @return                      the {@link ChannelFuture} that is notified once once the write completes.
169      */
170     final ChannelFuture writeQuery(boolean flush) {
171         assert id == Integer.MIN_VALUE : this.getClass().getSimpleName() +
172                 ".writeQuery(...) can only be executed once.";
173 
174         if ((id = queryContextManager.add(nameServerAddr, this)) == -1) {
175             // We did exhaust the id space, fail the query
176             IllegalStateException e = new IllegalStateException("query ID space exhausted: " + question());
177             finishFailure("failed to send a query via " + protocol(), e, false);
178             return channel.newFailedFuture(e);
179         }
180 
181         // Ensure we remove the id from the QueryContextManager once the query completes.
182         promise.addListener(new FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>() {
183             @Override
184             public void operationComplete(Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> future) {
185                 // Cancel the timeout task.
186                 Future<?> timeoutFuture = DnsQueryContext.this.timeoutFuture;
187                 if (timeoutFuture != null) {
188                     DnsQueryContext.this.timeoutFuture = null;
189                     timeoutFuture.cancel(false);
190                 }
191 
192                 Throwable cause = future.cause();
193                 if (cause instanceof DnsNameResolverTimeoutException || cause instanceof CancellationException) {
194                     // This query was failed due a timeout or cancellation. Let's delay the removal of the id to reduce
195                     // the risk of reusing the same id again while the remote nameserver might send the response after
196                     // the timeout.
197                     channel.eventLoop().schedule(new Runnable() {
198                         @Override
199                         public void run() {
200                             removeFromContextManager(nameServerAddr);
201                         }
202                     }, ID_REUSE_ON_TIMEOUT_DELAY_MILLIS, TimeUnit.MILLISECONDS);
203                 } else {
204                     // Remove the id from the manager as soon as the query completes. This may be because of success,
205                     // failure or cancellation
206                     removeFromContextManager(nameServerAddr);
207                 }
208             }
209         });
210         final DnsQuestion question = question();
211         final DnsQuery query = newQuery(id, nameServerAddr);
212 
213         query.setRecursionDesired(recursionDesired);
214 
215         query.addRecord(DnsSection.QUESTION, question);
216 
217         for (DnsRecord record: additionals) {
218             query.addRecord(DnsSection.ADDITIONAL, record);
219         }
220 
221         if (optResource != null) {
222             query.addRecord(DnsSection.ADDITIONAL, optResource);
223         }
224 
225         if (logger.isDebugEnabled()) {
226             logger.debug("{} WRITE: {}, [{}: {}], {}",
227                     channel, protocol(), id, nameServerAddr, question);
228         }
229 
230         return sendQuery(query, flush);
231     }
232 
233     private void removeFromContextManager(InetSocketAddress nameServerAddr) {
234         DnsQueryContext self = queryContextManager.remove(nameServerAddr, id);
235 
236         assert self == this : "Removed DnsQueryContext is not the correct instance";
237     }
238 
239     private ChannelFuture sendQuery(final DnsQuery query, final boolean flush) {
240         final ChannelPromise writePromise = channel.newPromise();
241         writeQuery(query, flush, writePromise);
242         return writePromise;
243     }
244 
245     private void writeQuery(final DnsQuery query,
246                             final boolean flush, ChannelPromise promise) {
247         final ChannelFuture writeFuture = flush ? channel.writeAndFlush(query, promise) :
248                 channel.write(query, promise);
249         if (writeFuture.isDone()) {
250             onQueryWriteCompletion(queryTimeoutMillis, writeFuture);
251         } else {
252             writeFuture.addListener(new ChannelFutureListener() {
253                 @Override
254                 public void operationComplete(ChannelFuture future) {
255                     onQueryWriteCompletion(queryTimeoutMillis, writeFuture);
256                 }
257             });
258         }
259     }
260 
261     private void onQueryWriteCompletion(final long queryTimeoutMillis,
262                                         ChannelFuture writeFuture) {
263         if (!writeFuture.isSuccess()) {
264             finishFailure("failed to send a query '" + id + "' via " + protocol(), writeFuture.cause(), false);
265             return;
266         }
267 
268         // Schedule a query timeout task if necessary.
269         if (queryTimeoutMillis > 0) {
270             timeoutFuture = channel.eventLoop().schedule(new Runnable() {
271                 @Override
272                 public void run() {
273                     if (promise.isDone()) {
274                         // Received a response before the query times out.
275                         return;
276                     }
277 
278                     finishFailure("query '" + id + "' via " + protocol() + " timed out after " +
279                             queryTimeoutMillis + " milliseconds", null, true);
280                 }
281             }, queryTimeoutMillis, TimeUnit.MILLISECONDS);
282         }
283     }
284 
285     /**
286      * Notifies the original {@link Promise} that the response for the query was received.
287      * This method takes ownership of passed {@link AddressedEnvelope}.
288      */
289     void finishSuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope, boolean truncated) {
290         // Check if the response was not truncated or if a fallback to TCP is possible.
291         if (!truncated || !retryWithTcp(envelope)) {
292             final DnsResponse res = envelope.content();
293             if (res.count(DnsSection.QUESTION) != 1) {
294                 logger.warn("{} Received a DNS response with invalid number of questions. Expected: 1, found: {}",
295                         channel, envelope);
296             } else if (!question().equals(res.recordAt(DnsSection.QUESTION))) {
297                 logger.warn("{} Received a mismatching DNS response. Expected: [{}], found: {}",
298                         channel, question(), envelope);
299             } else if (trySuccess(envelope)) {
300                 return; // Ownership transferred, don't release
301             }
302             envelope.release();
303         }
304     }
305 
306     @SuppressWarnings("unchecked")
307     private boolean trySuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
308         return promise.trySuccess((AddressedEnvelope<DnsResponse, InetSocketAddress>) envelope);
309     }
310 
311     /**
312      * Notifies the original {@link Promise} that the query completes because of an failure.
313      */
314     final boolean finishFailure(String message, Throwable cause, boolean timeout) {
315         if (promise.isDone()) {
316             return false;
317         }
318         final DnsQuestion question = question();
319 
320         final StringBuilder buf = new StringBuilder(message.length() + 128);
321         buf.append('[')
322            .append(id)
323            .append(": ")
324            .append(nameServerAddr)
325            .append("] ")
326            .append(question)
327            .append(' ')
328            .append(message)
329            .append(" (no stack trace available)");
330 
331         final DnsNameResolverException e;
332         if (timeout) {
333             // This was caused by a timeout so use DnsNameResolverTimeoutException to allow the user to
334             // handle it special (like retry the query).
335             e = new DnsNameResolverTimeoutException(nameServerAddr, question, buf.toString());
336             if (retryWithTcpOnTimeout && retryWithTcp(e)) {
337                 // We did successfully retry with TCP.
338                 return false;
339             }
340         } else {
341             e = new DnsNameResolverException(nameServerAddr, question, buf.toString(), cause);
342         }
343         return promise.tryFailure(e);
344     }
345 
346     /**
347      * Retry the original query with TCP if possible.
348      *
349      * @param originalResult    the result of the original {@link DnsQueryContext}.
350      * @return                  {@code true} if retry via TCP is supported and so the ownership of
351      *                          {@code originalResult} was transferred, {@code false} otherwise.
352      */
353     private boolean retryWithTcp(final Object originalResult) {
354         if (socketBootstrap == null) {
355             return false;
356         }
357 
358         socketBootstrap.connect(nameServerAddr).addListener(new ChannelFutureListener() {
359             @Override
360             public void operationComplete(ChannelFuture future) {
361                 if (!future.isSuccess()) {
362                     logger.debug("{} Unable to fallback to TCP [{}: {}]",
363                             future.channel(), id, nameServerAddr, future.cause());
364 
365                     // TCP fallback failed, just use the truncated response or error.
366                     finishOriginal(originalResult, future);
367                     return;
368                 }
369                 final Channel tcpCh = future.channel();
370                 Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise =
371                         tcpCh.eventLoop().newPromise();
372                 final TcpDnsQueryContext tcpCtx = new TcpDnsQueryContext(tcpCh,
373                         (InetSocketAddress) tcpCh.remoteAddress(), queryContextManager, 0,
374                         recursionDesired, queryTimeoutMillis, question(), additionals, promise);
375                 tcpCh.pipeline().addLast(TCP_ENCODER);
376                 tcpCh.pipeline().addLast(new TcpDnsResponseDecoder());
377                 tcpCh.pipeline().addLast(new ChannelInboundHandlerAdapter() {
378                     @Override
379                     public void channelRead(ChannelHandlerContext ctx, Object msg) {
380                         Channel tcpCh = ctx.channel();
381                         DnsResponse response = (DnsResponse) msg;
382                         int queryId = response.id();
383 
384                         if (logger.isDebugEnabled()) {
385                             logger.debug("{} RECEIVED: TCP [{}: {}], {}", tcpCh, queryId,
386                                     tcpCh.remoteAddress(), response);
387                         }
388 
389                         DnsQueryContext foundCtx = queryContextManager.get(nameServerAddr, queryId);
390                         if (foundCtx != null && foundCtx.isDone()) {
391                             logger.debug("{} Received a DNS response for a query that was timed out or cancelled " +
392                                     ": TCP [{}: {}]", tcpCh, queryId, nameServerAddr);
393                             response.release();
394                         } else if (foundCtx == tcpCtx) {
395                             tcpCtx.finishSuccess(new AddressedEnvelopeAdapter(
396                                     (InetSocketAddress) ctx.channel().remoteAddress(),
397                                     (InetSocketAddress) ctx.channel().localAddress(),
398                                     response), false);
399                         } else {
400                             response.release();
401                             tcpCtx.finishFailure("Received TCP DNS response with unexpected ID", null, false);
402                             if (logger.isDebugEnabled()) {
403                                 logger.debug("{} Received a DNS response with an unexpected ID: TCP [{}: {}]",
404                                         tcpCh, queryId, tcpCh.remoteAddress());
405                             }
406                         }
407                     }
408 
409                     @Override
410                     public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
411                         if (tcpCtx.finishFailure(
412                                 "TCP fallback error", cause, false) && logger.isDebugEnabled()) {
413                             logger.debug("{} Error during processing response: TCP [{}: {}]",
414                                     ctx.channel(), id,
415                                     ctx.channel().remoteAddress(), cause);
416                         }
417                     }
418                 });
419 
420                 promise.addListener(
421                         new FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>() {
422                             @Override
423                             public void operationComplete(
424                                     Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> future) {
425                                 if (future.isSuccess()) {
426                                     finishSuccess(future.getNow(), false);
427                                     // Release the original result.
428                                     ReferenceCountUtil.release(originalResult);
429                                 } else {
430                                     // TCP fallback failed, just use the truncated response or error.
431                                     finishOriginal(originalResult, future);
432                                 }
433                                 tcpCh.close();
434                             }
435                         });
436                 tcpCtx.writeQuery(true);
437             }
438         });
439         return true;
440     }
441 
442     @SuppressWarnings("unchecked")
443     private void finishOriginal(Object originalResult, Future<?> future) {
444         if (originalResult instanceof Throwable) {
445             Throwable error = (Throwable) originalResult;
446             ThrowableUtil.addSuppressed(error, future.cause());
447             promise.tryFailure(error);
448         } else {
449             finishSuccess((AddressedEnvelope<? extends DnsResponse, InetSocketAddress>) originalResult, false);
450         }
451     }
452 
453     private static final class AddressedEnvelopeAdapter implements AddressedEnvelope<DnsResponse, InetSocketAddress> {
454         private final InetSocketAddress sender;
455         private final InetSocketAddress recipient;
456         private final DnsResponse response;
457 
458         AddressedEnvelopeAdapter(InetSocketAddress sender, InetSocketAddress recipient, DnsResponse response) {
459             this.sender = sender;
460             this.recipient = recipient;
461             this.response = response;
462         }
463 
464         @Override
465         public DnsResponse content() {
466             return response;
467         }
468 
469         @Override
470         public InetSocketAddress sender() {
471             return sender;
472         }
473 
474         @Override
475         public InetSocketAddress recipient() {
476             return recipient;
477         }
478 
479         @Override
480         public AddressedEnvelope<DnsResponse, InetSocketAddress> retain() {
481             response.retain();
482             return this;
483         }
484 
485         @Override
486         public AddressedEnvelope<DnsResponse, InetSocketAddress> retain(int increment) {
487             response.retain(increment);
488             return this;
489         }
490 
491         @Override
492         public AddressedEnvelope<DnsResponse, InetSocketAddress> touch() {
493             response.touch();
494             return this;
495         }
496 
497         @Override
498         public AddressedEnvelope<DnsResponse, InetSocketAddress> touch(Object hint) {
499             response.touch(hint);
500             return this;
501         }
502 
503         @Override
504         public int refCnt() {
505             return response.refCnt();
506         }
507 
508         @Override
509         public boolean release() {
510             return response.release();
511         }
512 
513         @Override
514         public boolean release(int decrement) {
515             return response.release(decrement);
516         }
517 
518         @Override
519         public boolean equals(Object obj) {
520             if (this == obj) {
521                 return true;
522             }
523 
524             if (!(obj instanceof AddressedEnvelope)) {
525                 return false;
526             }
527 
528             @SuppressWarnings("unchecked")
529             final AddressedEnvelope<?, SocketAddress> that = (AddressedEnvelope<?, SocketAddress>) obj;
530             if (sender() == null) {
531                 if (that.sender() != null) {
532                     return false;
533                 }
534             } else if (!sender().equals(that.sender())) {
535                 return false;
536             }
537 
538             if (recipient() == null) {
539                 if (that.recipient() != null) {
540                     return false;
541                 }
542             } else if (!recipient().equals(that.recipient())) {
543                 return false;
544             }
545 
546             return response.equals(obj);
547         }
548 
549         @Override
550         public int hashCode() {
551             int hashCode = response.hashCode();
552             if (sender() != null) {
553                 hashCode = hashCode * 31 + sender().hashCode();
554             }
555             if (recipient() != null) {
556                 hashCode = hashCode * 31 + recipient().hashCode();
557             }
558             return hashCode;
559         }
560     }
561 }