"""Source code for modelhubapi.pythonapi: generic Python interface to access a model."""

import os
import io
import json
import time
from datetime import datetime
import numpy
import h5py

class ModelHubAPI:
    """
    Generic interface to access a model.
    """

    def __init__(self, model, contrib_src_dir, output_folder='/output'):
        """
        Args:
            model: Model object. :func:`~predict` calls its
                ``infer(input_file_path)`` method.
            contrib_src_dir (str): Root folder of the model contribution.
                The configuration is read from
                ``<contrib_src_dir>/model/config.json`` and sample data is
                listed from ``<contrib_src_dir>/sample_data``.
            output_folder (str): Folder where numpy prediction outputs are
                written as HDF5 files. Defaults to ``'/output'``.
        """
        self.model = model
        self.output_folder = output_folder
        self.contrib_src_dir = contrib_src_dir
        this_dir = os.path.dirname(os.path.realpath(__file__))
        self.framework_dir = os.path.normpath(os.path.join(this_dir, ".."))

    def get_config(self):
        """
        Returns:
            dict: Model configuration. If the configuration file cannot be
            read or parsed, a dictionary with an "error" key instead.
        """
        config_file_path = self.contrib_src_dir + "/model/config.json"
        return self._load_json(config_file_path)

    def get_model_io(self):
        """
        Returns:
            dict: The model's input/output sizes and types as dictionary.
            Convenience function, as this is a subset of what
            :func:`~get_config` returns. If the configuration cannot be
            loaded, the error dictionary from :func:`~get_config`.
        """
        config = self.get_config()
        if "error" in config:
            return config
        return config["model"]["io"]

    def get_samples(self):
        """
        Returns:
            dict: Folder and file names of sample data bundled with this
            model. The dictionary key "folder" holds the absolute path to
            the sample data folder in the model container. The key "files"
            contains a list of all file names in that folder. Join these
            together to get the full path to the sample files. On failure,
            a dictionary with an "error" key instead.
        """
        try:
            sample_data_dir = self.contrib_src_dir + "/sample_data"
            # Only the top level of the folder is listed; subfolders are ignored.
            _, _, sample_files = next(os.walk(sample_data_dir))
            return {"folder": sample_data_dir, "files": sample_files}
        except Exception as e:
            return {'error': repr(e)}

    def predict(self, input_file_path, numpyToFile=True, url_root=""):
        """
        Performs the model's inference on the given input.

        Args:
            input_file_path (str): Path to input file to run inference on.
            numpyToFile (bool): Only effective if prediction is a numpy array.
                Indicates if numpy outputs should be saved and a path to it
                is returned. If false, a json-serializable list representation
                of the numpy array is returned instead. List representations
                are very slow with large numpy arrays.
            url_root (str): Url root added by the rest api.

        Returns:
            dict: On success, a dictionary with the keys "output" (one entry
            per model output, holding its prediction, shape, type, name and
            description as specified in the model configuration, see
            :func:`~get_model_io`), "timestamp", "processing_time" (seconds)
            and "model" (id and name). In case of an error, a dictionary
            with an "error" key instead.
        """
        try:
            config = self.get_config()
            if "error" in config:
                # Report the config problem directly instead of running
                # inference and then failing on a lookup below.
                return config
            output_specs = config["model"]["io"]["output"]
            start = time.time()
            output = self.model.infer(input_file_path)
            output = self._correct_output_list_wrapping(output, config)
            end = time.time()
            # More outputs than the config describes cannot be labelled.
            if len(output) > len(output_specs):
                return {'error': "output formatting does not match output specifications in config file"}
            output_list = []
            for o, spec in zip(output, output_specs):
                name = spec["name"]
                if isinstance(o, numpy.ndarray):
                    shape = list(o.shape)
                    if numpyToFile:
                        o = url_root + "api" + self._save_output(o, name)
                    else:
                        o = o.tolist()
                else:
                    shape = [len(o)]
                output_list.append({
                    'prediction': o,
                    'shape': shape,
                    'type': spec["type"],
                    'name': name,
                    'description': spec.get("description", ""),
                })
            return {'output': output_list,
                    'timestamp': datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f"),
                    'processing_time': round(end - start, 3),
                    'model': {
                        "id": config["id"],
                        "name": config["meta"]["name"]
                    }
                    }
        except Exception as e:
            return {'error': repr(e)}

    # -------------------------------------------------------------------------
    # Private helper functions
    # -------------------------------------------------------------------------

    def _load_txt_as_dict(self, file_path, return_key):
        """Read a UTF-8 text file and return ``{return_key: contents}``,
        or ``{'error': message}`` if reading fails."""
        try:
            with io.open(file_path, mode='r', encoding='utf-8') as f:
                txt = f.read()
            return {return_key: txt}
        except Exception as e:
            return {'error': str(e)}

    def _load_json(self, file_path):
        """Parse a UTF-8 JSON file, or return ``{'error': message}`` if
        reading or parsing fails."""
        try:
            with io.open(file_path, mode='r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            return {'error': str(e)}

    def _correct_output_list_wrapping(self, output, config):
        """Normalize the raw model output to a list with one entry per
        configured output.

        A non-list output is wrapped in a list. A list output is treated as
        a single output (and wrapped) when the config declares exactly one
        output, and as the list of outputs when it declares several. If the
        config declares no outputs, a single error entry is returned.
        """
        if not isinstance(output, list):
            return [output]
        num_outputs = len(config["model"]["io"]["output"])
        if num_outputs == 1:
            return [output]
        if num_outputs > 1:
            return output
        return [{'error': "output formatting does not match output specifications in config file"}]

    def _save_output(self, output, name):
        """Write a numpy array to a timestamp-named HDF5 file in the output
        folder. The dataset is named ``name`` and its dtype is stored as
        the "type" attribute.

        Returns:
            str: Path of the written file.
        """
        now = datetime.now()
        stem = now.strftime("%Y-%m-%d-%H-%M-%S-%f")
        path = os.path.join(self.output_folder, "%s.%s" % (stem, "h5"))
        # Two outputs saved within the same microsecond would map to the
        # same file, and mode 'w' would silently truncate the first one.
        counter = 1
        while os.path.exists(path):
            path = os.path.join(self.output_folder, "%s-%d.%s" % (stem, counter, "h5"))
            counter += 1
        # The context manager closes the file even if dataset creation fails.
        with h5py.File(path, 'w') as h5f:
            dataset = h5f.create_dataset(name, data=output)
            # numpy.string_ was removed in NumPy 2.0; numpy.bytes_ is the
            # same type and exists in both 1.x and 2.x.
            dataset.attrs["type"] = numpy.bytes_(str(output.dtype))
        return path