View Javadoc
1   /*
2    * Copyright 2021 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.quic;
17  
18  import io.netty.util.AsciiString;
19  import io.netty.util.internal.SystemPropertyUtil;
20  import org.jetbrains.annotations.Nullable;
21  
22  import java.util.Iterator;
23  import java.util.LinkedHashMap;
24  import java.util.Map;
25  import java.util.concurrent.atomic.AtomicInteger;
26  
27  final class QuicClientSessionCache {
28  
29      private static final int DEFAULT_CACHE_SIZE;
30      static {
31          // Respect the same system property as the JDK implementation to make it easy to switch between implementations.
32          int cacheSize = SystemPropertyUtil.getInt("javax.net.ssl.sessionCacheSize", 20480);
33          if (cacheSize >= 0) {
34              DEFAULT_CACHE_SIZE = cacheSize;
35          } else {
36              DEFAULT_CACHE_SIZE = 20480;
37          }
38      }
39  
40      private final AtomicInteger maximumCacheSize = new AtomicInteger(DEFAULT_CACHE_SIZE);
41  
42      // Let's use the same default value as OpenSSL does.
43      // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_default_timeout.html
44      private final AtomicInteger sessionTimeout = new AtomicInteger(300);
45      private int sessionCounter;
46  
47      private final Map<HostPort, SessionHolder> sessions =
48              new LinkedHashMap<HostPort, SessionHolder>() {
49  
50                  private static final long serialVersionUID = -7773696788135734448L;
51  
52                  @Override
53                  protected boolean removeEldestEntry(Map.Entry<HostPort, SessionHolder> eldest) {
54                      int maxSize = maximumCacheSize.get();
55                      return maxSize >= 0 && size() > maxSize;
56                  }
57              };
58  
59      void saveSession(@Nullable String host, int port, long creationTime, long timeout, byte[] session,
60                       boolean isSingleUse) {
61          HostPort hostPort = keyFor(host, port);
62          if (hostPort != null) {
63              synchronized (sessions) {
64                  // Mimic what OpenSSL is doing and expunge every 255 new sessions
65                  // See https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_flush_sessions.html
66                  if (++sessionCounter == 255) {
67                      sessionCounter = 0;
68                      expungeInvalidSessions();
69                  }
70  
71                  sessions.put(hostPort, new SessionHolder(creationTime, timeout, session, isSingleUse));
72              }
73          }
74      }
75  
76      // Only used for testing.
77      boolean hasSession(@Nullable String host, int port) {
78          HostPort hostPort = keyFor(host, port);
79          if (hostPort != null) {
80              synchronized (sessions) {
81                  return sessions.containsKey(hostPort);
82              }
83          }
84          return false;
85      }
86  
87      byte @Nullable [] getSession(@Nullable String host, int port) {
88          HostPort hostPort = keyFor(host, port);
89          if (hostPort != null) {
90              SessionHolder sessionHolder;
91              synchronized (sessions) {
92                  sessionHolder = sessions.get(hostPort);
93                  if (sessionHolder == null) {
94                      return null;
95                  }
96                  if (sessionHolder.isSingleUse()) {
97                      // Remove session as it should only be re-used once.
98                      sessions.remove(hostPort);
99                  }
100             }
101             if (sessionHolder.isValid()) {
102                 return sessionHolder.sessionBytes();
103             }
104         }
105         return null;
106     }
107 
108     void removeSession(@Nullable String host, int port) {
109         HostPort hostPort = keyFor(host, port);
110         if (hostPort != null) {
111             synchronized (sessions) {
112                 sessions.remove(hostPort);
113             }
114         }
115     }
116 
117     void setSessionTimeout(int seconds) {
118         int oldTimeout = sessionTimeout.getAndSet(seconds);
119         if (oldTimeout > seconds) {
120             // Drain the whole cache as this way we can use the ordering of the LinkedHashMap to detect early
121             // if there are any other sessions left that are invalid.
122             clear();
123         }
124     }
125 
126     int getSessionTimeout() {
127         return sessionTimeout.get();
128     }
129 
130     void setSessionCacheSize(int size) {
131         long oldSize = maximumCacheSize.getAndSet(size);
132         if (oldSize > size || size == 0) {
133             // Just keep it simple for now and drain the whole cache.
134             clear();
135         }
136     }
137 
138     int getSessionCacheSize() {
139         return maximumCacheSize.get();
140     }
141 
142     /**
143      * Clear the cache and free all cached SSL_SESSION*.
144      */
145     void clear() {
146         synchronized (sessions) {
147             sessions.clear();
148         }
149     }
150 
151     private void expungeInvalidSessions() {
152         assert Thread.holdsLock(sessions);
153 
154         if (sessions.isEmpty()) {
155             return;
156         }
157         long now = System.currentTimeMillis();
158         Iterator<Map.Entry<HostPort, SessionHolder>> iterator = sessions.entrySet().iterator();
159         while (iterator.hasNext()) {
160             SessionHolder sessionHolder = iterator.next().getValue();
161             // As we use a LinkedHashMap we can break the while loop as soon as we find a valid session.
162             // This is true as we always drain the cache as soon as we change the timeout to a smaller value as
163             // it was set before. This way its true that the insertion order matches the timeout order.
164             if (sessionHolder.isValid(now)) {
165                 break;
166             }
167             iterator.remove();
168         }
169     }
170 
171     @Nullable
172     private static HostPort keyFor(@Nullable String host, int port) {
173         if (host == null && port < 1) {
174             return null;
175         }
176         return new HostPort(host, port);
177     }
178 
179     private static final class SessionHolder {
180         private final long creationTime;
181         private final long timeout;
182         private final byte[] sessionBytes;
183         private final boolean isSingleUse;
184 
185         SessionHolder(long creationTime, long timeout, byte[] session, boolean isSingleUse) {
186             this.creationTime = creationTime;
187             this.timeout = timeout;
188             this.sessionBytes = session;
189             this.isSingleUse = isSingleUse;
190         }
191 
192         boolean isValid() {
193             return isValid(System.currentTimeMillis());
194         }
195 
196         boolean isValid(long current) {
197             return current <= creationTime + timeout;
198         }
199 
200         boolean isSingleUse() {
201             return isSingleUse;
202         }
203 
204         byte[] sessionBytes() {
205             return sessionBytes;
206         }
207     }
208 
209     /**
210      * Host / Port tuple used to find a session in the cache.
211      */
212     private static final class HostPort {
213         private final int hash;
214         private final String host;
215         private final int port;
216 
217         HostPort(@Nullable String host, int port) {
218             this.host = host;
219             this.port = port;
220             // Calculate a hashCode that does ignore case.
221             this.hash = 31 * AsciiString.hashCode(host) + port;
222         }
223 
224         @Override
225         public int hashCode() {
226             return hash;
227         }
228 
229         @Override
230         public boolean equals(Object obj) {
231             if (!(obj instanceof HostPort)) {
232                 return false;
233             }
234             HostPort other = (HostPort) obj;
235             return port == other.port && host.equalsIgnoreCase(other.host);
236         }
237 
238         @Override
239         public String toString() {
240             return "HostPort{" +
241                     "host='" + host + '\'' +
242                     ", port=" + port +
243                     '}';
244         }
245     }
246 }