@@ -120,10 +120,10 @@ impl<C: DelegatedIdentityApi> SpireClient<C> {
120120 async fn get_cert_by_pid (
121121 & self ,
122122 pid : i32 ,
123- wl_uid : & WorkloadUid ,
123+ id : & CompositeId < WorkloadUid > ,
124124 ) -> Result < tls:: WorkloadCertificate , Error > {
125125 let certs = self
126- . get_cert_from_spire ( DelegateAttestationRequest :: Pid ( pid) )
126+ . get_cert_from_spire ( DelegateAttestationRequest :: Pid ( pid) , id . clone ( ) )
127127 . await ;
128128
129129 let certs = match certs {
@@ -136,14 +136,15 @@ impl<C: DelegatedIdentityApi> SpireClient<C> {
136136 Ok ( certs) => certs,
137137 } ;
138138
139- let pid_verify = self . pid . fetch_pid ( wl_uid) . await ;
139+ // Verify that the PID we used for attestation matches the PID associated with the workload UID
140+ let pid_verify = self . pid . fetch_pid ( id. key ( ) ) . await ;
140141
141142 match pid_verify {
142143 Ok ( fetched_pid) => {
143144 if fetched_pid. into_i32 ( ) != pid {
144145 return Err ( Error :: UnableToDeterminePidForWorkload ( format ! (
145146 "PID mismatch for workload UID {}: expected {}, got {}" ,
146- wl_uid . clone( ) . into_string( ) ,
147+ id . key ( ) . clone( ) . into_string( ) ,
147148 pid,
148149 fetched_pid. into_i32( )
149150 ) ) ) ;
@@ -152,7 +153,7 @@ impl<C: DelegatedIdentityApi> SpireClient<C> {
152153 }
153154 Err ( e) => Err ( Error :: UnableToDeterminePidForWorkload ( format ! (
154155 "Failed to verify PID for workload UID {}: {}" ,
155- wl_uid . clone( ) . into_string( ) ,
156+ id . key ( ) . clone( ) . into_string( ) ,
156157 e
157158 ) ) ) ,
158159 }
@@ -171,18 +172,18 @@ impl<C: DelegatedIdentityApi> SpireClient<C> {
171172 /// no certificates are received within timeout, or certificate construction fails.
172173 async fn get_cert_by_workload_uid (
173174 & self ,
174- wl_uid : & WorkloadUid ,
175+ id : & CompositeId < WorkloadUid > ,
175176 ) -> Result < tls:: WorkloadCertificate , Error > {
176177 tracing:: info!(
177178 "Fetching PID for workload UID: {}" ,
178- wl_uid . clone( ) . into_string( )
179+ id . key ( ) . clone( ) . into_string( )
179180 ) ;
180- let pid = self . pid . fetch_pid ( wl_uid ) . await ;
181+ let pid = self . pid . fetch_pid ( id . key ( ) ) . await ;
181182 match pid {
182- Ok ( pid) => self . get_cert_by_pid ( pid. into_i32 ( ) , wl_uid ) . await ,
183+ Ok ( pid) => self . get_cert_by_pid ( pid. into_i32 ( ) , id ) . await ,
183184 Err ( e) => Err ( Error :: UnableToDeterminePidForWorkload ( format ! (
184185 "Failed to fetch PID for workload UID {}: {}" ,
185- wl_uid . clone( ) . into_string( ) ,
186+ id . key ( ) . clone( ) . into_string( ) ,
186187 e
187188 ) ) ) ,
188189 }
@@ -265,6 +266,7 @@ impl<C: DelegatedIdentityApi> SpireClient<C> {
265266 async fn get_cert_from_spire (
266267 & self ,
267268 value : DelegateAttestationRequest ,
269+ id : CompositeId < WorkloadUid > ,
268270 ) -> Result < tls:: WorkloadCertificate , Error > {
269271 // Handle nested Result types from timeout + stream operations
270272 let svid_response = self . subscribe_and_wait_for_workload_cert ( value) . await ?;
@@ -275,14 +277,20 @@ impl<C: DelegatedIdentityApi> SpireClient<C> {
275277 // Construct the final WorkloadCertificate combining SVID and trust bundle
276278 let certs = tls:: WorkloadCertificate :: new_svid ( & svid_response, & bundle) ?;
277279
278- let id = format ! (
280+ let id_strg = format ! (
279281 "spiffe://{}{}" ,
280282 svid_response. spiffe_id( ) . trust_domain( ) ,
281283 svid_response. spiffe_id( ) . path( )
282284 ) ;
283285
284286 // Validate that the returned identity matches the requested one
285- Identity :: from_str ( & id) ?;
287+ if id. id ( ) . to_string ( ) != Identity :: from_str ( & id_strg) ?. to_string ( ) {
288+ return Err ( Error :: Spiffe ( format ! (
289+ "Mismatched identity: expected {}, got {}" ,
290+ id. id( ) ,
291+ id_strg
292+ ) ) ) ;
293+ }
286294
287295 Ok ( certs)
288296 }
@@ -345,7 +353,7 @@ impl<C: DelegatedIdentityApi> crate::identity::CaClientTrait for SpireClient<C>
345353 & self ,
346354 id : & CompositeId < WorkloadUid > ,
347355 ) -> Result < tls:: WorkloadCertificate , Error > {
348- self . get_cert_by_workload_uid ( id. key ( ) ) . await
356+ self . get_cert_by_workload_uid ( id) . await
349357 }
350358}
351359
@@ -380,7 +388,7 @@ pub mod spire_tests {
380388 #[ tokio:: test]
381389 async fn test_get_bundle_success ( ) {
382390 let mut mock_client = MockDelegatedIdentityApi :: new ( ) ;
383- let mut pid_client = MockPidClientTrait :: new ( ) ;
391+ let pid_client = MockPidClientTrait :: new ( ) ;
384392
385393 mock_client
386394 . expect_get_x509_bundles ( )
@@ -403,7 +411,7 @@ pub mod spire_tests {
403411 #[ tokio:: test]
404412 async fn test_get_bundle_trust_domain_not_found ( ) {
405413 let mut mock_client = MockDelegatedIdentityApi :: new ( ) ;
406- let mut pid_client = MockPidClientTrait :: new ( ) ;
414+ let pid_client = MockPidClientTrait :: new ( ) ;
407415
408416 mock_client
409417 . expect_get_x509_bundles ( )
@@ -459,9 +467,9 @@ pub mod spire_tests {
459467
460468 let identity =
461469 Identity :: from_parts ( "example.org" . into ( ) , "default" . into ( ) , "test-sa" . into ( ) ) ;
462- let result = spire_client
463- . get_cert_by_pid ( 10 , & WorkloadUid :: new ( "uid-123456" . to_string ( ) ) )
464- . await ;
470+ let composite_id =
471+ CompositeId :: with_key ( identity . clone ( ) , WorkloadUid :: new ( "uid-123456" . to_string ( ) ) ) ;
472+ let result = spire_client . get_cert_by_pid ( 10 , & composite_id ) . await ;
465473
466474 assert ! ( result. is_ok( ) ) ;
467475
@@ -473,9 +481,6 @@ pub mod spire_tests {
473481
474482 assert ! ( identity. to_string( ) == "spiffe://example.org/ns/default/sa/test-sa" ) ;
475483
476- let composite_id =
477- CompositeId :: with_key ( identity, WorkloadUid :: new ( "uid-123456" . to_string ( ) ) ) ;
478-
479484 let fetch_result = spire_client. fetch_certificate ( & composite_id) . await ;
480485
481486 assert ! ( fetch_result. is_ok( ) ) ;
@@ -514,9 +519,11 @@ pub mod spire_tests {
514519 Arc :: new ( cfg) ,
515520 ) ;
516521
517- let result = spire_client
518- . get_cert_by_pid ( 10 , & WorkloadUid :: new ( "uid-123456" . to_string ( ) ) )
519- . await ;
522+ let identity =
523+ Identity :: from_parts ( "example.org" . into ( ) , "default" . into ( ) , "test-sa" . into ( ) ) ;
524+ let composite_id =
525+ CompositeId :: with_key ( identity, WorkloadUid :: new ( "uid-123456" . to_string ( ) ) ) ;
526+ let result = spire_client. get_cert_by_pid ( 10 , & composite_id) . await ;
520527
521528 assert ! ( result. is_err( ) ) ;
522529 }
@@ -660,9 +667,9 @@ pub mod spire_tests {
660667
661668 let identity =
662669 Identity :: from_parts ( "example.org" . into ( ) , "default" . into ( ) , "test-sa" . into ( ) ) ;
663- let result = spire_client
664- . get_cert_by_pid ( 10 , & WorkloadUid :: new ( "uid-123456" . to_string ( ) ) )
665- . await ;
670+ let composite_id =
671+ CompositeId :: with_key ( identity . clone ( ) , WorkloadUid :: new ( "uid-123456" . to_string ( ) ) ) ;
672+ let result = spire_client . get_cert_by_pid ( 10 , & composite_id ) . await ;
666673
667674 assert ! ( result. is_ok( ) ) ;
668675
@@ -674,9 +681,6 @@ pub mod spire_tests {
674681
675682 assert ! ( identity. to_string( ) == "spiffe://example.org/ns/default/sa/test-sa" ) ;
676683
677- let composite_id =
678- CompositeId :: with_key ( identity, WorkloadUid :: new ( "uid-123456" . to_string ( ) ) ) ;
679-
680684 let fetch_result = spire_client. fetch_certificate ( & composite_id) . await ;
681685
682686 assert ! ( fetch_result. is_ok( ) ) ;
@@ -715,9 +719,11 @@ pub mod spire_tests {
715719 Arc :: clone ( & cfg) ,
716720 ) ;
717721
718- let result = spire_client
719- . get_cert_by_pid ( 10 , & WorkloadUid :: new ( "uid-123456" . to_string ( ) ) )
720- . await ;
722+ let identity =
723+ Identity :: from_parts ( "example.org" . into ( ) , "default" . into ( ) , "test-sa" . into ( ) ) ;
724+ let composite_id =
725+ CompositeId :: with_key ( identity, WorkloadUid :: new ( "uid-123456" . to_string ( ) ) ) ;
726+ let result = spire_client. get_cert_by_pid ( 10 , & composite_id) . await ;
721727
722728 assert ! ( result. is_err( ) ) ;
723729
@@ -757,9 +763,11 @@ pub mod spire_tests {
757763 Arc :: clone ( & cfg) ,
758764 ) ;
759765
760- let result = spire_client
761- . get_cert_by_pid ( 10 , & WorkloadUid :: new ( "uid-123456" . to_string ( ) ) )
762- . await ;
766+ let identity =
767+ Identity :: from_parts ( "example.org" . into ( ) , "default" . into ( ) , "test-sa" . into ( ) ) ;
768+ let composite_id =
769+ CompositeId :: with_key ( identity, WorkloadUid :: new ( "uid-123456" . to_string ( ) ) ) ;
770+ let result = spire_client. get_cert_by_pid ( 10 , & composite_id) . await ;
763771
764772 assert ! ( result. is_err( ) ) ;
765773
@@ -834,6 +842,52 @@ pub mod spire_tests {
834842 Box :: new ( rx)
835843 }
836844
845+ #[ tokio:: test]
846+ async fn test_get_cert_by_pid_mismatched_identity ( ) {
847+ let mut mock_client = MockDelegatedIdentityApi :: new ( ) ;
848+ let pid_client = MockPidClientTrait :: new ( ) ;
849+
850+ // SPIRE returns a certificate for a DIFFERENT identity than requested
851+ mock_client. expect_get_x509_svids ( ) . returning ( |_req| {
852+ let stream = mock_stream_svid_success_response (
853+ "spiffe://example.org/ns/other-namespace/sa/other-sa" . to_string ( ) ,
854+ ) ;
855+ Ok ( stream)
856+ } ) ;
857+
858+ mock_client
859+ . expect_get_x509_bundles ( )
860+ . returning ( || Ok ( mock_bundle_response ( ) ) ) ;
861+
862+ let mut cfg = config:: parse_config ( ) . unwrap ( ) ;
863+ cfg. spire_enabled = true ;
864+
865+ let spire_client = SpireClient :: new (
866+ mock_client,
867+ "example.org" . to_string ( ) ,
868+ Box :: new ( pid_client) ,
869+ Arc :: new ( cfg) ,
870+ ) ;
871+
872+ // Request certificate for ns/default/sa/test-sa but SPIRE returns ns/other-namespace/sa/other-sa
873+ let identity =
874+ Identity :: from_parts ( "example.org" . into ( ) , "default" . into ( ) , "test-sa" . into ( ) ) ;
875+ let composite_id =
876+ CompositeId :: with_key ( identity, WorkloadUid :: new ( "uid-123456" . to_string ( ) ) ) ;
877+
878+ let result = spire_client
879+ . get_cert_from_spire ( DelegateAttestationRequest :: Pid ( 10 ) , composite_id)
880+ . await ;
881+
882+ assert ! ( result. is_err( ) ) ;
883+ let err = result. err ( ) . unwrap ( ) . to_string ( ) ;
884+ assert ! (
885+ err. contains( "Mismatched identity" ) ,
886+ "Expected error to contain 'Mismatched identity', got: {}" ,
887+ err
888+ ) ;
889+ }
890+
837891 fn ca_params ( cn : & str ) -> Result < CertificateParams , rcgen:: Error > {
838892 // Empty Vec<String> => no DNS SANs; we'll just use DN for the CA
839893 let mut params = CertificateParams :: new ( Vec :: < String > :: new ( ) ) ?;
0 commit comments