You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
39 lines
1.2 KiB
Python
39 lines
1.2 KiB
Python
import json
|
|
|
|
from fastapi import WebSocket
|
|
|
|
from . import ws_router
|
|
from project import broadcast
|
|
from project.celery_utils import get_task_info
|
|
|
|
|
|
@ws_router.websocket("/ws/task_status/{task_id}")
async def ws_task_status(websocket: WebSocket):
    """Stream status updates for one Celery task over a websocket.

    After accepting the connection, the handler:

    1. subscribes to the broadcast channel named after the task id taken
       from the URL path;
    2. immediately sends the task's current info, in case the task has
       already finished;
    3. forwards every message published on that channel to the client,
       decoding it from JSON first.
    """
    await websocket.accept()
    requested_task = websocket.path_params["task_id"]

    # The subscription is opened before the snapshot is taken.
    # Presumably this keeps an update published between the two steps from
    # being missed; confirm against the broadcast backend's semantics.
    async with broadcast.subscribe(channel=requested_task) as channel_feed:
        # Just in case the task is already finished and nothing more will
        # ever be published on this channel.
        await websocket.send_json(data=get_task_info(requested_task))

        async for incoming in channel_feed:
            # Publishers in this module send JSON strings (json.dumps),
            # so decode before re-sending as a JSON frame.
            payload = json.loads(incoming.message)
            await websocket.send_json(payload)
|
|
|
|
|
|
async def update_celery_task_status(task_id: str):
    """
    Called by Celery worker in `task_postrun` signal handler after a task finishes.

    Publishes the task's current info (as returned by ``get_task_info``) as a
    JSON string on the broadcast channel named after ``task_id``. That is the
    same channel ``ws_task_status`` subscribes to for this task.

    The broadcast connection is always closed, even if publishing fails.

    :param task_id: id of the Celery task; also used as the channel name.
    """
    # Build the payload before connecting, so a failure here never leaves a
    # connection open. RedisProtocol.publish expects a str, hence json.dumps.
    message = json.dumps(get_task_info(task_id))

    await broadcast.connect()
    try:
        # Broadcast the task info through redis to the channel identified by task_id.
        await broadcast.publish(channel=task_id, message=message)
    finally:
        # Previously a failing publish skipped this and leaked the connection.
        await broadcast.disconnect()
|