You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

106 lines
3.4 KiB
Python

import datetime as dt
import decimal
from math import isfinite
import typing
from apistar import types, validators
from apistar.http import JSONResponse, Response
from apistar.server.components import Component
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker, Session, scoped_session
class Decimal(validators.NumericType):
    """Numeric validator whose accepted values are returned as ``decimal.Decimal``.

    Runs null/bool/type/finiteness checks on the raw input, converts it with
    ``numeric_type``, then applies the ``enum``, ``minimum``/``maximum`` and
    ``multiple_of`` constraints configured on the validator.
    """

    numeric_type = decimal.Decimal

    def validate(self, value, definitions=None, allow_coerce=False):
        """Validate *value* and return it converted to ``numeric_type``.

        Args:
            value: Incoming value. Unless ``allow_coerce`` is true it must be an
                ``int``, ``float`` or ``decimal.Decimal``; ``bool`` is always
                rejected.
            definitions: Accepted for interface compatibility; not used here.
            allow_coerce: When true, other input types (e.g. strings) are handed
                to ``numeric_type`` for conversion.

        Returns:
            ``None`` if *value* is ``None`` and ``allow_null`` is set, otherwise
            the converted value.

        Failures are reported through ``self.error(code)`` (defined on the base
        class; presumably raises a validation error — this method relies on it
        not returning).
        """
        if value is None and self.allow_null:
            return None
        elif value is None:
            self.error('null')
        elif isinstance(value, bool):
            self.error('type')
        elif self.numeric_type is int and isinstance(value, float) and not value.is_integer():
            self.error('integer')
        elif not isinstance(value, (int, float, decimal.Decimal)) and not allow_coerce:
            self.error('type')
        elif isinstance(value, float) and not isfinite(value):
            self.error('finite')

        try:
            value = self.numeric_type(value)
        except (TypeError, ValueError, decimal.InvalidOperation):
            # Decimal('abc') signals InvalidOperation, not ValueError.
            self.error('type')

        # Decimal inputs, and coerced strings such as "NaN" or "Infinity", skip
        # the float-only finiteness check above. A NaN would also make the
        # ordering comparisons below raise InvalidOperation.
        if isinstance(value, decimal.Decimal) and not value.is_finite():
            self.error('finite')

        if self.enum is not None and value not in self.enum:
            if len(self.enum) == 1:
                self.error('exact')
            self.error('enum')

        if self.minimum is not None:
            if self.exclusive_minimum:
                if value <= self.minimum:
                    self.error('exclusive_minimum')
            elif value < self.minimum:
                self.error('minimum')

        if self.maximum is not None:
            if self.exclusive_maximum:
                if value >= self.maximum:
                    self.error('exclusive_maximum')
            elif value > self.maximum:
                self.error('maximum')

        if self.multiple_of is not None:
            if isinstance(self.multiple_of, float):
                # Decimal * float raises TypeError, so run the approximate
                # float-step check entirely in float arithmetic.
                if not (float(value) * (1 / self.multiple_of)).is_integer():
                    self.error('multiple_of')
            elif value % self.multiple_of:
                self.error('multiple_of')

        return value
class ExtJSONResponse(JSONResponse):
    """JSON response that additionally serializes dates/times and decimals.

    Extra conversions performed by ``default``:

    * ``types.Type`` instances -> ``dict``
    * ``datetime.datetime`` / ``datetime.date`` / ``datetime.time`` -> ISO 8601
      string via ``isoformat()``
    * ``decimal.Decimal`` -> ``float`` (may lose precision)
    """

    def default(self, obj: typing.Any) -> typing.Any:
        """Return a JSON-encodable stand-in for *obj*.

        Raises:
            TypeError: if *obj* has no supported conversion. This presumes the
                method is used as a ``json.dumps`` ``default=`` hook (as its
                name and error message suggest), which must raise rather than
                return an exception object.
        """
        if isinstance(obj, types.Type):
            return dict(obj)
        # dt.datetime is a subclass of dt.date, so it is covered here as well.
        if isinstance(obj, (dt.date, dt.time)):
            return obj.isoformat()
        if isinstance(obj, decimal.Decimal):
            return float(obj)
        error = "Object of type '%s' is not JSON serializable."
        raise TypeError(error % type(obj).__name__)
# Module-wide scoped session registry. It starts unbound; SQLAlchemySession
# binds it to an engine via DBSession.configure() when constructed.
DBSession = scoped_session(session_factory=sessionmaker())
class SQLAlchemySession(Component):
    """Component that binds the module-level ``DBSession`` registry and yields sessions.

    Args:
        engine: A ``sqlalchemy.engine.Engine``. Any other value, including the
            default ``None``, raises ``ValueError``, so an engine is
            effectively required.
    """

    def __init__(self, engine=None):
        if isinstance(engine, Engine):
            self.engine = engine
            # Bind every session the shared registry creates from now on.
            DBSession.configure(bind=engine)
        else:
            raise ValueError('SQLAlchemySession must be instantiated with a sqlalchemy.engine.Engine object')

    def resolve(self) -> Session:
        """Return the session the scoped ``DBSession`` registry holds for the current scope."""
        current = DBSession()
        return current
class SQLAlchemyHook:
    """Event hooks that dispose of the scoped ``DBSession`` at the end of a request.

    No commit is issued here; handlers presumably commit explicitly — confirm.
    """

    def on_request(self, session: Session):
        """Do nothing on request start.

        The ``session`` parameter is presumably there so the session component
        is resolved when the request begins — confirm.
        """
        return

    def on_response(self, session: Session, response: Response):
        """Dispose of the scoped session and pass *response* through unchanged."""
        # remove() closes the current session, releasing its connection and
        # rolling back any transaction left uncommitted.
        DBSession.remove()
        return response

    def on_error(self, session: Session, response: Response):
        """Roll back *session*, dispose of the scoped session, and return *response*."""
        try:
            session.rollback()
        finally:
            # Always drop the scoped session, even if the rollback itself
            # fails (e.g. the connection is already broken); otherwise the
            # failed session would stay registered for the next request in
            # this scope.
            DBSession.remove()
        return response