383 lines
12 KiB
Python
383 lines
12 KiB
Python
import enum
|
|
import ipaddress
|
|
import re
|
|
import uuid
|
|
from distutils.version import StrictVersion
|
|
from typing import Any, Type, Union
|
|
|
|
from boltons.typeutils import classproperty
|
|
from boltons.urlutils import URL as BoltonsUrl
|
|
from ereuse_utils import if_none_return_none
|
|
from flask_sqlalchemy import BaseQuery
|
|
from flask_sqlalchemy import Model as _Model
|
|
from flask_sqlalchemy import SignallingSession
|
|
from flask_sqlalchemy import SQLAlchemy as FlaskSQLAlchemy
|
|
from sqlalchemy import CheckConstraint, SmallInteger, cast, event, types
|
|
from sqlalchemy.dialects.postgresql import ARRAY, INET
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound
|
|
from sqlalchemy_utils import Ltree
|
|
from werkzeug.exceptions import BadRequest, NotFound, UnprocessableEntity
|
|
|
|
|
|
class ResourceNotFound(NotFound):
    """404 error raised when the requested resource does not exist.

    :param resource: Human-readable name of the resource, interpolated into
                     the error message.
    """

    # todo show id
    def __init__(self, resource: str) -> None:
        message = 'The {} doesn\'t exist.'.format(resource)
        super().__init__(message)
|
|
|
|
|
|
class MultipleResourcesFound(UnprocessableEntity):
    """422 error raised when a query expecting one result got several.

    :param resource: Human-readable name of the queried resource,
                     interpolated into the error message.
    """

    # todo show id
    def __init__(self, resource: str) -> None:
        # Fixed typo in the user-facing message ("where" -> "were").
        super().__init__(
            'Expected only one {} but multiple were found'.format(resource)
        )
|
|
|
|
|
|
# Keys for SQLAlchemy ``__mapper_args__`` in polymorphic (inherited) models.
POLYMORPHIC_ID = 'polymorphic_identity'
POLYMORPHIC_ON = 'polymorphic_on'
INHERIT_COND = 'inherit_condition'

# Relationship ``cascade`` presets; each one extends the previous one.
DEFAULT_CASCADE = 'save-update, merge'
CASCADE_DEL = DEFAULT_CASCADE + ', delete'
CASCADE_OWN = CASCADE_DEL + ', delete-orphan'

# Database-level cascade action (presumably for a ForeignKey's ``ondelete``).
DB_CASCADE_SET_NULL = 'SET NULL'
|
|
|
|
|
|
class Query(BaseQuery):
    """Query whose ``one()`` raises HTTP-friendly errors.

    SQLAlchemy's ``NoResultFound`` becomes :class:`ResourceNotFound` and
    ``MultipleResultsFound`` becomes :class:`MultipleResourcesFound`. Both
    are named after the label of the query's first entity.
    """

    def one(self):
        try:
            return super().one()
        except (NoResultFound, MultipleResultsFound) as error:
            if isinstance(error, NoResultFound):
                http_error = ResourceNotFound
            else:
                http_error = MultipleResourcesFound
            # The label is only looked up on the failure path.
            raise http_error(self._entities[0]._label_name)
|
|
|
|
|
|
class Model(_Model):
    """Base model that declares this module's :class:`Query` and a ``t`` name."""

    # Declared here (with type comments) so type checkers know the query type.
    query_class = Query  # type: Type[Query]
    query = None  # type: Query

    @classproperty
    def t(cls):
        """The model's class name, e.g. ``'Device'`` for class ``Device``."""
        class_name = cls.__name__
        return class_name
|
|
|
|
|
|
class Session(SignallingSession):
    """SignallingSession that converts IntegrityErrors into :class:`DBError`."""

    def _flush(self, objects=None):
        try:
            super()._flush(objects)
        except IntegrityError as integrity_error:
            # DBError.__new__ picks the most specific subclass
            # (e.g. UniqueViolation) for this particular error.
            raise DBError(integrity_error)
|
|
|
|
|
|
class SchemaSession(Session):
    """Session configured to use a PostgreSQL schema.

    On creation it runs ``SET search_path TO <app.schema>, public``.

    Idea from `here <https://stackoverflow.com/a/9299021>`_.
    """

    def __init__(self, db, autocommit=False, autoflush=True, **options):
        super().__init__(db, autocommit, autoflush, **options)
        # NOTE(review): the schema name is interpolated unescaped; it is
        # assumed to come from trusted app configuration.
        statement = 'SET search_path TO {}, public'.format(self.app.schema)
        self.execute(statement)
|
|
|
|
|
|
class StrictVersionType(types.TypeDecorator):
    """Column type storing :class:`StrictVersion` values as Unicode text.

    Idea `from official documentation <http://docs.sqlalchemy.org/en/
    latest/core/custom_types.html#augmenting-existing-types>`_.
    """

    impl = types.Unicode

    @if_none_return_none
    def process_bind_param(self, value, dialect):
        """Python -> DB: the version's string form (e.g. ``'1.2.3'``).

        None handling is delegated to ``if_none_return_none``.
        """
        return str(value)

    @if_none_return_none
    def process_result_value(self, value, dialect):
        """DB -> Python: parse the stored text back into a StrictVersion."""
        version = StrictVersion(value)
        return version
|
|
|
|
|
|
class URL(types.TypeDecorator):
    """Column type storing boltons' :class:`~boltons.urlutils.URL` as Unicode."""

    impl = types.Unicode

    @if_none_return_none
    def process_bind_param(self, value: BoltonsUrl, dialect):
        """Python -> DB: serialise the URL with ``to_text()``."""
        text = value.to_text()
        return text

    @if_none_return_none
    def process_result_value(self, value, dialect):
        """DB -> Python: rebuild a boltons URL from the stored text."""
        return BoltonsUrl(value)
|
|
|
|
|
|
class IP(types.TypeDecorator):
    """Column type mapping :mod:`ipaddress` objects to PostgreSQL ``INET``."""

    impl = INET

    @if_none_return_none
    def process_bind_param(self, value, dialect):
        """Python -> DB: the textual form of the address."""
        return str(value)

    @if_none_return_none
    def process_result_value(self, value, dialect):
        """DB -> Python: an address object built by ``ipaddress.ip_address``."""
        address = ipaddress.ip_address(value)
        return address
|
|
|
|
|
|
class IntEnum(types.TypeDecorator):
    """Column type storing members of an :class:`enum.IntEnum` as SmallInteger."""

    impl = SmallInteger

    def __init__(self, enumeration: Type[enum.IntEnum], *args, **kwargs):
        """
        :param enumeration: The IntEnum subclass whose members this column
                            stores. Remaining arguments go to TypeDecorator.
        """
        self.enum = enumeration
        super().__init__(*args, **kwargs)

    @if_none_return_none
    def process_bind_param(self, value, dialect):
        """Python -> DB: the member's integer ``value``."""
        expected = self.enum
        # Validation uses ``assert``, so it is skipped under ``python -O``.
        assert isinstance(value, expected), 'Value should be instance of {}'.format(
            expected
        )
        return value.value

    @if_none_return_none
    def process_result_value(self, value, dialect):
        """DB -> Python: look the stored integer up in the enum."""
        member = self.enum(value)
        return member
|
|
|
|
|
|
class UUIDLtree(Ltree):
    """Ltree whose path elements are meant to be UUIDs."""

    def __init__(self, path_or_ltree: Union[Ltree, uuid.UUID]):
        """
        Build the Ltree from an existing Ltree or from a UUID.

        A UUID is first turned into an ltree-ready label with :meth:`convert`.

        :raises ValueError: if given anything other than an Ltree or a UUID.
        """
        path = path_or_ltree
        if not isinstance(path, Ltree):
            if not isinstance(path, uuid.UUID):
                raise ValueError(
                    'Ltree does not accept {}'.format(path.__class__)
                )
            path = self.convert(path)
        super().__init__(path)

    @staticmethod
    def convert(id: uuid.UUID) -> str:
        """Render a UUID as an ltree label by replacing hyphens with underscores."""
        return '_'.join(str(id).split('-'))
|
|
|
|
|
|
def check_range(column: str, min=1, max=None) -> CheckConstraint:
    """Database CHECK constraint restricting ``column`` to a range.

    Without ``max`` the constraint is ``column >= min``; otherwise it is
    ``column BETWEEN min AND max`` (inclusive in SQL).
    """
    if max is None:
        condition = '>= {}'.format(min)
    else:
        condition = 'BETWEEN {} AND {}'.format(min, max)
    return CheckConstraint('{} {}'.format(column, condition))
|
|
|
|
|
|
def check_lower(field_name: str):
    """Named CHECK constraint requiring ``field_name`` to equal its lower case."""
    expression = '{0} = lower({0})'.format(field_name)
    constraint_name = '{} must be lower'.format(field_name)
    return CheckConstraint(expression, name=constraint_name)
|
|
|
|
|
|
class ArrayOfEnum(ARRAY):
    """
    Allows to use Arrays of Enums for psql.

    From `the docs <http://docs.sqlalchemy.org/en/latest/dialects/
    postgresql.html?highlight=array#postgresql-array-of-enum>`_
    and `this issue <https://bitbucket.org/zzzeek/sqlalchemy/issues/
    3467/array-of-enums-does-not-allow-assigning>`_.

    Bound values are cast to this array type. Result values are handled in
    two forms:

    - A raw PostgreSQL array literal (e.g. ``'{a,b}'``), returned when the
      driver does not know the enum-array type, is parsed here.
    - An already-parsed sequence is handed to ARRAY's processor unchanged.
    """

    def bind_expression(self, bindvalue):
        return cast(bindvalue, self)

    def result_processor(self, dialect, coltype):
        super_rp = super(ArrayOfEnum, self).result_processor(dialect, coltype)

        def process(value):
            if value is None:
                return None
            # Only raw literals need parsing; lists/tuples coming from a driver
            # that understands the type would crash re.match.
            if isinstance(value, str):
                value = self._parse_array_literal(value)
            return super_rp(value)

        return process

    @classmethod
    def _parse_array_literal(cls, literal):
        """Turn a one-dimensional array literal like ``'{a,b}'`` into a list."""
        match = re.match(r'^{(.*)}$', literal)
        if match is None:
            raise ValueError(
                'Not a PostgreSQL array literal: {!r}'.format(literal)
            )
        return cls._split_array_elements(match.group(1))

    @staticmethod
    def _split_array_elements(inner):
        """Split the inside of an array literal on top-level commas.

        PostgreSQL wraps elements containing commas, spaces, quotes or
        backslashes in double quotes and backslash-escapes ``"`` and ``\\``
        inside them. Those quotes are removed and the escapes resolved.
        """
        if not inner:
            return []
        if '"' not in inner:
            # Fast path: nothing is quoted, so a plain split is exact.
            return inner.split(',')
        elements = []
        current = []
        in_quotes = False
        escaped = False
        for char in inner:
            if escaped:
                current.append(char)
                escaped = False
            elif char == '\\' and in_quotes:
                escaped = True
            elif char == '"':
                in_quotes = not in_quotes
            elif char == ',' and not in_quotes:
                elements.append(''.join(current))
                current = []
            else:
                current.append(char)
        elements.append(''.join(current))
        return elements
|
|
|
|
|
|
class SQLAlchemy(FlaskSQLAlchemy):
    """
    Enhances :class:`flask_sqlalchemy.SQLAlchemy` by adding our
    Session and Model.

    This module's custom column types are also exposed as attributes of the
    ``db`` object (``db.IP``, ``db.URL``, ...).
    """

    ArrayOfEnum = ArrayOfEnum
    IntEnum = IntEnum
    IP = IP
    StrictVersionType = StrictVersionType
    URL = URL
    UUIDLtree = UUIDLtree

    def __init__(
        self,
        app=None,
        use_native_unicode=True,
        session_options=None,
        metadata=None,
        query_class=BaseQuery,
        model_class=Model,
    ):
        # NOTE(review): the default query_class here is flask_sqlalchemy's
        # BaseQuery, whereas SchemaSQLAlchemy defaults to this module's
        # Query -- confirm this difference is intended.
        positional = (
            app,
            use_native_unicode,
            session_options,
            metadata,
            query_class,
            model_class,
        )
        super().__init__(*positional)

    def create_session(self, options):
        """Like the parent's create_session, but building our Session class."""
        factory = sessionmaker(class_=Session, db=self, **options)
        return factory
|
|
|
|
|
|
class SchemaSQLAlchemy(SQLAlchemy):
    """
    Enhances :class:`flask_sqlalchemy.SQLAlchemy` by using PostgreSQL's
    schemas when creating/dropping tables.

    See :attr:`teal.config.SCHEMA` for more info.
    """

    def __init__(
        self,
        app=None,
        use_native_unicode=True,
        session_options=None,
        metadata=None,
        query_class=Query,
        model_class=Model,
    ):
        super().__init__(
            app, use_native_unicode, session_options, metadata, query_class, model_class
        )
        # The following listeners set psql's search_path to the correct
        # schema and create the schemas accordingly.
        #
        # Specifically:
        # 1. Creates the schemas and set ``search_path`` to app's config SCHEMA
        event.listen(self.metadata, 'before_create', self.create_schemas)
        # Set ``search_path`` to default (``public``)
        event.listen(self.metadata, 'after_create', self.revert_connection)
        # Set ``search_path`` to app's config SCHEMA
        event.listen(self.metadata, 'before_drop', self.set_search_path)
        # Set ``search_path`` to default (``public``)
        event.listen(self.metadata, 'after_drop', self.revert_connection)

    def create_all(self, bind='__all__', app=None, exclude_schema=None):
        """Create all tables.

        :param exclude_schema: Do not create tables in this schema.
        """
        app = self.get_app(app)
        # todo how to pass exclude_schema without contaminating self?
        self._exclude_schema = exclude_schema
        try:
            super().create_all(bind, app)
        finally:
            # Reset it: get_tables_for_bind also serves drop_all() and later
            # create_all() calls, which must not keep filtering this schema.
            self._exclude_schema = None

    def _execute_for_all_tables(self, app, bind, operation, skip_tables=False):
        # todo how to pass app to our event listeners without contaminating self?
        self._app = self.get_app(app)
        super()._execute_for_all_tables(app, bind, operation, skip_tables)

    def get_tables_for_bind(self, bind=None):
        """As super method, but only getting tables that are not
        part of exclude_schema, if set.
        """
        tables = super().get_tables_for_bind(bind)
        excluded = getattr(self, '_exclude_schema', None)
        if excluded:
            tables = [t for t in tables if t.schema != excluded]
        return tables

    def create_schemas(self, target, connection, **kw):
        """
        Create the schemas and set the active schema.

        From `here <https://bitbucket.org/zzzeek/sqlalchemy/issues/3914/
        extend-create_all-drop_all-to-include#comment-40129850>`_.
        """
        # ``_app`` is only set when going through db.create_all(); fall back
        # to flask_sqlalchemy's app lookup when metadata.create_all() is used
        # directly instead of failing with AttributeError.
        app = getattr(self, '_app', None) or self.get_app()
        schemas = set(table.schema for table in target.tables.values() if table.schema)
        if app.schema:
            schemas.add(app.schema)
        for schema in schemas:
            # NOTE(review): schema names are identifiers (cannot be bound
            # parameters) and are interpolated as-is; they must be trusted.
            connection.execute('CREATE SCHEMA IF NOT EXISTS {}'.format(schema))
        self.set_search_path(target, connection)

    def set_search_path(self, _, connection, **kw):
        """Point ``search_path`` at the app's schema (then ``public``), if any."""
        app = self.get_app()
        if app.schema:
            connection.execute('SET search_path TO {}, public'.format(app.schema))

    def revert_connection(self, _, connection, **kw):
        """Reset ``search_path`` to ``public``."""
        connection.execute('SET search_path TO public')

    def create_session(self, options):
        """As parent's create_session but adding our SchemaSession."""
        return sessionmaker(class_=SchemaSession, db=self, **options)

    def drop_schema(self, app=None, schema=None):
        """Nukes a schema and everything that depends on it.

        :param schema: Schema to drop; defaults to the app's configured schema.
        """
        app = self.get_app(app)
        schema = schema or app.schema
        with self.engine.begin() as conn:
            # NOTE(review): identifier interpolated as-is; must be trusted.
            conn.execute('DROP SCHEMA IF EXISTS {} CASCADE'.format(schema))

    def has_schema(self, schema: str) -> bool:
        """Does the db have the passed-in schema?"""
        from sqlalchemy import text  # only needed here

        # The schema name is compared as a value, so bind it instead of
        # formatting it into the SQL string.
        query = text(
            'SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_namespace '
            'WHERE nspname = :schema)'
        )
        return self.engine.execute(query, schema=schema).scalar()
|
|
|
|
|
|
class DBError(BadRequest):
    """An Error from the database, exposed as an HTTP 400.

    This helper maps SQLAlchemy's IntegrityError to more precise errors
    (like UniqueViolation) that are understood as a client-ready HTTP Error.

    Instantiating ``DBError(origin)`` picks the class automatically. If the
    error text mentions a "unique constraint" (case-insensitive), the
    instance is a :class:`UniqueViolation`; otherwise it is of the class
    that was called.
    """

    def __new__(cls, origin: IntegrityError) -> Any:
        is_unique_violation = 'unique constraint' in str(origin).lower()
        target_class = UniqueViolation if is_unique_violation else cls
        return super().__new__(target_class)

    def __init__(self, origin: IntegrityError):
        super().__init__(str(origin))
        # Keep the original SQLAlchemy error around for callers.
        self._origin = origin
|
|
|
|
|
|
class UniqueViolation(DBError):
    """A unique constraint was violated.

    Attributes set from the originating IntegrityError:

    - ``constraint``: the first double-quoted token of the error message
      (the constraint name), or None if the message has no quoted token.
    - ``field_name`` / ``field_value``: the first bound parameter whose name
      appears inside ``constraint``. Both are None when the parameters are not
      a dict or no parameter name matches.
    """

    def __init__(self, origin: IntegrityError):
        super().__init__(origin)
        parts = self.description.split('"')
        # Guard against messages without a quoted constraint name instead of
        # raising IndexError from inside the flush error handler.
        self.constraint = parts[1] if len(parts) > 1 else None
        self.field_name = None
        self.field_value = None
        if self.constraint and isinstance(origin.params, dict):
            # Default to (None, None): e.g. a "*_pkey" constraint name often
            # contains no parameter name, and a bare next() would raise
            # StopIteration here.
            self.field_name, self.field_value = next(
                ((k, v) for k, v in origin.params.items() if k in self.constraint),
                (None, None),
            )
|