Pre-annotate your data with a lightning-fast zero-shot classifier
July 19, 2022
TL;DR In this blog post, you will learn how to:
- Build a lightning-fast zero-shot classifier with Neural Magic's DeepSparse;
- Use it with Rubrix to efficiently create a training set in less than 30 minutes;
- Use the training set to fine-tune a DistilBERT model to an accuracy of 0.98.
Manually annotating your data from scratch can be a cumbersome task. Rubrix provides several ways to ease this process, one of which is pre-annotating your data with model predictions. However, when you deal with a text classification task, chances are low that you will find a model already trained on your specific label schema.
Run, zero-shot, run!
Zero-shot classifiers can solve this problem. The idea here is to train a classifier on one set of labels and evaluate it on a different set it has never seen before. One approach to building a zero-shot classifier is to reframe the classification task as Natural Language Inference. It works by posing each candidate label as a "hypothesis" (such as "This text is {candidate label}") and the text we want to classify as the "premise". We can then take the candidate with the highest entailment score as the most likely label for the input text. Hugging Face's transformers library has a neat implementation of this approach; you can find more details in this discussion.
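To make this reframing concrete, here is a minimal sketch of scoring a single premise/hypothesis pair with a plain NLI model from the Hub (the example texts are made up, and the checkpoint is just one of many MNLI models):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Any NLI checkpoint works here; we reuse the DistilBERT-MNLI model that appears later in this post
tokenizer = AutoTokenizer.from_pretrained("typeform/distilbert-base-uncased-mnli")
model = AutoModelForSequenceClassification.from_pretrained("typeform/distilbert-base-uncased-mnli")

premise = "Windows 11 ships with a redesigned start menu."
hypothesis = "This text is about microsoft."

# NLI models score the pair with three classes: entailment, neutral, contradiction
inputs = tokenizer(premise, hypothesis, return_tensors="pt")
probs = torch.softmax(model(**inputs).logits, dim=-1)

print(model.config.id2label)  # shows which index corresponds to the entailment class
print(probs)                  # the entailment probability serves as the score for the label

Repeating this for every candidate label and taking the highest entailment score yields the zero-shot prediction.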
One downside of this approach is that each label needs its own forward pass, and for many candidate labels, inference can become computationally quite expensive. Here is where Neural Magic's DeepSparse Engine comes to the rescue. It takes advantage of sparsification and other tricks to achieve GPU-class performance on CPUs at inference time. Their SparseZoo contains a growing number of sparsified models that are much more lightweight than their dense counterparts but achieve comparable accuracies.
Defining the experiment
This blog post shows you how to program a lightning-fast zero-shot classifier with Neural Magic's deepsparse library. To test the classifier, we will use it together with Rubrix to pre-annotate the News Popularity dataset and build a training dataset in no time to ultimately fine-tune a model for our specific classification task.
For our experiment, apart from Rubrix, we need the following libraries:
Look at our short setup guide if you haven't set up Rubrix yet.
pip install "rubrix[server]" deepsparse "transformers[torch]"
Let's start with our simple implementation of the zero-shot approach discussed above:
from typing import List, Tuple, Optional, Dict

from deepsparse.transformers import pipeline
import numpy as np

DEFAULT_LABELS = ["microsoft", "economy", "obama", "palestine"]
INT2STR = {i: label for i, label in enumerate(DEFAULT_LABELS)}

# We use a heavily sparsified DistilBERT model fine-tuned on the MNLI dataset,
# available on Neural Magic's model zoo.
model_path = "zoo:nlp/text_classification/distilbert-none/pytorch/huggingface/mnli/pruned80_quant-none-vnni"

sparsified_classifier = pipeline(
    "text-classification",
    model_path=model_path,
    return_all_scores=True,
    batch_size=16,
    max_length=64,
)

def zero_shot_prediction(
    premise: List[str],
    labels: Optional[List[str]] = None,
    hypothesis_template: str = "This title is about {}.",
) -> Tuple[np.ndarray, np.ndarray]:
    """Make a zero-shot prediction given some input texts and a list of labels.

    Args:
        premise: Input texts to classify.
        labels: Labels for the classification task.
        hypothesis_template: Template of the hypothesis. Will be completed with the labels.

    Returns:
        Predictions and probabilities.
    """
    labels = labels or DEFAULT_LABELS

    # Formulate a hypothesis for each candidate label
    hypos = [hypothesis_template.format(label) for label in labels]

    # Store entailment logits in an array
    logits = np.empty((len(premise), len(labels)))

    # Extract entailment logits for each hypothesis given the premises
    for i, hypo in enumerate(hypos):
        predictions = sparsified_classifier(list(zip(premise, [hypo] * len(premise))))
        logits[:, i] = extract_entailment_logits(predictions)

    # Apply softmax over the entailment logits of all candidate labels
    probs = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)

    # Get predictions
    preds = probs.argmax(axis=1)

    return preds, probs

def extract_entailment_logits(
    predictions: List[List[Dict[str, float]]],
    entailment_id: int = 0,
) -> np.ndarray:
    """Helper function to extract the entailment logits of an NLI model output.

    This is a slight hack, since we arbitrarily assume `sum(logits) == 0`.
    The division by 3 reflects the three NLI classes (entailment, neutral, contradiction).
    """
    probs = np.array([[p["score"] for p in pred] for pred in predictions])
    logits = np.log(probs) - np.sum(np.log(probs), axis=-1, keepdims=True) / 3
    return logits[:, entailment_id]
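As a quick sanity check, we can call the function on a couple of made-up titles (the titles below are our own and not part of the dataset):

preds, probs = zero_shot_prediction(
    [
        "Microsoft announces a new version of Windows",
        "Oil prices fall as markets await rate decision",
    ]
)
for pred, prob in zip(preds, probs):
    print(INT2STR[pred], prob)
# We would expect "microsoft" and "economy" as the top labels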
Once we have coded our zero-shot classifier, it is time to establish a baseline for our small experiment.
Establish a baseline
The baseline for our small experiment is the zero-shot classifier itself applied to a test split of the News Popularity dataset. We restrict the model input to the titles of the news articles to make the task a bit harder and keep the model input small.
from datasets import load_dataset
from sklearn.metrics import classification_report

# Load the News Popularity dataset from the Hugging Face Hub
ds = load_dataset("newspop", split="train")
train, test = ds.train_test_split(test_size=10000, seed=43).values()

# Make predictions with our lightning-fast zero-shot classifier
def predict(rows):
    predictions, _ = zero_shot_prediction(rows["title"])
    return {
        "prediction": [INT2STR[pred] for pred in predictions],
        "label": rows["topic"],
    }

test_prediction = test.map(predict, batched=True, batch_size=16)

# Print out the test accuracy
print(classification_report(test_prediction["label"], test_prediction["prediction"]))
The inference took us roughly 5 minutes on our machine (Intel i7-9750H, 6 cores), and we are already achieving an accuracy of 0.95! Looking at the titles in more detail, we notice that many explicitly contain the topic keyword, which suits our zero-shot approach. Let us compare the performance and speed with Hugging Face's vanilla zero-shot pipeline.
from transformers import pipeline

# Load the zero-shot classifier of the transformers library
classifier = pipeline(
    "zero-shot-classification",
    model="typeform/distilbert-base-uncased-mnli",
)

# Make predictions
def predict(rows):
    outputs = classifier(
        rows["title"],
        candidate_labels=DEFAULT_LABELS,
        hypothesis_template="This title is about {}.",
    )
    return {
        "prediction": [output["labels"][0] for output in outputs],
        "label": rows["topic"],
    }

test_prediction_tr = test.map(predict, batched=True, batch_size=16)

# Print out the test accuracy
print(classification_report(test_prediction_tr["label"], test_prediction_tr["prediction"]))
While it achieves the same accuracy, inference takes roughly four times longer than with Neural Magic's sparsified model!
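To verify the speed difference on your own machine, a rough benchmark over a small sample is enough (absolute timings will of course vary with your hardware):

import time

sample = test["title"][:64]

start = time.perf_counter()
zero_shot_prediction(sample)
print(f"DeepSparse zero-shot: {time.perf_counter() - start:.1f} s")

start = time.perf_counter()
classifier(sample, candidate_labels=DEFAULT_LABELS, hypothesis_template="This title is about {}.")
print(f"transformers zero-shot: {time.perf_counter() - start:.1f} s")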
Build the training data
Let us try to beat the 0.95 by fine-tuning a DistilBERT base model to our specific classification task. For this, we need training examples labeled with our four topics (of course, the training split of the NewsPop dataset already contains those labels, but let us assume it did not). Instead of manually labeling the data from scratch, we will use the model predictions of our zero-shot classifier to guide our annotation process.
So first, let's apply our lightning-fast zero-shot classifier to a smaller random subset of the training split and log the data with the corresponding predictions to the Rubrix UI.
import rubrix as rb

# Create predictions for our training data that we will use as pre-annotations
def predict(rows):
    _, batch_probs = zero_shot_prediction(rows["title"])
    prediction = []
    for probs in batch_probs:
        prediction.append(
            [
                {"label": label, "score": score}
                for label, score in zip(DEFAULT_LABELS, probs)
            ]
        )
    return {"prediction": prediction}

train_predicted = train.select(range(10000)).map(predict, batched=True, batch_size=16)

# Create Rubrix records from the dataset
records = rb.read_datasets(train_predicted, task="TextClassification", inputs=["title", "headline"])

# Log the records to the Rubrix UI
rb.log(records=records, name="newspop")
We start by selecting the records with the highest prediction scores using the score filter of the Rubrix UI. We quickly see that these are clear-cut examples, and with the bulk-annotation tool, we can validate ~500 of them in a matter of minutes. While validating predictions, we also keep track of the label distribution of our annotations and use the "predicted as" filter to keep the labels balanced.
Now it's time to get some training examples that are more likely to give the fine-tuned model an advantage over the zero-shot classifier. We select the records with the lowest prediction scores and validate or correct the zero-shot predictions. Most of the challenging records lack the topic keyword in the news title but do contain it in the article's content, which we also logged to Rubrix. We can take advantage of this by searching for "Microsoft", for example, and bulk-annotating the results.
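If you prefer to surface these low-score records programmatically rather than through the UI filter, a minimal sketch over the train_predicted dataset from above could look like this:

# Rank the pre-annotated records by the zero-shot classifier's confidence,
# i.e. the maximum predicted probability per record
def confidence(row):
    return {"confidence": max(pred["score"] for pred in row["prediction"])}

ranked = train_predicted.map(confidence).sort("confidence")

print(ranked["title"][:5])   # the most challenging titles come first
print(ranked["title"][-5:])  # the clear-cut titles come last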
This way, we can build a training set of around 1000 examples in less than 30 minutes with Rubrix!
Fine-tune a model
Now that we have enough training examples, let us fine-tune a DistilBERT model using the Hugging Face transformers library and see if we can beat the accuracy of the zero-shot classifier.
First, let's load our training data from Rubrix and prepare it for training a transformers model.
# Load the annotated data from the Rubrix UI
rubrix_dataset = rb.load("newspop")

# Prepare the Rubrix dataset for training
train_ds = rubrix_dataset.prepare_for_training()

# We will only use the news titles as the model 'text' input
train_ds = train_ds.rename_column("title", "text")
We then tokenize the input data, set up the base DistilBERT model for fine-tuning, and train it on our training set. We have already slightly optimized the learning rate and the number of epochs in the following code block. If you want to optimize them further, don't forget to first split the data into train and validation sets (see the sketch after the code block).
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

# Tokenize our data
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def tokenize(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=64)

tokenized_train_ds = train_ds.map(tokenize, batched=True)

# Load our pre-trained model
labels = train_ds.features["label"].names
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=4,
    id2label={i: label for i, label in enumerate(labels)},
    label2id={label: i for i, label in enumerate(labels)},
)

# Fine-tune the pre-trained model with our training data
training_args = TrainingArguments(
    "newspop_titles",
    learning_rate=5e-5,
    num_train_epochs=3,
    per_device_train_batch_size=8,
)

trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=tokenized_train_ds,
)

trainer.train()
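If you do want to tune the hyperparameters yourself, here is a minimal sketch of the train/validation split mentioned above (the 80/20 split size is an arbitrary choice):

# Hold out 20% of the annotated examples for validation
split = tokenized_train_ds.train_test_split(test_size=0.2, seed=43)

trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=split["train"],
    eval_dataset=split["test"],
)
trainer.train()
print(trainer.evaluate())  # reports the validation loss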
As the last step, we load the trained model into a text-classification pipeline and evaluate it on our test set.
# Load our fine-tuned model into a pipeline
finetuned_classifier = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
)

# Make predictions with our fine-tuned model
def predict(rows):
    outputs = finetuned_classifier(rows["title"])
    return {
        "prediction": [output["label"] for output in outputs],
        "label": rows["topic"],
    }

test_prediction_ft = test.map(predict, batched=True, batch_size=32)

# Print out the test accuracy
print(classification_report(test_prediction_ft["label"], test_prediction_ft["prediction"]))
That's it. With our fine-tuned model, we achieve an accuracy of 0.98, reducing the error rate of the zero-shot classifier by more than half (from 5% to 2%).
Summary
In this blog post, you learned how to:
- Build a lightning-fast zero-shot classifier with Neural Magic's DeepSparse Engine;
- Use it with Rubrix to efficiently create a training set in less than 30 minutes;
- Use the training set to fine-tune a DistilBERT model to an accuracy of 0.98.
However, this is just the beginning. Neural Magic is working on implementing a zero-shot classification pipeline directly in their DeepSparse Engine, taking advantage of the models in their SparseZoo, so using a lightning-fast zero-shot classifier should become even more effortless.
At the same time, Rubrix is constantly implementing new annotation techniques. Besides pre-annotating your data with model predictions, Rubrix also has built-in support for weak supervision. If you are familiar with this branch of machine learning, you probably noticed that the dataset and task of this blog post are prime candidates for its application. We invite you to get inspired by our tutorial on weak supervision and try to beat the 0.98 accuracy of this small experiment.
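As a teaser, a minimal sketch of such keyword rules with Rubrix's weak labeling module might look like the following (we assume simple keyword queries and the rubrix.labeling.text_classification API; see the weak supervision tutorial for the full workflow):

from rubrix.labeling.text_classification import Rule, WeakLabels

# One keyword rule per topic: records matching the query get the corresponding label
rules = [Rule(query=label, label=label) for label in DEFAULT_LABELS]

# Apply the rules to the logged dataset and inspect their coverage and accuracy
weak_labels = WeakLabels(rules=rules, dataset="newspop")
print(weak_labels.summary())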