Skip to content

Commit a18d1a6

Browse files
committed
update AdvancedTlsX509KeyManager to support key alias for reloaded cert
1 parent 31fdb6c commit a18d1a6

3 files changed

Lines changed: 161 additions & 38 deletions

File tree

netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -436,16 +436,13 @@ public void onFileReloadingTrustManagerBadInitialContentTest() throws Exception
436436
}
437437

438438
@Test
439-
public void keyManagerAliasesTest() {
439+
public void keyManagerAliasesTest() throws Exception {
440440
AdvancedTlsX509KeyManager km = new AdvancedTlsX509KeyManager();
441-
assertArrayEquals(
442-
new String[] {"default"}, km.getClientAliases("", null));
443-
assertEquals(
444-
"default", km.chooseClientAlias(new String[] {"default"}, null, null));
445-
assertArrayEquals(
446-
new String[] {"default"}, km.getServerAliases("", null));
447-
assertEquals(
448-
"default", km.chooseServerAlias("default", null, null));
441+
km.updateIdentityCredentials(serverCert0, serverKey0);
442+
assertArrayEquals(new String[] {"key-1"}, km.getClientAliases("", null));
443+
assertEquals("key-1", km.chooseClientAlias(new String[] {"key-1"}, null, null));
444+
assertArrayEquals(new String[] {"key-1"}, km.getServerAliases("", null));
445+
assertEquals("key-1", km.chooseServerAlias("key-1", null, null));
449446
}
450447

451448
@Test

util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java

Lines changed: 66 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import java.security.PrivateKey;
3030
import java.security.cert.X509Certificate;
3131
import java.util.Arrays;
32+
import java.util.concurrent.atomic.AtomicInteger;
3233
import java.util.concurrent.ScheduledExecutorService;
3334
import java.util.concurrent.ScheduledFuture;
3435
import java.util.concurrent.TimeUnit;
@@ -40,59 +41,97 @@
4041
/**
4142
* AdvancedTlsX509KeyManager is an {@code X509ExtendedKeyManager} that allows users to configure
4243
* advanced TLS features, such as private key and certificate chain reloading.
44+
*
45+
* <p>The key material alias increments on every credential load (e.g. {@code "key-1"},
46+
* {@code "key-2"}, ...), ensuring the same alias always maps to the same key material. This is
47+
* required by Netty's {@code OpenSslCachingX509KeyManagerFactory} to correctly cache key
48+
* material and create a new cache entry on cert reload.
49+
*
50+
* <p>When using {@code SslProvider.OPENSSL}, wrap this key manager in Netty's
51+
* {@code OpenSslCachingX509KeyManagerFactory} to avoid per-handshake key material encoding
52+
* overhead, e.g. {@code new OpenSslCachingX509KeyManagerFactory(
53+
* new KeyManagerFactoryWrapper(advancedTlsKeyManager))}, and pass the factory to
54+
* {@code SslContextBuilder} instead of the key manager directly.
4355
*/
4456
public final class AdvancedTlsX509KeyManager extends X509ExtendedKeyManager {
4557
private static final Logger log = Logger.getLogger(AdvancedTlsX509KeyManager.class.getName());
4658
// Minimum allowed period for refreshing files with credential information.
47-
private static final int MINIMUM_REFRESH_PERIOD_IN_MINUTES = 1 ;
59+
private static final int MINIMUM_REFRESH_PERIOD_IN_MINUTES = 1;
60+
// Prefix for the key material alias; revision counter appended on each credential load.
61+
static final String ALIAS_PREFIX = "key-";
62+
63+
private final AtomicInteger revision = new AtomicInteger(0);
64+
private final int revisionWarningThreshold;
4865
// The credential information to be sent to peers to prove our identity.
4966
private volatile KeyInfo keyInfo;
5067

68+
public AdvancedTlsX509KeyManager() {
69+
// Netty's default OpenSslCachingX509KeyManagerFactory maxCachedEntries.
70+
this(1024);
71+
}
72+
73+
/**
74+
* Creates a key manager with a custom revision warning threshold.
75+
* @param revisionWarningThreshold the number of credential loads after which a warning is logged.
76+
* Only relevant when using {@code SslProvider.OPENSSL} with
77+
* {@code OpenSslCachingX509KeyManagerFactory}.
78+
*/
79+
public AdvancedTlsX509KeyManager(int revisionWarningThreshold) {
80+
this.revisionWarningThreshold = revisionWarningThreshold;
81+
}
82+
83+
private String alias() {
84+
KeyInfo info = this.keyInfo;
85+
if (info == null) {
86+
return null;
87+
}
88+
return info.alias;
89+
}
90+
5191
@Override
5292
public PrivateKey getPrivateKey(String alias) {
53-
if (alias.equals("default")) {
54-
return this.keyInfo.key;
55-
}
56-
return null;
93+
KeyInfo info = this.keyInfo;
94+
return info != null && info.alias.equals(alias) ? info.key : null;
5795
}
5896

5997
@Override
6098
public X509Certificate[] getCertificateChain(String alias) {
61-
if (alias.equals("default")) {
62-
return Arrays.copyOf(this.keyInfo.certs, this.keyInfo.certs.length);
63-
}
64-
return null;
99+
KeyInfo info = this.keyInfo;
100+
return info != null && info.alias.equals(alias)
101+
? Arrays.copyOf(info.certs, info.certs.length) : null;
65102
}
66103

67104
@Override
68105
public String[] getClientAliases(String keyType, Principal[] issuers) {
69-
return new String[] {"default"};
106+
String alias = alias();
107+
return alias != null ? new String[] {alias} : null;
70108
}
71109

72110
@Override
73111
public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) {
74-
return "default";
112+
return alias();
75113
}
76114

77115
@Override
78116
public String chooseEngineClientAlias(String[] keyType, Principal[] issuers, SSLEngine engine) {
79-
return "default";
117+
return alias();
80118
}
81119

82120
@Override
83121
public String[] getServerAliases(String keyType, Principal[] issuers) {
84-
return new String[] {"default"};
122+
String alias = alias();
123+
return alias != null ? new String[] {alias} : null;
85124
}
86125

87126
@Override
88127
public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) {
89-
return "default";
128+
return alias();
90129
}
91130

92131
@Override
93132
public String chooseEngineServerAlias(String keyType, Principal[] issuers,
94133
SSLEngine engine) {
95-
return "default";
134+
return alias();
96135
}
97136

98137
/**
@@ -116,7 +155,15 @@ public void updateIdentityCredentials(PrivateKey key, X509Certificate[] certs) {
116155
* @param key the private key that is going to be used
117156
*/
118157
public void updateIdentityCredentials(X509Certificate[] certs, PrivateKey key) {
119-
this.keyInfo = new KeyInfo(checkNotNull(certs, "certs"), checkNotNull(key, "key"));
158+
// When using SslProvider.OPENSSL with OpenSslCachingX509KeyManagerFactory, its cache stops
159+
// accepting new aliases once maxCachedEntries is reached (default: 1024). Beyond this,
160+
// handshakes still succeed but per-handshake re-encoding overhead resumes.
161+
if (revision.get() >= revisionWarningThreshold) {
162+
log.warning("AdvancedTlsX509KeyManager: revision counter has reached "
163+
+ revisionWarningThreshold);
164+
}
165+
this.keyInfo = new KeyInfo(checkNotNull(certs, "certs"), checkNotNull(key, "key"),
166+
ALIAS_PREFIX + revision.incrementAndGet());
120167
}
121168

122169
/**
@@ -218,10 +265,12 @@ private static class KeyInfo {
218265
// The private key and the cert chain we will use to send to peers to prove our identity.
219266
final X509Certificate[] certs;
220267
final PrivateKey key;
268+
final String alias;
221269

222-
public KeyInfo(X509Certificate[] certs, PrivateKey key) {
270+
public KeyInfo(X509Certificate[] certs, PrivateKey key, String alias) {
223271
this.certs = certs;
224272
this.key = key;
273+
this.alias = alias;
225274
}
226275
}
227276

@@ -309,4 +358,3 @@ public interface Closeable extends java.io.Closeable {
309358
void close();
310359
}
311360
}
312-

util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java

Lines changed: 89 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import static org.junit.Assert.assertArrayEquals;
2020
import static org.junit.Assert.assertEquals;
21+
import static org.junit.Assert.assertFalse;
22+
import static org.junit.Assert.assertNull;
2123
import static org.junit.Assert.assertThrows;
2224
import static org.junit.Assert.assertTrue;
2325
import static org.junit.Assert.fail;
@@ -48,7 +50,6 @@ public class AdvancedTlsX509KeyManagerTest {
4850
private static final String SERVER_0_PEM_FILE = "server0.pem";
4951
private static final String CLIENT_0_KEY_FILE = "client.key";
5052
private static final String CLIENT_0_PEM_FILE = "client.pem";
51-
private static final String ALIAS = "default";
5253

5354
private ScheduledExecutorService executor;
5455

@@ -79,22 +80,99 @@ public void setUp() throws Exception {
7980
public void updateTrustCredentials_replacesIssuers() throws Exception {
8081
// Overall happy path checking of public API.
8182
AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager();
83+
8284
serverKeyManager.updateIdentityCredentials(serverCert0, serverKey0);
83-
assertEquals(serverKey0, serverKeyManager.getPrivateKey(ALIAS));
84-
assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(ALIAS));
85+
String alias1 = serverKeyManager.chooseEngineServerAlias(null, null, null);
86+
assertEquals(AdvancedTlsX509KeyManager.ALIAS_PREFIX + "1", alias1);
87+
assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias1));
88+
assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias1));
8589

8690
serverKeyManager.updateIdentityCredentials(clientCert0File, clientKey0File);
87-
assertEquals(clientKey0, serverKeyManager.getPrivateKey(ALIAS));
88-
assertArrayEquals(clientCert0, serverKeyManager.getCertificateChain(ALIAS));
89-
90-
serverKeyManager.updateIdentityCredentials(serverCert0File, serverKey0File,1,
91+
String alias2 = serverKeyManager.chooseEngineServerAlias(null, null, null);
92+
assertEquals(AdvancedTlsX509KeyManager.ALIAS_PREFIX + "2", alias2);
93+
assertEquals(clientKey0, serverKeyManager.getPrivateKey(alias2));
94+
assertArrayEquals(clientCert0, serverKeyManager.getCertificateChain(alias2));
95+
// Old alias no longer resolves — ensures alias stability contract is enforced.
96+
assertNull(serverKeyManager.getPrivateKey(alias1));
97+
98+
serverKeyManager.updateIdentityCredentials(serverCert0File, serverKey0File, 1,
9199
TimeUnit.MINUTES, executor);
92-
assertEquals(serverKey0, serverKeyManager.getPrivateKey(ALIAS));
93-
assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(ALIAS));
100+
String alias3 = serverKeyManager.chooseEngineServerAlias(null, null, null);
101+
assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias3));
102+
assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias3));
94103

95104
serverKeyManager.updateIdentityCredentials(serverCert0, serverKey0);
96-
assertEquals(serverKey0, serverKeyManager.getPrivateKey(ALIAS));
97-
assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(ALIAS));
105+
String alias4 = serverKeyManager.chooseEngineServerAlias(null, null, null);
106+
assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias4));
107+
assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias4));
108+
}
109+
110+
@Test
111+
public void allAliasMethods_returnNullBeforeCredentialsLoaded() {
112+
AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager();
113+
114+
assertNull(keyManager.chooseClientAlias(null, null, null));
115+
assertNull(keyManager.chooseServerAlias(null, null, null));
116+
assertNull(keyManager.chooseEngineClientAlias(null, null, null));
117+
assertNull(keyManager.chooseEngineServerAlias(null, null, null));
118+
assertNull(keyManager.getClientAliases(null, null));
119+
assertNull(keyManager.getServerAliases(null, null));
120+
assertNull(keyManager.getPrivateKey("key-1"));
121+
assertNull(keyManager.getCertificateChain("key-1"));
122+
}
123+
124+
@Test
125+
public void allAliasMethods_agreeAfterCredentialLoad() throws Exception {
126+
AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager();
127+
keyManager.updateIdentityCredentials(serverCert0, serverKey0);
128+
129+
String expectedAlias = AdvancedTlsX509KeyManager.ALIAS_PREFIX + "1";
130+
assertEquals(expectedAlias, keyManager.chooseClientAlias(null, null, null));
131+
assertEquals(expectedAlias, keyManager.chooseServerAlias(null, null, null));
132+
assertEquals(expectedAlias, keyManager.chooseEngineClientAlias(null, null, null));
133+
assertEquals(expectedAlias, keyManager.chooseEngineServerAlias(null, null, null));
134+
assertArrayEquals(new String[]{expectedAlias}, keyManager.getClientAliases(null, null));
135+
assertArrayEquals(new String[]{expectedAlias}, keyManager.getServerAliases(null, null));
136+
}
137+
138+
@Test
139+
public void revisionWarningThreshold_logsWarningAtThreshold() throws Exception {
140+
Logger log = Logger.getLogger(AdvancedTlsX509KeyManager.class.getName());
141+
TestHandler handler = new TestHandler();
142+
log.addHandler(handler);
143+
log.setUseParentHandlers(false);
144+
log.setLevel(Level.ALL);
145+
146+
try {
147+
// Custom threshold: warning when revision reaches threshold.
148+
int threshold = 3;
149+
AdvancedTlsX509KeyManager customKeyManager = new AdvancedTlsX509KeyManager(threshold);
150+
for (int i = 0; i < threshold; i++) {
151+
customKeyManager.updateIdentityCredentials(serverCert0, serverKey0);
152+
}
153+
assertFalse(hasRevisionWarning(handler));
154+
customKeyManager.updateIdentityCredentials(serverCert0, serverKey0);
155+
assertTrue(hasRevisionWarning(handler));
156+
157+
// Key manager must still provide credentials correctly after soft threshold is exceeded.
158+
String alias = customKeyManager.chooseEngineServerAlias(null, null, null);
159+
assertEquals(serverKey0, customKeyManager.getPrivateKey(alias));
160+
assertArrayEquals(serverCert0, customKeyManager.getCertificateChain(alias));
161+
162+
// Further credential updates must also work.
163+
customKeyManager.updateIdentityCredentials(clientCert0File, clientKey0File);
164+
String newAlias = customKeyManager.chooseEngineServerAlias(null, null, null);
165+
assertEquals(clientKey0, customKeyManager.getPrivateKey(newAlias));
166+
assertArrayEquals(clientCert0, customKeyManager.getCertificateChain(newAlias));
167+
} finally {
168+
log.removeHandler(handler);
169+
}
170+
}
171+
172+
private static boolean hasRevisionWarning(TestHandler handler) {
173+
return handler.getRecords().stream()
174+
.anyMatch(r -> Level.WARNING.equals(r.getLevel())
175+
&& r.getMessage().contains("revision counter has reached"));
98176
}
99177

100178
@Test

0 commit comments

Comments
 (0)