Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .ai/ARCHITECTURE.md
Original file line number Diff line number Diff line change
Expand Up @@ -905,12 +905,14 @@ app = Dash(
server=server,
websocket_callbacks=True,
websocket_inactivity_timeout=300000, # 5 minutes (default)
websocket_heartbeat_interval=30000, # 30 seconds (default)
websocket_allowed_origins=['https://example.com'],
)
```

- **`websocket_callbacks`** - Enable WebSocket for all callbacks (default: `False`)
- **`websocket_inactivity_timeout`** - Close WebSocket after period of inactivity in milliseconds (default: `300000` = 5 minutes). Heartbeats do not count as activity. Set to `0` to disable timeout. Connection automatically reconnects when needed.
- **`websocket_heartbeat_interval`** - Interval for heartbeat/keep-alive checks in milliseconds (default: `30000` = 30 seconds). Also determines how frequently inactivity timeout is checked.
- **`websocket_allowed_origins`** - List of allowed origins for WebSocket connections (security)

### Architecture
Expand Down Expand Up @@ -968,6 +970,7 @@ WebSocket callbacks can stream updates to the client during execution using `set
```python
import asyncio
from dash import callback, Output, Input, set_props, ctx
from dash.exceptions import PreventUpdate

@callback(
Output('result', 'children'),
Expand All @@ -981,6 +984,9 @@ async def long_running_task(n_clicks):

# Stream progress updates to the client
for i in range(100):
# IMPORTANT: Check is_shutdown in loops to detect disconnections
if ws.is_shutdown:
raise PreventUpdate # Exit gracefully on disconnect
await asyncio.sleep(0.1)
set_props('progress-bar', {'value': i + 1})
set_props('status', {'children': f'Processing step {i + 1}/100...'})
Expand All @@ -991,9 +997,19 @@ async def long_running_task(n_clicks):
return f"Completed! Input was: {current_value}"
```

**IMPORTANT - Checking `is_shutdown` in Loops:**

Long-running callbacks that use loops **must** check `ws.is_shutdown` to detect when the WebSocket connection has closed. Without this check:
- Callbacks continue running after the client disconnects, wasting server resources
- `set_props` calls go to a closed connection and are lost
- The callback result is never delivered to the client

Only "persistent callbacks" (callbacks with no Output and no Input that use only `set_props`) are automatically restarted when the WebSocket reconnects. Regular callbacks with outputs are not restarted.

**API:**
- `set_props(component_id, props_dict)` - Stream prop updates immediately to client
- `ctx.websocket` - Get WebSocket interface (returns `None` if not in WS context)
- `ws.is_shutdown` - Check if the WebSocket connection has been closed
- `await ws.get_prop(component_id, prop_name)` - Read current prop value from client
- `await ws.set_prop(component_id, prop_name, value)` - Set single prop (async version)
- `await ws.close(code, reason)` - Close the WebSocket connection
Expand Down
8 changes: 8 additions & 0 deletions @plotly/dash-websocket-worker/src/WebSocketManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ export class WebSocketManager {
return this.ws !== null && this.ws.readyState === WebSocket.OPEN;
}

/**
* Reset the activity timer.
* Call this when a tab becomes visible to prevent inactivity timeout.
*/
public resetActivity(): void {
this.lastActivityTime = Date.now();
}

private createConnection(): void {
if (!this.serverUrl) {
return;
Expand Down
2 changes: 2 additions & 0 deletions @plotly/dash-websocket-worker/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export enum WorkerMessageType {
DISCONNECT = 'disconnect',
CALLBACK_REQUEST = 'callback_request',
GET_PROPS_RESPONSE = 'get_props_response',
TAB_VISIBLE = 'tab_visible',

// Worker -> Renderer
CONNECTED = 'connected',
Expand Down Expand Up @@ -39,6 +40,7 @@ export interface ConnectMessage extends WorkerMessage {
payload: {
serverUrl: string;
inactivityTimeout?: number;
heartbeatInterval?: number;
};
}

Expand Down
21 changes: 18 additions & 3 deletions @plotly/dash-websocket-worker/src/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,24 @@ self.onconnect = (event: MessageEvent) => {
const rendererId = connectMsg.rendererId;
const newServerUrl = connectMsg.payload.serverUrl;
const inactivityTimeout = connectMsg.payload.inactivityTimeout;
const heartbeatInterval = connectMsg.payload.heartbeatInterval;

// Register the renderer
router.registerRenderer(rendererId, port);
rendererIds.add(rendererId);

console.log(`[DashWSWorker] Renderer ${rendererId} connected, inactivityTimeout: ${inactivityTimeout}`);
console.log(`[DashWSWorker] Renderer ${rendererId} connected, inactivityTimeout: ${inactivityTimeout}, heartbeatInterval: ${heartbeatInterval}`);

// Update inactivity timeout if provided
// Update config if provided
const configUpdate: {inactivityTimeout?: number; heartbeatInterval?: number} = {};
if (typeof inactivityTimeout === 'number') {
wsManager.setConfig({ inactivityTimeout });
configUpdate.inactivityTimeout = inactivityTimeout;
}
if (typeof heartbeatInterval === 'number') {
configUpdate.heartbeatInterval = heartbeatInterval;
}
if (Object.keys(configUpdate).length > 0) {
wsManager.setConfig(configUpdate);
}

// Connect to server if not already connected
Expand Down Expand Up @@ -122,6 +130,13 @@ self.onconnect = (event: MessageEvent) => {
break;
}

case WorkerMessageType.TAB_VISIBLE: {
// Reset activity timer when tab becomes visible
// This prevents inactivity timeout while user is viewing the tab
wsManager.resetActivity();
break;
}

default:
// Forward other messages through the router
router.handleRendererMessage(message.rendererId, message);
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ This project adheres to [Semantic Versioning](https://semver.org/).
- [#3669](https://github.com/plotly/dash/pull/3669) Selection for DataTable cleared with custom action settings
- [#3680](https://github.com/plotly/dash/pull/3680) Added `search_order` prop to `Dropdown` to allow users to preserve original option order during search
- Added `csrf_token_name` and `csrf_header_name` config options to allow configuring the CSRF cookie and header names. Fixes [#729](https://github.com/plotly/dash/issues/729)
- [#3797](https://github.com/plotly/dash/pull/3797) Improved websocket callback management.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will the changes in this fix be documented somewhere besides this changelog entry?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A more comprehensive guide for the websocket shutdown management can be found on the WebsocketCallback docstring and in the .ai/ARCHITECTURE.md

- [#3523](https://github.com/plotly/dash/pull/3523) Fall back to background callback function names if source cannot be found
- [#3785](https://github.com/plotly/dash/pull/3785) Fix patch with dcc.Graph figure.
- [#3785](https://github.com/plotly/dash/pull/3785) Fix dcc.Graph not sending duplicate clicks because it had the same payload by adding a timestamp in the click event object.
Expand Down
2 changes: 1 addition & 1 deletion dash/backends/_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ async def websocket_handler(websocket: WebSocket):
dash_app._websocket_callbacks,
)

# Create WebSocket callback instance with outbound queue
# Create WebSocket callback instance
ws_cb = DashWebsocketCallback(
pending_get_props,
renderer_id,
Expand Down
2 changes: 1 addition & 1 deletion dash/backends/_quart.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ async def websocket_handler(): # pylint: disable=too-many-branches
dash_app._websocket_callbacks,
)

# Create WebSocket callback instance with outbound queue
# Create WebSocket callback instance
ws_cb = DashWebsocketCallback(
pending_get_props,
renderer_id,
Expand Down
134 changes: 103 additions & 31 deletions dash/backends/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,25 @@ class DashWebsocketCallback:
Uses janus.Queue for outbound messages (serialized with to_json) and
queue.Queue for get_props responses, enabling thread-safe communication
between worker threads and the main event loop.

IMPORTANT: For long-running callbacks that use loops (e.g., streaming updates),
you MUST check `ws.is_shutdown` in your loop to detect disconnections:

@callback(Input('btn', 'n_clicks')) # No Output - uses set_props only
async def long_running(n_clicks):
ws = ctx.websocket
while True:
if ws and ws.is_shutdown:
raise PreventUpdate # Exit gracefully on disconnect
set_props('progress', {'value': get_data()})
await asyncio.sleep(0.1)

Without this check, callbacks will continue running after the client disconnects,
wasting server resources.

Note: Only "persistent callbacks" (callbacks with no Output and no Input that use
only set_props) are automatically restarted when the WebSocket reconnects. Regular
callbacks with outputs are not restarted.
"""

def __init__(
Expand Down Expand Up @@ -69,15 +88,25 @@ def is_shutdown(self) -> bool:
"""Check if the websocket connection has been shut down."""
return self._shutdown_event.is_set()

def _get_outbound_queue(self) -> janus.Queue[str] | None:
"""Get the outbound queue."""
return self._outbound_queue

def _get_pending_get_props(self) -> Dict[str, queue.Queue[Any]] | None:
"""Get the pending_get_props dict."""
return self._pending_get_props

def _queue_message(self, msg: dict) -> None:
"""Serialize and queue message for sending (thread-safe, non-blocking).

Uses to_json for proper serialization of Dash components.
Does nothing if the connection has been shut down.
"""
if self._shutdown_event.is_set():
if self.is_shutdown:
return
self._outbound_queue.sync_q.put_nowait(cast(str, to_json(msg)))
outbound_queue = self._get_outbound_queue()
if outbound_queue is not None:
outbound_queue.sync_q.put_nowait(cast(str, to_json(msg)))

async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None:
"""Send immediate prop update to the client via WebSocket.
Expand Down Expand Up @@ -115,7 +144,11 @@ async def get_prop(
WebsocketDisconnected: If the websocket connection has been closed.
TimeoutError: If the response doesn't arrive within the timeout.
"""
if self._shutdown_event.is_set():
if self.is_shutdown:
raise WebsocketDisconnected()

pending_get_props = self._get_pending_get_props()
if pending_get_props is None:
raise WebsocketDisconnected()

request_id = str(uuid.uuid4())
Expand All @@ -128,7 +161,7 @@ async def get_prop(

# Use standard queue.Queue for response
response_queue: queue.Queue = queue.Queue()
self._pending_get_props[request_id] = response_queue
pending_get_props[request_id] = response_queue

# Queue the outbound request via janus sync interface
self._queue_message(msg)
Expand All @@ -146,7 +179,10 @@ async def get_prop(
f"Timeout waiting for {component_id}.{prop_name}"
) from exc
finally:
self._pending_get_props.pop(request_id, None)
# Get fresh reference in case of reconnection
current_pending = self._get_pending_get_props()
if current_pending is not None:
current_pending.pop(request_id, None)


def create_ws_context(
Expand Down Expand Up @@ -209,31 +245,60 @@ async def run_ws_sender(
messages: list[str] = []
try:
while True:
# Wait indefinitely for first message, then use timeout for batching
timeout = batch_delay if messages else None
try:
msg = await asyncio.wait_for(q.get(), timeout=timeout)
if msg == SHUTDOWN_SIGNAL:
if messages:
await _send_batched(send_text, messages)
return
if msg == FLUSH_SIGNAL:
if messages:
await _send_batched(send_text, messages)
messages = []
continue
if not batch_delay:
await send_text(msg)
else:
messages.append(msg)
except asyncio.TimeoutError:
await _send_batched(send_text, messages)
messages = []
result = await _process_ws_message(q, send_text, messages, batch_delay)
if result is False:
return
except asyncio.CancelledError:
pass


async def _send_batched(send_text: Callable[[str], Any], messages: list) -> None:
async def _process_ws_message(
q: "janus._AsyncQueueProxy[str]",
send_text: Callable[[str], Any],
messages: list[str],
batch_delay: float,
) -> bool:
"""Process a single WebSocket message from the queue.

Args:
q: The async queue to read from
send_text: Async function to send text data over WebSocket
messages: List to accumulate messages for batching (mutated in place)
batch_delay: Batch delay in seconds

Returns:
True to continue processing, False to stop the sender loop.
"""
timeout = batch_delay if messages else None
try:
msg = await asyncio.wait_for(q.get(), timeout=timeout)
except asyncio.TimeoutError:
success = await _send_batched(send_text, messages)
messages.clear()
return success

if msg == SHUTDOWN_SIGNAL:
if messages:
await _send_batched(send_text, messages)
return False

if msg == FLUSH_SIGNAL:
success = not messages or await _send_batched(send_text, messages)
messages.clear()
return success

if not batch_delay:
try:
await send_text(msg)
except Exception: # pylint: disable=broad-exception-caught
return False # WebSocketDisconnect, RuntimeError, etc.
else:
messages.append(msg)

return True


async def _send_batched(send_text: Callable[[str], Any], messages: list) -> bool:
"""Send messages as a batch.

Single messages are sent as-is. Multiple messages are wrapped
Expand All @@ -242,12 +307,19 @@ async def _send_batched(send_text: Callable[[str], Any], messages: list) -> None
Args:
send_text: Async function to send text data over WebSocket
messages: List of pre-serialized JSON message strings

Returns:
True if send succeeded, False if connection was closed
"""
if len(messages) == 1:
await send_text(messages[0])
else:
# Wrap in array: "[msg1,msg2,msg3]"
await send_text("[" + ",".join(messages) + "]")
try:
if len(messages) == 1:
await send_text(messages[0])
else:
# Wrap in array: "[msg1,msg2,msg3]"
await send_text("[" + ",".join(messages) + "]")
return True
except Exception: # pylint: disable=broad-exception-caught
return False # WebSocketDisconnect, RuntimeError, etc.


def make_callback_done_handler(
Expand Down
1 change: 1 addition & 0 deletions dash/dash-renderer/src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export type DashConfig = {
url: string;
worker_url: string;
inactivity_timeout?: number;
heartbeat_interval?: number;
};
csrf_token_name?: string;
csrf_header_name?: string;
Expand Down
Loading
Loading