import abc
import typing
from river import metrics
# A "flavor" is one kind of online-ml model (regression, binary, clustering, ...).
# Each kind needs its own model check, prediction functions and default metrics.
# https://github.com/online-ml/chantilly/blob/master/chantilly/flavors.py#L1
def all_models():
    """Return every known flavor class, in lookup order.

    ``check`` and ``allowed_flavors`` walk this list front to back.
    A new list is built on each call, so callers may modify it freely.
    """
    registry = (
        NeighborFlavor,
        RegressionFlavor,
        BinaryFlavor,
        MultiClassFlavor,
        ClusterFlavor,
        CustomFlavor,
        CremeFlavor,
    )
    return list(registry)
def check(model, flavor_name):
    """Validate ``model`` against the flavor registered as ``flavor_name``.

    Only the flavor whose ``name`` equals ``flavor_name`` is consulted, and
    its ``check_model`` result is returned unchanged.

    Returns:
        An ``(ok, message)`` tuple. For a known flavor this is whatever that
        flavor's ``check_model`` returns. For an unknown ``flavor_name`` it is
        ``(False, "This model flavor <name> is not recognized")``.
    """
    for flavor_cls in all_models():
        flavor = flavor_cls()
        if flavor.name == flavor_name:
            return flavor.check_model(model)
    # Use an f-string rather than ``%``. With ``%``, a tuple flavor_name would
    # raise TypeError, or a 1-tuple would be silently unwrapped.
    return False, f"This model flavor {flavor_name} is not recognized"
def allowed_flavors():
    """Map each flavor's ``name`` to one instance of that flavor.

    Flavors are visited in ``all_models()`` order. If two flavors ever shared
    a name, the later one would win.
    """
    # Instantiate each class once, and key the instance by its own ``name``.
    instances = (flavor_cls() for flavor_cls in all_models())
    return {flavor.name: flavor for flavor in instances}
class Flavor(abc.ABC):
    """Base class describing how to validate, train and query one kind of model.

    Subclasses declare:

    - a unique ``name``;
    - the methods a model must expose (``check_model``);
    - the metrics tracked by default (``default_metrics``);
    - the prediction methods to try (``pred_funcs``).
    """

    # ``abc.abstractproperty`` has been deprecated since Python 3.3. Stacking
    # ``property`` on top of ``abstractmethod`` is the supported equivalent.
    @property
    @abc.abstractmethod
    def name(self) -> str:
        """Identifier used to select this flavor (e.g. ``"regression"``)."""

    @abc.abstractmethod
    def check_model(
        self, model: typing.Any
    ) -> typing.Tuple[bool, typing.Optional[str]]:
        """Check whether or not ``model`` works for this flavor.

        Returns ``(True, None)`` when it does. Otherwise returns
        ``(False, reason)``.
        """

    @abc.abstractmethod
    def default_metrics(self) -> typing.List["metrics.base.Metric"]:
        """Default metrics to record globally as well as for each model."""

    @property
    @abc.abstractmethod
    def pred_funcs(self) -> typing.List[str]:
        """Names of prediction methods to try, in that order."""

    @property
    def learn_func(self) -> str:
        """Name of the method used to train a model.

        River models use ``learn_one``. Flavors for other libraries override
        this; creme models, for instance, use ``fit_one``.
        """
        return "learn_one"
class RegressionFlavor(Flavor):
    """Flavor for regression models.

    Models are trained with ``learn_one``, queried with ``predict_one`` and
    scored with MAE, RMSE and SMAPE.
    """

    @property
    def name(self):
        return "regression"

    def check_model(self, model):
        """Return ``(True, None)``, or ``(False, reason)`` naming the first missing method."""
        missing = next(
            (attr for attr in ("learn_one", "predict_one") if not hasattr(model, attr)),
            None,
        )
        if missing is None:
            return True, None
        return False, f"The model does not implement {missing}."

    def default_metrics(self):
        """Return fresh MAE, RMSE and SMAPE instances."""
        return [factory() for factory in (metrics.MAE, metrics.RMSE, metrics.SMAPE)]

    @property
    def pred_funcs(self):
        return ["predict_one"]
class NeighborFlavor(Flavor):
    """Flavor registered under the name ``"neighbor"``.

    It uses the same model contract as regression (``learn_one`` plus
    ``predict_one``), but tracks no metrics by default.
    """

    @property
    def name(self):
        return "neighbor"

    def check_model(self, model):
        """Return ``(True, None)``, or ``(False, reason)`` naming the first missing method."""
        lacking = next(
            (attr for attr in ("learn_one", "predict_one") if not hasattr(model, attr)),
            None,
        )
        if lacking is not None:
            return False, f"The model does not implement {lacking}."
        return True, None

    def default_metrics(self):
        """No default metrics for this flavor."""
        return []

    @property
    def pred_funcs(self):
        return ["predict_one"]
class BinaryFlavor(Flavor):
    """Flavor for binary classifiers, which are queried through ``predict_proba_one``."""

    @property
    def name(self):
        return "binary"

    def check_model(self, model):
        """Return ``(True, None)``, or ``(False, reason)`` naming the first missing method."""
        absent = next(
            (m for m in ("learn_one", "predict_proba_one") if not hasattr(model, m)),
            None,
        )
        if absent is not None:
            return False, f"The model does not implement {absent}."
        return True, None

    def default_metrics(self):
        """Return fresh accuracy, log-loss, precision, recall and F1 trackers."""
        metric_classes = (
            metrics.Accuracy,
            metrics.LogLoss,
            metrics.Precision,
            metrics.Recall,
            metrics.F1,
        )
        return [metric_cls() for metric_cls in metric_classes]

    @property
    def pred_funcs(self):
        return ["predict_proba_one"]
class MultiClassFlavor(Flavor):
    """Flavor for multiclass classifiers.

    Models must implement ``learn_one`` and ``predict_proba_one``. Predictions
    try ``predict_one`` first, then ``predict_proba_one``.
    """

    @property
    def name(self):
        return "multiclass"

    def check_model(self, model):
        """Return ``(True, None)``, or ``(False, reason)`` naming the first missing method."""
        gap = next(
            (
                needed
                for needed in ("learn_one", "predict_proba_one")
                if not hasattr(model, needed)
            ),
            None,
        )
        if gap is None:
            return True, None
        return False, f"The model does not implement {gap}."

    def default_metrics(self):
        """Return fresh accuracy and cross-entropy trackers plus macro/micro P, R and F1."""
        factories = (
            metrics.Accuracy,
            metrics.CrossEntropy,
            metrics.MacroPrecision,
            metrics.MacroRecall,
            metrics.MacroF1,
            metrics.MicroPrecision,
            metrics.MicroRecall,
            metrics.MicroF1,
        )
        return [make() for make in factories]

    @property
    def pred_funcs(self):
        # Order matters: these are tried from first to last.
        return ["predict_one", "predict_proba_one"]
class CustomFlavor(Flavor):
    """Flavor for user-supplied custom models.

    Any object is accepted: no methods are checked and no metrics are tracked
    by default. Predictions try ``predict_one`` first, then
    ``predict_proba_one``.
    """

    @property
    def name(self):
        return "custom"

    def check_model(self, model):
        # Deliberately permissive; no interface is enforced for custom models.
        return (True, None)

    def default_metrics(self):
        """No default metrics for this flavor."""
        return []

    @property
    def pred_funcs(self):
        return ["predict_one", "predict_proba_one"]
class ClusterFlavor(Flavor):
    """Flavor for clustering models.

    Models must implement ``learn_one`` and ``predict_one``. No metrics are
    tracked by default.
    """

    @property
    def name(self):
        return "cluster"

    def check_model(self, model):
        """Return ``(True, None)``, or ``(False, reason)`` naming the first missing method."""
        first_missing = next(
            (fn for fn in ("learn_one", "predict_one") if not hasattr(model, fn)),
            None,
        )
        if first_missing is None:
            return True, None
        return False, f"The model does not implement {first_missing}."

    def default_metrics(self):
        """No default metrics for this flavor."""
        return []

    @property
    def pred_funcs(self):
        return ["predict_one"]
class CremeFlavor(Flavor):
    """Flavor for creme models, which train with ``fit_one`` instead of ``learn_one``."""

    @property
    def name(self):
        return "creme"

    def check_model(self, model):
        """Return ``(True, None)``, or ``(False, reason)`` naming the first missing method."""
        missing = next(
            (attr for attr in ("fit_one", "predict_one") if not hasattr(model, attr)),
            None,
        )
        if missing is not None:
            return False, f"The model does not implement {missing}."
        return True, None

    def default_metrics(self):
        """No default metrics for this flavor."""
        return []

    @property
    def learn_func(self):
        # creme's training entry point is ``fit_one``.
        return "fit_one"

    @property
    def pred_funcs(self):
        return ["predict_one"]