1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.quic;
17
18 import io.netty.util.AsciiString;
19 import io.netty.util.internal.SystemPropertyUtil;
20 import org.jetbrains.annotations.Nullable;
21
22 import java.util.Iterator;
23 import java.util.LinkedHashMap;
24 import java.util.Map;
25 import java.util.concurrent.atomic.AtomicInteger;
26
27 final class QuicClientSessionCache {
28
29 private static final int DEFAULT_CACHE_SIZE;
30 static {
31
32 int cacheSize = SystemPropertyUtil.getInt("javax.net.ssl.sessionCacheSize", 20480);
33 if (cacheSize >= 0) {
34 DEFAULT_CACHE_SIZE = cacheSize;
35 } else {
36 DEFAULT_CACHE_SIZE = 20480;
37 }
38 }
39
40 private final AtomicInteger maximumCacheSize = new AtomicInteger(DEFAULT_CACHE_SIZE);
41
42
43
44 private final AtomicInteger sessionTimeout = new AtomicInteger(300);
45 private int sessionCounter;
46
47 private final Map<HostPort, SessionHolder> sessions =
48 new LinkedHashMap<HostPort, SessionHolder>() {
49
50 private static final long serialVersionUID = -7773696788135734448L;
51
52 @Override
53 protected boolean removeEldestEntry(Map.Entry<HostPort, SessionHolder> eldest) {
54 int maxSize = maximumCacheSize.get();
55 return maxSize >= 0 && size() > maxSize;
56 }
57 };
58
59 void saveSession(@Nullable String host, int port, long creationTime, long timeout, byte[] session,
60 boolean isSingleUse) {
61 HostPort hostPort = keyFor(host, port);
62 if (hostPort != null) {
63 synchronized (sessions) {
64
65
66 if (++sessionCounter == 255) {
67 sessionCounter = 0;
68 expungeInvalidSessions();
69 }
70
71 sessions.put(hostPort, new SessionHolder(creationTime, timeout, session, isSingleUse));
72 }
73 }
74 }
75
76
77 boolean hasSession(@Nullable String host, int port) {
78 HostPort hostPort = keyFor(host, port);
79 if (hostPort != null) {
80 synchronized (sessions) {
81 return sessions.containsKey(hostPort);
82 }
83 }
84 return false;
85 }
86
87 byte @Nullable [] getSession(@Nullable String host, int port) {
88 HostPort hostPort = keyFor(host, port);
89 if (hostPort != null) {
90 SessionHolder sessionHolder;
91 synchronized (sessions) {
92 sessionHolder = sessions.get(hostPort);
93 if (sessionHolder == null) {
94 return null;
95 }
96 if (sessionHolder.isSingleUse()) {
97
98 sessions.remove(hostPort);
99 }
100 }
101 if (sessionHolder.isValid()) {
102 return sessionHolder.sessionBytes();
103 }
104 }
105 return null;
106 }
107
108 void removeSession(@Nullable String host, int port) {
109 HostPort hostPort = keyFor(host, port);
110 if (hostPort != null) {
111 synchronized (sessions) {
112 sessions.remove(hostPort);
113 }
114 }
115 }
116
117 void setSessionTimeout(int seconds) {
118 int oldTimeout = sessionTimeout.getAndSet(seconds);
119 if (oldTimeout > seconds) {
120
121
122 clear();
123 }
124 }
125
126 int getSessionTimeout() {
127 return sessionTimeout.get();
128 }
129
130 void setSessionCacheSize(int size) {
131 long oldSize = maximumCacheSize.getAndSet(size);
132 if (oldSize > size || size == 0) {
133
134 clear();
135 }
136 }
137
138 int getSessionCacheSize() {
139 return maximumCacheSize.get();
140 }
141
142
143
144
145 void clear() {
146 synchronized (sessions) {
147 sessions.clear();
148 }
149 }
150
151 private void expungeInvalidSessions() {
152 assert Thread.holdsLock(sessions);
153
154 if (sessions.isEmpty()) {
155 return;
156 }
157 long now = System.currentTimeMillis();
158 Iterator<Map.Entry<HostPort, SessionHolder>> iterator = sessions.entrySet().iterator();
159 while (iterator.hasNext()) {
160 SessionHolder sessionHolder = iterator.next().getValue();
161
162
163
164 if (sessionHolder.isValid(now)) {
165 break;
166 }
167 iterator.remove();
168 }
169 }
170
171 @Nullable
172 private static HostPort keyFor(@Nullable String host, int port) {
173 if (host == null && port < 1) {
174 return null;
175 }
176 return new HostPort(host, port);
177 }
178
179 private static final class SessionHolder {
180 private final long creationTime;
181 private final long timeout;
182 private final byte[] sessionBytes;
183 private final boolean isSingleUse;
184
185 SessionHolder(long creationTime, long timeout, byte[] session, boolean isSingleUse) {
186 this.creationTime = creationTime;
187 this.timeout = timeout;
188 this.sessionBytes = session;
189 this.isSingleUse = isSingleUse;
190 }
191
192 boolean isValid() {
193 return isValid(System.currentTimeMillis());
194 }
195
196 boolean isValid(long current) {
197 return current <= creationTime + timeout;
198 }
199
200 boolean isSingleUse() {
201 return isSingleUse;
202 }
203
204 byte[] sessionBytes() {
205 return sessionBytes;
206 }
207 }
208
209
210
211
212 private static final class HostPort {
213 private final int hash;
214 private final String host;
215 private final int port;
216
217 HostPort(@Nullable String host, int port) {
218 this.host = host;
219 this.port = port;
220
221 this.hash = 31 * AsciiString.hashCode(host) + port;
222 }
223
224 @Override
225 public int hashCode() {
226 return hash;
227 }
228
229 @Override
230 public boolean equals(Object obj) {
231 if (!(obj instanceof HostPort)) {
232 return false;
233 }
234 HostPort other = (HostPort) obj;
235 return port == other.port && host.equalsIgnoreCase(other.host);
236 }
237
238 @Override
239 public String toString() {
240 return "HostPort{" +
241 "host='" + host + '\'' +
242 ", port=" + port +
243 '}';
244 }
245 }
246 }