1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.handler.codec.dns;
17
18 import io.netty5.util.Resource;
19 import io.netty5.util.AbstractReferenceCounted;
20 import io.netty5.util.ResourceLeakDetector;
21 import io.netty5.util.ResourceLeakDetectorFactory;
22 import io.netty5.util.ResourceLeakTracker;
23 import io.netty5.util.internal.StringUtil;
24 import io.netty5.util.internal.UnstableApi;
25
26 import java.util.ArrayList;
27 import java.util.List;
28
29 import static java.util.Objects.requireNonNull;
30
31
32
33
34 @UnstableApi
35 public abstract class AbstractDnsMessage extends AbstractReferenceCounted implements DnsMessage {
36
37 private static final ResourceLeakDetector<DnsMessage> leakDetector =
38 ResourceLeakDetectorFactory.instance().newResourceLeakDetector(DnsMessage.class);
39
40 private static final int SECTION_QUESTION = DnsSection.QUESTION.ordinal();
41 private static final int SECTION_COUNT = 4;
42
43 private final ResourceLeakTracker<DnsMessage> leak = leakDetector.track(this);
44 private short id;
45 private DnsOpCode opCode;
46 private boolean recursionDesired;
47 private byte z;
48
49
50
51 private Object questions;
52 private Object answers;
53 private Object authorities;
54 private Object additionals;
55
56
57
58
59 protected AbstractDnsMessage(int id) {
60 this(id, DnsOpCode.QUERY);
61 }
62
63
64
65
66 protected AbstractDnsMessage(int id, DnsOpCode opCode) {
67 setId(id);
68 setOpCode(opCode);
69 }
70
71 @Override
72 public int id() {
73 return id & 0xFFFF;
74 }
75
76 @Override
77 public DnsMessage setId(int id) {
78 this.id = (short) id;
79 return this;
80 }
81
82 @Override
83 public DnsOpCode opCode() {
84 return opCode;
85 }
86
87 @Override
88 public DnsMessage setOpCode(DnsOpCode opCode) {
89 this.opCode = requireNonNull(opCode, "opCode");
90 return this;
91 }
92
93 @Override
94 public boolean isRecursionDesired() {
95 return recursionDesired;
96 }
97
98 @Override
99 public DnsMessage setRecursionDesired(boolean recursionDesired) {
100 this.recursionDesired = recursionDesired;
101 return this;
102 }
103
104 @Override
105 public int z() {
106 return z;
107 }
108
109 @Override
110 public DnsMessage setZ(int z) {
111 this.z = (byte) (z & 7);
112 return this;
113 }
114
115 @Override
116 public int count(DnsSection section) {
117 return count(sectionOrdinal(section));
118 }
119
120 private int count(int section) {
121 final Object records = sectionAt(section);
122 if (records == null) {
123 return 0;
124 }
125 if (records instanceof DnsRecord) {
126 return 1;
127 }
128
129 @SuppressWarnings("unchecked")
130 final List<DnsRecord> recordList = (List<DnsRecord>) records;
131 return recordList.size();
132 }
133
134 @Override
135 public int count() {
136 int count = 0;
137 for (int i = 0; i < SECTION_COUNT; i ++) {
138 count += count(i);
139 }
140 return count;
141 }
142
143 @Override
144 public <T extends DnsRecord> T recordAt(DnsSection section) {
145 return recordAt(sectionOrdinal(section));
146 }
147
148 private <T extends DnsRecord> T recordAt(int section) {
149 final Object records = sectionAt(section);
150 if (records == null) {
151 return null;
152 }
153
154 if (records instanceof DnsRecord) {
155 return castRecord(records);
156 }
157
158 @SuppressWarnings("unchecked")
159 final List<DnsRecord> recordList = (List<DnsRecord>) records;
160 if (recordList.isEmpty()) {
161 return null;
162 }
163
164 return castRecord(recordList.get(0));
165 }
166
167 @Override
168 public <T extends DnsRecord> T recordAt(DnsSection section, int index) {
169 return recordAt(sectionOrdinal(section), index);
170 }
171
172 private <T extends DnsRecord> T recordAt(int section, int index) {
173 final Object records = sectionAt(section);
174 if (records == null) {
175 throw new IndexOutOfBoundsException("index: " + index + " (expected: none)");
176 }
177
178 if (records instanceof DnsRecord) {
179 if (index == 0) {
180 return castRecord(records);
181 } else {
182 throw new IndexOutOfBoundsException("index: " + index + "' (expected: 0)");
183 }
184 }
185
186 @SuppressWarnings("unchecked")
187 final List<DnsRecord> recordList = (List<DnsRecord>) records;
188 return castRecord(recordList.get(index));
189 }
190
191 @Override
192 public DnsMessage setRecord(DnsSection section, DnsRecord record) {
193 setRecord(sectionOrdinal(section), record);
194 return this;
195 }
196
197 private void setRecord(int section, DnsRecord record) {
198 clear(section);
199 setSection(section, checkQuestion(section, record));
200 }
201
202 @Override
203 public <T extends DnsRecord> T setRecord(DnsSection section, int index, DnsRecord record) {
204 return setRecord(sectionOrdinal(section), index, record);
205 }
206
207 private <T extends DnsRecord> T setRecord(int section, int index, DnsRecord record) {
208 checkQuestion(section, record);
209
210 final Object records = sectionAt(section);
211 if (records == null) {
212 throw new IndexOutOfBoundsException("index: " + index + " (expected: none)");
213 }
214
215 if (records instanceof DnsRecord) {
216 if (index == 0) {
217 setSection(section, record);
218 return castRecord(records);
219 } else {
220 throw new IndexOutOfBoundsException("index: " + index + " (expected: 0)");
221 }
222 }
223
224 @SuppressWarnings("unchecked")
225 final List<DnsRecord> recordList = (List<DnsRecord>) records;
226 return castRecord(recordList.set(index, record));
227 }
228
229 @Override
230 public DnsMessage addRecord(DnsSection section, DnsRecord record) {
231 addRecord(sectionOrdinal(section), record);
232 return this;
233 }
234
235 private void addRecord(int section, DnsRecord record) {
236 checkQuestion(section, record);
237
238 final Object records = sectionAt(section);
239 if (records == null) {
240 setSection(section, record);
241 return;
242 }
243
244 if (records instanceof DnsRecord) {
245 final List<DnsRecord> recordList = newRecordList();
246 recordList.add(castRecord(records));
247 recordList.add(record);
248 setSection(section, recordList);
249 return;
250 }
251
252 @SuppressWarnings("unchecked")
253 final List<DnsRecord> recordList = (List<DnsRecord>) records;
254 recordList.add(record);
255 }
256
257 @Override
258 public DnsMessage addRecord(DnsSection section, int index, DnsRecord record) {
259 addRecord(sectionOrdinal(section), index, record);
260 return this;
261 }
262
263 private void addRecord(int section, int index, DnsRecord record) {
264 checkQuestion(section, record);
265
266 final Object records = sectionAt(section);
267 if (records == null) {
268 if (index != 0) {
269 throw new IndexOutOfBoundsException("index: " + index + " (expected: 0)");
270 }
271
272 setSection(section, record);
273 return;
274 }
275
276 if (records instanceof DnsRecord) {
277 final List<DnsRecord> recordList;
278 if (index == 0) {
279 recordList = newRecordList();
280 recordList.add(record);
281 recordList.add(castRecord(records));
282 } else if (index == 1) {
283 recordList = newRecordList();
284 recordList.add(castRecord(records));
285 recordList.add(record);
286 } else {
287 throw new IndexOutOfBoundsException("index: " + index + " (expected: 0 or 1)");
288 }
289 setSection(section, recordList);
290 return;
291 }
292
293 @SuppressWarnings("unchecked")
294 final List<DnsRecord> recordList = (List<DnsRecord>) records;
295 recordList.add(index, record);
296 }
297
298 @Override
299 public <T extends DnsRecord> T removeRecord(DnsSection section, int index) {
300 return removeRecord(sectionOrdinal(section), index);
301 }
302
303 private <T extends DnsRecord> T removeRecord(int section, int index) {
304 final Object records = sectionAt(section);
305 if (records == null) {
306 throw new IndexOutOfBoundsException("index: " + index + " (expected: none)");
307 }
308
309 if (records instanceof DnsRecord) {
310 if (index != 0) {
311 throw new IndexOutOfBoundsException("index: " + index + " (expected: 0)");
312 }
313
314 T record = castRecord(records);
315 setSection(section, null);
316 return record;
317 }
318
319 @SuppressWarnings("unchecked")
320 final List<DnsRecord> recordList = (List<DnsRecord>) records;
321 return castRecord(recordList.remove(index));
322 }
323
324 @Override
325 public DnsMessage clear(DnsSection section) {
326 clear(sectionOrdinal(section));
327 return this;
328 }
329
330 @Override
331 public DnsMessage clear() {
332 for (int i = 0; i < SECTION_COUNT; i ++) {
333 clear(i);
334 }
335 return this;
336 }
337
338 private void clear(int section) {
339 final Object recordOrList = sectionAt(section);
340 setSection(section, null);
341 if (Resource.isAccessible(recordOrList, false)) {
342 Resource.dispose(recordOrList);
343 } else if (recordOrList instanceof List) {
344 @SuppressWarnings("unchecked")
345 List<DnsRecord> list = (List<DnsRecord>) recordOrList;
346 if (!list.isEmpty()) {
347 for (Object r : list) {
348 Resource.dispose(r);
349 }
350 }
351 }
352 }
353
354 @Override
355 public DnsMessage touch() {
356 return (DnsMessage) super.touch();
357 }
358
359 @Override
360 public DnsMessage touch(Object hint) {
361 if (leak != null) {
362 leak.record(hint);
363 }
364 return this;
365 }
366
367 @Override
368 public DnsMessage retain() {
369 return (DnsMessage) super.retain();
370 }
371
372 @Override
373 public DnsMessage retain(int increment) {
374 return (DnsMessage) super.retain(increment);
375 }
376
377 @Override
378 protected void deallocate() {
379 clear();
380
381 final ResourceLeakTracker<DnsMessage> leak = this.leak;
382 if (leak != null) {
383 boolean closed = leak.close(this);
384 assert closed;
385 }
386 }
387
388 @Override
389 public boolean equals(Object obj) {
390 if (this == obj) {
391 return true;
392 }
393
394 if (!(obj instanceof DnsMessage)) {
395 return false;
396 }
397
398 final DnsMessage that = (DnsMessage) obj;
399 if (id() != that.id()) {
400 return false;
401 }
402
403 if (this instanceof DnsQuery) {
404 return that instanceof DnsQuery;
405 } else {
406 return !(that instanceof DnsQuery);
407 }
408 }
409
410 @Override
411 public int hashCode() {
412 return id() * 31 + (this instanceof DnsQuery? 0 : 1);
413 }
414
415 private Object sectionAt(int section) {
416 switch (section) {
417 case 0:
418 return questions;
419 case 1:
420 return answers;
421 case 2:
422 return authorities;
423 case 3:
424 return additionals;
425 default:
426 break;
427 }
428
429 throw new Error();
430 }
431
432 private void setSection(int section, Object value) {
433 switch (section) {
434 case 0:
435 questions = value;
436 return;
437 case 1:
438 answers = value;
439 return;
440 case 2:
441 authorities = value;
442 return;
443 case 3:
444 additionals = value;
445 return;
446 default:
447 break;
448 }
449
450 throw new Error();
451 }
452
453 private static int sectionOrdinal(DnsSection section) {
454 return requireNonNull(section, "section").ordinal();
455 }
456
457 private static DnsRecord checkQuestion(int section, DnsRecord record) {
458 if (section == SECTION_QUESTION && !(requireNonNull(record, "record") instanceof DnsQuestion)) {
459 throw new IllegalArgumentException(
460 "record: " + record + " (expected: " + StringUtil.simpleClassName(DnsQuestion.class) + ')');
461 }
462 return record;
463 }
464
465 @SuppressWarnings("unchecked")
466 private static <T extends DnsRecord> T castRecord(Object record) {
467 return (T) record;
468 }
469
470 private static ArrayList<DnsRecord> newRecordList() {
471 return new ArrayList<>(2);
472 }
473 }