One-stop NLP: Multi-task prompts for LLMs

machine learning
Author

Paul Simmering

Published

October 29, 2023

In NLP, we often want to extract multiple pieces of information from a text. Each extraction task is typically done by one model. For example, we might want to classify the topic of a text, do named entity recognition and extract the sentiment. To build such a pipeline, we need to train three different models.

What if we asked a large language model (LLM) to do it all in one step and return a god-view JSON object with all the structured information we need? That’s the idea I’d like to explore in this article.

Illustration generated with DALL·E 3

I’ll use the instructor package to describe the desired JSON object using a Pydantic model. Then I’ll send the requests to the OpenAI API with the texttunnel package. I’m the main developer of texttunnel.

Note

This article is an exploration, not a recommendation. Please refer to the last section for a discussion of the pros and cons of this approach.

Data: News articles

Let’s say we are building a news analysis tool.

We’ll use the cc_news dataset from Hugging Face. It contains 708,241 English language news articles published between January 2017 and December 2019.

from datasets import load_dataset

dataset = load_dataset("cc_news", split="train")

We won’t be training a model in this article, so we’ll just use the first 500 unique articles from the training set and run them through a pre-trained LLM. Let’s load the data into a Polars dataframe and take a look at the first five rows.

import polars as pl

news = pl.from_arrow(dataset.data.table).unique(subset="text").head(500)

news.head(5)

# Save to disk for later use
news.write_parquet("news.parquet")

Defining the God-View JSON

Pydantic allows us to define a detailed schema for the JSON object we want to get from the LLM.

This is what it looks like:

from enum import Enum
from typing import List
from pydantic import BaseModel
from instructor import OpenAISchema


# Define the labels for the different tasks
class TopicLabel(Enum):
    ARTS = "ARTS"
    BUSINESS = "BUSINESS"
    ENTERTAINMENT = "ENTERTAINMENT"
    HEALTH = "HEALTH"
    POLITICS = "POLITICS"
    SCIENCE = "SCIENCE"
    SPORTS = "SPORTS"
    TECHNOLOGY = "TECHNOLOGY"


class SentimentLabel(Enum):
    POSITIVE = "POSITIVE"
    NEGATIVE = "NEGATIVE"
    NEUTRAL = "NEUTRAL"


class NamedEntityLabel(Enum):
    PERSON = "PERSON"
    ORG = "ORG"
    PRODUCT = "PRODUCT"
    LOCATION = "LOCATION"
    EVENT = "EVENT"


# Define how named entities are represented
class NamedEntity(BaseModel):
    text: str
    label: NamedEntityLabel


# Define the schema for the JSON object that
# we want the LLM to return
class News(OpenAISchema):
    topics: List[TopicLabel]
    sentiment: SentimentLabel
    named_entities: List[NamedEntity]

Now, how do we get the LLM to return this JSON object?

The OpenAI API has the function calling feature, which allows us to send a JSON schema describing a Python function to the API. The model will respond with a JSON object that matches the schema.
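As a rough sketch of what such a request looks like, the snippet below builds a Chat Completions payload with the `functions` and `function_call` fields used by function calling. The tiny `toy_schema` and the helper `build_payload` are hypothetical and only for illustration. In this article, texttunnel assembles the real requests, and its internals may differ.

```python
import json


def build_payload(text: str, function_schema: dict, model: str = "gpt-3.5-turbo") -> dict:
    """Assemble a chat completion request that forces a call to one function."""
    return {
        "model": model,
        "messages": [
            {"role": "system", "content": "Analyze news articles."},
            {"role": "user", "content": text},
        ],
        # The JSON schema(s) of the functions the model may "call"
        "functions": [function_schema],
        # Force the model to answer via this function instead of free text
        "function_call": {"name": function_schema["name"]},
    }


# Hypothetical, heavily simplified schema for illustration only
toy_schema = {
    "name": "News",
    "description": "Structured analysis of a news article",
    "parameters": {
        "type": "object",
        "properties": {
            "sentiment": {"type": "string", "enum": ["POSITIVE", "NEGATIVE", "NEUTRAL"]}
        },
        "required": ["sentiment"],
    },
}

payload = build_payload("Stocks rallied on Monday.", toy_schema)
print(json.dumps(payload, indent=2))
```

The model never executes anything. It only returns a JSON string with arguments that should fit the schema.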

The instructor package lets us take a Pydantic model and convert it to a JSON schema that we can send to the OpenAI API.

import pprint

function_schema = News.openai_schema

pprint.pprint(function_schema)
{'description': 'Correctly extracted `News` with all the required parameters '
                'with correct types',
 'name': 'News',
 'parameters': {'$defs': {'NamedEntity': {'properties': {'label': {'$ref': '#/$defs/NamedEntityLabel'},
                                                         'text': {'title': 'Text',
                                                                  'type': 'string'}},
                                          'required': ['text', 'label'],
                                          'title': 'NamedEntity',
                                          'type': 'object'},
                          'NamedEntityLabel': {'enum': ['PERSON',
                                                        'ORG',
                                                        'PRODUCT',
                                                        'LOCATION',
                                                        'EVENT'],
                                               'title': 'NamedEntityLabel',
                                               'type': 'string'},
                          'SentimentLabel': {'enum': ['POSITIVE',
                                                      'NEGATIVE',
                                                      'NEUTRAL'],
                                             'title': 'SentimentLabel',
                                             'type': 'string'},
                          'TopicLabel': {'enum': ['ARTS',
                                                  'BUSINESS',
                                                  'ENTERTAINMENT',
                                                  'HEALTH',
                                                  'POLITICS',
                                                  'SCIENCE',
                                                  'SPORTS',
                                                  'TECHNOLOGY'],
                                         'title': 'TopicLabel',
                                         'type': 'string'}},
                'properties': {'named_entities': {'items': {'$ref': '#/$defs/NamedEntity'},
                                                  'title': 'Named Entities',
                                                  'type': 'array'},
                               'sentiment': {'$ref': '#/$defs/SentimentLabel'},
                               'topics': {'items': {'$ref': '#/$defs/TopicLabel'},
                                          'title': 'Topics',
                                          'type': 'array'}},
                'required': ['named_entities', 'sentiment', 'topics'],
                'type': 'object'}}

This clearly defines what we want the LLM to return. It uses the enum, required and properties keywords from the JSON Schema specification.

Sending requests

Next, we need to send the requests to the OpenAI API. The texttunnel package makes this easy and efficient. We start by defining the requests. Each article is sent as a separate request.

from texttunnel import chat, models
import polars as pl

news = pl.read_parquet("news.parquet")

requests = chat.build_requests(
    model=models.GPT_3_5_TURBO,
    function=function_schema,
    system_message="Analyze news articles. Strictly stick to the allowed labels.",
    params=models.Parameters(max_tokens=1024),
    texts=news["text"].to_list(),
    long_text_handling="truncate",
)

print(f"Built {len(requests)} requests")
Built 500 requests

And how much will it cost to send these requests?

cost_usd = sum([x.estimate_cost_usd() for x in requests])

print(f"Estimated cost: ${cost_usd:.2f}")
Estimated cost: $1.68

Next, let’s set up a cache to store the responses. This way, we can experiment and never have to pay for the same request twice.

from aiohttp_client_cache import SQLiteBackend

# allowed_methods expects a collection of HTTP methods
cache = SQLiteBackend("cache.sqlite", allowed_methods=("POST",))

This will create a file called cache.sqlite in the current directory, which will hold a copy of the responses.
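To illustrate the idea behind request caching (not how aiohttp_client_cache is implemented), here is a minimal sketch that keys responses on a hash of the request payload. The `ResponseCache` class, its table layout and the fake response are all hypothetical.

```python
import hashlib
import json
import sqlite3


class ResponseCache:
    """Toy cache: maps a hash of the request payload to the stored response."""

    def __init__(self, path: str = ":memory:"):
        self.conn = sqlite3.connect(path)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS responses (key TEXT PRIMARY KEY, body TEXT)"
        )

    @staticmethod
    def make_key(payload: dict) -> str:
        # sort_keys makes the hash independent of dict ordering
        canonical = json.dumps(payload, sort_keys=True)
        return hashlib.sha256(canonical.encode("utf-8")).hexdigest()

    def get(self, payload: dict):
        row = self.conn.execute(
            "SELECT body FROM responses WHERE key = ?", (self.make_key(payload),)
        ).fetchone()
        return json.loads(row[0]) if row else None

    def set(self, payload: dict, response: dict) -> None:
        self.conn.execute(
            "INSERT OR REPLACE INTO responses VALUES (?, ?)",
            (self.make_key(payload), json.dumps(response)),
        )
        self.conn.commit()


cache_demo = ResponseCache()
request_payload = {"model": "gpt-3.5-turbo", "text": "Some article"}

if cache_demo.get(request_payload) is None:
    # Cache miss: this is where the paid API call would happen
    cache_demo.set(request_payload, {"sentiment": "NEUTRAL"})

# The second lookup is served from the cache, no API call needed
cached = cache_demo.get(request_payload)
```

Because the key depends on the full payload, changing the prompt, the schema or the model leads to a cache miss and a fresh request.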

Now we’re ready to actually send the requests.

from texttunnel import processor
import logging
import pickle

logging.basicConfig(level=logging.INFO)

# Setup logging for the texttunnel package
logging.getLogger("texttunnel").setLevel(logging.INFO)

logging.info(f"Sending {len(requests)} requests to the OpenAI API")

responses = processor.process_api_requests(
    requests=requests,
    cache=cache,
)

# Save to disk for later use
with open("responses.pickle", "wb") as f:
    pickle.dump(responses, f)

The texttunnel package sends the requests in parallel and caches the responses.
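The general pattern behind parallel sending can be sketched with `asyncio` and a semaphore that caps the number of requests in flight. This is not texttunnel’s code: `fake_api_call` is a stand-in for a real HTTP request and `max_concurrency` is a hypothetical parameter. A real client also has to deal with rate limits and retries, which this sketch omits.

```python
import asyncio


async def fake_api_call(text: str) -> dict:
    """Stand-in for an HTTP request to the API."""
    await asyncio.sleep(0.01)  # simulate network latency
    return {"text": text, "n_chars": len(text)}


async def process_all(texts: list[str], max_concurrency: int = 5) -> list[dict]:
    semaphore = asyncio.Semaphore(max_concurrency)

    async def bounded_call(text: str) -> dict:
        # At most max_concurrency calls run at the same time
        async with semaphore:
            return await fake_api_call(text)

    # gather returns results in the same order as the inputs
    return await asyncio.gather(*(bounded_call(t) for t in texts))


results = asyncio.run(process_all([f"article {i}" for i in range(20)]))
print(f"Received {len(results)} responses")
```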

Results

Parsing and validation

For each request, process_api_requests returned a list containing two dicts: one containing the request, the other the API’s response. Inside the response is the arguments key, which contains a string that should be parseable into a Python dict that matches the schema we defined.
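For orientation, this is roughly what the parsing step involves. The sketch assumes the response body follows the Chat Completions format for function calls, where the arguments sit under `choices[0].message.function_call.arguments`. The sample responses are made up, and `parse_arguments_sketch` is a hypothetical helper, not texttunnel’s `parse_arguments`.

```python
import json


def parse_arguments_sketch(response_body: dict) -> dict:
    """Extract and decode the JSON string with the function arguments."""
    message = response_body["choices"][0]["message"]
    # The API returns the arguments as a string, not as a parsed object
    return json.loads(message["function_call"]["arguments"])


# Made-up response body in the shape of a function call completion
sample_response = {
    "choices": [
        {
            "message": {
                "role": "assistant",
                "content": None,
                "function_call": {
                    "name": "News",
                    "arguments": '{"topics": ["SPORTS"], "sentiment": "NEUTRAL", "named_entities": []}',
                },
            }
        }
    ]
}

parsed = parse_arguments_sketch(sample_response)
print(parsed["topics"])

# A cut-off answer is not valid JSON and raises an error
broken = {"choices": [{"message": {"function_call": {"arguments": '{"topics": ["SPO'}}}]}
try:
    parse_arguments_sketch(broken)
except json.JSONDecodeError as e:
    print(f"Parsing failed: {e.msg}")
```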

We parse the responses and count the parsing errors.

import pickle
from texttunnel import processor

with open("responses.pickle", "rb") as f:
    responses = pickle.load(f)

parsing_errors = 0


def parse(response):
    global parsing_errors
    try:
        return processor.parse_arguments(response)
    except Exception:
        parsing_errors += 1
        return None


arguments = [parse(response) for response in responses]

print(f"Parsing errors: {parsing_errors} out of {len(arguments)} responses")
Parsing errors: 3 out of 500 responses

Next, we verify that they conform to the schema we defined.

from pydantic import ValidationError


def validate(argument):
    News.model_validate(argument)
    return argument


def run_validation(arguments, validation_fun):
    validation_errors = 0
    out = []
    for argument in arguments:
        if argument is None:
            # JSON parsing error
            out.append(None)
            continue
        try:
            argument = validation_fun(argument)
            out.append(argument)
        except ValidationError:
            validation_errors += 1
            out.append(None)

    print(f"Validation error in {validation_errors} out of {len(arguments)} responses")

    return out


valid_arguments = run_validation(arguments, validate)
Validation error in 130 out of 500 responses

The LLM doesn’t always follow the expected format. It adds extra labels to topics and entities that are not in the schema.

These can be fixed automatically. Let’s try again.

def fix_and_validate(argument):
    fixed_argument = argument.copy()

    topics = list(TopicLabel.__members__)

    # Remove topics that are not in the schema
    fixed_argument["topics"] = [x for x in argument["topics"] if x in topics]

    entities = list(NamedEntityLabel.__members__)

    if argument["named_entities"] is not None:
        fixed_argument["named_entities"] = [
            x for x in argument["named_entities"] if x["label"] in entities
        ]

    validate(fixed_argument)
    return fixed_argument


valid_arguments = run_validation(arguments, fix_and_validate)
Validation error in 0 out of 500 responses

Removing the invalid labels fixed all validation errors.
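Before discarding out-of-schema labels, it can be worth looking at which ones the model invented. A minimal sketch, using made-up example answers in place of the real `arguments` list; `count_invalid_topics` is a hypothetical helper:

```python
from collections import Counter

ALLOWED_TOPICS = {
    "ARTS", "BUSINESS", "ENTERTAINMENT", "HEALTH",
    "POLITICS", "SCIENCE", "SPORTS", "TECHNOLOGY",
}


def count_invalid_topics(arguments: list) -> Counter:
    """Tally topic labels that are not part of the schema."""
    counts = Counter()
    for argument in arguments:
        if argument is None:  # skip responses that failed JSON parsing
            continue
        for topic in argument.get("topics", []):
            if topic not in ALLOWED_TOPICS:
                counts[topic] += 1
    return counts


# Made-up answers for illustration
example_arguments = [
    {"topics": ["POLITICS", "CRIME"]},
    {"topics": ["SPORTS"]},
    None,
    {"topics": ["CRIME", "WEATHER"]},
]

invalid_counts = count_invalid_topics(example_arguments)
print(invalid_counts.most_common())
```

A label that shows up often in such a tally may hint that the schema is missing a category.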

Next, let’s bring the answers into a Polars dataframe.

valid_arguments = [x for x in valid_arguments if x is not None]
answers = pl.DataFrame(valid_arguments)

print(answers.head(5))
shape: (5, 3)
┌────────────────────────────┬───────────┬───────────────────────────────────┐
│ topics                     ┆ sentiment ┆ named_entities                    │
│ ---                        ┆ ---       ┆ ---                               │
│ list[str]                  ┆ str       ┆ list[struct[2]]                   │
╞════════════════════════════╪═══════════╪═══════════════════════════════════╡
│ ["POLITICS", "TECHNOLOGY"] ┆ NEGATIVE  ┆ [{"James Clapper","PERSON"}, {"R… │
│ ["POLITICS", "TECHNOLOGY"] ┆ NEGATIVE  ┆ [{"Canadian troops","ORG"}, {"Ma… │
│ ["BUSINESS"]               ┆ POSITIVE  ┆ [{"Moshe Kahlon","PERSON"}, {"Is… │
│ ["BUSINESS"]               ┆ NEUTRAL   ┆ [{"Bailoy Irrigation Control Sys… │
│ ["SPORTS"]                 ┆ NEUTRAL   ┆ [{"Pep Guardiola","PERSON"}, {"B… │
└────────────────────────────┴───────────┴───────────────────────────────────┘

Note that the topics and named entities are now represented as nested elements.

Visualization

The LLM’s answers could be used to power a dashboard that shows the most common topics, positive and negative sentiment and the most frequently mentioned named entities. Let’s get a preview of what that could look like.

import plotly.express as px

topic_sentiment = (
    answers.drop_nulls().explode("topics")
    # Sort for legend
    .sort(
        pl.when(pl.col("sentiment") == "POSITIVE")
        .then(pl.lit(0))
        .when(pl.col("sentiment") == "NEUTRAL")
        .then(pl.lit(1))
        .otherwise(pl.lit(2))
    )
)

sentiment_colors = {
    "POSITIVE": "#98FB98",
    "NEUTRAL": "#B0C4DE",
    "NEGATIVE": "#F08080",
}

fig = px.histogram(
    data_frame=topic_sentiment,
    x="topics",
    color="sentiment",
    barmode="group",
    labels={"topics": "Topic", "sentiment": "Sentiment"},
    color_discrete_map=sentiment_colors,
)

fig.update_yaxes(title_text="Mentions")
fig.update_layout(title="Topic and sentiment distribution")
fig.show()

We see that business, technology and politics are the most common topics. Articles about politics are most often negative, while articles about entertainment are most often positive.

named_entities = (
    answers.explode("named_entities")
    .unnest("named_entities")
    .group_by("text", "label")
    .agg(pl.count("label").alias("count"))
    .sort(by="count")
    .drop_nulls()
)

# Top 5 named entities by label
top_named_entities = pl.concat(
    [x.top_k(5, by="count") for x in named_entities.partition_by("label")]
)

fig = px.bar(
    data_frame=top_named_entities,
    facet_row="label",
    color="label",
    x="count",
    y="text",
    orientation="h",
    labels={"count": "Mentions"},
)

fig.update_yaxes(matches=None, title_text="", autorange="reversed")
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fig.update_layout(showlegend=False, title="Most frequent named entities by label")

fig.show()

The most frequently mentioned people are American politicians. Products are dominated by tech products. Events are dominated by sports events. China stands out as the most commonly mentioned location.

Unvalidated model

All of this is based on zero-shot classification and zero-shot named entity recognition. We don’t have a validation set, so we don’t know how accurate the model is. For production use, this would need to be tested.
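A minimal sketch of what such a test could look like for the sentiment task, assuming a small set of hand-labeled articles is available. The gold labels and predictions below are invented for illustration, and `precision_recall` is a hypothetical helper.

```python
def precision_recall(gold: list, predicted: list, label: str) -> tuple:
    """Compute precision and recall for one label."""
    pairs = list(zip(gold, predicted))
    tp = sum(1 for g, p in pairs if g == label and p == label)
    fp = sum(1 for g, p in pairs if g != label and p == label)
    fn = sum(1 for g, p in pairs if g == label and p != label)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall


# Invented example data: human annotations vs. LLM output
gold = ["POSITIVE", "NEGATIVE", "NEUTRAL", "NEGATIVE", "POSITIVE", "NEUTRAL"]
pred = ["POSITIVE", "NEGATIVE", "NEGATIVE", "NEGATIVE", "NEUTRAL", "NEUTRAL"]

accuracy = sum(g == p for g, p in zip(gold, pred)) / len(gold)
print(f"Accuracy: {accuracy:.2f}")

for label in ["POSITIVE", "NEUTRAL", "NEGATIVE"]:
    p, r = precision_recall(gold, pred, label)
    print(f"{label}: precision={p:.2f} recall={r:.2f}")
```

The same idea extends to topics and named entities, although comparing lists of labels or entity spans requires a decision on what counts as a match.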

Discussion

The one-stop approach is diametrically opposed to the position Matthew Honnibal takes in his article “Against LLM Maximalism”:

“They [LLMs] are extremely useful, but if you want to deliver reliable software you can improve over time, you can’t just write a prompt and call it a day.”

The alternate pipeline with a modular approach of specialized models could look like this:

graph LR
    A([Text]) --> B[Tokenization]
    B --> C[Sentence splitting]
    C --> D[Topic classification]
    D --> E[Sentiment classification]
    E --> F[NER]

The tokenization and sentence splitting don’t require trainable models.

Explosion AI’s spaCy package is excellent for constructing such pipelines. With the extension spacy-llm, it can also feature LLMs in the pipeline and Prodigy integrates them into the annotation workflow.

Advantages of multi-task prompts compared to pipelines

  • Simplicity: No training required and only one model to deploy or call by API. That means less code, infrastructure, and documentation to maintain. It also requires less knowledge about various model architectures. This article showed that it’s possible to build a multi-task prompt pipeline with just a few lines of code. Note that spaCy also allows training a regular model to perform multiple tasks.
  • Easy upgrading: If the LLM gets better, all tasks benefit from it. No need to retrain specialized models. When OpenAI releases GPT-5, one could switch to it with a single line of code.
  • Easy extension: If we want to add a new label, we just add it to the schema and we’re done. Same with adding a new task, e.g. summarization.
  • Cheaper than chained LLM calls: If we were to call an LLM separately for each step, we’d have to send over the text multiple times. That’s more expensive than sending it once and getting all the analysis in one go. But it may still be more expensive than a chain of specialized models.
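The cost argument in the last point can be made concrete with some back-of-the-envelope arithmetic. The token counts below are placeholders chosen for illustration, not measurements from this experiment, and `input_tokens` is a hypothetical helper.

```python
def input_tokens(article_tokens: int, prompt_tokens: int, n_calls: int) -> int:
    """Total input tokens when the article is sent n_calls times."""
    return n_calls * (article_tokens + prompt_tokens)


# Placeholder numbers, not measured values
article_tokens = 800   # length of one article
task_prompt = 100      # instructions and schema for a single task
n_tasks = 3            # topics, sentiment, named entities

# One call per task: the article is sent three times
chained = input_tokens(article_tokens, task_prompt, n_calls=n_tasks)

# One multi-task call: the article is sent once with a longer prompt
multi_task = input_tokens(article_tokens, task_prompt * n_tasks, n_calls=1)

print(f"Chained: {chained} input tokens")
print(f"Multi-task: {multi_task} input tokens")
print(f"Savings: {1 - multi_task / chained:.0%}")
```

Under these assumptions the saving comes from the input side and grows with article length, since the article text is the part that would otherwise be repeated.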

Disadvantages of multi-task prompts compared to pipelines

  • Temptation to skip validation: Wouldn’t it be nice to just trust that the LLM gets it right? Unfortunately, we can’t. LLMs still suffer from hallucinations, biases, and other problems.
  • Lack of modularity: Can’t reuse one task in another pipeline and can’t use specialized models that others have trained.
  • New error types: JSON parsing errors, use of labels that are not in the schema.
  • Monolithic model: If you wish to fine-tune the LLM, it must be trained on all tasks at once. Training data must be available for all tasks. If you want to add a new task, you have to retrain the whole model.
  • High inference cost: Compared to efficient models like DistilBERT that comfortably run on a single GPU from a few years ago, LLMs are very expensive to run, requiring a cluster of the latest GPUs.
  • High latency: LLMs have to do a lot more matrix multiplication than smaller models. That means they take longer to respond, which is a problem for interactive applications.

To conclude, I see unvalidated multi-task prompts as a tool for low-stakes exploratory work. If proper validation is added, they can be viable in batch processing scenarios where simplicity is valued over modularity and computational efficiency.