Skip to content

Commit 38e718f

Browse files
committed
address pr comments
1 parent 0080bfd commit 38e718f

File tree

2 files changed

+91
-37
lines changed

2 files changed

+91
-37
lines changed

src/cert_fetcher.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ impl CertFetcherImpl {
9999
// We only get certs for our own node
100100
Some(w.node.as_ref()) == self.local_node.as_deref() &&
101101
// If it doesn't support HBONE it *probably* doesn't need a cert.
102-
(w.native_tunnel || w.protocol == InboundProtocol::HBONE && !self.cfg.spire_enabled)
102+
(w.native_tunnel || w.protocol == InboundProtocol::HBONE) && !self.cfg.spire_enabled
103103
}
104104

105105
fn build_key(&self, w: &Workload) -> CompositeId<RequestKey> {

src/identity/spireclient.rs

Lines changed: 90 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,10 @@ impl<C: DelegatedIdentityApi> SpireClient<C> {
120120
async fn get_cert_by_pid(
121121
&self,
122122
pid: i32,
123-
wl_uid: &WorkloadUid,
123+
id: &CompositeId<WorkloadUid>,
124124
) -> Result<tls::WorkloadCertificate, Error> {
125125
let certs = self
126-
.get_cert_from_spire(DelegateAttestationRequest::Pid(pid))
126+
.get_cert_from_spire(DelegateAttestationRequest::Pid(pid), id.clone())
127127
.await;
128128

129129
let certs = match certs {
@@ -136,14 +136,15 @@ impl<C: DelegatedIdentityApi> SpireClient<C> {
136136
Ok(certs) => certs,
137137
};
138138

139-
let pid_verify = self.pid.fetch_pid(wl_uid).await;
139+
// Verify that the PID we used for attestation matches the PID associated with the workload UID
140+
let pid_verify = self.pid.fetch_pid(id.key()).await;
140141

141142
match pid_verify {
142143
Ok(fetched_pid) => {
143144
if fetched_pid.into_i32() != pid {
144145
return Err(Error::UnableToDeterminePidForWorkload(format!(
145146
"PID mismatch for workload UID {}: expected {}, got {}",
146-
wl_uid.clone().into_string(),
147+
id.key().clone().into_string(),
147148
pid,
148149
fetched_pid.into_i32()
149150
)));
@@ -152,7 +153,7 @@ impl<C: DelegatedIdentityApi> SpireClient<C> {
152153
}
153154
Err(e) => Err(Error::UnableToDeterminePidForWorkload(format!(
154155
"Failed to verify PID for workload UID {}: {}",
155-
wl_uid.clone().into_string(),
156+
id.key().clone().into_string(),
156157
e
157158
))),
158159
}
@@ -171,18 +172,18 @@ impl<C: DelegatedIdentityApi> SpireClient<C> {
171172
/// no certificates are received within timeout, or certificate construction fails.
172173
async fn get_cert_by_workload_uid(
173174
&self,
174-
wl_uid: &WorkloadUid,
175+
id: &CompositeId<WorkloadUid>,
175176
) -> Result<tls::WorkloadCertificate, Error> {
176177
tracing::info!(
177178
"Fetching PID for workload UID: {}",
178-
wl_uid.clone().into_string()
179+
id.key().clone().into_string()
179180
);
180-
let pid = self.pid.fetch_pid(wl_uid).await;
181+
let pid = self.pid.fetch_pid(id.key()).await;
181182
match pid {
182-
Ok(pid) => self.get_cert_by_pid(pid.into_i32(), wl_uid).await,
183+
Ok(pid) => self.get_cert_by_pid(pid.into_i32(), id).await,
183184
Err(e) => Err(Error::UnableToDeterminePidForWorkload(format!(
184185
"Failed to fetch PID for workload UID {}: {}",
185-
wl_uid.clone().into_string(),
186+
id.key().clone().into_string(),
186187
e
187188
))),
188189
}
@@ -265,6 +266,7 @@ impl<C: DelegatedIdentityApi> SpireClient<C> {
265266
async fn get_cert_from_spire(
266267
&self,
267268
value: DelegateAttestationRequest,
269+
id: CompositeId<WorkloadUid>,
268270
) -> Result<tls::WorkloadCertificate, Error> {
269271
// Handle nested Result types from timeout + stream operations
270272
let svid_response = self.subscribe_and_wait_for_workload_cert(value).await?;
@@ -275,14 +277,20 @@ impl<C: DelegatedIdentityApi> SpireClient<C> {
275277
// Construct the final WorkloadCertificate combining SVID and trust bundle
276278
let certs = tls::WorkloadCertificate::new_svid(&svid_response, &bundle)?;
277279

278-
let id = format!(
280+
let id_strg = format!(
279281
"spiffe://{}{}",
280282
svid_response.spiffe_id().trust_domain(),
281283
svid_response.spiffe_id().path()
282284
);
283285

284286
// Validate that the returned identity matches the requested one
285-
Identity::from_str(&id)?;
287+
if id.id().to_string() != Identity::from_str(&id_strg)?.to_string() {
288+
return Err(Error::Spiffe(format!(
289+
"Mismatched identity: expected {}, got {}",
290+
id.id(),
291+
id_strg
292+
)));
293+
}
286294

287295
Ok(certs)
288296
}
@@ -345,7 +353,7 @@ impl<C: DelegatedIdentityApi> crate::identity::CaClientTrait for SpireClient<C>
345353
&self,
346354
id: &CompositeId<WorkloadUid>,
347355
) -> Result<tls::WorkloadCertificate, Error> {
348-
self.get_cert_by_workload_uid(id.key()).await
356+
self.get_cert_by_workload_uid(id).await
349357
}
350358
}
351359

@@ -380,7 +388,7 @@ pub mod spire_tests {
380388
#[tokio::test]
381389
async fn test_get_bundle_success() {
382390
let mut mock_client = MockDelegatedIdentityApi::new();
383-
let mut pid_client = MockPidClientTrait::new();
391+
let pid_client = MockPidClientTrait::new();
384392

385393
mock_client
386394
.expect_get_x509_bundles()
@@ -403,7 +411,7 @@ pub mod spire_tests {
403411
#[tokio::test]
404412
async fn test_get_bundle_trust_domain_not_found() {
405413
let mut mock_client = MockDelegatedIdentityApi::new();
406-
let mut pid_client = MockPidClientTrait::new();
414+
let pid_client = MockPidClientTrait::new();
407415

408416
mock_client
409417
.expect_get_x509_bundles()
@@ -459,9 +467,9 @@ pub mod spire_tests {
459467

460468
let identity =
461469
Identity::from_parts("example.org".into(), "default".into(), "test-sa".into());
462-
let result = spire_client
463-
.get_cert_by_pid(10, &WorkloadUid::new("uid-123456".to_string()))
464-
.await;
470+
let composite_id =
471+
CompositeId::with_key(identity.clone(), WorkloadUid::new("uid-123456".to_string()));
472+
let result = spire_client.get_cert_by_pid(10, &composite_id).await;
465473

466474
assert!(result.is_ok());
467475

@@ -473,9 +481,6 @@ pub mod spire_tests {
473481

474482
assert!(identity.to_string() == "spiffe://example.org/ns/default/sa/test-sa");
475483

476-
let composite_id =
477-
CompositeId::with_key(identity, WorkloadUid::new("uid-123456".to_string()));
478-
479484
let fetch_result = spire_client.fetch_certificate(&composite_id).await;
480485

481486
assert!(fetch_result.is_ok());
@@ -514,9 +519,11 @@ pub mod spire_tests {
514519
Arc::new(cfg),
515520
);
516521

517-
let result = spire_client
518-
.get_cert_by_pid(10, &WorkloadUid::new("uid-123456".to_string()))
519-
.await;
522+
let identity =
523+
Identity::from_parts("example.org".into(), "default".into(), "test-sa".into());
524+
let composite_id =
525+
CompositeId::with_key(identity, WorkloadUid::new("uid-123456".to_string()));
526+
let result = spire_client.get_cert_by_pid(10, &composite_id).await;
520527

521528
assert!(result.is_err());
522529
}
@@ -660,9 +667,9 @@ pub mod spire_tests {
660667

661668
let identity =
662669
Identity::from_parts("example.org".into(), "default".into(), "test-sa".into());
663-
let result = spire_client
664-
.get_cert_by_pid(10, &WorkloadUid::new("uid-123456".to_string()))
665-
.await;
670+
let composite_id =
671+
CompositeId::with_key(identity.clone(), WorkloadUid::new("uid-123456".to_string()));
672+
let result = spire_client.get_cert_by_pid(10, &composite_id).await;
666673

667674
assert!(result.is_ok());
668675

@@ -674,9 +681,6 @@ pub mod spire_tests {
674681

675682
assert!(identity.to_string() == "spiffe://example.org/ns/default/sa/test-sa");
676683

677-
let composite_id =
678-
CompositeId::with_key(identity, WorkloadUid::new("uid-123456".to_string()));
679-
680684
let fetch_result = spire_client.fetch_certificate(&composite_id).await;
681685

682686
assert!(fetch_result.is_ok());
@@ -715,9 +719,11 @@ pub mod spire_tests {
715719
Arc::clone(&cfg),
716720
);
717721

718-
let result = spire_client
719-
.get_cert_by_pid(10, &WorkloadUid::new("uid-123456".to_string()))
720-
.await;
722+
let identity =
723+
Identity::from_parts("example.org".into(), "default".into(), "test-sa".into());
724+
let composite_id =
725+
CompositeId::with_key(identity, WorkloadUid::new("uid-123456".to_string()));
726+
let result = spire_client.get_cert_by_pid(10, &composite_id).await;
721727

722728
assert!(result.is_err());
723729

@@ -757,9 +763,11 @@ pub mod spire_tests {
757763
Arc::clone(&cfg),
758764
);
759765

760-
let result = spire_client
761-
.get_cert_by_pid(10, &WorkloadUid::new("uid-123456".to_string()))
762-
.await;
766+
let identity =
767+
Identity::from_parts("example.org".into(), "default".into(), "test-sa".into());
768+
let composite_id =
769+
CompositeId::with_key(identity, WorkloadUid::new("uid-123456".to_string()));
770+
let result = spire_client.get_cert_by_pid(10, &composite_id).await;
763771

764772
assert!(result.is_err());
765773

@@ -834,6 +842,52 @@ pub mod spire_tests {
834842
Box::new(rx)
835843
}
836844

845+
#[tokio::test]
846+
async fn test_get_cert_by_pid_mismatched_identity() {
847+
let mut mock_client = MockDelegatedIdentityApi::new();
848+
let pid_client = MockPidClientTrait::new();
849+
850+
// SPIRE returns a certificate for a DIFFERENT identity than requested
851+
mock_client.expect_get_x509_svids().returning(|_req| {
852+
let stream = mock_stream_svid_success_response(
853+
"spiffe://example.org/ns/other-namespace/sa/other-sa".to_string(),
854+
);
855+
Ok(stream)
856+
});
857+
858+
mock_client
859+
.expect_get_x509_bundles()
860+
.returning(|| Ok(mock_bundle_response()));
861+
862+
let mut cfg = config::parse_config().unwrap();
863+
cfg.spire_enabled = true;
864+
865+
let spire_client = SpireClient::new(
866+
mock_client,
867+
"example.org".to_string(),
868+
Box::new(pid_client),
869+
Arc::new(cfg),
870+
);
871+
872+
// Request certificate for ns/default/sa/test-sa but SPIRE returns ns/other-namespace/sa/other-sa
873+
let identity =
874+
Identity::from_parts("example.org".into(), "default".into(), "test-sa".into());
875+
let composite_id =
876+
CompositeId::with_key(identity, WorkloadUid::new("uid-123456".to_string()));
877+
878+
let result = spire_client
879+
.get_cert_from_spire(DelegateAttestationRequest::Pid(10), composite_id)
880+
.await;
881+
882+
assert!(result.is_err());
883+
let err = result.err().unwrap().to_string();
884+
assert!(
885+
err.contains("Mismatched identity"),
886+
"Expected error to contain 'Mismatched identity', got: {}",
887+
err
888+
);
889+
}
890+
837891
fn ca_params(cn: &str) -> Result<CertificateParams, rcgen::Error> {
838892
// Empty Vec<String> => no DNS SANs; we'll just use DN for the CA
839893
let mut params = CertificateParams::new(Vec::<String>::new())?;

0 commit comments

Comments
 (0)