You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
106 lines
3.4 KiB
Python
106 lines
3.4 KiB
Python
7 years ago
|
import datetime as dt
|
||
|
import decimal
|
||
|
from math import isfinite
|
||
|
import typing
|
||
|
from apistar import types, validators
|
||
|
from apistar.http import JSONResponse, Response
|
||
|
from apistar.server.components import Component
|
||
|
from sqlalchemy.engine import Engine
|
||
|
from sqlalchemy.orm import sessionmaker, Session, scoped_session
|
||
|
|
||
|
|
||
|
class Decimal(validators.NumericType):
    """Numeric validator whose validated value is a ``decimal.Decimal``."""

    numeric_type = decimal.Decimal

    def validate(self, value, definitions=None, allow_coerce=False):
        """Validate ``value`` and return it converted to ``numeric_type``.

        Args:
            value: The raw input. ``None`` is accepted only when
                ``self.allow_null`` is set; booleans are always rejected.
            definitions: Unused here; kept for signature compatibility.
            allow_coerce: When true, inputs that are not ``int``, ``float``
                or ``decimal.Decimal`` (for example strings) are passed to
                ``numeric_type`` for conversion instead of being rejected.

        Returns:
            ``None`` for an allowed null, otherwise the converted value.

        Every failed check is reported through ``self.error(<code>)`` with
        one of the codes: ``null``, ``type``, ``integer``, ``finite``,
        ``exact``, ``enum``, ``minimum``, ``exclusive_minimum``,
        ``maximum``, ``exclusive_maximum``, ``multiple_of``.
        """
        if value is None and self.allow_null:
            return None
        elif value is None:
            self.error('null')
        elif isinstance(value, bool):
            # bool is a subclass of int, so it must be rejected explicitly.
            self.error('type')
        elif self.numeric_type is int and isinstance(value, float) and not value.is_integer():
            self.error('integer')
        elif not isinstance(value, (int, float, decimal.Decimal)) and not allow_coerce:
            self.error('type')
        elif isinstance(value, float) and not isfinite(value):
            self.error('finite')

        try:
            value = self.numeric_type(value)
        except (TypeError, ValueError, decimal.InvalidOperation):
            # decimal.Decimal('abc') raises InvalidOperation, which is an
            # ArithmeticError rather than a ValueError, so it has to be
            # listed explicitly to be reported as a validation error.
            self.error('type')

        if isinstance(value, decimal.Decimal) and not value.is_finite():
            # NaN/Infinity can arrive as Decimal instances or coerced
            # strings; ordering comparisons against NaN would otherwise
            # raise InvalidOperation further down.
            self.error('finite')

        if self.enum is not None:
            if value not in self.enum:
                if len(self.enum) == 1:
                    self.error('exact')
                self.error('enum')

        if self.minimum is not None:
            if self.exclusive_minimum:
                if value <= self.minimum:
                    self.error('exclusive_minimum')
            else:
                if value < self.minimum:
                    self.error('minimum')

        if self.maximum is not None:
            if self.exclusive_maximum:
                if value >= self.maximum:
                    self.error('exclusive_maximum')
            else:
                if value > self.maximum:
                    self.error('maximum')

        if self.multiple_of is not None:
            step = self.multiple_of
            if isinstance(step, float):
                # Decimal and float cannot be mixed in arithmetic. Going
                # through str() keeps the step as written (0.1 -> '0.1')
                # rather than its binary floating point expansion.
                step = decimal.Decimal(str(step))
            if value % step:
                self.error('multiple_of')

        return value
|
||
|
|
||
|
|
||
|
class ExtJSONResponse(JSONResponse):
    """JSON Response with support for ISO 8601 datetime serialization and Decimal to float casting"""

    def default(self, obj: typing.Any) -> typing.Any:
        """Convert an object the JSON encoder cannot handle natively.

        Args:
            obj: The value to convert.

        Returns:
            A ``dict`` for ``types.Type`` instances, an ISO 8601 string for
            ``datetime`` objects, or a ``float`` for ``decimal.Decimal``.

        Raises:
            TypeError: If ``obj`` is none of the supported types.
        """
        if isinstance(obj, types.Type):
            return dict(obj)
        if isinstance(obj, dt.datetime):
            return obj.isoformat()
        if isinstance(obj, decimal.Decimal):
            return float(obj)
        error = "Object of type '%s' is not JSON serializable."
        # The error must be raised, not returned: a returned exception
        # instance would be treated as the converted value.
        raise TypeError(error % type(obj).__name__)
|
||
|
|
||
|
|
||
|
# Module-level scoped session registry. It is created without a bind;
# SQLAlchemySession.__init__ attaches the engine via DBSession.configure().
DBSession = scoped_session(sessionmaker())
|
||
|
|
||
|
|
||
|
class SQLAlchemySession(Component):
    """Component that hands out sessions from the module-level ``DBSession``."""

    def __init__(self, engine=None):
        """Store ``engine`` and bind ``DBSession`` to it.

        Raises:
            ValueError: If ``engine`` is not a ``sqlalchemy.engine.Engine``.
        """
        engine_is_valid = isinstance(engine, Engine)
        if not engine_is_valid:
            message = 'SQLAlchemySession must be instantiated with a sqlalchemy.engine.Engine object'
            raise ValueError(message)

        self.engine = engine
        DBSession.configure(bind=engine)

    def resolve(self) -> Session:
        """Return the session obtained by calling ``DBSession``."""
        current_session = DBSession()
        return current_session
|
||
|
|
||
|
|
||
|
class SQLAlchemyHook:
    """Event hook that disposes of the scoped session when a request ends."""

    def on_request(self, session: Session):
        """Do nothing; the ``session`` parameter is accepted but unused."""
        return

    def on_response(self, session: Session, response: Response):
        """Remove the scoped session and return ``response`` unchanged."""
        DBSession.remove()
        return response

    def on_error(self, session: Session, response: Response):
        """Roll back ``session``, remove the scoped session, return ``response``.

        The removal runs in a ``finally`` clause so that the scoped session
        is still discarded when the rollback itself raises.
        """
        try:
            session.rollback()
        finally:
            DBSession.remove()
        return response
|