Build a Streaming API for a Large Language Model#

Large Language Models can power a wide range of applications, from chatbots to assistants and intelligent systems. However, these models can be heavy and slow, and your users expect systems that are both intelligent and fast!

Large language models work by turning your prompt into tokens and then generating new tokens one at a time until the model decides that generation should stop. Rather than making users wait for the full answer, you want to stream the output tokens to the client as they are generated. In this tutorial, we will discuss how to achieve this with Streaming Endpoints in Jina.
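To make the tokenization step concrete, here is a minimal sketch, assuming the transformers library and the same gpt2 model used later in this tutorial:

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# The tokenizer maps the prompt string to a list of integer token IDs;
# generation then repeatedly predicts the next ID from this sequence.
print(tokenizer('what is the capital of France ?')['input_ids'])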

Service Schemas#

The first step is to define the streaming service schemas, as you would do in any other service framework. The input to the service is the prompt and the maximum number of tokens to generate, while the output is the ID of each generated token along with the text generated so far:

from docarray import BaseDoc


class PromptDocument(BaseDoc):
    prompt: str
    max_tokens: int


class ModelOutputDocument(BaseDoc):
    token_id: int
    generated_text: str

Note

Thanks to DocArray’s flexibility, you can design much richer schemas. For instance, you can use Tensor types to efficiently stream token logits back to the client and implement complex token sampling strategies on the client side.
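As an illustration, a logit-streaming output schema might look like the following sketch (ModelOutputWithLogits is a hypothetical name, not part of this tutorial's service):

from docarray import BaseDoc
from docarray.typing import TorchTensor


class ModelOutputWithLogits(BaseDoc):
    token_id: int
    # Raw next-token scores over the vocabulary, so the client can apply
    # its own sampling strategy (temperature, top-k, top-p, ...)
    logits: TorchTensor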

Service initialization#

Our service depends on a large language model. As an example, we will use the gpt2 model. This is how you would load such a model in your Executor:

from jina import Executor, requests
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')


class TokenStreamingExecutor(Executor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Load the model once when the Executor starts, not on every request
        self.model = GPT2LMHeadModel.from_pretrained('gpt2')

Implement the streaming endpoint#

Our streaming endpoint accepts a PromptDocument as input and streams ModelOutputDocuments back. To stream a document back to the client, use the yield keyword in the endpoint implementation. In our case, we use the model to generate up to max_tokens tokens, yielding each one until generation stops:

class TokenStreamingExecutor(Executor):
    ...

    @requests(on='/stream')
    async def task(self, doc: PromptDocument, **kwargs) -> ModelOutputDocument:
        input = tokenizer(doc.prompt, return_tensors='pt')
        input_len = input['input_ids'].shape[1]
        for _ in range(doc.max_tokens):
            # Generate exactly one new token on top of the current sequence
            output = self.model.generate(**input, max_new_tokens=1)
            # Stop early if the model emits the end-of-sequence token
            if output[0][-1] == tokenizer.eos_token_id:
                break
            yield ModelOutputDocument(
                token_id=output[0][-1],
                generated_text=tokenizer.decode(
                    output[0][input_len:], skip_special_tokens=True
                ),
            )
            # Feed the extended sequence back as input for the next step
            input = {
                'input_ids': output,
                'attention_mask': torch.ones(1, len(output[0])),
            }
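Note that each iteration feeds the full sequence (prompt plus generated tokens) back into the model and extends the attention mask to cover the new token. This is a standard autoregressive decoding loop kept deliberately simple; with these arguments, model.generate performs greedy decoding, but you can pass sampling parameters if you want more varied output.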

Learn more about streaming endpoints from the Executor documentation.

Serve and send requests#

The final step is to serve the Executor and send requests using the client. To serve the Executor using gRPC:

from jina import Deployment

with Deployment(uses=TokenStreamingExecutor, port=12345, protocol='grpc') as dep:
    dep.block()
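Streaming endpoints are not tied to gRPC. If your clients cannot speak gRPC, the following sketch serves the same Executor over HTTP, assuming your Jina version supports streaming over the http protocol:

from jina import Deployment

# Same Deployment, only the protocol flag changes
with Deployment(uses=TokenStreamingExecutor, port=12345, protocol='http') as dep:
    dep.block()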

To send requests from a client:

import asyncio
from jina import Client


async def main():
    # asyncio=True gives an async client; stream_doc yields each document
    # as soon as the server sends it
    client = Client(port=12345, protocol='grpc', asyncio=True)
    async for doc in client.stream_doc(
        on='/stream',
        inputs=PromptDocument(prompt='what is the capital of France ?', max_tokens=10),
        return_type=ModelOutputDocument,
    ):
        print(doc.generated_text)


asyncio.run(main())
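Running the client prints the partial generations as each token is streamed back: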
The
The capital
The capital of
The capital of France
The capital of France is
The capital of France is Paris
The capital of France is Paris.