diff --git a/project/__init__.py b/project/__init__.py index 1ee0803..d7c5630 100644 --- a/project/__init__.py +++ b/project/__init__.py @@ -1,16 +1,36 @@ +from broadcaster import Broadcast from fastapi import FastAPI -from project.celery_utils import create_celery +from project.config import settings + +broadcast = Broadcast(settings.WS_MESSAGE_QUEUE) + +# Event Hooks +async def startup_event(): + await broadcast.connect() + + +async def shutdown_event(): + await broadcast.disconnect() def create_app() -> FastAPI: - app = FastAPI() + # Starlette discourages decorator syntax now for eventhooks + app = FastAPI(on_startup=[startup_event], on_shutdown=[shutdown_event]) + + # do this before loading routes + from project.celery_utils import create_celery + app.celery_app = create_celery() from .users import users_router app.include_router(users_router) + from project.ws import ws_router + + app.include_router(ws_router) + @app.get("/") async def root(): return {"message": "Hello World"} diff --git a/project/celery_utils.py b/project/celery_utils.py index a97a25f..f15bf37 100644 --- a/project/celery_utils.py +++ b/project/celery_utils.py @@ -1,7 +1,11 @@ +import typing from celery import current_app as celery_current_app +from celery.result import AsyncResult from project.config import settings +TaskInfoResponse = typing.Dict[str, str] + def create_celery(): """Clerey factory function.""" @@ -9,3 +13,23 @@ def create_celery(): celery_app.config_from_object(settings, namespace="CELERY") return celery_app + + +def get_task_info(task_id: str) -> TaskInfoResponse: + """ + return task info according to the task_id + """ + task = AsyncResult(task_id) + state = task.state + + if state == "FAILURE": + error = str(task.result) + response = { + "state": task.state, + "error": error, + } + else: + response = { + "state": task.state, + } + return response diff --git a/project/config.py b/project/config.py index c4273d0..6f047cb 100644 --- a/project/config.py +++ b/project/config.py @@ -17,6 +17,8 @@ class BaseConfig: CELERY_RESULT_BACKEND: str = os.environ.get( "CELERY_RESULT_BACKEND", "redis://127.0.0.1:6379/0" ) + # will be used for multi-process communication by FastAPI and the Celery Worker. + WS_MESSAGE_QUEUE: str = os.environ.get("WS_MESSAGE_QUEUE", "redis://redis:6379/0") class DevelopmentConfig(BaseConfig): @@ -32,7 +34,7 @@ class TestingConfig(BaseConfig): @lru_cache() -def get_settings(): +def get_settings() -> BaseConfig: config_cls_dict = { "development": DevelopmentConfig, "production": ProductionConfig, diff --git a/project/users/tasks.py b/project/users/tasks.py index 7d22f28..b6b2048 100644 --- a/project/users/tasks.py +++ b/project/users/tasks.py @@ -5,9 +5,11 @@ Now, we can copy this file anywhere in the app and it will work as expected. In having to import the Celery instance, so it can make your task code more reusable. """ import random + import requests +from asgiref.sync import async_to_sync from celery import shared_task - +from celery.signals import task_postrun from celery.utils.log import get_task_logger # See https://docs.celeryq.dev/en/stable/userguide/tasks.html#logging @@ -57,3 +59,17 @@ def task_process_notification(self): # Remember to raise the exception returned by the self.retry method to make it work. # By setting the countdown argument to 5, the task will retry after a 5 second delay. raise self.retry(exc=e, countdown=5) + + +@task_postrun.connect +def task_postrun_handler(task_id, **kwargs): # pylint: disable=unused-argument + """ + Celery signal handler called after each Celery task is executed. + + Sends a message to the relevant channel via the `update_celery_task_status`. + Celery does not support asyncio so we must convert the async function, to a + synchronous function with asgiref. + """ + from project.ws.views import update_celery_task_status + + async_to_sync(update_celery_task_status)(task_id) diff --git a/project/users/templates/form_ws.html b/project/users/templates/form_ws.html new file mode 100644 index 0000000..fd627d5 --- /dev/null +++ b/project/users/templates/form_ws.html @@ -0,0 +1,115 @@ + + + + + + Celery example + + + +
+
+
+
+
+ + +
+
+ + +
+
+ +
+
+
+
+ + + + + diff --git a/project/users/views.py b/project/users/views.py index e6d9c13..56bd921 100644 --- a/project/users/views.py +++ b/project/users/views.py @@ -83,3 +83,8 @@ def webhook_test_async(): task = task_process_notification.delay() print(task.id) return "pong" + + +@users_router.get("/form_ws/") +def form_ws_example(request: Request): + return templates.TemplateResponse("form_ws.html", {"request": request}) diff --git a/project/ws/__init__.py b/project/ws/__init__.py new file mode 100644 index 0000000..d5bbe61 --- /dev/null +++ b/project/ws/__init__.py @@ -0,0 +1,5 @@ +from fastapi import APIRouter + +ws_router = APIRouter() + +from . import views diff --git a/project/ws/views.py b/project/ws/views.py new file mode 100644 index 0000000..59a196e --- /dev/null +++ b/project/ws/views.py @@ -0,0 +1,38 @@ +import json + +from fastapi import WebSocket + +from . import ws_router +from project import broadcast +from project.celery_utils import get_task_info + + +@ws_router.websocket("/ws/task_status/{task_id}") +async def ws_task_status(websocket: WebSocket): + # Accept websocket connection and obtain task id + await websocket.accept() + + task_id = websocket.scope["path_params"]["task_id"] + + # Subscribe to specific channel and send message to browser if + # it exists + async with broadcast.subscribe(channel=task_id) as subscriber: + # just in case the task is already finished + data = get_task_info(task_id) + await websocket.send_json(data=data) + + async for event in subscriber: + await websocket.send_json(json.loads(event.message)) + + +async def update_celery_task_status(task_id: str): + """ + Called by Celery worker in `task_postrun` signal handler after tasks finishes. + """ + await broadcast.connect() + # broadcast task info to channel identified by task_id through redis. + await broadcast.publish( + channel=task_id, + message=json.dumps(get_task_info(task_id)), # RedisProtocol.publish expect str + ) + await broadcast.disconnect() diff --git a/requirements.in b/requirements.in index dd60f9e..5398262 100644 --- a/requirements.in +++ b/requirements.in @@ -1,4 +1,7 @@ alembic==1.8.1 +asgiref==3.5.2 +asyncio-redis==0.16.0 +broadcaster==0.2.0 celery==5.2.7 fastapi==0.79.0 flower==1.2.0 diff --git a/requirements.txt b/requirements.txt index 4be0a45..ace92a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,10 +12,16 @@ anyio==3.6.1 # via # starlette # watchfiles +asgiref==3.5.2 + # via -r requirements.in async-timeout==4.0.2 # via redis +asyncio-redis==0.16.0 + # via -r requirements.in billiard==3.6.4.0 # via celery +broadcaster==0.2.0 + # via -r requirements.in celery==5.2.7 # via # -r requirements.in