

Two-Headed Classifier Use Case

Last Updated on November 5, 2023 by Editorial Team

Author(s): Argo Saakyan

Originally published on Towards AI.

Photo by Vincent van Zalinge on Unsplash


Let’s talk about some real-world computer vision cases. At first glance, classification is as simple as it gets, and that’s largely true. But in the real world you often have many constraints, such as the model’s speed, its size, and its ability to run on mobile devices. Moreover, you’ll probably have several tasks, and having a separate model for each one is not the best idea. When you can optimize the architecture of your system and use fewer models, you should. But you also don’t want to lose accuracy, right?
So when you consider all the constraints and optimizations, your task becomes more complex. I want to show an example of a classification problem with several classes that might not be visually similar.

I’ll start with a simple task: classify whether an image shows a real paper document or a screen displaying a document. The screen could be a tablet, a phone, or a big monitor.

Real document

This one is pretty straightforward. You start by collecting a dataset that is representative, clean, and big enough. Then you take a model that fits your constraints (speed, accuracy, exportability) and use an ordinary training pipeline, paying attention to imbalanced data. That should give you pretty good results.
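One common way to pay attention to imbalanced data is to weight each class inversely to its frequency. The sketch below is an assumption about how that could be done, not necessarily what this pipeline does; the `inverse_frequency_weights` helper and the sample counts are hypothetical.

```python
from collections import Counter
from typing import Dict, List


def inverse_frequency_weights(labels: List[int]) -> Dict[int, float]:
    """Return a weight per class id, proportional to 1 / class frequency.

    Weights are scaled so that a perfectly balanced dataset gives 1.0
    for every class.
    """
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {
        class_id: n_samples / (n_classes * count)
        for class_id, count in counts.items()
    }


# Hypothetical imbalanced split: 900 documents (0) and 100 screens (1).
labels = [0] * 900 + [1] * 100
weights = inverse_frequency_weights(labels)
print(weights)
```

In PyTorch, weights like these can be passed as the `weight` tensor of `torch.nn.CrossEntropyLoss`, ordered by class id, so that mistakes on the rare class cost more.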

But let’s say you now need to add a new feature: the model should also tell whether the incoming image is a document or something that is not a document, like a bag of chips, a can, or some marketing material. This task is not as important as the original one, and it is not as hard either.

Not a document

Here is the structure of our dataset:

├── documents/
│ ├── img_1.jpg
│ ...
│ └── img_100.jpg
├── screens/
│ ├── img_1.jpg
│ ...
│ └── img_100.jpg
├── not a document/
│ ├── img_1.jpg
│ ...
│ └── img_100.jpg
├── train.csv
├── val.csv
└── test.csv

And csv file structure:

documents/img_1.jpg | 0
not a document/img_1.jpg | 1
screens/img_1.jpg | 2

The first column contains a relative path to the image, and the second column contains the class id. Now let’s talk about two approaches to solving this task.
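For reference, a split file in this layout could be parsed as sketched below. This assumes the separator really is `|` as displayed and that the file has no header row; the `read_split` helper is hypothetical, and an in-memory string stands in for train.csv.

```python
import csv
import io
from typing import Iterable, List, Tuple


def read_split(lines: Iterable[str]) -> List[Tuple[str, int]]:
    """Parse '<relative path> | <class id>' rows into (path, class_id) pairs."""
    rows = []
    for row in csv.reader(lines, delimiter="|"):
        if not row:
            continue  # skip empty lines
        path, class_id = (field.strip() for field in row)
        rows.append((path, int(class_id)))
    return rows


# In-memory stand-in for train.csv, using the rows shown above.
sample = io.StringIO(
    "documents/img_1.jpg | 0\n"
    "not a document/img_1.jpg | 1\n"
    "screens/img_1.jpg | 2\n"
)
split = read_split(sample)
print(split)
```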

Three output neurons approach (simple)

As we want an optimal system architecture, we are not going to add a new binary classifier for every small task. The first idea that comes to mind is to add ‘not a document’ as a third class to our original model, so we end up with the classes ‘document’, ‘screen’, and ‘not_document’.

This is a viable option, but the two tasks may not be equally important, the classes might not be visually similar, and you might want slightly different features extracted for each classification layer. Also, let’s not forget that it’s very important not to lose accuracy on the original task.

Two heads with binary classification approach (custom)

Another approach is to use one shared backbone and two binary classification heads, one for each task. This way we have 1 model for 2 tasks, the tasks are kept separate, and we have a lot of control over each of them.

Speed barely suffers (I got ~5–7% slower inference on 1 image with a 3060), and the model gets a little bigger (in my case, after exporting to TFLite, it went from 500 KB to 700 KB). One more handy thing for our case is to weight the losses, so that the loss of the first head has N times more weight than the loss of the second head. This way we can be sure that our focus is on the first (main) task, and we are less likely to lose accuracy on it.
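The loss weighting can be sketched in plain Python as follows. This is a minimal illustration that assumes each head outputs two logits scored with cross-entropy; the `cross_entropy` and `two_head_loss` helpers and the `main_weight` parameter (the N above) are hypothetical names.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Cross-entropy of one sample: -log(softmax(logits)[target])."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(
        sum(math.exp(v - max_logit) for v in logits)
    )
    return log_sum_exp - logits[target]


def two_head_loss(
    logits_main: Sequence[float],
    target_main: int,
    logits_aux: Sequence[float],
    target_aux: int,
    main_weight: float = 2.0,
) -> float:
    """Weighted sum: the main head counts `main_weight` times more."""
    loss_main = cross_entropy(logits_main, target_main)
    loss_aux = cross_entropy(logits_aux, target_aux)
    return main_weight * loss_main + loss_aux


# Both heads are undecided (equal logits), so each loss is ln(2).
loss = two_head_loss([0.0, 0.0], 1, [0.0, 0.0], 0, main_weight=2.0)
print(round(loss, 4))  # 3 * ln(2), about 2.0794
```

Because the main head's term is scaled up, an error on the main task moves the total loss (and therefore the gradients of the shared backbone) more than the same error on the secondary task.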

Here is what the architecture looks like:

Two headed output

I am using ShuffleNetV2 for this task, and I split the architecture into two parts, starting with the last convolutional layer. Each head has its own last conv layer, global pooling, and fully connected layer for classification.

Code examples

Now that we understand the model architecture, it’s clear that we need to make some changes to our training pipeline, starting with the dataset. The dataset and dataloader now need to return 1 image and 2 labels for each item. The first label is used for the first head and the second label for the second head. Let’s take a look at a code example:

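from pathlib import Path
from typing import Tuple

import pandas as pd
import torch
from PIL import Image, ImageOps
from torch.utils.data import Dataset
from torchvision import transforms


class CustomDataset(Dataset):
    def __init__(
        self,
        root_path: Path,
        split: pd.DataFrame,
        train_mode: bool,
    ) -> None:
        self.root_path = root_path
        self.split = split
        self.img_size = (256, 256)
        self.norm = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        self._init_augs(train_mode)

    def _init_augs(self, train_mode: bool) -> None:
        # NOTE: the transform lists are missing from this listing. The steps
        # below are an assumed minimal pipeline; see the repo for the real one.
        base = [
            transforms.Lambda(self._convert_rgb),
            transforms.Resize(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize(*self.norm),
        ]
        if train_mode:
            # Training-only augmentations would be added to this list.
            self.transform = transforms.Compose(base)
        else:
            self.transform = transforms.Compose(base)

    def _convert_rgb(self, x: Image.Image) -> Image.Image:
        return x.convert("RGB")

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int, int]:
        image_path, label = self.split.iloc[idx]

        image = Image.open(self.root_path / image_path)
        image.draft("RGB", self.img_size)  # lets JPEGs decode at a reduced size
        image = ImageOps.exif_transpose(image)  # fix rotation
        image = self.transform(image)

        label_lcd = int(label == 2)  # 1 for 'screen', 0 otherwise
        label_other = int(label == 1)  # 1 for 'not a document', 0 otherwise

        return image, label_lcd, label_other

    def __len__(self) -> int:
        return len(self.split)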

We are interested only in __getitem__, where we split label into label_lcd and label_other (one per head). label_lcd is 1 for a 'screen' and 0 otherwise. label_other is 1 for 'not a document' and 0 otherwise.
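The label split can be checked in isolation. The sketch below reuses the class ids from the split files; the `split_label` function and the constant names are hypothetical and only mirror the two lines in __getitem__.

```python
from typing import Tuple

# Class ids as used in the split files above.
DOCUMENT, NOT_A_DOCUMENT, SCREEN = 0, 1, 2


def split_label(label: int) -> Tuple[int, int]:
    """Map a 3-class id to (label_lcd, label_other) for the two heads."""
    label_lcd = int(label == SCREEN)  # head 1: screen vs. everything else
    label_other = int(label == NOT_A_DOCUMENT)  # head 2: not a document vs. rest
    return label_lcd, label_other


classes = [
    ("document", DOCUMENT),
    ("not a document", NOT_A_DOCUMENT),
    ("screen", SCREEN),
]
for name, class_id in classes:
    print(name, split_label(class_id))
```

Note that a 'not a document' image gets label_lcd = 0, so it counts as a negative example for the first head, and a 'screen' image likewise counts as a negative for the second head.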

For our architecture, we have the following:

from typing import Tuple

import torch
from torch import nn
from torchvision import models


class CustomShuffleNet(nn.Module):
    def __init__(self, n_outputs_1: int, n_outputs_2: int) -> None:
        super().__init__()
        # The arguments of this call (e.g. pretrained weights) are cut off in
        # the original listing, so none are passed here.
        self.base_model = models.shufflenet_v2_x0_5()

        # Create head convolution layers
        self.head1_conv = self._create_head_conv()
        self.head2_conv = self._create_head_conv()

        # Create fully connected layers for both heads
        in_features = self.base_model.fc.in_features
        del self.base_model.fc
        self.fc1 = nn.Linear(in_features, n_outputs_1)
        self.fc2 = nn.Linear(in_features, n_outputs_2)

    def _create_head_conv(self) -> nn.Module:
        # Only the Conv2d line survives in the original listing. BatchNorm and
        # ReLU are assumed here because torchvision's own conv5 block has them.
        return nn.Sequential(
            nn.Conv2d(192, 1024, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Shared backbone
        x = self.base_model.conv1(x)
        x = self.base_model.maxpool(x)
        x = self.base_model.stage2(x)
        x = self.base_model.stage3(x)
        x = self.base_model.stage4(x)

        # Pass through the separate convolutions for each head
        x1 = self.head1_conv(x)
        x1 = x1.mean([2, 3])  # globalpool for first head
        out1 = self.fc1(x1)

        x2 = self.head2_conv(x)
        x2 = x2.mean([2, 3])  # globalpool for second head
        out2 = self.fc2(x2)
        return out1, out2

Starting from the last conv layer (inclusive), the architecture is split into two parallel heads. Now the model has 2 outputs, as we need.

Training loop:

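from pathlib import Path
from typing import Optional

import torch
import wandb
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

# `evaluate` and `wandb_logger` are helper functions from the repo.
# Lines that were cut off in the original listing are reconstructed here as a
# standard PyTorch loop, so details may differ from the repo.


def train(
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: str,
    model: nn.Module,
    loss_func: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[torch.optim.lr_scheduler.LRScheduler],
    epochs: int,
    path_to_save: Path,
) -> None:
    best_metric = 0
    wandb.watch(model, log_freq=100)
    for epoch in range(1, epochs + 1):
        model.train()
        with tqdm(train_loader, unit="batch") as tepoch:
            for inputs, labels_1, labels_2 in tepoch:
                inputs, labels_1, labels_2 = (
                    inputs.to(device),
                    labels_1.to(device),
                    labels_2.to(device),
                )
                tepoch.set_description(f"Epoch {epoch}/{epochs}")

                optimizer.zero_grad()
                outputs_1, outputs_2 = model(inputs)
                loss_1 = loss_func(outputs_1, labels_1)
                loss_2 = loss_func(outputs_2, labels_2)
                loss = 2 * loss_1 + loss_2  # the main head weighs twice as much
                loss.backward()
                optimizer.step()

        metrics = evaluate(
            test_loader=val_loader, model=model, device=device, mode="val"
        )

        if scheduler is not None:
            scheduler.step()

        # Keep the checkpoint with the best F1 on the main (first) head
        if metrics["f1_1"] > best_metric:
            best_metric = metrics["f1_1"]

            print("Saving new best model...")
            path_to_save.parent.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), path_to_save)

        wandb_logger(loss, metrics, mode="val")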

We get image, label_1, and label_2 from the dataset, run the image (actually a batch) through the model, then compute the loss twice (once for each head's output). We multiply the main loss by 2 to stay focused on the 'main' head. Of course, we also need to change things like metric computation to accommodate the two-headed model (you can find a full example in the repo). Just as important, we save the model based on the metric we get from the 'main' head.


It doesn’t make sense to compare F1-scores from the training pipelines, as they are computed for 3 and 2 classes, and we are interested in each task’s metrics separately. That’s why I used a dedicated test dataset, ran both models on it, and compared precision and recall for the document/screen and document/not_document tasks separately.
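A per-task comparison like this can be sketched by reducing the 3-class predictions to the same binary task that a head solves, then computing precision and recall on both. The predictions below are made up for illustration and are not the results reported here; treating 'screen' as the positive class is an assumption, and both helper functions are hypothetical.

```python
from typing import List, Tuple


def precision_recall(y_true: List[int], y_pred: List[int]) -> Tuple[float, float]:
    """Precision and recall for binary labels, where 1 is the positive class."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


def to_screen_task(class_ids: List[int], screen_id: int = 2) -> List[int]:
    """Reduce 3-class ids to the binary document/screen task (1 = screen)."""
    return [int(c == screen_id) for c in class_ids]


# Made-up data: 0 = document, 2 = screen.
true_ids = [2, 2, 2, 2, 0, 0, 0, 0]
three_class_preds = [2, 2, 0, 0, 0, 0, 0, 2]  # argmax of the 3-output model
two_head_preds = [1, 1, 1, 0, 0, 0, 0, 0]  # argmax of head 1 (1 = screen)

y_true = to_screen_task(true_ids)
pr_three = precision_recall(y_true, to_screen_task(three_class_preds))
pr_two = precision_recall(y_true, two_head_preds)
print("three outputs:", pr_three)
print("two heads:", pr_two)
```

The same reduction with the 'not a document' class id gives the numbers for the second task.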

Both models use a 256×256 input size, but I also added a version of the simple three-output-neuron approach with a 320×320 input size. Its inference time was pretty much the same as that of the two-headed model, so it was interesting to compare. The second task ended up with exactly the same results for both approaches (it is an easy task for a model in my case), but there are differences on the main task.

| Model (img size)           | Precision | Recall | Latency (s)* |
| Three output neurons (256) | 0.993     | 0.855  | 0.027        |
| Three output neurons (320) | 1.0       | 0.846  | 0.029        |
| Two heads (256)            | 1.0       | 0.873  | 0.029        |

Latency (s)* — mean inference time on 1 image, including transforms and softmax.
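A mean latency of this kind can be measured with a small timing helper. The sketch below is an assumption about the measurement setup, not the exact script used here; `mean_latency` and `fake_predict` are hypothetical, and `fake_predict` stands in for the full "transforms, forward pass, softmax" call on one image.

```python
import time
from typing import Callable


def mean_latency(
    predict: Callable[[], object], n_runs: int = 50, warmup: int = 5
) -> float:
    """Mean wall-clock time in seconds of `predict()` over `n_runs` calls."""
    for _ in range(warmup):
        predict()  # warm-up calls are not timed
    start = time.perf_counter()
    for _ in range(n_runs):
        predict()
    return (time.perf_counter() - start) / n_runs


# Stand-in workload for "transforms + model forward pass + softmax".
def fake_predict() -> int:
    return sum(i * i for i in range(10_000))


latency = mean_latency(fake_predict, n_runs=20, warmup=2)
print(f"{latency:.6f} s per image")
```

When timing a model on a GPU, the device work is asynchronous, so the prediction function should wait for it to finish (for example with `torch.cuda.synchronize()`) before the timer is read.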

And here is the boost we needed! The two-headed model has the same scores on the secondary task, but on the main task it has the same or better precision and higher recall. And this is with real-world data (not from the train/val/test splits).

Note: In this case, not only is one task (document/screen) more important than the other, but precision is also more important than recall, so within the ‘Three output neurons’ approach the 320 input size wins. But in the end, the two-headed model still gets better scores with the same inference time.

One more important thing: this approach worked better in my case, with a specific model and data. It has worked for me in some other tasks too, but it is critical to always form hypotheses and run experiments to test them and find the best approach. For that, I recommend using tools to save your configs and experiment results. Here I used Hydra for configs and Wandb for tracking experiments.

To sum up

  • Classification is easy, but it gets harder with all the real-world constraints
  • Optimize subtasks and try not to create K models for every big task
  • Customize models and training pipelines to have better control
  • Test your hypotheses, run experiments, and save the results (Hydra, Wandb…)

And that’s basically it. You can find the full code example here, so you can run the tests yourself. Feel free to contact me if you have any questions or suggestions!

