View Javadoc
1   /*
2    * Copyright 2015 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.dns;
17  
18  import io.netty.util.AbstractReferenceCounted;
19  import io.netty.util.ReferenceCountUtil;
20  import io.netty.util.ReferenceCounted;
21  import io.netty.util.ResourceLeakDetector;
22  import io.netty.util.ResourceLeakDetectorFactory;
23  import io.netty.util.ResourceLeakTracker;
24  import io.netty.util.internal.StringUtil;
25  
26  import java.util.ArrayList;
27  import java.util.List;
28  
29  import static io.netty.util.internal.ObjectUtil.checkNotNull;
30  
31  /**
32   * A skeletal implementation of {@link DnsMessage}.
33   */
34  public abstract class AbstractDnsMessage extends AbstractReferenceCounted implements DnsMessage {
35  
36      private static final ResourceLeakDetector<DnsMessage> leakDetector =
37              ResourceLeakDetectorFactory.instance().newResourceLeakDetector(DnsMessage.class);
38  
39      private static final int SECTION_QUESTION = DnsSection.QUESTION.ordinal();
40      private static final int SECTION_COUNT = 4;
41  
42      private final ResourceLeakTracker<DnsMessage> leak = leakDetector.track(this);
43      private short id;
44      private DnsOpCode opCode;
45      private boolean recursionDesired;
46      private byte z;
47  
48      // To reduce the memory footprint of a message,
49      // each of the following fields is a single record or a list of records.
50      private Object questions;
51      private Object answers;
52      private Object authorities;
53      private Object additionals;
54  
55      /**
56       * Creates a new instance with the specified {@code id} and {@link DnsOpCode#QUERY} opCode.
57       */
58      protected AbstractDnsMessage(int id) {
59          this(id, DnsOpCode.QUERY);
60      }
61  
62      /**
63       * Creates a new instance with the specified {@code id} and {@code opCode}.
64       */
65      protected AbstractDnsMessage(int id, DnsOpCode opCode) {
66          setId(id);
67          setOpCode(opCode);
68      }
69  
70      @Override
71      public int id() {
72          return id & 0xFFFF;
73      }
74  
75      @Override
76      public DnsMessage setId(int id) {
77          this.id = (short) id;
78          return this;
79      }
80  
81      @Override
82      public DnsOpCode opCode() {
83          return opCode;
84      }
85  
86      @Override
87      public DnsMessage setOpCode(DnsOpCode opCode) {
88          this.opCode = checkNotNull(opCode, "opCode");
89          return this;
90      }
91  
92      @Override
93      public boolean isRecursionDesired() {
94          return recursionDesired;
95      }
96  
97      @Override
98      public DnsMessage setRecursionDesired(boolean recursionDesired) {
99          this.recursionDesired = recursionDesired;
100         return this;
101     }
102 
103     @Override
104     public int z() {
105         return z;
106     }
107 
108     @Override
109     public DnsMessage setZ(int z) {
110         this.z = (byte) (z & 7);
111         return this;
112     }
113 
114     @Override
115     public int count(DnsSection section) {
116         return count(sectionOrdinal(section));
117     }
118 
119     private int count(int section) {
120         final Object records = sectionAt(section);
121         if (records == null) {
122             return 0;
123         }
124         if (records instanceof DnsRecord) {
125             return 1;
126         }
127 
128         @SuppressWarnings("unchecked")
129         final List<DnsRecord> recordList = (List<DnsRecord>) records;
130         return recordList.size();
131     }
132 
133     @Override
134     public int count() {
135         int count = 0;
136         for (int i = 0; i < SECTION_COUNT; i ++) {
137             count += count(i);
138         }
139         return count;
140     }
141 
142     @Override
143     public <T extends DnsRecord> T recordAt(DnsSection section) {
144         return recordAt(sectionOrdinal(section));
145     }
146 
147     private <T extends DnsRecord> T recordAt(int section) {
148         final Object records = sectionAt(section);
149         if (records == null) {
150             return null;
151         }
152 
153         if (records instanceof DnsRecord) {
154             return castRecord(records);
155         }
156 
157         @SuppressWarnings("unchecked")
158         final List<DnsRecord> recordList = (List<DnsRecord>) records;
159         if (recordList.isEmpty()) {
160             return null;
161         }
162 
163         return castRecord(recordList.get(0));
164     }
165 
166     @Override
167     public <T extends DnsRecord> T recordAt(DnsSection section, int index) {
168         return recordAt(sectionOrdinal(section), index);
169     }
170 
171     private <T extends DnsRecord> T recordAt(int section, int index) {
172         final Object records = sectionAt(section);
173         if (records == null) {
174             throw new IndexOutOfBoundsException("index: " + index + " (expected: none)");
175         }
176 
177         if (records instanceof DnsRecord) {
178             if (index == 0) {
179                 return castRecord(records);
180             } else {
181                 throw new IndexOutOfBoundsException("index: " + index + "' (expected: 0)");
182             }
183         }
184 
185         @SuppressWarnings("unchecked")
186         final List<DnsRecord> recordList = (List<DnsRecord>) records;
187         return castRecord(recordList.get(index));
188     }
189 
190     @Override
191     public DnsMessage setRecord(DnsSection section, DnsRecord record) {
192         setRecord(sectionOrdinal(section), record);
193         return this;
194     }
195 
196     private void setRecord(int section, DnsRecord record) {
197         clear(section);
198         setSection(section, checkQuestion(section, record));
199     }
200 
201     @Override
202     public <T extends DnsRecord> T setRecord(DnsSection section, int index, DnsRecord record) {
203         return setRecord(sectionOrdinal(section), index, record);
204     }
205 
206     private <T extends DnsRecord> T setRecord(int section, int index, DnsRecord record) {
207         checkQuestion(section, record);
208 
209         final Object records = sectionAt(section);
210         if (records == null) {
211             throw new IndexOutOfBoundsException("index: " + index + " (expected: none)");
212         }
213 
214         if (records instanceof DnsRecord) {
215             if (index == 0) {
216                 setSection(section, record);
217                 return castRecord(records);
218             } else {
219                 throw new IndexOutOfBoundsException("index: " + index + " (expected: 0)");
220             }
221         }
222 
223         @SuppressWarnings("unchecked")
224         final List<DnsRecord> recordList = (List<DnsRecord>) records;
225         return castRecord(recordList.set(index, record));
226     }
227 
228     @Override
229     public DnsMessage addRecord(DnsSection section, DnsRecord record) {
230         addRecord(sectionOrdinal(section), record);
231         return this;
232     }
233 
234     private void addRecord(int section, DnsRecord record) {
235         checkQuestion(section, record);
236 
237         final Object records = sectionAt(section);
238         if (records == null) {
239             setSection(section, record);
240             return;
241         }
242 
243         if (records instanceof DnsRecord) {
244             final List<DnsRecord> recordList = newRecordList();
245             recordList.add(castRecord(records));
246             recordList.add(record);
247             setSection(section, recordList);
248             return;
249         }
250 
251         @SuppressWarnings("unchecked")
252         final List<DnsRecord> recordList = (List<DnsRecord>) records;
253         recordList.add(record);
254     }
255 
256     @Override
257     public DnsMessage addRecord(DnsSection section, int index, DnsRecord record) {
258         addRecord(sectionOrdinal(section), index, record);
259         return this;
260     }
261 
262     private void addRecord(int section, int index, DnsRecord record) {
263         checkQuestion(section, record);
264 
265         final Object records = sectionAt(section);
266         if (records == null) {
267             if (index != 0) {
268                 throw new IndexOutOfBoundsException("index: " + index + " (expected: 0)");
269             }
270 
271             setSection(section, record);
272             return;
273         }
274 
275         if (records instanceof DnsRecord) {
276             final List<DnsRecord> recordList;
277             if (index == 0) {
278                 recordList = newRecordList();
279                 recordList.add(record);
280                 recordList.add(castRecord(records));
281             } else if (index == 1) {
282                 recordList = newRecordList();
283                 recordList.add(castRecord(records));
284                 recordList.add(record);
285             } else {
286                 throw new IndexOutOfBoundsException("index: " + index + " (expected: 0 or 1)");
287             }
288             setSection(section, recordList);
289             return;
290         }
291 
292         @SuppressWarnings("unchecked")
293         final List<DnsRecord> recordList = (List<DnsRecord>) records;
294         recordList.add(index, record);
295     }
296 
297     @Override
298     public <T extends DnsRecord> T removeRecord(DnsSection section, int index) {
299         return removeRecord(sectionOrdinal(section), index);
300     }
301 
302     private <T extends DnsRecord> T removeRecord(int section, int index) {
303         final Object records = sectionAt(section);
304         if (records == null) {
305             throw new IndexOutOfBoundsException("index: " + index + " (expected: none)");
306         }
307 
308         if (records instanceof DnsRecord) {
309             if (index != 0) {
310                 throw new IndexOutOfBoundsException("index: " + index + " (expected: 0)");
311             }
312 
313             T record = castRecord(records);
314             setSection(section, null);
315             return record;
316         }
317 
318         @SuppressWarnings("unchecked")
319         final List<DnsRecord> recordList = (List<DnsRecord>) records;
320         return castRecord(recordList.remove(index));
321     }
322 
323     @Override
324     public DnsMessage clear(DnsSection section) {
325         clear(sectionOrdinal(section));
326         return this;
327     }
328 
329     @Override
330     public DnsMessage clear() {
331         for (int i = 0; i < SECTION_COUNT; i ++) {
332             clear(i);
333         }
334         return this;
335     }
336 
337     private void clear(int section) {
338         final Object recordOrList = sectionAt(section);
339         setSection(section, null);
340         if (recordOrList instanceof ReferenceCounted) {
341             ((ReferenceCounted) recordOrList).release();
342         } else if (recordOrList instanceof List) {
343             @SuppressWarnings("unchecked")
344             List<DnsRecord> list = (List<DnsRecord>) recordOrList;
345             if (!list.isEmpty()) {
346                 for (Object r : list) {
347                     ReferenceCountUtil.release(r);
348                 }
349             }
350         }
351     }
352 
353     @Override
354     public DnsMessage touch() {
355         return (DnsMessage) super.touch();
356     }
357 
358     @Override
359     public DnsMessage touch(Object hint) {
360         if (leak != null) {
361             leak.record(hint);
362         }
363         return this;
364     }
365 
366     @Override
367     public DnsMessage retain() {
368         return (DnsMessage) super.retain();
369     }
370 
371     @Override
372     public DnsMessage retain(int increment) {
373         return (DnsMessage) super.retain(increment);
374     }
375 
376     @Override
377     protected void deallocate() {
378         clear();
379 
380         final ResourceLeakTracker<DnsMessage> leak = this.leak;
381         if (leak != null) {
382             boolean closed = leak.close(this);
383             assert closed;
384         }
385     }
386 
387     @Override
388     public boolean equals(Object obj) {
389         if (this == obj) {
390             return true;
391         }
392 
393         if (!(obj instanceof DnsMessage)) {
394             return false;
395         }
396 
397         final DnsMessage that = (DnsMessage) obj;
398         if (id() != that.id()) {
399             return false;
400         }
401 
402         if (this instanceof DnsQuery) {
403             if (!(that instanceof DnsQuery)) {
404                 return false;
405             }
406         } else if (that instanceof DnsQuery) {
407             return false;
408         }
409 
410         return true;
411     }
412 
413     @Override
414     public int hashCode() {
415         return id() * 31 + (this instanceof DnsQuery? 0 : 1);
416     }
417 
418     private Object sectionAt(int section) {
419         switch (section) {
420         case 0:
421             return questions;
422         case 1:
423             return answers;
424         case 2:
425             return authorities;
426         case 3:
427             return additionals;
428         default:
429             break;
430         }
431 
432         throw new Error(); // Should never reach here.
433     }
434 
435     private void setSection(int section, Object value) {
436         switch (section) {
437         case 0:
438             questions = value;
439             return;
440         case 1:
441             answers = value;
442             return;
443         case 2:
444             authorities = value;
445             return;
446         case 3:
447             additionals = value;
448             return;
449         default:
450             break;
451         }
452 
453         throw new Error(); // Should never reach here.
454     }
455 
456     private static int sectionOrdinal(DnsSection section) {
457         return checkNotNull(section, "section").ordinal();
458     }
459 
460     private static DnsRecord checkQuestion(int section, DnsRecord record) {
461         if (section == SECTION_QUESTION && !(checkNotNull(record, "record") instanceof DnsQuestion)) {
462             throw new IllegalArgumentException(
463                     "record: " + record + " (expected: " + StringUtil.simpleClassName(DnsQuestion.class) + ')');
464         }
465         return record;
466     }
467 
468     @SuppressWarnings("unchecked")
469     private static <T extends DnsRecord> T castRecord(Object record) {
470         return (T) record;
471     }
472 
473     private static ArrayList<DnsRecord> newRecordList() {
474         return new ArrayList<DnsRecord>(2);
475     }
476 }