View Javadoc
1   /*
2    * Copyright 2021 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.ssl;
17  
18  import io.netty.internal.tcnative.SSLSession;
19  import io.netty.internal.tcnative.SSLSessionCache;
20  import io.netty.util.ResourceLeakDetector;
21  import io.netty.util.ResourceLeakDetectorFactory;
22  import io.netty.util.ResourceLeakTracker;
23  import io.netty.util.internal.EmptyArrays;
24  import io.netty.util.internal.SystemPropertyUtil;
25  
26  import javax.security.cert.X509Certificate;
27  import java.security.Principal;
28  import java.security.cert.Certificate;
29  import java.util.ArrayList;
30  import java.util.Iterator;
31  import java.util.LinkedHashMap;
32  import java.util.List;
33  import java.util.Map;
34  import java.util.concurrent.atomic.AtomicInteger;
35  
36  /**
37   * {@link SSLSessionCache} implementation for our native SSL implementation.
38   */
39  class OpenSslSessionCache implements SSLSessionCache {
40      private static final OpenSslSession[] EMPTY_SESSIONS = new OpenSslSession[0];
41  
42      private static final int DEFAULT_CACHE_SIZE;
43      static {
44          // Respect the same system property as the JDK implementation to make it easy to switch between implementations.
45          int cacheSize = SystemPropertyUtil.getInt("javax.net.ssl.sessionCacheSize", 20480);
46          if (cacheSize >= 0) {
47              DEFAULT_CACHE_SIZE = cacheSize;
48          } else {
49              DEFAULT_CACHE_SIZE = 20480;
50          }
51      }
52      private final OpenSslEngineMap engineMap;
53  
54      private final Map<OpenSslSessionId, NativeSslSession> sessions =
55              new LinkedHashMap<OpenSslSessionId, NativeSslSession>() {
56  
57                  private static final long serialVersionUID = -7773696788135734448L;
58  
59                  @Override
60                  protected boolean removeEldestEntry(Map.Entry<OpenSslSessionId, NativeSslSession> eldest) {
61                      int maxSize = maximumCacheSize.get();
62                      if (maxSize >= 0 && size() > maxSize) {
63                          removeSessionWithId(eldest.getKey());
64                      }
65                      // We always need to return false as we modify the map directly.
66                      return false;
67                  }
68              };
69  
70      private final AtomicInteger maximumCacheSize = new AtomicInteger(DEFAULT_CACHE_SIZE);
71  
72      // Let's use the same default value as OpenSSL does.
73      // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_default_timeout.html
74      private final AtomicInteger sessionTimeout = new AtomicInteger(300);
75      private int sessionCounter;
76  
77      OpenSslSessionCache(OpenSslEngineMap engineMap) {
78          this.engineMap = engineMap;
79      }
80  
81      final void setSessionTimeout(int seconds) {
82          int oldTimeout = sessionTimeout.getAndSet(seconds);
83          if (oldTimeout > seconds) {
84              // Drain the whole cache as this way we can use the ordering of the LinkedHashMap to detect early
85              // if there are any other sessions left that are invalid.
86              clear();
87          }
88      }
89  
90      final int getSessionTimeout() {
91          return sessionTimeout.get();
92      }
93  
94      /**
95       * Called once a new {@link OpenSslSession} was created.
96       *
97       * @param session the new session.
98       * @return {@code true} if the session should be cached, {@code false} otherwise.
99       */
100     protected boolean sessionCreated(NativeSslSession session) {
101         return true;
102     }
103 
104     /**
105      * Called once an {@link OpenSslSession} was removed from the cache.
106      *
107      * @param session the session to remove.
108      */
109     protected void sessionRemoved(NativeSslSession session) { }
110 
111     final void setSessionCacheSize(int size) {
112         long oldSize = maximumCacheSize.getAndSet(size);
113         if (oldSize > size || size == 0) {
114             // Just keep it simple for now and drain the whole cache.
115             clear();
116         }
117     }
118 
119     final int getSessionCacheSize() {
120         return maximumCacheSize.get();
121     }
122 
123     private void expungeInvalidSessions() {
124         if (sessions.isEmpty()) {
125             return;
126         }
127         long now = System.currentTimeMillis();
128         Iterator<Map.Entry<OpenSslSessionId, NativeSslSession>> iterator = sessions.entrySet().iterator();
129         while (iterator.hasNext()) {
130             NativeSslSession session = iterator.next().getValue();
131             // As we use a LinkedHashMap we can break the while loop as soon as we find a valid session.
132             // This is true as we always drain the cache as soon as we change the timeout to a smaller value as
133             // it was set before. This way its true that the insertion order matches the timeout order.
134             if (session.isValid(now)) {
135                 break;
136             }
137             iterator.remove();
138 
139             notifyRemovalAndFree(session);
140         }
141     }
142 
143     @Override
144     public boolean sessionCreated(long ssl, long sslSession) {
145         ReferenceCountedOpenSslEngine engine = engineMap.get(ssl);
146         if (engine == null) {
147             // We couldn't find the engine itself.
148             return false;
149         }
150         OpenSslSession openSslSession = (OpenSslSession) engine.getSession();
151         // Create the native session that we will put into our cache. We will share the key-value storage
152         // with the already existing session instance.
153         NativeSslSession session = new NativeSslSession(sslSession, engine.getPeerHost(), engine.getPeerPort(),
154                 getSessionTimeout() * 1000L, openSslSession.keyValueStorage());
155 
156         openSslSession.setSessionDetails(
157                 session.creationTime, session.lastAccessedTime, session.sessionId(), session.keyValueStorage);
158         synchronized (this) {
159             // Mimic what OpenSSL is doing and expunge every 255 new sessions
160             // See https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_flush_sessions.html
161             if (++sessionCounter == 255) {
162                 sessionCounter = 0;
163                 expungeInvalidSessions();
164             }
165 
166             if (!sessionCreated(session)) {
167                 // Should not be cached, return false. In this case we also need to call close() to ensure we
168                 // close the ResourceLeakTracker.
169                 session.close();
170                 return false;
171             }
172             final NativeSslSession old = sessions.put(session.sessionId(), session);
173             if (old != null) {
174                 notifyRemovalAndFree(old);
175             }
176         }
177         return true;
178     }
179 
180     @Override
181     public final long getSession(long ssl, byte[] sessionId) {
182         OpenSslSessionId id = new OpenSslSessionId(sessionId);
183         final NativeSslSession session;
184         synchronized (this) {
185             session = sessions.get(id);
186             if (session == null) {
187                 return -1;
188             }
189 
190             // If the session is not valid anymore we should remove it from the cache and just signal back
191             // that we couldn't find a session that is re-usable.
192             if (!session.isValid() ||
193                     // This needs to happen in the synchronized block so we ensure we never destroy it before we
194                     // incremented the reference count. If we cant increment the reference count there is something
195                     // wrong. In this case just remove the session from the cache and signal back that we couldn't
196                     // find a session for re-use.
197                     !session.upRef()) {
198                 // Remove the session from the cache. This will also take care of calling SSL_SESSION_free(...)
199                 removeSessionWithId(session.sessionId());
200                 return -1;
201             }
202 
203             // At this point we already incremented the reference count via SSL_SESSION_up_ref(...).
204             if (session.shouldBeSingleUse()) {
205                 // Should only be used once. In this case invalidate the session which will also ensure we remove it
206                 // from the cache and call SSL_SESSION_free(...).
207                 removeSessionWithId(session.sessionId());
208             }
209         }
210         session.setLastAccessedTime(System.currentTimeMillis());
211         ReferenceCountedOpenSslEngine engine = engineMap.get(ssl);
212         if (engine != null) {
213             OpenSslSession sslSession = (OpenSslSession) engine.getSession();
214             sslSession.setSessionDetails(session.getCreationTime(),
215                     session.getLastAccessedTime(), session.sessionId(), session.keyValueStorage);
216         }
217 
218         return session.session();
219     }
220 
221     boolean setSession(long ssl, OpenSslSession session, String host, int port) {
222         // Do nothing by default as this needs special handling for the client side.
223        return false;
224     }
225 
226     /**
227      * Remove the session with the given id from the cache
228      */
229     final synchronized void removeSessionWithId(OpenSslSessionId id) {
230         NativeSslSession sslSession = sessions.remove(id);
231         if (sslSession != null) {
232             notifyRemovalAndFree(sslSession);
233         }
234     }
235 
236     /**
237      * Returns {@code true} if there is a session for the given id in the cache.
238      */
239     final synchronized boolean containsSessionWithId(OpenSslSessionId id) {
240         return sessions.containsKey(id);
241     }
242 
243     private void notifyRemovalAndFree(NativeSslSession session) {
244         sessionRemoved(session);
245         session.free();
246     }
247 
248     /**
249      * Return the {@link OpenSslSession} which is cached for the given id.
250      */
251     final synchronized OpenSslSession getSession(OpenSslSessionId id) {
252         NativeSslSession session = sessions.get(id);
253         if (session != null && !session.isValid()) {
254             // The session is not valid anymore, let's remove it and just signal back that there is no session
255             // with the given ID in the cache anymore. This also takes care of calling SSL_SESSION_free(...)
256             removeSessionWithId(session.sessionId());
257             return null;
258         }
259         return session;
260     }
261 
262     /**
263      * Returns a snapshot of the session ids of the current valid sessions.
264      */
265     final List<OpenSslSessionId> getIds() {
266         final OpenSslSession[] sessionsArray;
267         synchronized (this) {
268             sessionsArray = sessions.values().toArray(EMPTY_SESSIONS);
269         }
270         List<OpenSslSessionId> ids = new ArrayList<OpenSslSessionId>(sessionsArray.length);
271         for (OpenSslSession session: sessionsArray) {
272             if (session.isValid()) {
273                 ids.add(session.sessionId());
274             }
275         }
276         return ids;
277     }
278 
279     /**
280      * Clear the cache and free all cached SSL_SESSION*.
281      */
282     synchronized void clear() {
283         Iterator<Map.Entry<OpenSslSessionId, NativeSslSession>> iterator = sessions.entrySet().iterator();
284         while (iterator.hasNext()) {
285             NativeSslSession session = iterator.next().getValue();
286             iterator.remove();
287 
288             // Notify about removal. This also takes care of calling SSL_SESSION_free(...).
289             notifyRemovalAndFree(session);
290         }
291     }
292 
293     /**
294      * {@link OpenSslSession} implementation which wraps the native SSL_SESSION* while in cache.
295      */
296     static final class NativeSslSession implements OpenSslSession {
297         static final ResourceLeakDetector<NativeSslSession> LEAK_DETECTOR = ResourceLeakDetectorFactory.instance()
298                 .newResourceLeakDetector(NativeSslSession.class);
299         private final ResourceLeakTracker<NativeSslSession> leakTracker;
300 
301         final Map<String, Object> keyValueStorage;
302 
303         private final long session;
304         private final String peerHost;
305         private final int peerPort;
306         private final OpenSslSessionId id;
307         private final long timeout;
308         private final long creationTime = System.currentTimeMillis();
309         private volatile long lastAccessedTime = creationTime;
310         private volatile boolean valid = true;
311         private boolean freed;
312 
313         NativeSslSession(long session, String peerHost, int peerPort, long timeout,
314                          Map<String, Object> keyValueStorage) {
315             this.session = session;
316             this.peerHost = peerHost;
317             this.peerPort = peerPort;
318             this.timeout = timeout;
319             this.id = new OpenSslSessionId(io.netty.internal.tcnative.SSLSession.getSessionId(session));
320             this.keyValueStorage = keyValueStorage;
321             leakTracker = LEAK_DETECTOR.track(this);
322         }
323 
324         @Override
325         public Map<String, Object> keyValueStorage() {
326             return keyValueStorage;
327         }
328 
329         @Override
330         public void prepareHandshake() {
331             throw new UnsupportedOperationException();
332         }
333 
334         @Override
335         public void setSessionDetails(long creationTime, long lastAccessedTime,
336                                       OpenSslSessionId id, Map<String, Object> keyValueStorage) {
337             throw new UnsupportedOperationException();
338         }
339 
340         boolean shouldBeSingleUse() {
341             assert !freed;
342             return SSLSession.shouldBeSingleUse(session);
343         }
344 
345         long session() {
346             assert !freed;
347             return session;
348         }
349 
350         boolean upRef() {
351             assert !freed;
352             return SSLSession.upRef(session);
353         }
354 
355         synchronized void free() {
356             close();
357             SSLSession.free(session);
358         }
359 
360         void close() {
361             assert !freed;
362             freed = true;
363             invalidate();
364             if (leakTracker != null) {
365                 leakTracker.close(this);
366             }
367         }
368 
369         @Override
370         public OpenSslSessionId sessionId() {
371             return id;
372         }
373 
374         boolean isValid(long now) {
375             return creationTime + timeout >= now && valid;
376         }
377 
378         @Override
379         public void setLocalCertificate(Certificate[] localCertificate) {
380             throw new UnsupportedOperationException();
381         }
382 
383         @Override
384         public OpenSslSessionContext getSessionContext() {
385             return null;
386         }
387 
388         @Override
389         public void tryExpandApplicationBufferSize(int packetLengthDataOnly) {
390             throw new UnsupportedOperationException();
391         }
392 
393         @Override
394         public void handshakeFinished(byte[] id, String cipher, String protocol, byte[] peerCertificate,
395                                       byte[][] peerCertificateChain, long creationTime, long timeout) {
396             throw new UnsupportedOperationException();
397         }
398 
399         @Override
400         public byte[] getId() {
401             return id.cloneBytes();
402         }
403 
404         @Override
405         public long getCreationTime() {
406             return creationTime;
407         }
408 
409         @Override
410         public void setLastAccessedTime(long time) {
411             lastAccessedTime = time;
412         }
413 
414         @Override
415         public long getLastAccessedTime() {
416             return lastAccessedTime;
417         }
418 
419         @Override
420         public void invalidate() {
421             valid = false;
422         }
423 
424         @Override
425         public boolean isValid() {
426             return isValid(System.currentTimeMillis());
427         }
428 
429         @Override
430         public void putValue(String name, Object value) {
431             throw new UnsupportedOperationException();
432         }
433 
434         @Override
435         public Object getValue(String name) {
436             return null;
437         }
438 
439         @Override
440         public void removeValue(String name) {
441             // NOOP
442         }
443 
444         @Override
445         public String[] getValueNames() {
446             return EmptyArrays.EMPTY_STRINGS;
447         }
448 
449         @Override
450         public Certificate[] getPeerCertificates() {
451             throw new UnsupportedOperationException();
452         }
453 
454         @Override
455         public Certificate[] getLocalCertificates() {
456             throw new UnsupportedOperationException();
457         }
458 
459         @Override
460         public X509Certificate[] getPeerCertificateChain() {
461             throw new UnsupportedOperationException();
462         }
463 
464         @Override
465         public Principal getPeerPrincipal() {
466             throw new UnsupportedOperationException();
467         }
468 
469         @Override
470         public Principal getLocalPrincipal() {
471             throw new UnsupportedOperationException();
472         }
473 
474         @Override
475         public String getCipherSuite() {
476             return null;
477         }
478 
479         @Override
480         public String getProtocol() {
481             return null;
482         }
483 
484         @Override
485         public String getPeerHost() {
486             return peerHost;
487         }
488 
489         @Override
490         public int getPeerPort() {
491             return peerPort;
492         }
493 
494         @Override
495         public int getPacketBufferSize() {
496             return ReferenceCountedOpenSslEngine.MAX_RECORD_SIZE;
497         }
498 
499         @Override
500         public int getApplicationBufferSize() {
501             return ReferenceCountedOpenSslEngine.MAX_PLAINTEXT_LENGTH;
502         }
503 
504         @Override
505         public int hashCode() {
506             return id.hashCode();
507         }
508 
509         @Override
510         public boolean equals(Object o) {
511             if (this == o) {
512                 return true;
513             }
514             if (!(o instanceof OpenSslSession)) {
515                 return false;
516             }
517             OpenSslSession session1 = (OpenSslSession) o;
518             return id.equals(session1.sessionId());
519         }
520     }
521 }