1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.resolver.dns;
17
18 import io.netty5.channel.AddressedEnvelope;
19 import io.netty5.channel.Channel;
20 import io.netty5.handler.codec.dns.AbstractDnsOptPseudoRrRecord;
21 import io.netty5.handler.codec.dns.DnsOptPseudoRecord;
22 import io.netty5.handler.codec.dns.DnsQuery;
23 import io.netty5.handler.codec.dns.DnsQuestion;
24 import io.netty5.handler.codec.dns.DnsRecord;
25 import io.netty5.handler.codec.dns.DnsRecordType;
26 import io.netty5.handler.codec.dns.DnsResponse;
27 import io.netty5.handler.codec.dns.DnsSection;
28 import io.netty5.util.Resource;
29 import io.netty5.util.concurrent.Future;
30 import io.netty5.util.concurrent.FutureListener;
31 import io.netty5.util.concurrent.Promise;
32 import io.netty5.util.internal.SilentDispose;
33 import io.netty5.util.internal.logging.InternalLogger;
34 import io.netty5.util.internal.logging.InternalLoggerFactory;
35
36 import java.net.InetSocketAddress;
37 import java.util.concurrent.TimeUnit;
38
39 import static java.util.Objects.requireNonNull;
40
41 abstract class DnsQueryContext implements FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>> {
42
43 private static final InternalLogger logger = InternalLoggerFactory.getInstance(DnsQueryContext.class);
44
45 private final DnsNameResolver parent;
46 private final Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise;
47 private final int id;
48 private final DnsQuestion question;
49 private final DnsRecord[] additionals;
50 private final DnsRecord optResource;
51 private final InetSocketAddress nameServerAddr;
52
53 private final boolean recursionDesired;
54 private volatile Future<?> timeoutFuture;
55
56 DnsQueryContext(DnsNameResolver parent,
57 InetSocketAddress nameServerAddr,
58 DnsQuestion question,
59 DnsRecord[] additionals,
60 Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise) {
61
62 this.parent = requireNonNull(parent, "parent");
63 this.nameServerAddr = requireNonNull(nameServerAddr, "nameServerAddr");
64 this.question = requireNonNull(question, "question");
65 this.additionals = requireNonNull(additionals, "additionals");
66 this.promise = requireNonNull(promise, "promise");
67 recursionDesired = parent.isRecursionDesired();
68 id = parent.queryContextManager.add(this);
69
70
71 promise.asFuture().addListener(this);
72
73 if (parent.isOptResourceEnabled() &&
74
75
76
77 !hasOptRecord(additionals)) {
78 optResource = new AbstractDnsOptPseudoRrRecord(parent.maxPayloadSize(), 0, 0) {
79
80 @Override
81 public DnsOptPseudoRecord copy() {
82 return this;
83 }
84 };
85 } else {
86 optResource = null;
87 }
88 }
89
90 private static boolean hasOptRecord(DnsRecord[] additionals) {
91 if (additionals != null && additionals.length > 0) {
92 for (DnsRecord additional: additionals) {
93 if (additional.type() == DnsRecordType.OPT) {
94 return true;
95 }
96 }
97 }
98 return false;
99 }
100
101 InetSocketAddress nameServerAddr() {
102 return nameServerAddr;
103 }
104
105 DnsQuestion question() {
106 return question;
107 }
108
109 DnsNameResolver parent() {
110 return parent;
111 }
112
113 protected abstract DnsQuery newQuery(int id);
114 protected abstract Channel channel();
115 protected abstract String protocol();
116
117 void query(boolean flush, Promise<Void> writePromise) {
118 final DnsQuestion question = question();
119 final InetSocketAddress nameServerAddr = nameServerAddr();
120 final DnsQuery query = newQuery(id);
121
122 query.setRecursionDesired(recursionDesired);
123
124 query.addRecord(DnsSection.QUESTION, question);
125
126 for (DnsRecord record: additionals) {
127 query.addRecord(DnsSection.ADDITIONAL, record);
128 }
129
130 if (optResource != null) {
131 query.addRecord(DnsSection.ADDITIONAL, optResource);
132 }
133
134 if (logger.isDebugEnabled()) {
135 logger.debug("{} WRITE: {}, [{}: {}], {}", channel(), protocol(), id, nameServerAddr, question);
136 }
137
138 sendQuery(query, flush, writePromise);
139 }
140
141 private void sendQuery(final DnsQuery query, final boolean flush, final Promise<Void> writePromise) {
142 if (parent.channelReadyPromise.isSuccess()) {
143 writeQuery(query, flush, writePromise);
144 } else if (parent.channelReadyPromise.isFailed()) {
145 failQuery(query, parent.channelReadyPromise.cause(), writePromise);
146 } else {
147 parent.channelReadyPromise.asFuture().addListener(future -> {
148 if (future.isSuccess()) {
149
150
151
152 writeQuery(query, true, writePromise);
153 } else {
154 failQuery(query, future.cause(), writePromise);
155 }
156 });
157 }
158 }
159
160 private void failQuery(DnsQuery query, Throwable cause, Promise<Void> writePromise) {
161 try {
162 promise.tryFailure(cause);
163 writePromise.setFailure(cause);
164 } catch (Throwable throwable) {
165 SilentDispose.dispose(query, logger);
166 throw throwable;
167 }
168 Resource.dispose(query);
169 }
170
171 private void writeQuery(final DnsQuery query, final boolean flush, final Promise<Void> writePromise) {
172 final Future<Void> writeFuture = flush ? channel().writeAndFlush(query) :
173 channel().write(query);
174 if (writeFuture.isDone()) {
175 onQueryWriteCompletion(writeFuture, writePromise);
176 } else {
177 writeFuture.addListener(future ->
178 onQueryWriteCompletion(future, writePromise));
179 }
180 }
181
182 private void onQueryWriteCompletion(Future<?> writeFuture, Promise<Void> writePromise) {
183 if (writeFuture.isFailed()) {
184 writePromise.setFailure(writeFuture.cause());
185 tryFailure("failed to send a query via " + protocol(), writeFuture.cause(), false);
186 return;
187 }
188 writePromise.setSuccess(null);
189
190 final long queryTimeoutMillis = parent.queryTimeoutMillis();
191 if (queryTimeoutMillis > 0) {
192 timeoutFuture = parent.ch.executor().schedule(() -> {
193 if (promise.isDone()) {
194
195 return;
196 }
197
198 tryFailure("query via " + protocol() + " timed out after " +
199 queryTimeoutMillis + " milliseconds", null, true);
200 }, queryTimeoutMillis, TimeUnit.MILLISECONDS);
201 }
202 }
203
204
205
206
207 void finish(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
208 final DnsResponse res = envelope.content();
209 if (res.count(DnsSection.QUESTION) != 1) {
210 logger.warn("Received a DNS response with invalid number of questions: {}", envelope);
211 } else if (!question().equals(res.recordAt(DnsSection.QUESTION))) {
212 logger.warn("Received a mismatching DNS response: {}", envelope);
213 } else if (trySuccess(envelope)) {
214 return;
215 }
216 Resource.dispose(envelope);
217 }
218
219 @SuppressWarnings("unchecked")
220 private boolean trySuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
221 return promise.trySuccess((AddressedEnvelope<DnsResponse, InetSocketAddress>) envelope);
222 }
223
224 boolean tryFailure(String message, Throwable cause, boolean timeout) {
225 if (promise.isDone()) {
226 return false;
227 }
228 final InetSocketAddress nameServerAddr = nameServerAddr();
229
230 final StringBuilder buf = new StringBuilder(message.length() + 64);
231 buf.append('[')
232 .append(nameServerAddr)
233 .append("] ")
234 .append(message)
235 .append(" (no stack trace available)");
236
237 final DnsNameResolverException e;
238 if (timeout) {
239
240
241 e = new DnsNameResolverTimeoutException(nameServerAddr, question(), buf.toString());
242 } else {
243 e = new DnsNameResolverException(nameServerAddr, question(), buf.toString(), cause);
244 }
245 return promise.tryFailure(e);
246 }
247
248 @Override
249 public void operationComplete(Future<? extends AddressedEnvelope<DnsResponse, InetSocketAddress>> future) {
250
251 final Future<?> timeoutFuture = this.timeoutFuture;
252 if (timeoutFuture != null) {
253 this.timeoutFuture = null;
254 timeoutFuture.cancel();
255 }
256
257
258
259 parent.queryContextManager.remove(nameServerAddr, id);
260 }
261 }