Source code for ventu.protocol

import json
import logging
import socket
import struct

import msgpack
from pydantic import ValidationError


class BatchProtocol:
    """
    protocol used to communicate with the batching service

    :param infer: model infer function (contains `preprocess`,
        `batch_inference` and `postprocess`)
    :param req_schema: request schema defined with `pydantic`
    :param resp_schema: response schema defined with `pydantic`
    :param bool use_msgpack: use msgpack for serialization or not (default: JSON)
    """
    STRUCT_FORMAT = '!i'
    INT_BYTE_SIZE = 4
    INIT_MESSAGE = struct.pack(STRUCT_FORMAT, 0)

    def __init__(self, infer, req_schema, resp_schema, use_msgpack):
        self.req_schema = req_schema
        self.resp_schema = resp_schema
        self.use_msgpack = use_msgpack
        self.packer = msgpack.Packer(autoreset=True, use_bin_type=True)
        self.logger = logging.getLogger(__name__)
        self.infer = infer
        self.sock = None

    def _pack(self, data):
        # serialize one payload with the configured format
        return self.packer.pack(data) if self.use_msgpack else json.dumps(data)

    def _unpack(self, data):
        # deserialize one payload with the configured format
        return msgpack.unpackb(data, raw=False) if self.use_msgpack else json.loads(data)

    def _init_request(self, conn):
        # announce this worker to the batching service with a zero-length frame
        self.logger.info('Send init message')
        conn.sendall(self.INIT_MESSAGE)

    def _recv_all(self, conn, length):
        # `socket.recv` may return fewer bytes than requested, so loop
        # until the whole message has been read
        buffer = bytearray()
        while len(buffer) < length:
            chunk = conn.recv(length - len(buffer))
            if not chunk:
                raise BrokenPipeError('connection closed by the peer')
            buffer.extend(chunk)
        return bytes(buffer)

    def _request(self, conn):
        # read the 4-byte length prefix, then the payload it announces
        length_bytes = self._recv_all(conn, self.INT_BYTE_SIZE)
        length = struct.unpack(self.STRUCT_FORMAT, length_bytes)[0]
        return self._recv_all(conn, length)
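    # Wire-format sketch (illustrative): every message on this socket is
    # assumed to be a 4-byte big-endian signed integer (STRUCT_FORMAT '!i')
    # carrying the payload length, followed by the payload bytes. The
    # batching service could frame one batch like this:
    #
    #     payload = msgpack.packb({'job-1': b'{"text": "hi"}'}, use_bin_type=True)
    #     conn.sendall(struct.pack('!i', len(payload)) + payload)
    #
    # INIT_MESSAGE above is simply a zero-length frame.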
    def process(self, conn):
        """
        process batch queries and return the inference results

        :param conn: socket connection
        """
        batch = msgpack.unpackb(self._request(conn), raw=False)
        ids = list(batch.keys())
        self.logger.debug(f'Received job ids: {ids}')

        # validate the request
        validated = []
        errors = []
        for i, byte in enumerate(batch.values()):
            try:
                data = self._unpack(byte)
                obj = self.req_schema.parse_obj(data)
                validated.append(obj)
                self.logger.debug(f'{obj} passes the validation')
            except ValidationError as err:
                errors.append((i, self._pack(err.errors())))
                self.logger.info(
                    f'Job {ids[i]} validation error',
                    extra={'Validation': err.errors()},
                )
            except (json.JSONDecodeError, msgpack.ExtraData,
                    msgpack.FormatError, msgpack.StackError) as err:
                errors.append((i, self._pack(str(err))))
                self.logger.info(f'Job {ids[i]} error: {err}')

        # inference
        self.logger.debug(f'Validated: {validated}, Errors: {errors}')
        result = []
        if validated:
            result = self.infer(validated)
            assert len(result) == len(validated), (
                'Wrong number of inference results. '
                f'Expect {len(validated)}, got {len(result)}.'
            )

        # validate the response
        for data in result:
            self.resp_schema.parse_obj(data)

        # add error information
        err_ids = ''
        result = [self._pack(data) for data in result]
        for index, err_msg in errors:
            err_ids += ids[index]
            result.insert(index, err_msg)

        # build the batch job table
        resp = dict(zip(ids, result))
        if err_ids:
            resp['error_ids'] = err_ids
        self._response(conn, resp)
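    # Round-trip sketch for `process` (illustrative), assuming string job ids
    # and JSON serialization (with msgpack the values would be bytes):
    #
    #     request batch:  {'job-1': b'{"text": "hi"}', 'job-2': b'not-json'}
    #     response batch: {'job-1': '{"label": "HI"}',
    #                      'job-2': '"Expecting value: ..."',
    #                      'error_ids': 'job-2'}
    #
    # a job that fails validation keeps its slot in the table, and its id is
    # concatenated into the extra 'error_ids' entry.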
    def _response(self, conn, data):
        data = self.packer.pack(data)
        conn.sendall(struct.pack(self.STRUCT_FORMAT, len(data)))
        conn.sendall(data)
    def run(self, addr, protocol='unix'):
        """
        run the socket communication

        this should run **after** the socket file is created by the
        batching service

        :param string protocol: 'unix' or 'tcp'
        :param addr: socket file path or (host:str, port:int)
        """
        self.logger.info(f'Connect to socket: {addr}')
        while True:
            try:
                # a broken socket cannot be reused for `connect`, so build
                # a fresh one for every (re)connection attempt
                self.sock = socket.socket(
                    socket.AF_UNIX if protocol == 'unix' else socket.AF_INET,
                    socket.SOCK_STREAM,
                )
                self.sock.connect(addr)
                self.logger.info(f'Connected to {self.sock.getpeername()}')
                self._init_request(self.sock)
                while True:
                    self.process(self.sock)
            except BrokenPipeError as err:
                self.logger.warning(f'Broken socket: {err}')
                self.sock.close()
                continue
    def stop(self):
        """
        stop the socket communication
        """
        self.logger.info('Close socket')
        self.sock.close()
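A minimal usage sketch for `run`: the `Req`/`Resp` schemas, the `fake_infer`
function, and the socket path below are hypothetical, and the batching
service is assumed to have created the socket file already (in real
deployments this object is wired up by the library rather than by hand).

from pydantic import BaseModel


class Req(BaseModel):  # hypothetical request schema
    text: str


class Resp(BaseModel):  # hypothetical response schema
    label: str


def fake_infer(batch):  # hypothetical infer function
    # return one dict per validated request, matching the response schema
    return [{'label': req.text.upper()} for req in batch]


protocol = BatchProtocol(fake_infer, Req, Resp, use_msgpack=False)
# blocks and serves batches until the process is stopped
protocol.run('/tmp/batching.socket', protocol='unix')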