1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.resolver.dns;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.channel.AddressedEnvelope;
20 import io.netty.channel.Channel;
21 import io.netty.channel.ChannelFuture;
22 import io.netty.channel.ChannelFutureListener;
23 import io.netty.channel.ChannelHandlerContext;
24 import io.netty.channel.ChannelInboundHandlerAdapter;
25 import io.netty.channel.ChannelPromise;
26 import io.netty.handler.codec.dns.AbstractDnsOptPseudoRrRecord;
27 import io.netty.handler.codec.dns.DnsQuery;
28 import io.netty.handler.codec.dns.DnsQuestion;
29 import io.netty.handler.codec.dns.DnsRecord;
30 import io.netty.handler.codec.dns.DnsRecordType;
31 import io.netty.handler.codec.dns.DnsResponse;
32 import io.netty.handler.codec.dns.DnsSection;
33 import io.netty.handler.codec.dns.TcpDnsQueryEncoder;
34 import io.netty.handler.codec.dns.TcpDnsResponseDecoder;
35 import io.netty.util.ReferenceCountUtil;
36 import io.netty.util.concurrent.Future;
37 import io.netty.util.concurrent.FutureListener;
38 import io.netty.util.concurrent.Promise;
39 import io.netty.util.internal.SystemPropertyUtil;
40 import io.netty.util.internal.ThrowableUtil;
41 import io.netty.util.internal.logging.InternalLogger;
42 import io.netty.util.internal.logging.InternalLoggerFactory;
43
44 import java.net.InetSocketAddress;
45 import java.net.SocketAddress;
46 import java.util.concurrent.CancellationException;
47 import java.util.concurrent.TimeUnit;
48
49 import static io.netty.util.internal.ObjectUtil.checkNotNull;
50
51 abstract class DnsQueryContext {
52
53 private static final InternalLogger logger = InternalLoggerFactory.getInstance(DnsQueryContext.class);
54 private static final long ID_REUSE_ON_TIMEOUT_DELAY_MILLIS;
55
56 static {
57 ID_REUSE_ON_TIMEOUT_DELAY_MILLIS =
58 SystemPropertyUtil.getLong("io.netty.resolver.dns.idReuseOnTimeoutDelayMillis", 10000);
59 logger.debug("-Dio.netty.resolver.dns.idReuseOnTimeoutDelayMillis: {}", ID_REUSE_ON_TIMEOUT_DELAY_MILLIS);
60 }
61
62 private static final TcpDnsQueryEncoder TCP_ENCODER = new TcpDnsQueryEncoder();
63
64 private final Channel channel;
65 private final InetSocketAddress nameServerAddr;
66 private final DnsQueryContextManager queryContextManager;
67 private final DnsQueryLifecycleObserver queryLifecycleObserver;
68 private final Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise;
69
70 private final DnsQuestion question;
71 private final DnsRecord[] additionals;
72 private final DnsRecord optResource;
73
74 private final boolean recursionDesired;
75
76 private final Bootstrap socketBootstrap;
77
78 private final boolean retryWithTcpOnTimeout;
79 private final long queryTimeoutMillis;
80
81 private volatile Future<?> timeoutFuture;
82
83 private int id = Integer.MIN_VALUE;
84
85 DnsQueryContext(Channel channel,
86 InetSocketAddress nameServerAddr,
87 DnsQueryContextManager queryContextManager,
88 DnsQueryLifecycleObserver queryLifecycleObserver,
89 int maxPayLoadSize,
90 boolean recursionDesired,
91 long queryTimeoutMillis,
92 DnsQuestion question,
93 DnsRecord[] additionals,
94 Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise,
95 Bootstrap socketBootstrap,
96 boolean retryWithTcpOnTimeout) {
97 this.channel = checkNotNull(channel, "channel");
98 this.queryContextManager = checkNotNull(queryContextManager, "queryContextManager");
99 this.queryLifecycleObserver = checkNotNull(queryLifecycleObserver, "queryLifecycleObserver");
100 this.nameServerAddr = checkNotNull(nameServerAddr, "nameServerAddr");
101 this.question = checkNotNull(question, "question");
102 this.additionals = checkNotNull(additionals, "additionals");
103 this.promise = checkNotNull(promise, "promise");
104 this.recursionDesired = recursionDesired;
105 this.queryTimeoutMillis = queryTimeoutMillis;
106 this.socketBootstrap = socketBootstrap;
107 this.retryWithTcpOnTimeout = retryWithTcpOnTimeout;
108
109 if (maxPayLoadSize > 0 &&
110
111
112
113 !hasOptRecord(additionals)) {
114 optResource = new AbstractDnsOptPseudoRrRecord(maxPayLoadSize, 0, 0) {
115
116 };
117 } else {
118 optResource = null;
119 }
120 }
121
122 private static boolean hasOptRecord(DnsRecord[] additionals) {
123 if (additionals != null && additionals.length > 0) {
124 for (DnsRecord additional: additionals) {
125 if (additional.type() == DnsRecordType.OPT) {
126 return true;
127 }
128 }
129 }
130 return false;
131 }
132
133
134
135
136
137
138 final boolean isDone() {
139 return promise.isDone();
140 }
141
142
143
144
145
146
147 final DnsQuestion question() {
148 return question;
149 }
150
151
152
153
154
155
156
157
158 protected abstract DnsQuery newQuery(int id, InetSocketAddress nameServerAddr);
159
160
161
162
163
164
165 protected abstract String protocol();
166
167
168
169
170
171
172 final void writeQuery(boolean flush) {
173 assert id == Integer.MIN_VALUE : this.getClass().getSimpleName() +
174 ".writeQuery(...) can only be executed once.";
175
176 if ((id = queryContextManager.add(nameServerAddr, this)) == -1) {
177
178 IllegalStateException e = new IllegalStateException("query ID space exhausted: " + question());
179 finishFailure("failed to send a query via " + protocol(), e, false);
180 queryLifecycleObserver.queryWritten(nameServerAddr, channel.newFailedFuture(e));
181 }
182
183
184 promise.addListener((FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>) future -> {
185
186 Future<?> timeoutFuture = DnsQueryContext.this.timeoutFuture;
187 if (timeoutFuture != null) {
188 DnsQueryContext.this.timeoutFuture = null;
189 timeoutFuture.cancel(false);
190 }
191
192 Throwable cause = future.cause();
193 if (cause instanceof DnsNameResolverTimeoutException || cause instanceof CancellationException) {
194
195
196
197 channel.eventLoop().schedule(new Runnable() {
198 @Override
199 public void run() {
200 removeFromContextManager(nameServerAddr);
201 }
202 }, ID_REUSE_ON_TIMEOUT_DELAY_MILLIS, TimeUnit.MILLISECONDS);
203 } else {
204
205
206 removeFromContextManager(nameServerAddr);
207 }
208 });
209 final DnsQuestion question = question();
210 final DnsQuery query = newQuery(id, nameServerAddr);
211
212 query.setRecursionDesired(recursionDesired);
213
214 query.addRecord(DnsSection.QUESTION, question);
215
216 for (DnsRecord record: additionals) {
217 query.addRecord(DnsSection.ADDITIONAL, record);
218 }
219
220 if (optResource != null) {
221 query.addRecord(DnsSection.ADDITIONAL, optResource);
222 }
223
224 if (logger.isDebugEnabled()) {
225 logger.debug("{} WRITE: {}, [{}: {}], {}",
226 channel, protocol(), id, nameServerAddr, question);
227 }
228
229 ChannelFuture f = sendQuery(query, flush);
230 queryLifecycleObserver.queryWritten(nameServerAddr, f);
231 }
232
233 private void removeFromContextManager(InetSocketAddress nameServerAddr) {
234 DnsQueryContext self = queryContextManager.remove(nameServerAddr, id);
235
236 assert self == this : "Removed DnsQueryContext is not the correct instance";
237 }
238
239 private ChannelFuture sendQuery(final DnsQuery query, final boolean flush) {
240 final ChannelPromise writePromise = channel.newPromise();
241 writeQuery(query, flush, writePromise);
242 return writePromise;
243 }
244
245 private void writeQuery(final DnsQuery query,
246 final boolean flush, ChannelPromise promise) {
247 final ChannelFuture writeFuture = flush ? channel.writeAndFlush(query, promise) :
248 channel.write(query, promise);
249 if (writeFuture.isDone()) {
250 onQueryWriteCompletion(queryTimeoutMillis, writeFuture);
251 } else {
252 writeFuture.addListener((ChannelFutureListener) future ->
253 onQueryWriteCompletion(queryTimeoutMillis, future));
254 }
255 }
256
257 private void onQueryWriteCompletion(final long queryTimeoutMillis,
258 ChannelFuture writeFuture) {
259 if (!writeFuture.isSuccess()) {
260 finishFailure("failed to send a query '" + id + "' via " + protocol(), writeFuture.cause(), false);
261 return;
262 }
263
264
265 if (queryTimeoutMillis > 0) {
266 timeoutFuture = channel.eventLoop().schedule(new Runnable() {
267 @Override
268 public void run() {
269 if (promise.isDone()) {
270
271 return;
272 }
273
274 finishFailure("query '" + id + "' via " + protocol() + " timed out after " +
275 queryTimeoutMillis + " milliseconds", null, true);
276 }
277 }, queryTimeoutMillis, TimeUnit.MILLISECONDS);
278 }
279 }
280
281
282
283
284
285 void finishSuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope, boolean truncated) {
286
287 if (!truncated || !retryWithTcp(envelope)) {
288 final DnsResponse res = envelope.content();
289 if (res.count(DnsSection.QUESTION) != 1) {
290 logger.warn("{} Received a DNS response with invalid number of questions. Expected: 1, found: {}",
291 channel, envelope);
292 } else if (!question().equals(res.recordAt(DnsSection.QUESTION))) {
293 logger.warn("{} Received a mismatching DNS response. Expected: [{}], found: {}",
294 channel, question(), envelope);
295 } else if (trySuccess(envelope)) {
296 return;
297 }
298 envelope.release();
299 }
300 }
301
302 @SuppressWarnings("unchecked")
303 private boolean trySuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
304 return promise.trySuccess((AddressedEnvelope<DnsResponse, InetSocketAddress>) envelope);
305 }
306
307
308
309
310 final boolean finishFailure(String message, Throwable cause, boolean timeout) {
311 if (promise.isDone()) {
312 return false;
313 }
314 final DnsQuestion question = question();
315
316 final StringBuilder buf = new StringBuilder(message.length() + 128);
317 buf.append('[')
318 .append(id)
319 .append(": ")
320 .append(nameServerAddr)
321 .append("] ")
322 .append(question)
323 .append(' ')
324 .append(message)
325 .append(" (no stack trace available)");
326
327 final DnsNameResolverException e;
328 if (timeout) {
329
330
331 e = new DnsNameResolverTimeoutException(nameServerAddr, question, buf.toString());
332 if (retryWithTcpOnTimeout && retryWithTcp(e)) {
333
334 return false;
335 }
336 } else {
337 e = new DnsNameResolverException(nameServerAddr, question, buf.toString(), cause);
338 }
339 return promise.tryFailure(e);
340 }
341
342
343
344
345
346
347
348
349 private boolean retryWithTcp(final Object originalResult) {
350 if (socketBootstrap == null) {
351 return false;
352 }
353
354 socketBootstrap.connect(nameServerAddr).addListener((ChannelFutureListener) future -> {
355 if (!future.isSuccess()) {
356 logger.debug("{} Unable to fallback to TCP [{}: {}]",
357 future.channel(), id, nameServerAddr, future.cause());
358
359 finishOriginal(originalResult, future);
360 return;
361 }
362 final Channel tcpCh = future.channel();
363 Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise =
364 tcpCh.eventLoop().newPromise();
365 final TcpDnsQueryContext tcpCtx = new TcpDnsQueryContext(tcpCh,
366 (InetSocketAddress) tcpCh.remoteAddress(), queryContextManager, queryLifecycleObserver, 0,
367 recursionDesired, queryTimeoutMillis, question(), additionals, promise);
368 tcpCh.pipeline().addLast(TCP_ENCODER);
369 tcpCh.pipeline().addLast(new TcpDnsResponseDecoder());
370 tcpCh.pipeline().addLast(new ChannelInboundHandlerAdapter() {
371 @Override
372 public void channelRead(ChannelHandlerContext ctx, Object msg) {
373 Channel tcpCh = ctx.channel();
374 DnsResponse response = (DnsResponse) msg;
375 int queryId = response.id();
376
377 if (logger.isDebugEnabled()) {
378 logger.debug("{} RECEIVED: TCP [{}: {}], {}", tcpCh, queryId,
379 tcpCh.remoteAddress(), response);
380 }
381
382 DnsQueryContext foundCtx = queryContextManager.get(nameServerAddr, queryId);
383 if (foundCtx != null && foundCtx.isDone()) {
384 logger.debug("{} Received a DNS response for a query that was timed out or cancelled " +
385 ": TCP [{}: {}]", tcpCh, queryId, nameServerAddr);
386 response.release();
387 } else if (foundCtx == tcpCtx) {
388 tcpCtx.finishSuccess(new AddressedEnvelopeAdapter(
389 (InetSocketAddress) ctx.channel().remoteAddress(),
390 (InetSocketAddress) ctx.channel().localAddress(),
391 response), false);
392 } else {
393 response.release();
394 tcpCtx.finishFailure("Received TCP DNS response with unexpected ID", null, false);
395 if (logger.isDebugEnabled()) {
396 logger.debug("{} Received a DNS response with an unexpected ID: TCP [{}: {}]",
397 tcpCh, queryId, tcpCh.remoteAddress());
398 }
399 }
400 }
401
402 @Override
403 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
404 if (tcpCtx.finishFailure(
405 "TCP fallback error", cause, false) && logger.isDebugEnabled()) {
406 logger.debug("{} Error during processing response: TCP [{}: {}]",
407 ctx.channel(), id,
408 ctx.channel().remoteAddress(), cause);
409 }
410 }
411 });
412
413 promise.addListener(
414 (FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>) future1 -> {
415 if (future1.isSuccess()) {
416 finishSuccess(future1.getNow(), false);
417
418 ReferenceCountUtil.release(originalResult);
419 } else {
420
421 finishOriginal(originalResult, future1);
422 }
423 tcpCh.close();
424 });
425 tcpCtx.writeQuery(true);
426 });
427 return true;
428 }
429
430 @SuppressWarnings("unchecked")
431 private void finishOriginal(Object originalResult, Future<?> future) {
432 if (originalResult instanceof Throwable) {
433 Throwable error = (Throwable) originalResult;
434 ThrowableUtil.addSuppressed(error, future.cause());
435 promise.tryFailure(error);
436 } else {
437 finishSuccess((AddressedEnvelope<? extends DnsResponse, InetSocketAddress>) originalResult, false);
438 }
439 }
440
441 private static final class AddressedEnvelopeAdapter implements AddressedEnvelope<DnsResponse, InetSocketAddress> {
442 private final InetSocketAddress sender;
443 private final InetSocketAddress recipient;
444 private final DnsResponse response;
445
446 AddressedEnvelopeAdapter(InetSocketAddress sender, InetSocketAddress recipient, DnsResponse response) {
447 this.sender = sender;
448 this.recipient = recipient;
449 this.response = response;
450 }
451
452 @Override
453 public DnsResponse content() {
454 return response;
455 }
456
457 @Override
458 public InetSocketAddress sender() {
459 return sender;
460 }
461
462 @Override
463 public InetSocketAddress recipient() {
464 return recipient;
465 }
466
467 @Override
468 public AddressedEnvelope<DnsResponse, InetSocketAddress> retain() {
469 response.retain();
470 return this;
471 }
472
473 @Override
474 public AddressedEnvelope<DnsResponse, InetSocketAddress> retain(int increment) {
475 response.retain(increment);
476 return this;
477 }
478
479 @Override
480 public AddressedEnvelope<DnsResponse, InetSocketAddress> touch() {
481 response.touch();
482 return this;
483 }
484
485 @Override
486 public AddressedEnvelope<DnsResponse, InetSocketAddress> touch(Object hint) {
487 response.touch(hint);
488 return this;
489 }
490
491 @Override
492 public int refCnt() {
493 return response.refCnt();
494 }
495
496 @Override
497 public boolean release() {
498 return response.release();
499 }
500
501 @Override
502 public boolean release(int decrement) {
503 return response.release(decrement);
504 }
505
506 @Override
507 public boolean equals(Object obj) {
508 if (this == obj) {
509 return true;
510 }
511
512 if (!(obj instanceof AddressedEnvelope)) {
513 return false;
514 }
515
516 @SuppressWarnings("unchecked")
517 final AddressedEnvelope<?, SocketAddress> that = (AddressedEnvelope<?, SocketAddress>) obj;
518 if (sender() == null) {
519 if (that.sender() != null) {
520 return false;
521 }
522 } else if (!sender().equals(that.sender())) {
523 return false;
524 }
525
526 if (recipient() == null) {
527 if (that.recipient() != null) {
528 return false;
529 }
530 } else if (!recipient().equals(that.recipient())) {
531 return false;
532 }
533
534 return response.equals(obj);
535 }
536
537 @Override
538 public int hashCode() {
539 int hashCode = response.hashCode();
540 if (sender() != null) {
541 hashCode = hashCode * 31 + sender().hashCode();
542 }
543 if (recipient() != null) {
544 hashCode = hashCode * 31 + recipient().hashCode();
545 }
546 return hashCode;
547 }
548 }
549 }