View Javadoc
1   /*
2    * Copyright 2024 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.util.internal;
17  
18  import org.jetbrains.annotations.NotNull;
19  
20  import java.io.FilterInputStream;
21  import java.io.IOException;
22  import java.io.InputStream;
23  
24  public final class BoundedInputStream extends FilterInputStream {
25  
26      private final int maxBytesRead;
27      private int numRead;
28  
29      public BoundedInputStream(@NotNull InputStream in, int maxBytesRead) {
30          super(in);
31          this.maxBytesRead = ObjectUtil.checkPositive(maxBytesRead, "maxRead");
32      }
33  
34      public BoundedInputStream(@NotNull InputStream in) {
35          this(in, 8 * 1024);
36      }
37  
38      @Override
39      public int read() throws IOException {
40          checkMaxBytesRead();
41  
42          int b = super.read();
43          if (b > 0) {
44              numRead++;
45          }
46  
47          checkMaxBytesRead();
48          return b;
49      }
50  
51      @Override
52      public int read(byte[] buf, int off, int len) throws IOException {
53          checkMaxBytesRead();
54  
55          // Calculate the maximum number of bytes that we should try to read.
56          int num = Math.min(len, maxBytesRead - numRead + 1);
57  
58          int b = super.read(buf, off, num);
59  
60          if (b > 0) {
61              numRead += b;
62          }
63  
64          checkMaxBytesRead();
65          return b;
66      }
67  
68      private void checkMaxBytesRead() throws IOException {
69          if (numRead > maxBytesRead) {
70              throw new IOException("Maximum number of bytes read: " + numRead);
71          }
72      }
73  }