View Javadoc
1   /*
2    * Copyright 2024 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.util.internal;
17  
18  import org.jetbrains.annotations.NotNull;
19  
20  import java.io.FilterInputStream;
21  import java.io.IOException;
22  import java.io.InputStream;
23  
24  public final class BoundedInputStream extends FilterInputStream {
25  
26      private final int maxBytesRead;
27      private int numRead;
28  
29      public BoundedInputStream(@NotNull InputStream in, int maxBytesRead) {
30          super(in);
31          this.maxBytesRead = ObjectUtil.checkPositive(maxBytesRead, "maxRead");
32      }
33  
34      public BoundedInputStream(@NotNull InputStream in) {
35          this(in, 8 * 1024);
36      }
37  
38      @Override
39      public int read() throws IOException {
40          checkMaxBytesRead(1);
41          try {
42              int b = super.read();
43              if (b <= 0) {
44                  // We couldn't read anything.
45                  numRead--;
46              }
47              return b;
48          } catch (IOException e) {
49              numRead--;
50              throw e;
51          }
52      }
53  
54      @Override
55      public int read(byte[] buf, int off, int len) throws IOException {
56          // Calculate the maximum number of bytes that we should try to read.
57          int num = Math.min(len, maxBytesRead - numRead + 1);
58          checkMaxBytesRead(num);
59          try {
60              int b = super.read(buf, off, num);
61              if (b == -1) {
62                  // We couldn't read anything.
63                  numRead -= num;
64              } else if (b != num) {
65                  // Correct numRead based on the actual amount we were able to read.
66                  numRead -= num - b;
67              }
68              return b;
69          } catch (IOException e) {
70              numRead -= num;
71              throw e;
72          }
73      }
74  
75      private void checkMaxBytesRead(int n) throws IOException {
76          int sum = numRead + n;
77          if (sum < 0 || sum > maxBytesRead) {
78              numRead = maxBytesRead + 1;
79              throw new IOException("Maximum number of bytes read: " + maxBytesRead);
80          }
81          numRead = sum;
82      }
83  }