import enum
import ipaddress
import re
import uuid
from distutils.version import StrictVersion
from typing import Any, Type, Union

from boltons.typeutils import classproperty
from boltons.urlutils import URL as BoltonsUrl
from ereuse_devicehub.ereuse_utils import if_none_return_none
from flask_sqlalchemy import BaseQuery
from flask_sqlalchemy import Model as _Model
from flask_sqlalchemy import SignallingSession
from flask_sqlalchemy import SQLAlchemy as FlaskSQLAlchemy
from sqlalchemy import CheckConstraint, SmallInteger, cast, event, text, types
from sqlalchemy.dialects.postgresql import ARRAY, INET
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound
from sqlalchemy_utils import Ltree
from werkzeug.exceptions import BadRequest, NotFound, UnprocessableEntity


class ResourceNotFound(NotFound):
    """``NotFound`` error whose message names the missing resource."""

    # todo show id
    def __init__(self, resource: str) -> None:
        """Build the error.

        :param resource: Name of the resource that could not be found;
                         it is interpolated into the error message.
        """
        message = "The {} doesn't exist.".format(resource)
        super().__init__(message)


class MultipleResourcesFound(UnprocessableEntity):
    """``UnprocessableEntity`` error for lookups that expected one result."""

    # todo show id
    def __init__(self, resource: str) -> None:
        """Build the error.

        :param resource: Name of the resource that was found more than
                         once; it is interpolated into the error message.
        """
        template = 'Expected only one {} but multiple where found'
        super().__init__(template.format(resource))


# Presumably keys for SQLAlchemy's ``__mapper_args__`` (they are not used
# inside this file -- confirm against the models that import them).
POLYMORPHIC_ID = 'polymorphic_identity'
POLYMORPHIC_ON = 'polymorphic_on'
INHERIT_COND = 'inherit_condition'

# Cascade strings; each one extends the previous with one more rule.
DEFAULT_CASCADE = 'save-update, merge'
CASCADE_DEL = ', '.join((DEFAULT_CASCADE, 'delete'))
CASCADE_OWN = ', '.join((CASCADE_DEL, 'delete-orphan'))

# Presumably the value for a foreign key's ``ondelete`` -- TODO confirm.
DB_CASCADE_SET_NULL = 'SET NULL'


class Query(BaseQuery):
    """Query whose :meth:`one` raises this module's HTTP errors."""

    def one(self):
        """Return exactly one result, as the parent's ``one()`` does.

        :raise ResourceNotFound: If the parent raises ``NoResultFound``.
        :raise MultipleResourcesFound: If the parent raises
                                       ``MultipleResultsFound``.
        """
        try:
            result = super().one()
        except NoResultFound:
            # The label of the first queried entity names the resource.
            label = self._entities[0]._label_name
            raise ResourceNotFound(label)
        except MultipleResultsFound:
            label = self._entities[0]._label_name
            raise MultipleResourcesFound(label)
        return result


class Model(_Model):
    """Base model that uses this module's :class:`Query` as query class."""

    # Just provide typing
    query_class = Query  # type: Type[Query]
    query = None  # type: Query

    @classproperty
    def t(cls) -> str:
        """The name of the class, readable directly from the class."""
        return cls.__name__


class Session(SignallingSession):
    """A SQLAlchemy session that raises better exceptions."""

    def _flush(self, objects=None):
        """Flush as the parent does, translating integrity errors.

        :raise DBError: When the parent's flush raises ``IntegrityError``.
        """
        try:
            super()._flush(objects)
        except IntegrityError as integrity_error:
            # Instantiating DBError picks a suitable subclass by itself.
            translated = DBError(integrity_error)
            raise translated


class SchemaSession(Session):
    """Session that is configured to use a PostgreSQL's Schema.

    Idea from `here <https://stackoverflow.com/a/9299021>`_.
    """

    def __init__(self, db, autocommit=False, autoflush=True, **options):
        """Create the session and point ``search_path`` at the app's schema.

        The schema name is read from ``self.app.schema`` and is followed
        by ``public`` in the search path.
        """
        super().__init__(db, autocommit, autoflush, **options)
        schema = self.app.schema
        statement = 'SET search_path TO {}, public'.format(schema)
        self.execute(statement)


class StrictVersionType(types.TypeDecorator):
    """StrictVersion support for SQLAlchemy as Unicode.

    Idea `from official documentation <http://docs.sqlalchemy.org/en/
    latest/core/custom_types.html#augmenting-existing-types>`_.

    Both converters are wrapped by ``if_none_return_none`` (defined
    outside this file); presumably it short-circuits ``None`` values.
    """

    impl = types.Unicode

    @if_none_return_none
    def process_bind_param(self, value, dialect):
        """Convert the outgoing value to its ``str`` form."""
        as_text = str(value)
        return as_text

    @if_none_return_none
    def process_result_value(self, value, dialect):
        """Parse the stored text back into a ``StrictVersion``."""
        version = StrictVersion(value)
        return version


class URL(types.TypeDecorator):
    """bolton's URL support for SQLAlchemy as Unicode.

    Both converters are wrapped by ``if_none_return_none`` (defined
    outside this file); presumably it short-circuits ``None`` values.
    """

    impl = types.Unicode

    @if_none_return_none
    def process_bind_param(self, value: BoltonsUrl, dialect):
        """Convert the outgoing URL object to text via ``to_text()``."""
        as_text = value.to_text()
        return as_text

    @if_none_return_none
    def process_result_value(self, value, dialect):
        """Wrap the stored text in a boltons ``URL``."""
        url = BoltonsUrl(value)
        return url


class IP(types.TypeDecorator):
    """ipaddress support for SQLAlchemy as PSQL INET.

    Both converters are wrapped by ``if_none_return_none`` (defined
    outside this file); presumably it short-circuits ``None`` values.
    """

    impl = INET

    @if_none_return_none
    def process_bind_param(self, value, dialect):
        """Convert the outgoing address to its ``str`` form."""
        as_text = str(value)
        return as_text

    @if_none_return_none
    def process_result_value(self, value, dialect):
        """Parse the stored value with ``ipaddress.ip_address``."""
        address = ipaddress.ip_address(value)
        return address


class IntEnum(types.TypeDecorator):
    """SmallInteger -- IntEnum

    Stores members of one enumeration as their integer value and returns
    them as members of that enumeration when loading.

    Both converters are wrapped by ``if_none_return_none`` (defined
    outside this file); presumably it short-circuits ``None`` values.
    """

    impl = SmallInteger

    def __init__(self, enumeration: Type[enum.IntEnum], *args, **kwargs):
        """
        :param enumeration: The enumeration class whose members this
                            column stores.
        :param args: Forwarded to the parent type.
        :param kwargs: Forwarded to the parent type.
        """
        self.enum = enumeration
        super().__init__(*args, **kwargs)

    @if_none_return_none
    def process_bind_param(self, value, dialect):
        """Return the integer value of an enumeration member.

        :raise AssertionError: If ``value`` is not an instance of the
                               configured enumeration.
        """
        # Explicit raise instead of an ``assert`` statement: asserts are
        # removed when Python runs with ``-O``, which would let a wrong
        # value reach ``value.value`` unchecked. The exception type and
        # message are kept as they were for backward compatibility.
        if not isinstance(value, self.enum):
            raise AssertionError(
                'Value should be instance of {}'.format(self.enum)
            )
        return value.value

    @if_none_return_none
    def process_result_value(self, value, dialect):
        """Turn the stored integer back into an enumeration member."""
        return self.enum(value)


class UUIDLtree(Ltree):
    """This Ltree only wants UUIDs as paths elements."""

    def __init__(self, path_or_ltree: Union[Ltree, uuid.UUID]):
        """
        Creates a new Ltree. If the passed-in value is an UUID,
        it automatically generates a suitable string for Ltree.

        :raise ValueError: If the value is neither an ``Ltree`` nor an
                           ``UUID``.
        """
        if not isinstance(path_or_ltree, Ltree):
            if not isinstance(path_or_ltree, uuid.UUID):
                rejected_type = path_or_ltree.__class__
                raise ValueError('Ltree does not accept {}'.format(rejected_type))
            path_or_ltree = self.convert(path_or_ltree)
        super().__init__(path_or_ltree)

    @staticmethod
    def convert(id: uuid.UUID) -> str:
        """Transforms an uuid to a ready-to-ltree str representation.

        Every hyphen of the textual UUID is replaced by an underscore.
        """
        chunks = str(id).split('-')
        return '_'.join(chunks)


def check_range(column: str, min=1, max=None) -> CheckConstraint:
    """Database constraint for ranged values.

    :param column: Name of the column to constrain.
    :param min: Lower bound of the range.
    :param max: Upper bound of the range. When ``None`` only the lower
                bound is checked (``column >= min``); otherwise the
                constraint is ``column BETWEEN min AND max``.
    """
    if max is None:
        condition = '>= {}'.format(min)
    else:
        condition = 'BETWEEN {} AND {}'.format(min, max)
    sql = '{} {}'.format(column, condition)
    return CheckConstraint(sql)


def check_lower(field_name: str):
    """Constraint that checks if the string is lower case.

    :param field_name: Name of the column; it is used both in the SQL
                       expression and in the name of the constraint.
    """
    expression = '{0} = lower({0})'.format(field_name)
    constraint_name = '{} must be lower'.format(field_name)
    return CheckConstraint(expression, name=constraint_name)


class ArrayOfEnum(ARRAY):
    """
    Allows to use Arrays of Enums for psql.

    From `the docs <http://docs.sqlalchemy.org/en/latest/dialects/
    postgresql.html?highlight=array#postgresql-array-of-enum>`_
    and `this issue <https://bitbucket.org/zzzeek/sqlalchemy/issues/
    3467/array-of-enums-does-not-allow-assigning>`_.
    """

    def bind_expression(self, bindvalue):
        """Wrap the bound value in a ``CAST`` to this array type."""
        return cast(bindvalue, self)

    def result_processor(self, dialect, coltype):
        """Return a processor that parses the raw ``{a,b,c}`` string.

        The raw value is split on every comma and the resulting list is
        handed to the parent's result processor. ``None`` stays ``None``.
        """
        parent_processor = super().result_processor(dialect, coltype)

        def split_raw_array(raw):
            # Take what is between the outer braces; '{}' is an empty array.
            body = re.match(r'^{(.*)}$', raw).group(1)
            if not body:
                return []
            return body.split(',')

        def convert(value):
            if value is not None:
                return parent_processor(split_raw_array(value))
            return None

        return convert


class SQLAlchemy(FlaskSQLAlchemy):
    """
    Enhances :class:`flask_sqlalchemy.SQLAlchemy` by adding our
    Session and Model.

    The helper types defined in this module are also exposed as class
    attributes under their own names.
    """

    # Helper types of this module, re-exported on the class.
    StrictVersionType = StrictVersionType
    URL = URL
    IP = IP
    IntEnum = IntEnum
    UUIDLtree = UUIDLtree
    ArrayOfEnum = ArrayOfEnum

    def __init__(
        self,
        app=None,
        use_native_unicode=True,
        session_options=None,
        metadata=None,
        query_class=BaseQuery,
        model_class=Model,
    ):
        """Forward every argument, positionally, to the parent class.

        ``model_class`` defaults to this module's :class:`Model`.
        """
        super().__init__(
            app, use_native_unicode, session_options, metadata, query_class, model_class
        )

    def create_session(self, options):
        """As parent's create_session but adding our Session."""
        return sessionmaker(class_=Session, db=self, **options)


class SchemaSQLAlchemy(SQLAlchemy):
    """
    Enhances :class:`flask_sqlalchemy.SQLAlchemy` by using PostgreSQL's
    schemas when creating/dropping tables.

    See :attr:`teal.config.SCHEMA` for more info.

    Schema names are interpolated into the SQL text wherever an
    identifier is required (identifiers cannot be sent as bound
    parameters), so they must come from trusted configuration.
    """

    def __init__(
        self,
        app=None,
        use_native_unicode=True,
        session_options=None,
        metadata=None,
        query_class=Query,
        model_class=Model,
    ):
        """Initialize as the parent and register the schema listeners.

        Arguments are forwarded positionally to the parent; here
        ``query_class`` defaults to this module's :class:`Query`.
        """
        super().__init__(
            app, use_native_unicode, session_options, metadata, query_class, model_class
        )
        # The following listeners set psql's search_path to the correct
        # schema and create the schemas accordingly

        # Specifically:
        # 1. Creates the schemas and set ``search_path`` to app's config SCHEMA
        event.listen(self.metadata, 'before_create', self.create_schemas)
        # Set ``search_path`` to default (``public``)
        event.listen(self.metadata, 'after_create', self.revert_connection)
        # Set ``search_path`` to app's config SCHEMA
        event.listen(self.metadata, 'before_drop', self.set_search_path)
        # Set ``search_path`` to default (``public``)
        event.listen(self.metadata, 'after_drop', self.revert_connection)

    def create_all(self, bind='__all__', app=None, exclude_schema=None):
        """Create all tables.

        :param exclude_schema: Do not create tables in this schema.

        Note that ``exclude_schema`` is stored on the instance and is not
        cleared afterwards, so it keeps filtering the result of
        :meth:`get_tables_for_bind` until ``create_all`` is called again.
        """
        app = self.get_app(app)
        # todo how to pass exclude_schema without contaminating self?
        self._exclude_schema = exclude_schema
        super().create_all(bind, app)

    def _execute_for_all_tables(self, app, bind, operation, skip_tables=False):
        """As the parent's method, remembering the app for the listeners.

        The resolved app is stored in ``self._app`` so that
        :meth:`create_schemas` can read its schema.
        """
        # todo how to pass app to our event listeners without contaminating self?
        self._app = self.get_app(app)
        super()._execute_for_all_tables(app, bind, operation, skip_tables)

    def get_tables_for_bind(self, bind=None):
        """As super method, but only getting tables that are not
        part of exclude_schema, if set.
        """
        tables = super().get_tables_for_bind(bind)
        if getattr(self, '_exclude_schema', None):
            tables = [t for t in tables if t.schema != self._exclude_schema]
        return tables

    def create_schemas(self, target, connection, **kw):
        """
        Create the schemas and set the active schema.

        Registered as ``before_create`` listener of the metadata. It
        creates every schema referenced by the tables of ``target`` plus
        the schema of the app stored by :meth:`_execute_for_all_tables`.

        From `here <https://bitbucket.org/zzzeek/sqlalchemy/issues/3914/
        extend-create_all-drop_all-to-include#comment-40129850>`_.
        """
        schemas = set(table.schema for table in target.tables.values() if table.schema)
        if self._app.schema:
            schemas.add(self._app.schema)
        for schema in schemas:
            connection.execute('CREATE SCHEMA IF NOT EXISTS {}'.format(schema))
        self.set_search_path(target, connection)

    def set_search_path(self, _, connection, **kw):
        """Set ``search_path`` to the app's schema followed by ``public``.

        Does nothing when the app has no schema set.
        """
        app = self.get_app()
        if app.schema:
            connection.execute('SET search_path TO {}, public'.format(app.schema))

    def revert_connection(self, _, connection, **kw):
        """Set ``search_path`` back to ``public``."""
        connection.execute('SET search_path TO public')

    def create_session(self, options):
        """As parent's create_session but adding our SchemaSession."""
        return sessionmaker(class_=SchemaSession, db=self, **options)

    def drop_schema(self, app=None, schema=None):
        """Nukes a schema and everything that depends on it.

        :param app: The app whose engine (and default schema) is used;
                    resolved through ``get_app`` when not given.
        :param schema: Schema to drop; defaults to the app's schema.
        """
        app = self.get_app(app)
        schema = schema or app.schema
        # Take the engine of the resolved app, so an explicitly passed
        # ``app`` is honoured instead of only the current one.
        with self.get_engine(app).begin() as conn:
            conn.execute('DROP SCHEMA IF EXISTS {} CASCADE'.format(schema))

    def has_schema(self, schema: str) -> bool:
        """Does the db have the passed-in schema?"""
        # Here the schema name is a value, not an identifier, so it is
        # sent as a bound parameter rather than formatted into the SQL.
        query = text(
            'SELECT EXISTS('
            'SELECT 1 FROM pg_catalog.pg_namespace WHERE nspname = :schema'
            ')'
        )
        return self.engine.execute(query, schema=schema).scalar()


class DBError(BadRequest):
    """An Error from the database.

    This helper error is used to map SQLAlchemy's IntegrityError
    to more precise errors (like UniqueViolation) that are understood
    as a client-ready HTTP Error.

    When instantiating the class it auto-selects the best error.
    """

    def __new__(cls, origin: IntegrityError) -> Any:
        """Allocate the most specific error class for ``origin``.

        A message mentioning ``unique constraint`` (case-insensitive)
        yields a :class:`UniqueViolation`; anything else yields ``cls``.
        """
        is_unique_violation = 'unique constraint' in str(origin).lower()
        chosen_class = UniqueViolation if is_unique_violation else cls
        return super().__new__(chosen_class)

    def __init__(self, origin: IntegrityError):
        """Use the text of ``origin`` as description and keep the original."""
        description = str(origin)
        super().__init__(description)
        self._origin = origin


class UniqueViolation(DBError):
    """A :class:`DBError` for violated unique constraints.

    Attributes set on the instance:

    - ``constraint``: the text between the first pair of double quotes
      of the error description, or ``None`` if there is no such pair.
    - ``field_name`` / ``field_value``: the first statement parameter
      whose name is contained in ``constraint``, or ``None`` for both
      when no parameter matches or the parameters are not a dict.
    """

    def __init__(self, origin: IntegrityError):
        super().__init__(origin)
        # A description without a quoted name must not make the
        # construction of the error itself fail with an IndexError.
        pieces = self.description.split('"')
        self.constraint = pieces[1] if len(pieces) > 1 else None
        self.field_name = None
        self.field_value = None
        if self.constraint is not None and isinstance(origin.params, dict):
            # The default avoids a StopIteration escaping from the
            # constructor when no parameter name is part of the
            # constraint name.
            self.field_name, self.field_value = next(
                ((k, v) for k, v in origin.params.items() if k in self.constraint),
                (None, None),
            )