Source code for modelhubapi.restapi

from flask import Flask, jsonify, abort, make_response, \
                    send_file, url_for, send_from_directory, request
from .pythonapi import ModelHubAPI
import os
import io
import json
import shutil
from mimetypes import MimeTypes
import requests
from datetime import datetime
from flask_cors import CORS
import magic
import re


class ModelHubRESTAPI:
    """
    Flask based REST front end around :class:`ModelHubAPI`.

    The constructor creates the Flask app, enables CORS on it and registers
    all routes. Call :func:`~start` to run the server.

    Args:
        model: Passed through unchanged to :class:`ModelHubAPI`.
        contrib_src_dir (str): Folder containing the contributed model
            source. Its ``sample_data`` and ``model`` sub folders are served
            by the routing functions of this class.
    """

    # Seconds to wait for a remote host (connect / between received bytes)
    # before a download is aborted. Without a timeout an unresponsive host
    # would block the request handler indefinitely.
    download_timeout = 60

    # Standard URL detection pattern. Kept as a single raw string so that no
    # whitespace from a source line continuation ends up inside the regex.
    _URL_PATTERN = re.compile(
        r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]'
        r'|(?:%[0-9a-fA-F][0-9a-fA-F]))+')

    def __init__(self, model, contrib_src_dir):
        self.app = Flask(__name__)
        CORS(self.app)
        self.model = model
        self.contrib_src_dir = contrib_src_dir
        self.working_folder = '/working'
        self.api = ModelHubAPI(model, contrib_src_dir)
        # routes
        self.app.add_url_rule('/api/samples/<sample_name>', 'samples',
                              self._samples)
        self.app.add_url_rule('/api/thumbnail/<thumbnail_name>', 'thumbnail',
                              self._thumbnail)
        self.app.add_url_rule('/api/output/<output_name>', 'output',
                              self._output)
        # primary REST API calls
        self.app.add_url_rule('/api/get_config', 'get_config',
                              self.get_config)
        self.app.add_url_rule('/api/get_legal', 'get_legal',
                              self.get_legal)
        self.app.add_url_rule('/api/get_model_io', 'get_model_io',
                              self.get_model_io)
        self.app.add_url_rule('/api/get_model_files', 'get_model_files',
                              self.get_model_files)
        self.app.add_url_rule('/api/get_samples', 'get_samples',
                              self.get_samples)
        self.app.add_url_rule('/api/predict', 'predict',
                              self.predict, methods=['GET', 'POST'])
        self.app.add_url_rule('/api/predict_sample', 'predict_sample',
                              self.predict_sample)

    def get_config(self):
        """
        GET method

        Returns:
            application/json: Model configuration dictionary.
        """
        return self._jsonify(self.api.get_config())

    def get_model_io(self):
        """
        GET method

        Returns:
            application/json: The model's input/output sizes and types as
            dictionary. Convenience function, as this is a subset of what
            :func:`~get_config` returns
        """
        return self._jsonify(self.api.get_model_io())

    def get_model_files(self):
        """
        GET method

        Returns:
            application/zip: The trained deep learning model in its native
            format and all its associated files in a single zip folder.
            In case of an error, returns a dictionary with error info.
        """
        # TODO
        # * This returns a error: [Errno 32] Broken pipe when url is typed
        #   into chrome and before hitting enter - chrome sends
        #   request earlier, and this messes up with flask.
        # * Currently no mechanism for catching errors (except what flask
        #   will catch).
        try:
            model_name = self.api.get_config()["meta"]["name"].lower()
            archive_name = os.path.join(self.working_folder,
                                        model_name + "_model")
            # zips the "model" sub folder of contrib_src_dir
            shutil.make_archive(archive_name, "zip", self.contrib_src_dir,
                                "model")
            return send_file(archive_name + ".zip", as_attachment=True)
        except Exception as e:
            return self._jsonify({'error': str(e)})

    def get_samples(self):
        """
        GET method

        Returns:
            application/json: List of URLs to all sample files associated
            with the model. In case of an error, returns a dictionary with
            error info.
        """
        try:
            # derive the sample base URL from the URL of this very request
            url = request.url.replace("api/get_samples", "api/samples/")
            samples = [url + sample_name
                       for sample_name in self.api.get_samples()["files"]]
            return self._jsonify(samples)
        except Exception as e:
            return self._jsonify({'error': str(e)})

    def predict(self):
        """
        GET/POST method

        Returns:
            application/json: Prediction result on input data. Return
            type/format as specified in the model configuration
            (see :func:`~get_model_io`), and wrapped in json.
            In case of an error, returns a dictionary with error info.

        GET method

        Args:
            fileurl: URL to input data for prediction. Input type must match
                     specification in the model configuration
                     (see :func:`~get_model_io`).
                     URL must not contain any arguments and should end with
                     the file extension.

        GET Example: :code:
        `curl -X GET http://localhost:80/api/predict?fileurl=<URL_OF_FILE>`

        POST method

        Args:
            file: Input file with data for prediction. Input type must match
                  specification in the model configuration
                  (see :func:`~get_model_io`)

        POST Example: :code:
        `curl -i -X POST -F file=@<PATH_TO_FILE> http://localhost:80/api/predict`
        """
        try:
            file_name, mime_type = self._save_file_get_mime_type(request)
            if str(mime_type) not in self._get_allowed_extensions():
                return self._jsonify({'error': 'Incorrect file type.'})
            file_name = self._check_multi_inputs(file_name)
            return self._jsonify(
                self.api.predict(file_name, url_root=request.url_root))
        except Exception as e:
            return self._jsonify({'error': str(e)})
        finally:
            # Clean up on every path, so rejected or failed uploads do not
            # pile up in the working folder.
            self._delete_temp_files(self.working_folder)

    def predict_sample(self):
        """
        GET method

        Performs prediction on sample data.

        .. note:: Currently you cannot use :func:`~predict` for inference
                  on sample data hosted under the same IP as the model API.
                  This function is a temporary workaround. To be removed
                  in the future.

        Returns:
            application/json: Prediction result on input data. Return type
            as specified in the model configuration
            (see :func:`~get_model_io`), and wrapped in json.
            In case of an error, returns a dictionary with error info.

        Args:
            filename: File name of the sample data. No folders or URLs.
        """
        try:
            if request.method != 'GET':
                # e.g. the HEAD requests flask accepts on every GET route
                return self._jsonify(
                    {'error': 'Only GET requests are supported.'})
            sample_name = request.args.get('filename')
            if not sample_name:
                return self._jsonify(
                    {'error': 'No sample file name was given.'})
            # The name comes straight from the query string: refuse anything
            # carrying path components so that it cannot escape sample_data.
            if os.path.basename(sample_name) != sample_name:
                return self._jsonify(
                    {'error': 'The given sample file does not exist.'})
            file_name = self.contrib_src_dir + "/sample_data/" + sample_name
            if not os.path.isfile(file_name):
                return self._jsonify(
                    {'error': 'The given sample file does not exist.'})
            result = self.api.predict(str(file_name),
                                      url_root=request.url_root)
            return self._jsonify(result)
        except Exception as e:
            return self._jsonify({'error': str(e)})

    def start(self):
        """
        Starts the flask app.
        """
        self.app.run(host='0.0.0.0', port=80, threaded=True)

    # -------------------------------------------------------------------------
    # Private helper functions
    # -------------------------------------------------------------------------
    def _timestamped_path(self, extension=""):
        """
        Returns a path inside the working folder whose file name is the
        current date/time (down to microseconds) followed by ``extension``.
        """
        stamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")
        return os.path.join(self.working_folder,
                            "%s%s" % (stamp, extension))

    def _delete_temp_files(self, folder):
        """
        Removes all files in the given folder. Sub folders are left alone.
        Errors are printed, not raised (best-effort clean up).
        """
        try:
            for file in os.listdir(folder):
                file_path = os.path.join(folder, file)
                if os.path.isfile(file_path):
                    os.unlink(file_path)
        except Exception as e:
            print(e)

    def _jsonify(self, content):
        """
        This helper function wraps the flask jsonify function, and also
        allows for error checking. This is used for calls that use the
        api._get_txt_file() function that returns a dict key "error" in case
        the file is not found.

        Returns:
            The flask response for ``content``. Its status code is set to
            400 if ``content`` is a dictionary containing the key "error".

        TODO
        * All errors are returned as 400. Would be better to customize the
          error code based on the actual error.
        """
        response = jsonify(content)
        if isinstance(content, dict) and "error" in content:
            response.status_code = 400
        return response

    def _check_multi_inputs(self, file_name):
        """
        If file_name is a path to a json file, the file is loaded and
        processed as follows:

        * the function iterates through each key containing files or urls
          (the key "format" is skipped); if it is a local path, it is left
          untouched, but if it contains a url (_check_if_url returns True),
          the file is downloaded
        * the path to the downloaded file then replaces the URL in the
          dictionary
        * afterwards, the new dictionary is dumped to a new json file
          containing only local paths

        The path to the new json file is then returned. If the passed path
        is no json file, it is simply returned unchanged.
        """
        if not file_name.lower().endswith('.json'):
            return file_name
        input_dict = self.api._load_json(file_name)
        for key, value in input_dict.items():
            if key == "format":
                continue
            if self._check_if_url(value["fileurl"]):
                value["fileurl"] = self._save_input_from_url(
                    value["fileurl"], value["type"])
            else:
                print("Local path found: " + value["fileurl"])
        new_file_name = self._timestamped_path('.json')
        # dump to file
        self.api._write_json(new_file_name, input_dict)
        return new_file_name

    def _check_if_url(self, candidate):
        """
        Uses a standard regex to check if the passed string contains a URL.

        Returns:
            True if exactly one URL is found in the string, False if none
            is found.

        Raises:
            IOError: if multiple URLs are contained in the string.
        """
        matches = self._URL_PATTERN.findall(candidate)
        if len(matches) > 1:
            raise IOError("Multiple URLs detected in the input json!")
        return len(matches) == 1

    def _save_input_from_url(self, url, type):
        """
        This function downloads an arbitrary file from a URL and saves it
        first without extension in its raw format and then adds the file
        extension registered for the given mime type.

        Args:
            url (str): the url pointing to the file to download
            type (list): the mime type of the file; only the first entry is
                used

        Returns:
            str: Path of the downloaded file, including the extension.

        Raises:
            KeyError: if no extension is known for the mime type.
            requests.exceptions.RequestException: if the download fails,
                times out (see ``download_timeout``) or the server answers
                with an HTTP error status.
        """
        # Resolve the extension first: an unsupported type fails before any
        # network traffic happens.
        file_ext = self._modify_mime_types_inv()[type[0]][0]
        file_path = self._timestamped_path()
        r = requests.get(url, stream=True, timeout=self.download_timeout)
        try:
            # do not store an HTTP error page as if it were input data
            r.raise_for_status()
            with open(file_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=512):
                    f.write(chunk)
        finally:
            # streamed responses keep their connection until closed
            r.close()
        # add extension
        file_path_with_ext = file_path + file_ext
        os.rename(file_path, file_path_with_ext)
        return file_path_with_ext

    def _samples(self, sample_name):
        """
        Routing function for sample files that exist in contrib_src.
        """
        return send_from_directory(self.contrib_src_dir + "/sample_data/",
                                   sample_name, cache_timeout=-1)

    def _output(self, output_name):
        """
        Routing function for output files that may exist in the output
        folder.
        """
        return send_from_directory(self.api.output_folder, output_name,
                                   cache_timeout=-1)

    def _thumbnail(self, thumbnail_name):
        """
        Routing function for the thumbnail that exists in contrib_src. The
        thumbnail must be named "thumbnail.jpg".
        """
        return send_from_directory(self.contrib_src_dir + "/model/",
                                   thumbnail_name, cache_timeout=-1)

    def _get_allowed_extensions(self):
        """
        Returns the "format" entry of the model's input specification. It is
        compared against mime types in :func:`~predict`.
        """
        return self.api.get_model_io()["input"]["format"]

    def _get_file_name(self, mime_type=""):
        """
        This utility function gets the current date/time and returns a full
        path to save either an uploaded file or one grabbed through a url.
        If mimetype is provided, it will grab the appropriate extension.
        If no mimetype is provided, it will return the file without an
        extension.

        Raises:
            KeyError: if no extension is known for ``mime_type``.
        """
        extension = ""
        if mime_type != "":
            extension = self._modify_mime_types_inv()[mime_type][0]
        return self._timestamped_path(extension)

    def _save_file_get_mime_type(self, request):
        """
        This utility checks first if the request method is POST or GET.
        It then saves the file with a unique date/time name but without an
        extension. Finally it uses magic to identify the mime type and
        changes the filename to one with the correct extension.

        If magic only reports a catch-all type ("text/plain" or
        "application/octet-stream"), the mime type is looked up from the
        extension of the original file name instead.

        Returns:
            Both the full path file name and the mime type.

        Raises:
            IOError: if the request carries no input or uses an unsupported
                method.
            KeyError: if the file extension is not supported.
            requests.exceptions.RequestException: if a GET download fails.
        """
        if request.method == 'GET':
            # NOTE: the server fetches whatever URL the client supplies.
            file_url = request.args.get('fileurl')
            if not file_url:
                raise IOError("No 'fileurl' argument was given.")
            file_name_raw = str(file_url).split('/')[-1]
            r = requests.get(file_url, timeout=self.download_timeout)
            r.raise_for_status()
            file_name = self._get_file_name()
            with open(file_name, 'wb') as f:
                f.write(r.content)
        elif request.method == 'POST':
            file = request.files.get('file')
            if file is None:
                raise IOError("No 'file' was uploaded.")
            # Use the uploaded file's own name. str(file) would be the repr
            # of the upload object and never yields a usable extension.
            file_name_raw = str(file.filename or "").split('/')[-1]
            file_name = self._get_file_name()
            file.save(file_name)
        else:
            raise IOError("Unsupported request method: " + request.method)
        # cache file extension (everything after the first separator, so
        # that double extensions like "nii.gz" survive)
        file_ext_cache = file_name_raw.split(os.extsep, 1)[-1]
        _magic = magic.Magic(mime=True)
        mime_type = _magic.from_file(file_name)
        # checks if a catchall type has been set and takes action:
        if mime_type in ("text/plain", "application/octet-stream"):
            types = self._modify_mime_types()
            try:
                mime_type = types["." + file_ext_cache]
            except KeyError:
                raise KeyError("The file extension ." + file_ext_cache +
                               " is not supported.")
        file_name_with_extension = self._get_file_name(mime_type)
        os.rename(file_name, file_name_with_extension)
        return file_name_with_extension, mime_type

    def _modify_mime_types(self):
        """
        Returns a dictionary with file extensions as keys and mime type
        strings as values.

        Because some file extensions are not available in the MimeTypes
        library, this helper function modifies it on the fly. This houses
        all the edge conditions. Open an issue to request a new extension.
        """
        mime = MimeTypes()
        original_mime_types = mime.types_map[1]
        # Values are plain strings, like every entry already in the map;
        # they are used as keys of the inverse map later on.
        original_mime_types[".npy"] = "application/octet-stream"
        original_mime_types[".nii"] = "application/nii"
        original_mime_types[".nii.gz"] = "application/nii-gzip"
        original_mime_types[".nrrd"] = "application/nrrd"
        original_mime_types[".dcm"] = "application/dicom"
        return original_mime_types

    def _modify_mime_types_inv(self):
        """
        Inverse: returns a dictionary with mime types as keys and lists of
        file extensions as values.

        Because some file extensions are not available in the MimeTypes
        library, this helper function modifies it on the fly. This houses
        all the edge conditions. Open an issue to request a new extension.
        """
        mime = MimeTypes()
        original_mime_types = mime.types_map_inv[1]
        original_mime_types["application/octet-stream"] = [".npy"]
        original_mime_types["application/nii"] = [".nii"]
        original_mime_types["application/nii-gzip"] = [".nii.gz"]
        original_mime_types["application/nrrd"] = [".nrrd"]
        original_mime_types["application/dicom"] = [".dcm"]
        return original_mime_types