diff --git a/news/app.py b/news/app.py index 5545dd5..fff4060 100644 --- a/news/app.py +++ b/news/app.py @@ -1,39 +1,54 @@ import os +import typing from apistar import Route, http, App from sqlalchemy import create_engine -from .models import NewsArticle, NewsSource -from .schema import NewsSourceSchema, NewsArticleSchema +from sqlalchemy.exc import InvalidRequestError, IntegrityError + +from .models import Article, Source, MediaProtocol +from .schema import NewsArticleSchema, SourceSchema, NewsSourceSchema from .util import Session, SQLAlchemySession, SQLAlchemyHook +from .exc import Conflict news_source_schema = NewsSourceSchema() news_article_schema = NewsArticleSchema() -def get_sources(session: Session): +def get_sources(session: Session) -> typing.List[SourceSchema]: """Retrieves all News Sources""" - sources = session.query(NewsSource).all() - return http.JSONResponse(news_source_schema.dump(sources, many=True).data, status_code=200) + sources = session.query(Source).all() + return [SourceSchema(source) for source in sources] def add_source(session: Session, request_data: http.RequestData, app: App): """Adds a single source to the News Source collection""" - news_source_data, errors = news_source_schema.load(request_data) - if errors: - msg = {"message": "400 Bad Request", "error": errors} - return http.JSONResponse(msg, status_code=400) - news_source = NewsSource(**news_source_data) - session.add(news_source) - session.flush() + media_protocol = MediaProtocol.get_or_create(session, request_data.get('media_protocol')) - headers = {"Location": app.reverse_url('get_source', id=news_source.id)} + if media_protocol.id is None: + session.add(media_protocol) + session.flush() + + source_data = SourceSchema(url=request_data.get('url'), name=request_data.get('name')) + + try: + source = Source(media_protocol_id=media_protocol.id, **source_data) + session.add(source) + session.flush() + + except (InvalidRequestError, IntegrityError): + session.rollback() + source = 
session.query(Source).filter_by(url=source_data.url).one_or_none() + raise Conflict(location=app.reverse_url('get_source', id=source.id)) + + headers = {"Location": app.reverse_url('get_source', id=source.id)} session.commit() - return http.JSONResponse(news_source_schema.dump(news_source).data, status_code=201, headers=headers) + + return http.JSONResponse(SourceSchema(source), status_code=201, headers=headers) def delete_source(session: Session, id: int): """Delete a single News Source from the collection by id""" - news_source = session.query(NewsSource).filter_by(id=id).one_or_none() + news_source = session.query(Source).filter_by(id=id).one_or_none() if news_source is None: msg = {"message": "404 Not Found"} return http.JSONResponse(msg, status_code=404) @@ -46,7 +61,7 @@ def delete_source(session: Session, id: int): def get_source(session: Session, id: int): """Retrieves a single News Source by id""" - news_source = session.query(NewsSource).filter_by(id=id).one_or_none() + news_source = session.query(Source).filter_by(id=id).one_or_none() if news_source is None: msg = {"message": "404 Not Found"} return http.JSONResponse(msg, status_code=404) @@ -56,13 +71,13 @@ def get_source(session: Session, id: int): def get_articles(session: Session): """Retrieves all articles""" - articles = session.query(NewsArticle).all() + articles = session.query(Article).all() return http.JSONResponse(news_article_schema.dump(articles, many=True).data, status_code=200) def get_article(session: Session, id: int): """Retrieves a single article by id""" - news_article = session.query(NewsArticle).filter_by(id=id).one_or_none() + news_article = session.query(Article).filter_by(id=id).one_or_none() if news_article is None: msg = {"message": "404 Not Found"} return http.JSONResponse(msg, status_code=404) @@ -75,7 +90,7 @@ def add_article(session: Session, request_data: http.RequestData, app: App): if errors: msg = {"message": "400 Bad Request", "error": errors} return 
http.JSONResponse(msg, status_code=400) - news_article = NewsArticle(**news_article_data) + news_article = Article(**news_article_data) session.add(news_article) session.flush() @@ -87,7 +102,7 @@ def add_article(session: Session, request_data: http.RequestData, app: App): def delete_article(session: Session, id): """Delete a single News Sources from the collection by id""" - news_article = session.query(NewsArticle).filter_by(id=id).one_or_none() + news_article = session.query(Article).filter_by(id=id).one_or_none() if news_article is None: msg = {"message": "404 Not Found"} return http.JSONResponse(msg, status_code=404) diff --git a/news/exc.py b/news/exc.py new file mode 100644 index 0000000..9690631 --- /dev/null +++ b/news/exc.py @@ -0,0 +1,20 @@ +from typing import Union +from apistar.exceptions import HTTPException + + +class Conflict(HTTPException): + default_status_code = 409 + default_detail = 'The request could not be completed due to a conflict with the current state of the resource' + + def __init__(self, + location: str=None, + detail: Union[str, dict]=None, + status_code: int=None) -> None: + self.location = location + super().__init__(detail, status_code) + + def get_headers(self): + if self.location: + return {'Location': self.location} + else: + return {} \ No newline at end of file diff --git a/news/models.py b/news/models.py index c2daa29..0e982ff 100644 --- a/news/models.py +++ b/news/models.py @@ -1,6 +1,14 @@ from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy import Column, ForeignKey, DateTime, BigInteger, Text, Table, ARRAY -from sqlalchemy.orm import relationship, backref +from sqlalchemy import ( + Column, + Integer, + ForeignKey, + DateTime, + BigInteger, + Text, + Table, +) +from sqlalchemy.orm import relationship from sqlalchemy.sql import expression from sqlalchemy.ext.compiler import compiles from sqlalchemy.types import DateTime as DateTimeType @@ -31,63 +39,174 @@ class DBMixin: return d -def 
ReferenceCol(tablename, nullable=False, **kw): - return Column(ForeignKey('{}.id'.format(tablename)), nullable=nullable, **kw) +# Reference Tables +article_author = Table( + 'article_author', Base.metadata, + Column('article_id', ForeignKey('article.id')), + Column('author_id', ForeignKey('author.id')) +) +article_tag = Table( + 'articles_tag', Base.metadata, + Column('article_id', ForeignKey('article.id')), + Column('tag_id', ForeignKey('tag.id')) +) -categories = Table('news_source_category', Base.metadata, - Column('news_source_id,', ForeignKey('news_source.id')), - Column('category_id,', ForeignKey('category.id')) - ) +source_author = Table( + 'sources_author', Base.metadata, + Column('source_id', ForeignKey('source.id')), + Column('author_id', ForeignKey('author.id')) +) -tags = Table('news_article_tag', Base.metadata, - Column('news_article_id', ForeignKey('news_article.id')), - Column('tag_id', ForeignKey('tag.id')), - ) +source_category = Table( + 'source_category', Base.metadata, + Column('source_id', ForeignKey('source.id')), + Column('category_id', ForeignKey('category.id')) +) +# Tables class Tag(DBMixin, Base): __tablename__ = 'tag' - tag_name = Column(Text, unique=True) + tag = Column(Text, unique=True, nullable=False) + articles = relationship( + 'Article', + secondary=article_tag, + back_populates='tags', + ) @staticmethod def get_or_create(session, name): try: - return session.query(Tag).filter_by(tag_name=name).one() + return session.query(Tag).filter_by(tag=name).one() except NoResultFound: - return Tag(tag_name=name) + return Tag(tag=name) class Category(DBMixin, Base): __tablename__ = 'category' - category_name = Column(Text, unique=True) + category = Column(Text, unique=True, nullable=False) + sources = relationship( + 'Source', + secondary=source_category, + back_populates='categories', + ) @staticmethod def get_or_create(session, name): try: - return session.query(Category).filter_by(category_name=name).one() + return
session.query(Category).filter_by(category=name).one() except NoResultFound: - return Category(category_name=name) - - -class NewsSource(DBMixin, Base): - __tablename__ = 'news_source' - url = Column(Text, unique=True) - source_name = Column(Text) - source_type = Column(Text) - categories = relationship('Category', secondary=categories, - backref=backref('news_sources', lazy='dynamic')) - articles = relationship('NewsArticle', - backref=backref('news_source'), - cascade="all, delete, delete-orphan") - - -class NewsArticle(DBMixin, Base): - __tablename__ = 'news_article' - news_source_id = ReferenceCol('news_source') - url = Column(Text, unique=True) - title = Column(Text) - authors = Column(ARRAY(Text)) + return Category(category=name) + + +class MediaProtocol(DBMixin, Base): + __tablename__ = 'media_protocol' + protocol = Column(Text, nullable=False) + sources = relationship('Source', back_populates='media_protocol') + + @staticmethod + def get_or_create(session, name): + try: + return session.query(MediaProtocol).filter_by(protocol=name).one() + except NoResultFound: + return MediaProtocol(protocol=name) + + +class Source(DBMixin, Base): + __tablename__ = 'source' + url = Column(Text, unique=True, nullable=False) + name = Column(Text, nullable=False) + # Many-to-One + media_protocol_id = Column(Integer, ForeignKey('media_protocol.id')) + media_protocol = relationship('MediaProtocol', back_populates='sources') + # One-to-Many + articles = relationship('Article', back_populates='source') + # Many-to-Many + authors = relationship( + 'Author', + secondary=source_author, + back_populates='sources', + ) + categories = relationship( + 'Category', + secondary=source_category, + back_populates='sources', + ) + + @staticmethod + def get_or_create(session, source, url): + try: + return session.query(Source).filter_by( + name=source, + url=url + ).one() + except NoResultFound: + return Source(name=source, url=url) + + +class Author(DBMixin, Base): + __tablename__ =
'author' + first_name = Column(Text, nullable=False) + middle_name = Column(Text, nullable=True) + last_name = Column(Text, nullable=False) + # __table_args__ = ( + # UniqueConstraint('first_name', 'last_name', name='full_name'), + # ) + # Many-to-Many + articles = relationship( + 'Article', + secondary=article_author, + back_populates='authors', + ) + sources = relationship( + 'Source', + secondary=source_author, + back_populates='authors', + ) + + @staticmethod + def get_or_create(session, first_name, middle_name, last_name): + try: + return session.query(Author).filter_by( + first_name=first_name, + middle_name=middle_name, + last_name=last_name, + ).one() + except NoResultFound: + return Author( + first_name=first_name, + middle_name=middle_name, + last_name=last_name + ) + + +class Article(DBMixin, Base): + __tablename__ = 'article' + url = Column(Text, unique=True, nullable=False) + title = Column(Text, nullable=False) published_date = Column(DateTime) - news_blob = Column(Text) - tags = relationship('Tag', secondary=tags, backref=backref('articles', lazy='dynamic')) + blob = Column(Text, nullable=False) + # Many-to-One + source_id = Column(Integer, ForeignKey('source.id')) + source = relationship('Source', back_populates='articles') + # Many-to-Many + authors = relationship( + 'Author', + secondary=article_author, + back_populates='articles', + lazy='dynamic', + ) + tags = relationship( + 'Tag', + secondary=article_tag, + back_populates='articles', + lazy='dynamic', + ) + + @staticmethod + def get_or_create(session, url): + try: + return session.query(Article).filter_by(url=url).one() + except NoResultFound: + return Article(url=url) diff --git a/news/schema.py b/news/schema.py index 8ee8580..b329253 100644 --- a/news/schema.py +++ b/news/schema.py @@ -1,5 +1,5 @@ import datetime as dt - +from apistar import types, validators from marshmallow import Schema, fields @@ -52,6 +52,18 @@ class NewsSourceSchema(Schema): ordered = True +# class
UnixTimestamp(validators.DateTime): +# pass + + +class SourceSchema(types.Type): + id = validators.Integer(allow_null=True) + created_date = validators.DateTime(allow_null=True) + modified_date = validators.DateTime(allow_null=True) + url = validators.String() + name = validators.String() + + # TODO deserialization of timestamp to class NewsArticleSchema(Schema): id = fields.Int(dump_only=True) diff --git a/scripts/manage.py b/scripts/manage.py new file mode 100644 index 0000000..7ad7776 --- /dev/null +++ b/scripts/manage.py @@ -0,0 +1,36 @@ +import os +from sqlalchemy import create_engine +from news.models import Base +import click + +engine = create_engine(os.getenv('NEWS_DB', '')) + +metadata = Base.metadata + +metadata.bind = engine + + +@click.group() +def cli(): + pass + + +@click.command('create_db') +def create_db(): + click.echo("Creating news database") + Base.metadata.create_all() + click.echo("Database created") + + +@click.command('drop_db') +def drop_db(): + click.echo("Dropping news database") + Base.metadata.drop_all() + click.echo("Database dropped") + + +cli.add_command(create_db) +cli.add_command(drop_db) + +if __name__ == "__main__": + cli()