1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.handler.ssl;
17
18 import io.netty5.buffer.api.Buffer;
19 import io.netty5.buffer.api.ReadableComponent;
20 import io.netty5.buffer.api.ReadableComponentProcessor;
21 import io.netty5.buffer.api.WritableComponent;
22 import io.netty5.buffer.api.WritableComponentProcessor;
23 import io.netty5.buffer.api.internal.Statics;
24 import io.netty5.util.internal.PlatformDependent;
25
26 import javax.net.ssl.SSLEngine;
27 import javax.net.ssl.SSLEngineResult;
28 import javax.net.ssl.SSLException;
29 import java.lang.ref.Reference;
30 import java.nio.ByteBuffer;
31 import java.util.Objects;
32
33
34
35
36
37
38
39
40 class EngineWrapper implements ReadableComponentProcessor<RuntimeException>,
41 WritableComponentProcessor<RuntimeException> {
42
43
44
45 private static final ByteBuffer EMPTY_BUFFER_DIRECT = ByteBuffer.allocateDirect(0);
46
47
48
49 private static final ByteBuffer EMPTY_BUFFER_HEAP = ByteBuffer.allocate(0);
50
51 private final SSLEngine engine;
52 private final boolean useDirectBuffer;
53
54
55
56
57
58
59 private final ByteBuffer[] singleEmptyBuffer;
60 private final ByteBuffer[] singleReadableBuffer;
61 private final ByteBuffer[] singleWritableBuffer;
62
63 private ByteBuffer[] inputs;
64 private ByteBuffer[] outputs;
65 private SSLEngineResult result;
66 private ByteBuffer cachedReadingBuffer;
67 private ByteBuffer cachedWritingBuffer;
68 private boolean writeBack;
69
70 EngineWrapper(SSLEngine engine, boolean useDirectBuffer) {
71 this.engine = Objects.requireNonNull(engine, "engine");
72 this.useDirectBuffer = useDirectBuffer;
73 singleEmptyBuffer = new ByteBuffer[1];
74 singleEmptyBuffer[0] = useDirectBuffer? EMPTY_BUFFER_DIRECT : EMPTY_BUFFER_HEAP;
75 singleReadableBuffer = new ByteBuffer[1];
76 singleWritableBuffer = new ByteBuffer[1];
77 }
78
79 SSLEngineResult wrap(Buffer in, Buffer out) throws SSLException {
80 try {
81 prepare(in, out);
82 int count = outputs.length;
83 assert count == 1 : "Wrap can only output to a single buffer, but got " + count + " buffers.";
84 return processResult(engine.wrap(inputs, outputs[0]));
85 } finally {
86 finish(in, out);
87 }
88 }
89
90 SSLEngineResult unwrap(Buffer in, int length, Buffer out) throws SSLException {
91 try {
92 prepare(in, out);
93 limitInput(length);
94 if (engine instanceof VectoredUnwrap) {
95 VectoredUnwrap vectoredEngine = (VectoredUnwrap) engine;
96 return processResult(vectoredEngine.unwrap(inputs, outputs));
97 }
98 if (inputs.length > 1) {
99 coalesceInputs();
100 }
101 return processResult(engine.unwrap(inputs[0], outputs));
102 } finally {
103 finish(in, out);
104 }
105 }
106
107 private void prepare(Buffer in, Buffer out) {
108 if (in == null || in.readableBytes() == 0) {
109 inputs = singleEmptyBuffer;
110 } else if (in.isDirect() == useDirectBuffer) {
111 int count = in.countReadableComponents();
112 assert count > 0 : "Input buffer has readable bytes, but no readable components: " + in;
113 inputs = count == 1? singleReadableBuffer : new ByteBuffer[count];
114 int prepared = in.forEachReadable(0, this);
115 assert prepared == count : "Expected to prepare " + count + " buffers, but got " + prepared;
116 } else {
117 inputs = singleReadableBuffer;
118 int readable = in.readableBytes();
119 if (cachedReadingBuffer == null || cachedReadingBuffer.capacity() < readable) {
120 cachedReadingBuffer = allocateCachingBuffer(readable);
121 }
122 cachedReadingBuffer.clear();
123 in.copyInto(in.readerOffset(), cachedReadingBuffer, 0, readable);
124 cachedReadingBuffer.limit(readable);
125 inputs[0] = cachedReadingBuffer;
126 }
127 if (out == null || out.writableBytes() == 0) {
128 outputs = singleEmptyBuffer;
129 } else if (out.isDirect() == useDirectBuffer) {
130 int count = out.countWritableComponents();
131 assert count > 0 : "Output buffer has writable space, but no writable components: " + out;
132 outputs = count == 1? singleWritableBuffer : new ByteBuffer[count];
133 int prepared = out.forEachWritable(0, this);
134 assert prepared == count : "Expected to prepare " + count + " buffers, but got " + prepared;
135 } else {
136 inputs = singleWritableBuffer;
137 int writable = out.writableBytes();
138 if (cachedWritingBuffer == null || cachedWritingBuffer.capacity() < writable) {
139 cachedWritingBuffer = allocateCachingBuffer(writable);
140 }
141 outputs[0] = cachedWritingBuffer.position(0).limit(writable);
142 writeBack = true;
143 }
144 }
145
146 private ByteBuffer allocateCachingBuffer(int capacity) {
147 capacity = PlatformDependent.roundToPowerOfTwo(capacity);
148 return useDirectBuffer? ByteBuffer.allocateDirect(capacity) : ByteBuffer.allocate(capacity);
149 }
150
151 private void limitInput(int length) {
152 for (ByteBuffer input : inputs) {
153 int remaining = input.remaining();
154 if (remaining > length) {
155 input.limit(input.position() + length);
156 length = 0;
157 } else {
158 length -= remaining;
159 }
160 }
161 }
162
163 private void coalesceInputs() {
164 int rem = 0;
165 for (ByteBuffer input : inputs) {
166 rem += input.remaining();
167 }
168 if (cachedReadingBuffer == null || cachedReadingBuffer.capacity() < rem) {
169 cachedReadingBuffer = allocateCachingBuffer(rem);
170 }
171 cachedReadingBuffer.clear();
172 for (ByteBuffer input : inputs) {
173 cachedReadingBuffer.put(input);
174 }
175 cachedReadingBuffer.flip();
176 singleReadableBuffer[0] = cachedReadingBuffer;
177 inputs = singleReadableBuffer;
178 }
179
180 private SSLEngineResult processResult(SSLEngineResult result) {
181 this.result = result;
182 return result;
183 }
184
185 private void finish(Buffer in, Buffer out) {
186 if (result != null) {
187 if (in != null) {
188 in.skipReadableBytes(result.bytesConsumed());
189 }
190 if (out != null) {
191 if (writeBack) {
192 assert outputs.length == 1;
193 ByteBuffer buf = outputs[0];
194 while (buf.remaining() >= Long.BYTES) {
195 out.writeLong(buf.getLong());
196 }
197 if (buf.remaining() >= Integer.BYTES) {
198 out.writeInt(buf.getInt());
199 }
200 if (buf.remaining() >= Short.BYTES) {
201 out.writeShort(buf.getShort());
202 }
203 if (buf.hasRemaining()) {
204 out.writeByte(buf.get());
205 }
206 } else {
207 out.skipWritableBytes(result.bytesProduced());
208 }
209 }
210 result = null;
211 }
212 singleReadableBuffer[0] = null;
213 singleWritableBuffer[0] = null;
214 inputs = null;
215 outputs = null;
216
217
218
219
220 Reference.reachabilityFence(in);
221 Reference.reachabilityFence(out);
222 }
223
224 @Override
225 public boolean process(int index, ReadableComponent component) {
226
227
228 ByteBuffer byteBuffer = Statics.tryGetWritableBufferFromReadableComponent(component);
229 if (byteBuffer == null) {
230 byteBuffer = component.readableBuffer();
231 }
232 inputs[index] = byteBuffer;
233 return true;
234 }
235
236 @Override
237 public boolean process(int index, WritableComponent component) {
238 outputs[index] = component.writableBuffer();
239 return true;
240 }
241
242 @Override
243 public String toString() {
244 return "EngineWrapper(for " + engine.getPeerPort() + ')';
245 }
246 }