import abc
import contextlib
import logging
import os
import shelve
import dill
import redis
import river.base
import river.metrics
import river.stats
import river.utils
from django_river_ml import settings
from . import exceptions, flavors, namer
logger = logging.getLogger(__name__)
[docs]
class StorageBackend(abc.ABC):
    """Abstract storage backend.

    This interface defines a set of methods to implement in order for a database to be used as a
    storage backend. This allows using different databases in a homogeneous manner by providing a
    single interface. Since online-ml models are largely defined by Python dictionaries, we use
    key value store databases like redis.
    """

    @abc.abstractmethod
    def __setitem__(self, key, obj):
        """Store an object."""

    @abc.abstractmethod
    def __getitem__(self, key):
        """Retrieve an object; implementations should raise KeyError when it is missing."""

    @abc.abstractmethod
    def __delitem__(self, key):
        """Remove an object from storage."""

    @abc.abstractmethod
    def __iter__(self):
        """Iterate over the keys."""

    @abc.abstractmethod
    def close(self):
        """Do something when the app shuts down."""

    def get(self, key, default=None):
        """Return the object stored under ``key``, or ``default`` if ``__getitem__`` raises KeyError."""
        try:
            return self[key]
        except KeyError:
            return default

    def __contains__(self, key):
        """Return True if ``key`` is stored.

        Without this method, ``key in backend`` falls back to ``__iter__`` and walks every
        stored key. This default does a single lookup instead. Note that it fetches (and may
        deserialize) the value, so backends with a cheaper existence check should override it.
        """
        try:
            self[key]
        except KeyError:
            return False
        return True
 
[docs]
class ShelveBackend(shelve.DbfilenameShelf, StorageBackend):  # type: ignore
    """Storage backend built on the standard library's ``shelve`` module.

    Intended for development and testing only, not for production. This class adds no methods
    of its own: item access, iteration and ``close`` all come from ``shelve.DbfilenameShelf``,
    which precedes ``StorageBackend`` in the base-class list.
    """
[docs]
class RedisBackend(StorageBackend):
    """Redis is the suggested backend for a more production-ready database.

    Values are serialized with ``dill`` on write and deserialized on read.
    NOTE(review): ``dill.loads`` can execute arbitrary code, so the Redis instance must only
    ever contain data written by this application (do not share it with untrusted writers).
    """

    def __init__(self, host, port, db):
        self.r = redis.Redis(host=host, port=port, db=db)

    def __setitem__(self, key, obj):
        self.r[key] = dill.dumps(obj)

    def __getitem__(self, key):
        # redis-py's Redis.__getitem__ raises KeyError for a missing key, which is what
        # StorageBackend.get() expects.
        return dill.loads(self.r[key])

    def __delitem__(self, key):
        # DEL returns how many keys were removed. Raise KeyError when nothing was deleted,
        # matching the mapping protocol and the shelve backend.
        if not self.r.delete(key):
            raise KeyError(key)

    def __contains__(self, key):
        # One EXISTS round-trip. Without this, ``in`` would fall back to __iter__ and SCAN
        # the whole keyspace.
        return bool(self.r.exists(key))

    def __iter__(self):
        # The client is created without decode_responses, so keys arrive as bytes.
        for key in self.r.scan_iter():
            yield key.decode()

    def close(self):
        # Nothing to release explicitly; the client is simply dropped by close_db.
        return
 
# Replace the class that shelve.open instantiates, so that shelve.open hands back a
# ShelveBackend rather than a plain DbfilenameShelf.
# NOTE(review): this patches the stdlib ``shelve`` module for the whole process, so any other
# shelve.open caller in the same interpreter also receives a ShelveBackend.
setattr(shelve, "DbfilenameShelf", ShelveBackend)
[docs]
def get_db() -> StorageBackend:
    """
    Get the database, an attribute of settings.

    The handle is created lazily on the first call and cached as ``settings.db``; later calls
    return the cached handle until ``close_db`` removes it.

    Raises:
        ValueError: if ``settings.STORAGE_BACKEND`` is neither "shelve" nor "redis".
    """
    if not hasattr(settings, "db"):
        backend = settings.STORAGE_BACKEND
        if backend == "shelve":
            # This module patches shelve.DbfilenameShelf, so this yields a ShelveBackend.
            settings.db = shelve.open(settings.SHELVE_PATH)
        elif backend == "redis":
            # Fall back to Redis' default port / logical db when unset, the same way
            # drop_db does, instead of crashing on int(None).
            settings.db = RedisBackend(
                host=settings.REDIS_HOST,
                port=int(settings.REDIS_PORT or 6379),
                db=int(settings.REDIS_DB or 0),
            )
        else:
            raise ValueError(f"Unknown storage backend: {backend}")
    return settings.db
[docs]
def close_db(e=None):
    """Close the cached database handle (if any) and forget it.

    ``e`` is accepted but never used (presumably so this can be registered as a teardown
    callback that receives an error argument — TODO confirm against the caller).
    """
    try:
        handle = settings.db
    except AttributeError:
        # No handle has been opened (or it was already closed); nothing to do.
        return
    if handle is not None:
        handle.close()
    del settings.db
[docs]
def drop_db():
    """Wipe out the configured database.

    This could be implemented within each StorageBackend, but that is a bit more awkward
    because at this point the database connection is no longer stored in the app.

    For the "shelve" backend, every file the underlying dbm implementation may have created
    is removed. For the "redis" backend, the configured logical database is flushed. Any
    other backend value is silently ignored.
    """
    backend = settings.STORAGE_BACKEND
    if backend == "shelve":
        path = settings.SHELVE_PATH
        # shelve.open treats SHELVE_PATH as a *base* filename. Depending on the dbm
        # implementation in use, it may create "<path>" itself, "<path>.db" or
        # "<path>.dir"/"<path>.pag" (ndbm), or "<path>.dat"/"<path>.dir"/"<path>.bak"
        # (dbm.dumb). Removing only "<path>" could leave the data behind.
        for suffix in ("", ".db", ".dat", ".dir", ".bak", ".pag"):
            with contextlib.suppress(FileNotFoundError):
                os.remove(f"{path}{suffix}")
    elif backend == "redis":
        r = redis.Redis(
            host=settings.REDIS_HOST,
            port=int(settings.REDIS_PORT or 6379),
            db=int(settings.REDIS_DB or 0),
        )
        r.flushdb()
[docs]
def set_flavor(flavor: str, name: str):
    """Validate ``flavor`` and store the matching flavor object under ``flavor/<name>``.

    Args:
        flavor: key into ``flavors.allowed_flavors()``.
        name: model name the flavor is associated with.

    Raises:
        exceptions.UnknownFlavor: if ``flavor`` is not an allowed flavor. This is raised
            before anything is written to storage.
    """
    try:
        # Use a separate name instead of rebinding the str parameter to a flavor object.
        flavor_obj = flavors.allowed_flavors()[flavor]
    except KeyError:
        # The lookup KeyError is an implementation detail; surface only the domain error.
        raise exceptions.UnknownFlavor from None
    db = get_db()
    db[f"flavor/{name}"] = flavor_obj
[docs]
def init_stats(name: str):
    """(Re)initialise the learn/predict running statistics stored under ``stats/<name>``.

    Both the learn and the predict phase get a ``river.stats.Mean`` and a
    ``river.stats.EWMean`` (constructed with 0.3, presumably the fading factor — confirm
    against the river docs). Any existing entry is overwritten.
    """
    storage = get_db()
    ewm_arg = 0.3
    fresh = {}
    fresh["learn_mean"] = river.stats.Mean()
    fresh["learn_ewm"] = river.stats.EWMean(ewm_arg)
    fresh["predict_mean"] = river.stats.Mean()
    fresh["predict_ewm"] = river.stats.EWMean(ewm_arg)
    storage[f"stats/{name}"] = fresh
[docs]
def init_metrics(name: str):
    """Store the default metrics of model ``name``'s flavor under ``metrics/<name>``.

    The flavor must already be stored under ``flavor/<name>`` (``set_flavor`` does this).

    Raises:
        exceptions.FlavorNotSet: if no flavor entry exists for ``name``.
    """
    db = get_db()
    try:
        flavor = db[f"flavor/{name}"]
    except KeyError:
        # Report the domain error without chaining the storage-level KeyError.
        raise exceptions.FlavorNotSet from None
    db[f"metrics/{name}"] = flavor.default_metrics()
[docs]
def add_model(model: river.base.Estimator, flavor: str, name: str = None) -> str:
    """Register ``model`` under ``name`` and initialise its flavor, stats and metrics.

    Args:
        model: the river estimator to store under ``models/<name>``.
        flavor: flavor key, validated by ``set_flavor`` before anything else is stored.
        name: model name. When None, random slugs are drawn until one is found that has
            no ``models/<slug>`` entry yet. NOTE: an explicitly given name that already
            exists is overwritten.

    Returns:
        The name the model was stored under.

    Raises:
        exceptions.UnknownFlavor: (from ``set_flavor``) if ``flavor`` is not allowed; the
            model is not stored in that case.
    """
    storage = get_db()

    if name is None:
        name = _random_slug()
        while f"models/{name}" in storage:
            name = _random_slug()

    # The flavor must be validated and stored before the model: init_metrics reads it back.
    set_flavor(flavor=flavor, name=name)
    storage[f"models/{name}"] = model
    init_stats(name)
    init_metrics(name)
    return name
[docs]
def delete_model(name: str):
    """Delete model ``name`` together with its stats, metrics and flavor entries.

    Raises:
        Whatever the backend's ``__delitem__`` raises for a missing model, stats or metrics
        entry (KeyError for the shelve backend). A missing flavor entry is tolerated.
    """
    db = get_db()
    del db[f"models/{name}"]
    del db[f"stats/{name}"]
    del db[f"metrics/{name}"]
    # add_model also writes "flavor/<name>" (via set_flavor). Remove it too so deleting a
    # model leaves no orphaned key, but tolerate its absence.
    with contextlib.suppress(KeyError):
        del db[f"flavor/{name}"]
def _random_slug() -> str:
    """Return a candidate model name produced by the project's ``namer`` generator."""
    generator = namer.namer
    slug = generator.generate()
    return slug