View Javadoc
1   /*
2    * Copyright 2021 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty5.handler.ssl;
17  
18  import io.netty.internal.tcnative.SSLSession;
19  import io.netty.internal.tcnative.SSLSessionCache;
20  import io.netty5.util.ResourceLeakDetector;
21  import io.netty5.util.ResourceLeakDetectorFactory;
22  import io.netty5.util.ResourceLeakTracker;
23  import io.netty5.util.internal.EmptyArrays;
24  import io.netty5.util.internal.SystemPropertyUtil;
25  
26  import javax.security.cert.X509Certificate;
27  import java.security.Principal;
28  import java.security.cert.Certificate;
29  import java.util.ArrayList;
30  import java.util.Iterator;
31  import java.util.LinkedHashMap;
32  import java.util.List;
33  import java.util.Map;
34  import java.util.concurrent.atomic.AtomicInteger;
35  
36  /**
37   * {@link SSLSessionCache} implementation for our native SSL implementation.
38   */
39  class OpenSslSessionCache implements SSLSessionCache {
40      private static final OpenSslSession[] EMPTY_SESSIONS = new OpenSslSession[0];
41  
42      private static final int DEFAULT_CACHE_SIZE;
43      static {
44          // Respect the same system property as the JDK implementation to make it easy to switch between implementations.
45          int cacheSize = SystemPropertyUtil.getInt("javax.net.ssl.sessionCacheSize", 20480);
46          if (cacheSize >= 0) {
47              DEFAULT_CACHE_SIZE = cacheSize;
48          } else {
49              DEFAULT_CACHE_SIZE = 20480;
50          }
51      }
52      private final OpenSslEngineMap engineMap;
53  
54      private final Map<OpenSslSessionId, NativeSslSession> sessions =
55              new LinkedHashMap<>() {
56  
57                  private static final long serialVersionUID = -7773696788135734448L;
58  
59                  @Override
60                  protected boolean removeEldestEntry(Map.Entry<OpenSslSessionId, NativeSslSession> eldest) {
61                      int maxSize = maximumCacheSize.get();
62                      if (maxSize >= 0 && size() > maxSize) {
63                          removeSessionWithId(eldest.getKey());
64                      }
65                      // We always need to return false as we modify the map directly.
66                      return false;
67                  }
68              };
69  
70      private final AtomicInteger maximumCacheSize = new AtomicInteger(DEFAULT_CACHE_SIZE);
71  
72      // Let's use the same default value as OpenSSL does.
73      // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_default_timeout.html
74      private final AtomicInteger sessionTimeout = new AtomicInteger(300);
75      private int sessionCounter;
76  
77      OpenSslSessionCache(OpenSslEngineMap engineMap) {
78          this.engineMap = engineMap;
79      }
80  
81      final void setSessionTimeout(int seconds) {
82          int oldTimeout = sessionTimeout.getAndSet(seconds);
83          if (oldTimeout > seconds) {
84              // Drain the whole cache as this way we can use the ordering of the LinkedHashMap to detect early
85              // if there are any other sessions left that are invalid.
86              clear();
87          }
88      }
89  
90      final int getSessionTimeout() {
91          return sessionTimeout.get();
92      }
93  
94      /**
95       * Called once a new {@link OpenSslSession} was created.
96       *
97       * @param session the new session.
98       * @return {@code true} if the session should be cached, {@code false} otherwise.
99       */
100     protected boolean sessionCreated(NativeSslSession session) {
101         return true;
102     }
103 
104     /**
105      * Called once an {@link OpenSslSession} was removed from the cache.
106      *
107      * @param session the session to remove.
108      */
109     protected void sessionRemoved(NativeSslSession session) { }
110 
111     final void setSessionCacheSize(int size) {
112         long oldSize = maximumCacheSize.getAndSet(size);
113         if (oldSize > size || size == 0) {
114             // Just keep it simple for now and drain the whole cache.
115             clear();
116         }
117     }
118 
119     final int getSessionCacheSize() {
120         return maximumCacheSize.get();
121     }
122 
123     private void expungeInvalidSessions() {
124         if (sessions.isEmpty()) {
125             return;
126         }
127         long now = System.currentTimeMillis();
128         Iterator<Map.Entry<OpenSslSessionId, NativeSslSession>> iterator = sessions.entrySet().iterator();
129         while (iterator.hasNext()) {
130             NativeSslSession session = iterator.next().getValue();
131             // As we use a LinkedHashMap we can break the while loop as soon as we find a valid session.
132             // This is true as we always drain the cache as soon as we change the timeout to a smaller value as
133             // it was set before. This way its true that the insertion order matches the timeout order.
134             if (session.isValid(now)) {
135                 break;
136             }
137             iterator.remove();
138 
139             notifyRemovalAndFree(session);
140         }
141     }
142 
143     @Override
144     public final boolean sessionCreated(long ssl, long sslSession) {
145         ReferenceCountedOpenSslEngine engine = engineMap.get(ssl);
146         if (engine == null) {
147             // We couldn't find the engine itself.
148             return false;
149         }
150         NativeSslSession session = new NativeSslSession(sslSession, engine.getPeerHost(), engine.getPeerPort(),
151                 getSessionTimeout() * 1000L);
152         engine.setSessionId(session.sessionId());
153         synchronized (this) {
154             // Mimic what OpenSSL is doing and expunge every 255 new sessions
155             // See https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_flush_sessions.html
156             if (++sessionCounter == 255) {
157                 sessionCounter = 0;
158                 expungeInvalidSessions();
159             }
160 
161             if (!sessionCreated(session)) {
162                 // Should not be cached, return false. In this case we also need to call close() to ensure we
163                 // close the ResourceLeakTracker.
164                 session.close();
165                 return false;
166             }
167 
168             final NativeSslSession old = sessions.put(session.sessionId(), session);
169             if (old != null) {
170                 notifyRemovalAndFree(old);
171             }
172         }
173         return true;
174     }
175 
176     @Override
177     public final long getSession(long ssl, byte[] sessionId) {
178         OpenSslSessionId id = new OpenSslSessionId(sessionId);
179         final NativeSslSession session;
180         synchronized (this) {
181             session = sessions.get(id);
182             if (session == null) {
183                 return -1;
184             }
185 
186             // If the session is not valid anymore we should remove it from the cache and just signal back
187             // that we couldn't find a session that is re-usable.
188             if (!session.isValid() ||
189                     // This needs to happen in the synchronized block so we ensure we never destroy it before we
190                     // incremented the reference count. If we cant increment the reference count there is something
191                     // wrong. In this case just remove the session from the cache and signal back that we couldn't
192                     // find a session for re-use.
193                     !session.upRef()) {
194                 // Remove the session from the cache. This will also take care of calling SSL_SESSION_free(...)
195                 removeSessionWithId(session.sessionId());
196                 return -1;
197             }
198 
199             // At this point we already incremented the reference count via SSL_SESSION_up_ref(...).
200             if (session.shouldBeSingleUse()) {
201                 // Should only be used once. In this case invalidate the session which will also ensure we remove it
202                 // from the cache and call SSL_SESSION_free(...).
203                 removeSessionWithId(session.sessionId());
204             }
205         }
206         session.updateLastAccessedTime();
207         return session.session();
208     }
209 
210     void setSession(long ssl, String host, int port) {
211         // Do nothing by default as this needs special handling for the client side.
212     }
213 
214     /**
215      * Remove the session with the given id from the cache
216      */
217     final synchronized void removeSessionWithId(OpenSslSessionId id) {
218         NativeSslSession sslSession = sessions.remove(id);
219         if (sslSession != null) {
220             notifyRemovalAndFree(sslSession);
221         }
222     }
223 
224     /**
225      * Returns {@code true} if there is a session for the given id in the cache.
226      */
227     final synchronized boolean containsSessionWithId(OpenSslSessionId id) {
228         return sessions.containsKey(id);
229     }
230 
231     private void notifyRemovalAndFree(NativeSslSession session) {
232         sessionRemoved(session);
233         session.free();
234     }
235 
236     /**
237      * Return the {@link OpenSslSession} which is cached for the given id.
238      */
239     final synchronized OpenSslSession getSession(OpenSslSessionId id) {
240         NativeSslSession session = sessions.get(id);
241         if (session != null && !session.isValid()) {
242             // The session is not valid anymore, let's remove it and just signal back that there is no session
243             // with the given ID in the cache anymore. This also takes care of calling SSL_SESSION_free(...)
244             removeSessionWithId(session.sessionId());
245             return null;
246         }
247         return session;
248     }
249 
250     /**
251      * Returns a snapshot of the session ids of the current valid sessions.
252      */
253     final List<OpenSslSessionId> getIds() {
254         final OpenSslSession[] sessionsArray;
255         synchronized (this) {
256             sessionsArray = sessions.values().toArray(EMPTY_SESSIONS);
257         }
258         List<OpenSslSessionId> ids = new ArrayList<>(sessionsArray.length);
259         for (OpenSslSession session: sessionsArray) {
260             if (session.isValid()) {
261                 ids.add(session.sessionId());
262             }
263         }
264         return ids;
265     }
266 
267     /**
268      * Clear the cache and free all cached SSL_SESSION*.
269      */
270     synchronized void clear() {
271         Iterator<Map.Entry<OpenSslSessionId, NativeSslSession>> iterator = sessions.entrySet().iterator();
272         while (iterator.hasNext()) {
273             NativeSslSession session = iterator.next().getValue();
274             iterator.remove();
275 
276             // Notify about removal. This also takes care of calling SSL_SESSION_free(...).
277             notifyRemovalAndFree(session);
278         }
279     }
280 
281     /**
282      * {@link OpenSslSession} implementation which wraps the native SSL_SESSION* while in cache.
283      */
284     static final class NativeSslSession implements OpenSslSession {
285         static final ResourceLeakDetector<NativeSslSession> LEAK_DETECTOR = ResourceLeakDetectorFactory.instance()
286                 .newResourceLeakDetector(NativeSslSession.class);
287         private final ResourceLeakTracker<NativeSslSession> leakTracker;
288         private final long session;
289         private final String peerHost;
290         private final int peerPort;
291         private final OpenSslSessionId id;
292         private final long timeout;
293         private final long creationTime = System.currentTimeMillis();
294         private volatile long lastAccessedTime = creationTime;
295         private volatile boolean valid = true;
296         private boolean freed;
297 
298         NativeSslSession(long session, String peerHost, int peerPort, long timeout) {
299             this.session = session;
300             this.peerHost = peerHost;
301             this.peerPort = peerPort;
302             this.timeout = timeout;
303             id = new OpenSslSessionId(SSLSession.getSessionId(session));
304             leakTracker = LEAK_DETECTOR.track(this);
305         }
306 
307         @Override
308         public void setSessionId(OpenSslSessionId id) {
309             throw new UnsupportedOperationException();
310         }
311 
312         boolean shouldBeSingleUse() {
313             assert !freed;
314             return SSLSession.shouldBeSingleUse(session);
315         }
316 
317         long session() {
318             assert !freed;
319             return session;
320         }
321 
322         boolean upRef() {
323             assert !freed;
324             return SSLSession.upRef(session);
325         }
326 
327         synchronized void free() {
328             close();
329             SSLSession.free(session);
330         }
331 
332         void close() {
333             assert !freed;
334             freed = true;
335             invalidate();
336             if (leakTracker != null) {
337                 leakTracker.close(this);
338             }
339         }
340 
341         @Override
342         public OpenSslSessionId sessionId() {
343             return id;
344         }
345 
346         boolean isValid(long now) {
347             return creationTime + timeout >= now && valid;
348         }
349 
350         @Override
351         public void setLocalCertificate(Certificate[] localCertificate) {
352             throw new UnsupportedOperationException();
353         }
354 
355         @Override
356         public OpenSslSessionContext getSessionContext() {
357             return null;
358         }
359 
360         @Override
361         public void tryExpandApplicationBufferSize(int packetLengthDataOnly) {
362             throw new UnsupportedOperationException();
363         }
364 
365         @Override
366         public void handshakeFinished(byte[] id, String cipher, String protocol, byte[] peerCertificate,
367                                       byte[][] peerCertificateChain, long creationTime, long timeout) {
368             throw new UnsupportedOperationException();
369         }
370 
371         @Override
372         public byte[] getId() {
373             return id.cloneBytes();
374         }
375 
376         @Override
377         public long getCreationTime() {
378             return creationTime;
379         }
380 
381         void updateLastAccessedTime() {
382             lastAccessedTime = System.currentTimeMillis();
383         }
384 
385         @Override
386         public long getLastAccessedTime() {
387             return lastAccessedTime;
388         }
389 
390         @Override
391         public void invalidate() {
392             valid = false;
393         }
394 
395         @Override
396         public boolean isValid() {
397             return isValid(System.currentTimeMillis());
398         }
399 
400         @Override
401         public void putValue(String name, Object value) {
402             throw new UnsupportedOperationException();
403         }
404 
405         @Override
406         public Object getValue(String name) {
407             return null;
408         }
409 
410         @Override
411         public void removeValue(String name) {
412             // NOOP
413         }
414 
415         @Override
416         public String[] getValueNames() {
417             return EmptyArrays.EMPTY_STRINGS;
418         }
419 
420         @Override
421         public Certificate[] getPeerCertificates() {
422             throw new UnsupportedOperationException();
423         }
424 
425         @Override
426         public Certificate[] getLocalCertificates() {
427             throw new UnsupportedOperationException();
428         }
429 
430         @Override
431         public X509Certificate[] getPeerCertificateChain() {
432             throw new UnsupportedOperationException();
433         }
434 
435         @Override
436         public Principal getPeerPrincipal() {
437             throw new UnsupportedOperationException();
438         }
439 
440         @Override
441         public Principal getLocalPrincipal() {
442             throw new UnsupportedOperationException();
443         }
444 
445         @Override
446         public String getCipherSuite() {
447             return null;
448         }
449 
450         @Override
451         public String getProtocol() {
452             return null;
453         }
454 
455         @Override
456         public String getPeerHost() {
457             return peerHost;
458         }
459 
460         @Override
461         public int getPeerPort() {
462             return peerPort;
463         }
464 
465         @Override
466         public int getPacketBufferSize() {
467             return ReferenceCountedOpenSslEngine.MAX_RECORD_SIZE;
468         }
469 
470         @Override
471         public int getApplicationBufferSize() {
472             return ReferenceCountedOpenSslEngine.MAX_PLAINTEXT_LENGTH;
473         }
474 
475         @Override
476         public int hashCode() {
477             return id.hashCode();
478         }
479 
480         @Override
481         public boolean equals(Object o) {
482             if (this == o) {
483                 return true;
484             }
485             if (!(o instanceof OpenSslSession)) {
486                 return false;
487             }
488             OpenSslSession session1 = (OpenSslSession) o;
489             return id.equals(session1.sessionId());
490         }
491     }
492 }