Adding websocket form

drew/tilt-local-dev
Drew Bednar 2 years ago
parent 82691dd9fb
commit 80a42f88aa

@ -1,16 +1,36 @@
from broadcaster import Broadcast
from fastapi import FastAPI
from project.celery_utils import create_celery
from project.config import settings
broadcast = Broadcast(settings.WS_MESSAGE_QUEUE)
# Event Hooks
async def startup_event():
await broadcast.connect()
async def shutdown_event():
await broadcast.disconnect()
def create_app() -> FastAPI:
app = FastAPI()
# Starlette discourages decorator syntax now for eventhooks
app = FastAPI(on_startup=[startup_event], on_shutdown=[shutdown_event])
# do this before loading routes
from project.celery_utils import create_celery
app.celery_app = create_celery()
from .users import users_router
app.include_router(users_router)
from project.ws import ws_router
app.include_router(ws_router)
@app.get("/")
async def root():
return {"message": "Hello World"}

@ -1,7 +1,11 @@
import typing
from celery import current_app as celery_current_app
from celery.result import AsyncResult
from project.config import settings
TaskInfoResponse = typing.Dict[str, str]
def create_celery():
"""Clerey factory function."""
@ -9,3 +13,23 @@ def create_celery():
celery_app.config_from_object(settings, namespace="CELERY")
return celery_app
def get_task_info(task_id: str) -> TaskInfoResponse:
"""
return task info according to the task_id
"""
task = AsyncResult(task_id)
state = task.state
if state == "FAILURE":
error = str(task.result)
response = {
"state": task.state,
"error": error,
}
else:
response = {
"state": task.state,
}
return response

@ -17,6 +17,8 @@ class BaseConfig:
CELERY_RESULT_BACKEND: str = os.environ.get(
"CELERY_RESULT_BACKEND", "redis://127.0.0.1:6379/0"
)
# will be used for multi-process communication by FastAPI and the Celery Worker.
WS_MESSAGE_QUEUE: str = os.environ.get("WS_MESSAGE_QUEUE", "redis://redis:6379/0")
class DevelopmentConfig(BaseConfig):
@ -32,7 +34,7 @@ class TestingConfig(BaseConfig):
@lru_cache()
def get_settings():
def get_settings() -> BaseConfig:
config_cls_dict = {
"development": DevelopmentConfig,
"production": ProductionConfig,

@ -5,9 +5,11 @@ Now, we can copy this file anywhere in the app and it will work as expected. In
having to import the Celery instance, so it can make your task code more reusable.
"""
import random
import requests
from asgiref.sync import async_to_sync
from celery import shared_task
from celery.signals import task_postrun
from celery.utils.log import get_task_logger
# See https://docs.celeryq.dev/en/stable/userguide/tasks.html#logging
@ -57,3 +59,17 @@ def task_process_notification(self):
# Remember to raise the exception returned by the self.retry method to make it work.
# By setting the countdown argument to 5, the task will retry after a 5 second delay.
raise self.retry(exc=e, countdown=5)
@task_postrun.connect
def task_postrun_handler(task_id, **kwargs): # pylint: disable=unused-argument
"""
Celery signal handler called after each Celery task is executed.
Sends a message to the relevant channel via the `update_celery_task_status`.
Celery does not support asyncio so we must convert the async function, to a
synchronous function with asgiref.
"""
from project.ws.views import update_celery_task_status
async_to_sync(update_celery_task_status)(task_id)

@ -0,0 +1,115 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Celery example</title>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/bootstrap/5.2.0/css/bootstrap.min.css"
integrity="sha512-XWTTruHZEYJsxV3W/lSXG1n3Q39YIWOstqvmFsdNEEQfHoZ6vm6E9GK2OrF6DSJSpIbRbi+Nn0WDPID9O7xB2Q=="
crossorigin="anonymous" referrerpolicy="no-referrer"/>
</head>
<body>
<div class="container">
<div class="row">
<div class="col-12 col-md-4">
<form id="your-form">
<div class="mb-3">
<label for="email" class="form-label">Email address</label>
<input type="email" class="form-control" id="email" name="email">
</div>
<div class="mb-3">
<label for="username" class="form-label">Username</label>
<input type="text" class="form-control" id="username" name="username">
</div>
<div class="mb-3" id="messages"></div>
<button type="submit" class="btn btn-primary">Submit</button>
</form>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/bootstrap/5.2.0/js/bootstrap.min.js"
integrity="sha512-8Y8eGK92dzouwpROIppwr+0kPauu0qqtnzZZNEF8Pat5tuRNJxJXCkbQfJ0HlUG3y1HB3z18CSKmUo7i2zcPpg=="
crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script>
function updateProgress(yourForm, task_id, btnHtml) {
const ws_url = `/ws/task_status/${task_id}`;
const WS = new WebSocket((location.protocol === 'https:' ? 'wss' : 'ws') + '://' + window.location.host + ws_url);
WS.onmessage = function (event) {
const res = JSON.parse(event.data);
const taskStatus = res.state;
if (['SUCCESS', 'FAILURE'].includes(taskStatus)) {
const msg = yourForm.querySelector('#messages');
const submitBtn = yourForm.querySelector('button[type="submit"]');
if (taskStatus === 'SUCCESS') {
msg.innerHTML = 'job succeeded';
} else if (taskStatus === 'FAILURE') {
msg.innerHTML = res.error;
}
submitBtn.disabled = false;
submitBtn.innerHTML = btnHtml;
// close the websocket because we do not need it now
WS.close();
}
}
}
function serialize (data) {
let obj = {};
for (let [key, value] of data) {
if (obj[key] !== undefined) {
if (!Array.isArray(obj[key])) {
obj[key] = [obj[key]];
}
obj[key].push(value);
} else {
obj[key] = value;
}
}
return obj;
}
document.addEventListener("DOMContentLoaded", function () {
const yourForm = document.getElementById("your-form");
yourForm.addEventListener("submit", function (event) {
event.preventDefault();
const submitBtn = yourForm.querySelector('button[type="submit"]');
const btnHtml = submitBtn.innerHTML;
const spinnerHtml = 'Processing...';
submitBtn.disabled = true;
submitBtn.innerHTML = spinnerHtml;
const msg = yourForm.querySelector('#messages');
msg.innerHTML = '';
// Get all field data from the form
let data = new FormData(yourForm);
// Convert to an object
let formData = serialize(data);
fetch('/users/form/', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify(formData),
})
.then(response => response.json())
.then((res) => {
// after we get Celery task id, we start polling
const task_id = res.task_id;
updateProgress(yourForm, task_id, btnHtml);
console.log(res);
}).catch((error) => {
console.error('Error:', error);
});
});
});
</script>
</body>
</html>

@ -83,3 +83,8 @@ def webhook_test_async():
task = task_process_notification.delay()
print(task.id)
return "pong"
@users_router.get("/form_ws/")
def form_ws_example(request: Request):
return templates.TemplateResponse("form_ws.html", {"request": request})

@ -0,0 +1,5 @@
from fastapi import APIRouter
ws_router = APIRouter()
from . import views

@ -0,0 +1,38 @@
import json
from fastapi import WebSocket
from . import ws_router
from project import broadcast
from project.celery_utils import get_task_info
@ws_router.websocket("/ws/task_status/{task_id}")
async def ws_task_status(websocket: WebSocket):
# Accept websocket connection and obtain task id
await websocket.accept()
task_id = websocket.scope["path_params"]["task_id"]
# Subscribe to specific channel and send message to browser if
# it exists
async with broadcast.subscribe(channel=task_id) as subscriber:
# just in case the task is already finished
data = get_task_info(task_id)
await websocket.send_json(data=data)
async for event in subscriber:
await websocket.send_json(json.loads(event.message))
async def update_celery_task_status(task_id: str):
"""
Called by Celery worker in `task_postrun` signal handler after tasks finishes.
"""
await broadcast.connect()
# broadcast task info to channel identified by task_id through redis.
await broadcast.publish(
channel=task_id,
message=json.dumps(get_task_info(task_id)), # RedisProtocol.publish expect str
)
await broadcast.disconnect()

@ -1,4 +1,7 @@
alembic==1.8.1
asgiref==3.5.2
asyncio-redis==0.16.0
broadcaster==0.2.0
celery==5.2.7
fastapi==0.79.0
flower==1.2.0

@ -12,10 +12,16 @@ anyio==3.6.1
# via
# starlette
# watchfiles
asgiref==3.5.2
# via -r requirements.in
async-timeout==4.0.2
# via redis
asyncio-redis==0.16.0
# via -r requirements.in
billiard==3.6.4.0
# via celery
broadcaster==0.2.0
# via -r requirements.in
celery==5.2.7
# via
# -r requirements.in

Loading…
Cancel
Save