Yet another deep learning serving framework that is easy to use.

Previously, I tested the performance of several deep learning serving frameworks, such as TensorFlow Serving and Triton, and found that they are not that easy to use. Moreover, they don't offer much of an advantage in performance. So I wrote one as a prototype.

Feel free to give it a try.

Basic features

  • serve deep learning models over HTTP
  • preprocess and postprocess (optional)
  • dynamic batching (to increase throughput)
  • load balancing (idle workers first)
  • monitoring metrics (Prometheus)
  • health check (requires user-provided examples)
  • request & response validation
  • model inference warm-up (requires user-provided examples)
  • OpenAPI document
  • JSON and msgpack serialization

Advantages

  • supports all kinds of deep learning runtimes
  • preprocess and postprocess are easy to implement
  • request validation
  • health check and warm-up with examples
  • OpenAPI document

Design

                                       +-----------+
                                       | Inference |
                               +-------+ Worker    |
                               |       +-----------+
+----------+        +--------+ |
|          |        |        +-+       +-----------+
| Batching |        | Unix   |         | Inference |
| Service  +--------> Domain +---------+ Worker    |
| (HTTP)   <--------+ Socket |         +-----------+
|          |        |        +-+
+----------+        +--------+ |       +-----------+
                               |       | Inference |
                               +-------+ Worker    |
                                       +-----------+

Dynamic Batching

To implement dynamic batching, we need a high-performance job queue that can be consumed by multiple workers. A Go channel is a good fit. Here we have one producer and multiple consumers, so a graceful shutdown is easy: the producer closes the channel, and the consumers drain the remaining jobs and then exit.

type Batching struct {
	Name       string          // socket name
	socket     net.Listener    // Unix domain socket listener
	maxLatency time.Duration   // max time a batch waits to fill before inference
	batchSize  int             // max batch size for a batch inference
	capacity   int             // capacity of the batching queue
	timeout    time.Duration   // timeout for jobs in the queue
	logger     *zap.Logger     // structured logger
	queue      chan *Job       // job queue
	jobs       map[string]*Job // use the job ID as the key to find the job
	jobsLock   sync.Mutex      // guards jobs
}

Each job in this queue gets a UUID as its key, so that after inference the job can be found by looking up that key in a hash table. Because the hash table is shared between the HTTP handlers and the batching loop, it also needs a mutex.

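type Job struct {
	id        string    // job UUID, the key in the jobs hash table
	done      chan bool // signaled when the job is finished
	data      []byte    // request data
	result    []byte    // inference result or error message
	errorCode int       // HTTP status code
	expire    time.Time // deadline after which the job times out
}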

Because the batching service and the Python inference workers run on the same machine (or in the same pod), a Unix domain socket should be the most efficient way for them to communicate. We also need to define a simple protocol for this use case. Since we only need to transfer the data of a batch of jobs, let's keep everything as simple as possible.

| length  |       data        |
| 4 bytes |   {length} bytes  |

The data carried by this protocol is a hash table of type <string:bytes>, keyed by job ID.

  1. each worker sends a first request with empty data to the batching service
  2. the batching service collects a batch of jobs and sends it to the worker
  3. the worker processes these jobs
    • preprocess (per job)
    • inference (on the whole batch)
    • postprocess (per job)
    • send the results back to the batching service
  4. the batching service notifies each job's handler that the job is done, the handler sends the result to the original client, and the cycle returns to step 2

Error handling

  • timeout

If no worker picks up a job before its timeout expires, the batching service deletes the job from the hash table and returns HTTP 408 (Request Timeout).

When collecting jobs from the queue channel, the batching service checks each job's expire field first.

  • validation error

To make sure the request data is valid, we validate it with pydantic, so the user needs to define the data schema as a pydantic model.

If one job's data is invalid, that job is marked and its result is the validation error message generated by pydantic. This doesn't affect the other jobs in the same batch. That part is handled by ventu.

  • pass the errors from the workers to the batching service

The data received by a worker is validated first. If some of the jobs are invalid, their IDs are put into an error ID array. Only the valid jobs go through the preprocess -> inference -> postprocess pipeline. After that, the results look like this:

{
	'valid job ID': inference result,
	'error job ID': error message,
	'error_id': all the error job IDs
}

When the batching service receives this data, it sets the status code of the error jobs to 422 (Unprocessable Entity). Then every job in the batch gets its corresponding result attached and is marked as done, so each handler knows whether its job had validation errors and can return the error message to the client.
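A sketch of the batching service applying such a result table to one batch. It assumes the `error_id` value is a JSON array of job IDs and that valid jobs get status 200; neither detail is specified above. If `error_id` cannot be decoded, the sketch fails the whole batch with 500 so that no handler is left waiting. `applyResults` is a hypothetical name.

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
)

type Job struct {
	id        string
	done      chan bool
	result    []byte
	errorCode int
}

// applyResults attaches each job's result, marks the jobs listed under
// "error_id" with 422, and signals every job in the batch as done.
func applyResults(batch []*Job, results map[string][]byte) {
	failed := map[string]bool{}
	var decodeErr error
	if raw, ok := results["error_id"]; ok {
		var ids []string
		if decodeErr = json.Unmarshal(raw, &ids); decodeErr == nil {
			for _, id := range ids {
				failed[id] = true
			}
		}
	}
	for _, job := range batch {
		switch {
		case decodeErr != nil: // malformed error_id: fail the whole batch
			job.result, job.errorCode = []byte(decodeErr.Error()), http.StatusInternalServerError
		case failed[job.id]: // validation error message from the worker
			job.result, job.errorCode = results[job.id], http.StatusUnprocessableEntity
		default:
			job.result, job.errorCode = results[job.id], http.StatusOK
		}
		job.done <- true // done is buffered, so this does not block
	}
}

func main() {
	a := &Job{id: "a", done: make(chan bool, 1)}
	b := &Job{id: "b", done: make(chan bool, 1)}
	applyResults([]*Job{a, b}, map[string][]byte{
		"a":        []byte(`{"label": "cat"}`),
		"b":        []byte("field required"),
		"error_id": []byte(`["b"]`),
	})
	<-a.done
	<-b.done
	fmt.Println(a.errorCode, b.errorCode, string(b.result))
}
```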

Simple HTTP service without dynamic batching

For this part, we use falcon, a Python framework for building web APIs. To generate the OpenAPI document and validate the request data, we use spectree.

If you would like to run it with gunicorn, ventu also exposes the app object.