1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.resolver.dns;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.channel.AddressedEnvelope;
20 import io.netty.channel.Channel;
21 import io.netty.channel.ChannelFuture;
22 import io.netty.channel.ChannelFutureListener;
23 import io.netty.channel.ChannelHandlerContext;
24 import io.netty.channel.ChannelInboundHandlerAdapter;
25 import io.netty.channel.ChannelPromise;
26 import io.netty.handler.codec.dns.AbstractDnsOptPseudoRrRecord;
27 import io.netty.handler.codec.dns.DnsQuery;
28 import io.netty.handler.codec.dns.DnsQuestion;
29 import io.netty.handler.codec.dns.DnsRecord;
30 import io.netty.handler.codec.dns.DnsRecordType;
31 import io.netty.handler.codec.dns.DnsResponse;
32 import io.netty.handler.codec.dns.DnsSection;
33 import io.netty.handler.codec.dns.TcpDnsQueryEncoder;
34 import io.netty.handler.codec.dns.TcpDnsResponseDecoder;
35 import io.netty.util.ReferenceCountUtil;
36 import io.netty.util.concurrent.Future;
37 import io.netty.util.concurrent.FutureListener;
38 import io.netty.util.concurrent.Promise;
39 import io.netty.util.internal.SystemPropertyUtil;
40 import io.netty.util.internal.ThrowableUtil;
41 import io.netty.util.internal.logging.InternalLogger;
42 import io.netty.util.internal.logging.InternalLoggerFactory;
43
44 import java.net.InetSocketAddress;
45 import java.net.SocketAddress;
46 import java.util.concurrent.CancellationException;
47 import java.util.concurrent.TimeUnit;
48
49 import static io.netty.util.internal.ObjectUtil.checkNotNull;
50
51 abstract class DnsQueryContext {
52
/** Logger used for query tracing and diagnostics. */
private static final InternalLogger logger = InternalLoggerFactory.getInstance(DnsQueryContext.class);

/**
 * Delay in milliseconds before the ID of a query that timed out or was cancelled is removed from the
 * {@link DnsQueryContextManager}. Read from the system property
 * {@code io.netty.resolver.dns.idReuseOnTimeoutDelayMillis}; defaults to {@code 10000}.
 */
private static final long ID_REUSE_ON_TIMEOUT_DELAY_MILLIS;

static {
    final long configuredDelayMillis =
            SystemPropertyUtil.getLong("io.netty.resolver.dns.idReuseOnTimeoutDelayMillis", 10000);
    ID_REUSE_ON_TIMEOUT_DELAY_MILLIS = configuredDelayMillis;
    logger.debug("-Dio.netty.resolver.dns.idReuseOnTimeoutDelayMillis: {}", configuredDelayMillis);
}

/** Single encoder instance that is added to the pipeline of every TCP fallback channel. */
private static final TcpDnsQueryEncoder TCP_ENCODER = new TcpDnsQueryEncoder();

// Where the query is written and which name server it is addressed to.
private final Channel channel;
private final InetSocketAddress nameServerAddr;

// Registry that hands out query IDs, and the promise completed with the final outcome of this query.
private final DnsQueryContextManager queryContextManager;
private final Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise;

// Content of the query: the question, caller-supplied additional records, and an optional
// OPT record created by the constructor (null when none is added).
private final DnsQuestion question;
private final DnsRecord[] additionals;
private final DnsRecord optResource;

private final boolean recursionDesired;

// Bootstrap used to open a TCP connection for the fallback; null disables the fallback.
private final Bootstrap socketBootstrap;

private final boolean retryWithTcpOnTimeout;
private final long queryTimeoutMillis;

// Scheduled timeout task; written after a successful write and cleared when the promise completes.
private volatile Future<?> timeoutFuture;

// Integer.MIN_VALUE marks "writeQuery(...) has not been called yet".
private int id = Integer.MIN_VALUE;
83
84 DnsQueryContext(Channel channel,
85 InetSocketAddress nameServerAddr,
86 DnsQueryContextManager queryContextManager,
87 int maxPayLoadSize,
88 boolean recursionDesired,
89 long queryTimeoutMillis,
90 DnsQuestion question,
91 DnsRecord[] additionals,
92 Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise,
93 Bootstrap socketBootstrap,
94 boolean retryWithTcpOnTimeout) {
95 this.channel = checkNotNull(channel, "channel");
96 this.queryContextManager = checkNotNull(queryContextManager, "queryContextManager");
97 this.nameServerAddr = checkNotNull(nameServerAddr, "nameServerAddr");
98 this.question = checkNotNull(question, "question");
99 this.additionals = checkNotNull(additionals, "additionals");
100 this.promise = checkNotNull(promise, "promise");
101 this.recursionDesired = recursionDesired;
102 this.queryTimeoutMillis = queryTimeoutMillis;
103 this.socketBootstrap = socketBootstrap;
104 this.retryWithTcpOnTimeout = retryWithTcpOnTimeout;
105
106 if (maxPayLoadSize > 0 &&
107
108
109
110 !hasOptRecord(additionals)) {
111 optResource = new AbstractDnsOptPseudoRrRecord(maxPayLoadSize, 0, 0) {
112
113 };
114 } else {
115 optResource = null;
116 }
117 }
118
119 private static boolean hasOptRecord(DnsRecord[] additionals) {
120 if (additionals != null && additionals.length > 0) {
121 for (DnsRecord additional: additionals) {
122 if (additional.type() == DnsRecordType.OPT) {
123 return true;
124 }
125 }
126 }
127 return false;
128 }
129
130
131
132
133
134
135 final boolean isDone() {
136 return promise.isDone();
137 }
138
139
140
141
142
143
144 final DnsQuestion question() {
145 return question;
146 }
147
148
149
150
151
152
153
154
155 protected abstract DnsQuery newQuery(int id, InetSocketAddress nameServerAddr);
156
157
158
159
160
161
162 protected abstract String protocol();
163
164
165
166
167
168
169
170 final ChannelFuture writeQuery(boolean flush) {
171 assert id == Integer.MIN_VALUE : this.getClass().getSimpleName() +
172 ".writeQuery(...) can only be executed once.";
173
174 if ((id = queryContextManager.add(nameServerAddr, this)) == -1) {
175
176 IllegalStateException e = new IllegalStateException("query ID space exhausted: " + question());
177 finishFailure("failed to send a query via " + protocol(), e, false);
178 return channel.newFailedFuture(e);
179 }
180
181
182 promise.addListener(new FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>() {
183 @Override
184 public void operationComplete(Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> future) {
185
186 Future<?> timeoutFuture = DnsQueryContext.this.timeoutFuture;
187 if (timeoutFuture != null) {
188 DnsQueryContext.this.timeoutFuture = null;
189 timeoutFuture.cancel(false);
190 }
191
192 Throwable cause = future.cause();
193 if (cause instanceof DnsNameResolverTimeoutException || cause instanceof CancellationException) {
194
195
196
197 channel.eventLoop().schedule(new Runnable() {
198 @Override
199 public void run() {
200 removeFromContextManager(nameServerAddr);
201 }
202 }, ID_REUSE_ON_TIMEOUT_DELAY_MILLIS, TimeUnit.MILLISECONDS);
203 } else {
204
205
206 removeFromContextManager(nameServerAddr);
207 }
208 }
209 });
210 final DnsQuestion question = question();
211 final DnsQuery query = newQuery(id, nameServerAddr);
212
213 query.setRecursionDesired(recursionDesired);
214
215 query.addRecord(DnsSection.QUESTION, question);
216
217 for (DnsRecord record: additionals) {
218 query.addRecord(DnsSection.ADDITIONAL, record);
219 }
220
221 if (optResource != null) {
222 query.addRecord(DnsSection.ADDITIONAL, optResource);
223 }
224
225 if (logger.isDebugEnabled()) {
226 logger.debug("{} WRITE: {}, [{}: {}], {}",
227 channel, protocol(), id, nameServerAddr, question);
228 }
229
230 return sendQuery(query, flush);
231 }
232
233 private void removeFromContextManager(InetSocketAddress nameServerAddr) {
234 DnsQueryContext self = queryContextManager.remove(nameServerAddr, id);
235
236 assert self == this : "Removed DnsQueryContext is not the correct instance";
237 }
238
239 private ChannelFuture sendQuery(final DnsQuery query, final boolean flush) {
240 final ChannelPromise writePromise = channel.newPromise();
241 writeQuery(query, flush, writePromise);
242 return writePromise;
243 }
244
245 private void writeQuery(final DnsQuery query,
246 final boolean flush, ChannelPromise promise) {
247 final ChannelFuture writeFuture = flush ? channel.writeAndFlush(query, promise) :
248 channel.write(query, promise);
249 if (writeFuture.isDone()) {
250 onQueryWriteCompletion(queryTimeoutMillis, writeFuture);
251 } else {
252 writeFuture.addListener(new ChannelFutureListener() {
253 @Override
254 public void operationComplete(ChannelFuture future) {
255 onQueryWriteCompletion(queryTimeoutMillis, writeFuture);
256 }
257 });
258 }
259 }
260
261 private void onQueryWriteCompletion(final long queryTimeoutMillis,
262 ChannelFuture writeFuture) {
263 if (!writeFuture.isSuccess()) {
264 finishFailure("failed to send a query '" + id + "' via " + protocol(), writeFuture.cause(), false);
265 return;
266 }
267
268
269 if (queryTimeoutMillis > 0) {
270 timeoutFuture = channel.eventLoop().schedule(new Runnable() {
271 @Override
272 public void run() {
273 if (promise.isDone()) {
274
275 return;
276 }
277
278 finishFailure("query '" + id + "' via " + protocol() + " timed out after " +
279 queryTimeoutMillis + " milliseconds", null, true);
280 }
281 }, queryTimeoutMillis, TimeUnit.MILLISECONDS);
282 }
283 }
284
285
286
287
288
289 void finishSuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope, boolean truncated) {
290
291 if (!truncated || !retryWithTcp(envelope)) {
292 final DnsResponse res = envelope.content();
293 if (res.count(DnsSection.QUESTION) != 1) {
294 logger.warn("{} Received a DNS response with invalid number of questions. Expected: 1, found: {}",
295 channel, envelope);
296 } else if (!question().equals(res.recordAt(DnsSection.QUESTION))) {
297 logger.warn("{} Received a mismatching DNS response. Expected: [{}], found: {}",
298 channel, question(), envelope);
299 } else if (trySuccess(envelope)) {
300 return;
301 }
302 envelope.release();
303 }
304 }
305
306 @SuppressWarnings("unchecked")
307 private boolean trySuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
308 return promise.trySuccess((AddressedEnvelope<DnsResponse, InetSocketAddress>) envelope);
309 }
310
311
312
313
314 final boolean finishFailure(String message, Throwable cause, boolean timeout) {
315 if (promise.isDone()) {
316 return false;
317 }
318 final DnsQuestion question = question();
319
320 final StringBuilder buf = new StringBuilder(message.length() + 128);
321 buf.append('[')
322 .append(id)
323 .append(": ")
324 .append(nameServerAddr)
325 .append("] ")
326 .append(question)
327 .append(' ')
328 .append(message)
329 .append(" (no stack trace available)");
330
331 final DnsNameResolverException e;
332 if (timeout) {
333
334
335 e = new DnsNameResolverTimeoutException(nameServerAddr, question, buf.toString());
336 if (retryWithTcpOnTimeout && retryWithTcp(e)) {
337
338 return false;
339 }
340 } else {
341 e = new DnsNameResolverException(nameServerAddr, question, buf.toString(), cause);
342 }
343 return promise.tryFailure(e);
344 }
345
346
347
348
349
350
351
352
353 private boolean retryWithTcp(final Object originalResult) {
354 if (socketBootstrap == null) {
355 return false;
356 }
357
358 socketBootstrap.connect(nameServerAddr).addListener(new ChannelFutureListener() {
359 @Override
360 public void operationComplete(ChannelFuture future) {
361 if (!future.isSuccess()) {
362 logger.debug("{} Unable to fallback to TCP [{}: {}]",
363 future.channel(), id, nameServerAddr, future.cause());
364
365
366 finishOriginal(originalResult, future);
367 return;
368 }
369 final Channel tcpCh = future.channel();
370 Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise =
371 tcpCh.eventLoop().newPromise();
372 final TcpDnsQueryContext tcpCtx = new TcpDnsQueryContext(tcpCh,
373 (InetSocketAddress) tcpCh.remoteAddress(), queryContextManager, 0,
374 recursionDesired, queryTimeoutMillis, question(), additionals, promise);
375 tcpCh.pipeline().addLast(TCP_ENCODER);
376 tcpCh.pipeline().addLast(new TcpDnsResponseDecoder());
377 tcpCh.pipeline().addLast(new ChannelInboundHandlerAdapter() {
378 @Override
379 public void channelRead(ChannelHandlerContext ctx, Object msg) {
380 Channel tcpCh = ctx.channel();
381 DnsResponse response = (DnsResponse) msg;
382 int queryId = response.id();
383
384 if (logger.isDebugEnabled()) {
385 logger.debug("{} RECEIVED: TCP [{}: {}], {}", tcpCh, queryId,
386 tcpCh.remoteAddress(), response);
387 }
388
389 DnsQueryContext foundCtx = queryContextManager.get(nameServerAddr, queryId);
390 if (foundCtx != null && foundCtx.isDone()) {
391 logger.debug("{} Received a DNS response for a query that was timed out or cancelled " +
392 ": TCP [{}: {}]", tcpCh, queryId, nameServerAddr);
393 response.release();
394 } else if (foundCtx == tcpCtx) {
395 tcpCtx.finishSuccess(new AddressedEnvelopeAdapter(
396 (InetSocketAddress) ctx.channel().remoteAddress(),
397 (InetSocketAddress) ctx.channel().localAddress(),
398 response), false);
399 } else {
400 response.release();
401 tcpCtx.finishFailure("Received TCP DNS response with unexpected ID", null, false);
402 if (logger.isDebugEnabled()) {
403 logger.debug("{} Received a DNS response with an unexpected ID: TCP [{}: {}]",
404 tcpCh, queryId, tcpCh.remoteAddress());
405 }
406 }
407 }
408
409 @Override
410 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
411 if (tcpCtx.finishFailure(
412 "TCP fallback error", cause, false) && logger.isDebugEnabled()) {
413 logger.debug("{} Error during processing response: TCP [{}: {}]",
414 ctx.channel(), id,
415 ctx.channel().remoteAddress(), cause);
416 }
417 }
418 });
419
420 promise.addListener(
421 new FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>() {
422 @Override
423 public void operationComplete(
424 Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> future) {
425 if (future.isSuccess()) {
426 finishSuccess(future.getNow(), false);
427
428 ReferenceCountUtil.release(originalResult);
429 } else {
430
431 finishOriginal(originalResult, future);
432 }
433 tcpCh.close();
434 }
435 });
436 tcpCtx.writeQuery(true);
437 }
438 });
439 return true;
440 }
441
442 @SuppressWarnings("unchecked")
443 private void finishOriginal(Object originalResult, Future<?> future) {
444 if (originalResult instanceof Throwable) {
445 Throwable error = (Throwable) originalResult;
446 ThrowableUtil.addSuppressed(error, future.cause());
447 promise.tryFailure(error);
448 } else {
449 finishSuccess((AddressedEnvelope<? extends DnsResponse, InetSocketAddress>) originalResult, false);
450 }
451 }
452
453 private static final class AddressedEnvelopeAdapter implements AddressedEnvelope<DnsResponse, InetSocketAddress> {
454 private final InetSocketAddress sender;
455 private final InetSocketAddress recipient;
456 private final DnsResponse response;
457
458 AddressedEnvelopeAdapter(InetSocketAddress sender, InetSocketAddress recipient, DnsResponse response) {
459 this.sender = sender;
460 this.recipient = recipient;
461 this.response = response;
462 }
463
464 @Override
465 public DnsResponse content() {
466 return response;
467 }
468
469 @Override
470 public InetSocketAddress sender() {
471 return sender;
472 }
473
474 @Override
475 public InetSocketAddress recipient() {
476 return recipient;
477 }
478
479 @Override
480 public AddressedEnvelope<DnsResponse, InetSocketAddress> retain() {
481 response.retain();
482 return this;
483 }
484
485 @Override
486 public AddressedEnvelope<DnsResponse, InetSocketAddress> retain(int increment) {
487 response.retain(increment);
488 return this;
489 }
490
491 @Override
492 public AddressedEnvelope<DnsResponse, InetSocketAddress> touch() {
493 response.touch();
494 return this;
495 }
496
497 @Override
498 public AddressedEnvelope<DnsResponse, InetSocketAddress> touch(Object hint) {
499 response.touch(hint);
500 return this;
501 }
502
503 @Override
504 public int refCnt() {
505 return response.refCnt();
506 }
507
508 @Override
509 public boolean release() {
510 return response.release();
511 }
512
513 @Override
514 public boolean release(int decrement) {
515 return response.release(decrement);
516 }
517
518 @Override
519 public boolean equals(Object obj) {
520 if (this == obj) {
521 return true;
522 }
523
524 if (!(obj instanceof AddressedEnvelope)) {
525 return false;
526 }
527
528 @SuppressWarnings("unchecked")
529 final AddressedEnvelope<?, SocketAddress> that = (AddressedEnvelope<?, SocketAddress>) obj;
530 if (sender() == null) {
531 if (that.sender() != null) {
532 return false;
533 }
534 } else if (!sender().equals(that.sender())) {
535 return false;
536 }
537
538 if (recipient() == null) {
539 if (that.recipient() != null) {
540 return false;
541 }
542 } else if (!recipient().equals(that.recipient())) {
543 return false;
544 }
545
546 return response.equals(obj);
547 }
548
549 @Override
550 public int hashCode() {
551 int hashCode = response.hashCode();
552 if (sender() != null) {
553 hashCode = hashCode * 31 + sender().hashCode();
554 }
555 if (recipient() != null) {
556 hashCode = hashCode * 31 + recipient().hashCode();
557 }
558 return hashCode;
559 }
560 }
561 }