
Scaling Virtual Reality Medical Conversations using Multi-Head Classification

Techniques:

Content categorization

Objectives

Implement a fast and accurate multi-model solution able to scale to thousands of labels

Problem

Our client develops virtual reality technology for the medical sector and needed to extract structured information, such as symptoms and body parts, from dialogues between doctors and non-player character (NPC) patients. There was a very large number of possible entities of interest (labels) to extract: around 700 at the time of implementation, with plans for this number to eventually reach several thousand.

The client’s application offered a range of scenarios to their clients, each one representing a particular conversation between a doctor and their patient about a given health issue. Consequently, not all labels were relevant in every scenario, which informed our decision-making when it came to the technological solution.

Due to the nature of this domain and use case, the number of possible entities, or labels, to consider was in the hundreds and potentially thousands. The client already had a machine learning model in place to apply these labels, developed jointly with Mantis in a previous iteration of the project. This model performed well at classifying which labels (e.g. symptoms, affected body parts) occurred in any given utterance, up to a few hundred labels in the dataset. After a new batch of labels was added, however, the model's accuracy unexpectedly dropped.

Our investigation showed that the model became saturated when it had to learn too many labels, causing performance to suffer. Obvious solutions to this problem were to use a larger model, or to fine-tune multiple versions of the original model on different subsets of the labels. However, the client considered minimising resource use to be critical and aimed to be as efficient as possible when it came to:

  • server disk space
  • training time
  • inference time

We therefore had to find a solution that performed efficiently on all three fronts while still being able to scale to thousands of labels.

Solution

The insight that not all labels are relevant in all application scenarios formed the foundation of our solution. Instead of training one model that had to deal with thousands of labels, we followed a divide-and-conquer approach: train several models, each dealing with a different type of scenario and therefore only a subset of the labels. This greatly simplifies the learning process for each model and leads to better results.

These days, the canonical answer to challenges like this is usually parameter-efficient fine-tuning (such as LoRA), which makes it easy to create different model “flavors” while adding only a small overhead for each one. At the time we began working on this problem, these techniques and the corresponding tooling weren't sufficiently mature, so we opted for a different approach: multi-head classification.
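
As an illustration of that modern alternative, the sketch below shows how a per-scenario “flavor” of a shared base model could be created with the Hugging Face peft library. The base model name, label count and target modules are assumptions for illustration, not the client's actual setup.

```python
# A minimal sketch (assumptions only, not the client's setup) of creating a
# per-scenario "flavor" of one shared base model with a LoRA adapter.
from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model

base = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",          # assumed base model
    num_labels=120,                     # labels relevant to one scenario group
    problem_type="multi_label_classification",
)

lora_config = LoraConfig(
    r=8,                                # low rank keeps the per-adapter overhead small
    lora_alpha=16,
    target_modules=["q_lin", "v_lin"],  # attention projections in DistilBERT
    task_type="SEQ_CLS",
)

# Each scenario group would get its own lightweight adapter on the shared base.
scenario_model = get_peft_model(base, lora_config)
scenario_model.print_trainable_parameters()  # only a small fraction is trainable
```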

Multi-head classification is the idea of having one shared base model that learns the fundamental properties of the domain, with an arbitrary number of lightweight head models plugged onto it. In our scenario, each head's task is to learn to classify one subset of the overall labels. This checked all of our boxes:

  • It would scale to an arbitrary number of labels by allowing an arbitrary number of head models.
  • Head models could be very small, which 1) cut down on disk space usage and 2) kept the inference-time overhead negligible.
  • Training time overhead would be kept low by sharing the base model among head models.
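
To make this concrete, here is a minimal sketch of the idea, assuming a PyTorch / Hugging Face stack; the encoder name, scenario groups, label counts and example utterance are illustrative assumptions rather than the client's actual configuration.

```python
# A minimal sketch of multi-head classification: one shared encoder,
# one lightweight classification head per scenario group.
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer


class MultiHeadClassifier(nn.Module):
    def __init__(self, base_name, labels_per_scenario):
        super().__init__()
        # One shared base model learns the fundamental properties of the domain.
        self.encoder = AutoModel.from_pretrained(base_name)
        hidden = self.encoder.config.hidden_size
        # One lightweight head per scenario group, each covering only its own
        # subset of labels (multi-label classification, so independent logits).
        self.heads = nn.ModuleDict({
            scenario: nn.Linear(hidden, n_labels)
            for scenario, n_labels in labels_per_scenario.items()
        })

    def forward(self, scenario, **encoder_inputs):
        # Encode once with the shared base, then apply only the relevant head.
        hidden_states = self.encoder(**encoder_inputs).last_hidden_state
        pooled = hidden_states[:, 0]          # first-token ([CLS]-style) pooling
        return self.heads[scenario](pooled)   # raw logits for that label subset


# Hypothetical scenario groups and label counts, for illustration only.
labels_per_scenario = {"respiratory": 120, "musculoskeletal": 90, "cardiology": 150}
model = MultiHeadClassifier("distilbert-base-uncased", labels_per_scenario)

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
batch = tokenizer(["My chest hurts when I breathe in."], return_tensors="pt")
logits = model("respiratory", **batch)
probs = torch.sigmoid(logits)  # independent probability per label in that subset
```

Because the encoder is shared across heads, each additional scenario group only adds one small linear layer, which is what keeps the disk space, training-time and inference-time overhead low.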

Impact

The multi-head model setup exhibited stable and strongly improved performance across the roughly 700 labels used by the client at the time of implementation. Training and inference time increased only slightly, in proportion to the number of labels, and disk space usage remained modest, so the client was able to scale their model training and deployment robustly.
