# Module metadata: author, copyright notice, and license identifier.
__author__ = "Vanessa Sochat"
__copyright__ = "Copyright 2022, Vanessa Sochat"
__license__ = "MPL 2.0"
from riverapi.logger import logger
from riverapi.auth import parse_auth_header
import riverapi.defaults as defaults
from copy import deepcopy
import base64
import os
import json
import dill
import requests
class Client:
    """
    Interact with a River Server.

    The client wraps the server's HTTP endpoints (models, learn, predict,
    label, stats, metrics and the streaming endpoints). If a request is
    answered with a 401 that carries a Www-Authenticate header, the client
    tries to obtain a bearer token using the RIVER_ML_USER / RIVER_ML_TOKEN
    environment variables and retries the request once.
    """

    def __init__(self, baseurl=None, quiet=False, prefix="api"):
        """
        Create a new client.

        Arguments:
          baseurl : server base url (falls back to defaults.baseurl)
          quiet   : if True, do not log each request and response
          prefix  : path prefix under which the API is served
        """
        # Surrounding slashes are removed so apiroot can join parts with "/"
        self.baseurl = (baseurl or defaults.baseurl).strip("/")
        self.quiet = quiet
        # Model flavors accepted by check_flavor
        self.flavors = [
            "regression",
            "binary",
            "creme",
            "multiclass",
            "cluster",
            "custom",
            "neighbor",
        ]
        # The session is used to re-send a request after authentication
        self.session = requests.Session()
        # Default headers, merged into every request by do_request
        self.headers = {"Accept": "application/json", "User-Agent": "riverapi-python"}
        self.prefix = prefix
        self.getenv()

    def __repr__(self):
        return str(self)

    def __str__(self):
        return "[riverapi-client]"

    @property
    def apiroot(self):
        """
        Combine the baseurl and prefix to get the complete root.
        """
        return self.baseurl + "/" + self.prefix.strip("/")

    def check(self):
        """
        The user can run check to perform a service info, and update the
        prefix or baseurl if the server provides different ones.
        """
        info = self.info()
        for field in ["prefix", "baseurl"]:
            if field in info:
                updated = info[field].strip("/")
                print("Updating %s to %s" % (field, updated))
                setattr(self, field, updated)

    def getenv(self):
        """
        Get any token / username set in the environment

        Both attributes are None when the variable is not exported.
        """
        self.token = os.environ.get("RIVER_ML_TOKEN")
        self.user = os.environ.get("RIVER_ML_USER")

    def check_flavor(self, flavor):
        """
        Verify that the flavor is known

        An unknown flavor is reported through logger.exit along with the
        list of valid choices.
        """
        if flavor not in self.flavors:
            logger.exit(
                "%s is not a valid flavor. Choices are: %s"
                % (flavor, " ".join(self.flavors))
            )

    def check_response(self, typ, r, return_json=True, stream=False, retry=True):
        """
        Ensure the response status code is 20x

        Arguments:
          typ         : the request type (get, post, ...), currently unused
          r           : the response to check
          return_json : return the parsed json body instead of the response
          stream      : the request was streamed (the response is returned)
          retry       : allow one authenticated retry after a 401

        Returns the parsed json when return_json is True and stream is False,
        otherwise the response object itself.
        """
        if r.status_code == 401 and retry:
            if self.authenticate_request(r):
                # Re-send the original prepared request with the new token
                r.request.headers.update(self.headers)
                # Preserve streaming on the retry, otherwise a streamed
                # endpoint would be read completely before returning.
                r = self.session.send(r.request, stream=stream)
                # Call itself once more just to check the status code
                return self.check_response(typ, r, return_json, stream, retry=False)

        if r.status_code not in [200, 201]:
            logger.exit("Unsuccessful response: %s, %s" % (r.status_code, r.reason))

        # All data is typically json
        if return_json and not stream:
            return r.json()
        return r

    def set_header(self, name, value):
        """
        Set a single default header (name and value) sent with every request.
        """
        self.headers.update({name: value})

    def set_basic_auth(self, username, password):
        """
        A wrapper to adding basic authentication to the Request

        The credentials are base64 encoded and stored as the default
        Authorization header.
        """
        auth_str = "%s:%s" % (username, password)
        auth_header = base64.b64encode(auth_str.encode("utf-8"))
        self.set_header("Authorization", "Basic %s" % auth_header.decode("utf-8"))

    def authenticate_request(self, originalResponse):
        """
        Authenticate the request.

        Given a response (an HTTPError 401), look for a Www-Authenticate
        header to parse. We return True/False to indicate if the request
        should be retried. On success the bearer token is stored in the
        default headers.
        """
        authHeaderRaw = originalResponse.headers.get("Www-Authenticate")
        if not authHeaderRaw:
            return False

        # If we have a username and password, set basic auth automatically
        if self.token and self.user:
            self.set_basic_auth(self.user, self.token)

        # Work on a copy so the service header below is not sent afterwards
        headers = deepcopy(self.headers)
        if "Authorization" not in headers:
            logger.exit(
                "This endpoint requires a token. Please export RIVER_ML_TOKEN and RIVER_ML_USER first."
            )
            return False

        # Prepare request to retry
        h = parse_auth_header(authHeaderRaw)
        headers.update(
            {
                "service": h.Service,
                "Accept": "application/json",
                "User-Agent": "riverapi-python",
            }
        )

        # Currently we don't set a scope (it defaults to build)
        try:
            authResponse = requests.get(h.Realm, headers=headers).json()
        except (requests.RequestException, ValueError):
            # Connection problems, or a token response that is not json
            logger.exit("Failed to get token from %s" % h.Realm)
            return False

        # Request the token
        token = authResponse.get("token")
        if not token:
            return False

        # Set the token to the original request and retry
        self.headers.update({"Authorization": "Bearer %s" % token})
        return True

    def print_response(self, r):
        """
        Print the result of a response

        The body is pretty printed as json when it can be parsed, and shown
        as raw text otherwise (e.g., an error page or an empty 401 body), so
        logging never prevents the status check that follows.
        """
        try:
            body = json.dumps(r.json(), indent=4)
        except ValueError:
            body = r.text
        logger.info("%s: %s" % (r.url, body))

    def info(self):
        """
        Get basic server information
        """
        return self.get("/")

    def do_request(
        self,
        typ,
        url,
        data=None,
        json=None,
        headers=None,
        return_json=True,
        stream=False,
    ):
        """
        Do a request (get, post, etc)

        Arguments:
          typ         : request method name (get, post, delete)
          url         : path relative to the api root (starting with /)
          data        : body passed as data, used when json is None
          json        : body to be sent as json (takes precedence over data)
          headers     : extra headers; the client defaults take precedence
          return_json : return the parsed json body instead of the response
          stream      : stream the response

        Returns the result of check_response.
        """
        # Copy so the dictionary provided by the caller is never modified.
        # If we have a cached token, use it!
        request_headers = dict(headers or {})
        request_headers.update(self.headers)

        if not self.quiet:
            logger.info("%s %s" % (typ.upper(), url))

        # Compare with None so an empty json payload ({} or []) is still sent
        if json is not None:
            body = {"json": json}
        else:
            body = {"data": data}
        r = requests.request(
            typ, self.apiroot + url, headers=request_headers, stream=stream, **body
        )

        if not self.quiet and not stream and return_json:
            self.print_response(r)
        return self.check_response(typ, r, return_json=return_json, stream=stream)

    def post(self, url, data=None, json=None, headers=None, return_json=True):
        """
        Perform a POST request
        """
        return self.do_request(
            "post", url, data=data, json=json, headers=headers, return_json=return_json
        )

    def delete(self, url, data=None, json=None, headers=None, return_json=True):
        """
        Perform a DELETE request
        """
        return self.do_request(
            "delete",
            url,
            data=data,
            json=json,
            headers=headers,
            return_json=return_json,
        )

    def get(
        self, url, data=None, json=None, headers=None, return_json=True, stream=False
    ):
        """
        Perform a GET request
        """
        return self.do_request(
            "get",
            url,
            data=data,
            json=json,
            headers=headers,
            return_json=return_json,
            stream=stream,
        )

    def upload_model(self, model, flavor, model_name=None):
        """
        Given a model / pipeline, upload to an online-ml server.

        model = preprocessing.StandardScaler() | linear_model.LinearRegression()

        The model is serialized with dill. When model_name is not provided
        the name is taken from the "name" field of the response. Returns
        the model name.
        """
        self.check_flavor(flavor)
        payload = dill.dumps(model)
        if model_name:
            r = self.post("/model/%s/%s/" % (flavor, model_name), data=payload)
        else:
            r = self.post("/model/%s/" % flavor, data=payload)
            model_name = r["name"]
        logger.info("Created model %s" % model_name)
        return model_name

    def label(self, label, identifier, model_name):
        """
        Given a label we know for a prediction after the fact (which we can
        look up with an identifier from the server), use the label endpoint
        to update the model metrics and call learn one. Note that the
        model_name is not technically required (it's stored with the cached
        entry) however we require providing it to validate the association.
        If you have a label at the time of running predict you can use it
        then and should not need this endpoint. Also note that ground_truth
        of a prediction is synonymous with label here.
        """
        return self.post(
            "/label/",
            json={"model": model_name, "identifier": identifier, "label": label},
        )

    def learn(self, model_name, x, y=None):
        """
        Train on some data. You are required to provide at least the model
        name known to the server and x (data).

        for x, y in datasets.TrumpApproval().take(100):
            cli.learn(model_name, x, y)
        """
        return self.post(
            "/learn/", json={"model": model_name, "features": x, "ground_truth": y}
        )

    def delete_model(self, model_name):
        """
        Delete a model by name
        """
        return self.delete("/model/", data={"model": model_name})

    def get_model_json(self, model_name):
        """
        Get a json representation of a model.
        """
        return self.get("/model/%s/" % model_name)

    def download_model(self, model_name, dest=None):
        """
        Download a model to file (e.g., pickle)

        with open("muffled-pancake-9439.pkl", "rb") as fd:
            content=pickle.load(fd)

        Returns the path written, which defaults to <model_name>.pkl in the
        present working directory. NOTE: unpickling can run arbitrary code,
        so only load models downloaded from a server you trust.
        """
        # Get the model (this is a download of the pickled model with dill)
        r = self.get("/model/download/%s/" % model_name, return_json=False)

        # Default to pickle in PWD
        dest = dest or "%s.pkl" % model_name

        # Save to dest file
        with open(dest, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
        return dest

    def predict(self, model_name, x):
        """
        Make a prediction
        """
        return self.post("/predict/", json={"model": model_name, "features": x})

    def models(self):
        """
        Get a listing of known models
        """
        return self.get("/models/")

    def stats(self, model_name):
        """
        Get stats for a model name
        """
        return self.get("/stats/", json={"model": model_name})

    def metrics(self, model_name):
        """
        Get metrics for a model name
        """
        return self.get("/metrics/", json={"model": model_name})

    def stream(self, url):
        """
        General stream endpoint

        This is a generator that yields each non-empty line of the streamed
        response as a string.
        """
        with self.get(url, stream=True, return_json=False) as r:
            for line in r.iter_lines():
                if not line:
                    continue
                if isinstance(line, bytes):
                    line = line.decode("utf-8")
                yield line

    def stream_metrics(self):
        """
        Stream metrics
        """
        return self.stream("/stream/metrics/")

    def stream_events(self):
        """
        Stream events
        """
        return self.stream("/stream/events/")