diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index b107f7b2b..33201c9f2 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -272,3 +272,12 @@ def get_access_token(self): def sync_auth_flow(self, request): request.headers["Authorization"] = f"Bearer {self.get_access_token()}" yield request + + +class OPAClient: # placeholder until https://jira.diamond.ac.uk/browse/ACQP-550 is done + def do_some_checks(self, task_request) -> bool: + return True + + +def get_opa_client() -> OPAClient: # placeholder + return OPAClient() diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index a53c46885..2075b7a74 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -37,6 +37,7 @@ from blueapi import __version__ from blueapi.config import ApplicationConfig, OIDCConfig, Tag from blueapi.service import interface +from blueapi.service.authentication import OPAClient, get_opa_client from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum @@ -166,6 +167,16 @@ def inner(request: Request, access_token: str = Depends(oauth_scheme)): TRACER = get_tracer("interface") +def submit_permission( + opa: Annotated[OPAClient, Depends(get_opa_client)], + task_request: TaskRequest, +): + allowed = opa.do_some_checks(task_request) + + if not allowed: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + + async def on_key_error_404(_: Request, __: Exception): return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, @@ -291,6 +302,7 @@ def submit_task( request: Request, response: Response, task_request: Annotated[TaskRequest, Body(..., examples=[example_task_request])], + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TaskResponse: """Submit a task to the worker.""" @@ -336,6 +348,7 @@ def submit_task( @start_as_current_span(TRACER, "task_id") def delete_submitted_task( task_id: str, + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TaskResponse: return TaskResponse(task_id=runner.run(interface.clear_task, task_id)) @@ -354,6 +367,7 @@ def validate_task_status(v: str) -> TaskStatusEnum: @start_as_current_span(TRACER) def get_tasks( runner: Annotated[WorkerDispatcher, Depends(_runner)], + _: Annotated[None, Depends(submit_permission)], task_status: str | SkipJsonSchema[None] = None, ) -> TasksListResponse: """ @@ -390,6 +404,7 @@ def get_tasks( def set_active_task( request: Request, task: WorkerTask, + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerTask: """Set a task to active status, the worker should begin it as soon as possible. @@ -420,6 +435,7 @@ def get_passthrough_headers(request: Request) -> dict[str, str]: @start_as_current_span(TRACER, "task_id") def get_task( task_id: str, + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TrackableTask: """Retrieve a task""" @@ -497,6 +513,7 @@ def get_state(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> WorkerSt def set_state( state_change_request: StateChangeRequest, response: Response, + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerState: """