The Role of Dataset Classes in Transfer Learning
Last Updated on July 24, 2023 by Editorial Team
Author(s): Akula Hemanth Kumar
Originally published on Towards AI.
Making computer vision easy with Monk, a low-code deep learning tool and unified wrapper for computer vision
What do you do with a deep learning model in transfer learning?
The following steps have already been done for you by contributors to PyTorch, Keras, and MXNet:
- You take a deep learning architecture, such as ResNet, DenseNet, or even a custom network.
- You train the architecture on a large dataset such as ImageNet, COCO, etc.
- The trained weights become your starting point for transfer learning.
The final layer of this pre-trained model has as many neurons as there are classes in the large dataset (for example, 1,000 for ImageNet).
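To see why the final layer's size is tied to the number of classes, here is a minimal pure-Python sketch of a fully connected output layer. It has one row of weights (one neuron) per class, so its output vector, the logits, has one entry per class, and softmax turns those into one probability per class. The helpers `dense_forward` and `softmax`, the random weights, and the sizes are illustrative assumptions; real frameworks do this with tensors, and real pre-trained weights are learned, not random.

```python
import math
import random


def dense_forward(weights, biases, features):
    """Fully connected layer: logits[k] = sum_i weights[k][i] * features[i] + biases[k].

    `weights` holds one row per output neuron, i.e. one row per class.
    """
    return [sum(w * x for w, x in zip(row, features)) + b
            for row, b in zip(weights, biases)]


def softmax(logits):
    """Convert logits into class probabilities (shifted by the max for stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Illustrative sizes: 512 input features (ResNet-18's final feature size)
# and 1000 output neurons (one per ImageNet class). Weights are random
# stand-ins, not real pre-trained values.
in_features, num_classes = 512, 1000
rng = random.Random(0)
weights = [[rng.gauss(0.0, 0.01) for _ in range(in_features)]
           for _ in range(num_classes)]
biases = [0.0] * num_classes
features = [rng.random() for _ in range(in_features)]

probs = softmax(dense_forward(weights, biases, features))
print(len(probs), "class probabilities")
```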
In transfer learning:
- You take the network and load the pre-trained weights into it.
- You then remove the final layer, which has more (or fewer) neurons than your task needs.
- You add a new final layer whose number of neurons equals the number of classes in your custom dataset.
- Optionally, you can add more layers between this newly added final layer and the old network.
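The steps above can be sketched in plain Python with a toy model represented as a list of named layers. The `Layer` class, the `build_transfer_model` helper, the layer names, and the sizes are all hypothetical stand-ins, not Monk's or any framework's API. In a real framework the same idea means swapping out the final layer object; in torchvision's ResNet models, for instance, that is the `fc` attribute, an `nn.Linear`.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class Layer:
    name: str
    in_features: int
    out_features: int
    pretrained: bool = False  # True if the weights came from the large dataset


def build_transfer_model(base: List[Layer], num_classes: int,
                         hidden: Sequence[int] = ()) -> List[Layer]:
    """Replace the old classification head with one sized for the custom dataset."""
    # Keep the pre-trained layers but drop the old final layer.
    model = list(base[:-1])
    width = model[-1].out_features  # feature size produced by the base network
    # Optional extra layers between the base network and the new final layer.
    for i, size in enumerate(hidden):
        model.append(Layer(f"hidden{i}", width, size))
        width = size
    # New final layer: one neuron per class in the custom dataset.
    model.append(Layer("new_fc", width, num_classes))
    return model


# Hypothetical pre-trained network: a feature extractor plus a 1000-class head.
pretrained = [
    Layer("features", 3 * 224 * 224, 512, pretrained=True),
    Layer("fc", 512, 1000, pretrained=True),
]

model = build_transfer_model(pretrained, num_classes=2, hidden=[256])
for layer in model:
    print(layer)
```

Keeping the base layers as the same objects (rather than rebuilding them) mirrors the idea that their pre-trained weights are reused unchanged.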
Your network now has two parts:
- The part that already existed (the pre-trained base network).
- The new sub-network, or single layer, that you added.
This introduces a hyper-parameter: whether to freeze the base network.
- Freezing the base network makes it untrainable.
- The base network then acts as a fixed feature extractor, and only the newly added part is trained.
- If you do not freeze the base network, the entire network is trained.
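A minimal sketch of what freezing means during training, assuming plain gradient descent on a toy list of scalar parameters: frozen parameters are simply skipped when weights are updated, so the base network keeps its pre-trained values. The `Param` class, the `freeze` and `sgd_step` helpers, the values, and the learning rate are hypothetical; frameworks usually express the same idea with a per-parameter flag (in PyTorch, `requires_grad = False`).

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Param:
    name: str
    value: float
    grad: float
    trainable: bool = True


def freeze(params: List[Param], prefix: str) -> None:
    """Mark every parameter whose name starts with `prefix` as untrainable."""
    for p in params:
        if p.name.startswith(prefix):
            p.trainable = False


def sgd_step(params: List[Param], lr: float = 0.1) -> None:
    """Update only trainable parameters; frozen ones keep their values."""
    for p in params:
        if p.trainable:
            p.value -= lr * p.grad


params = [
    Param("base.conv1", value=0.5, grad=0.2),   # pre-trained base network
    Param("base.conv2", value=-0.3, grad=0.1),  # pre-trained base network
    Param("head.fc", value=0.0, grad=0.4),      # newly added final layer
]

freeze(params, "base.")  # freeze the base network
sgd_step(params)
for p in params:
    print(p.name, round(p.value, 3), "trainable" if p.trainable else "frozen")
```

Without the `freeze` call, every parameter would be updated, which corresponds to training the entire network.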
Here we use two datasets:
- A Cats-Dogs dataset with 2 classes.
- A logo classification dataset with 16 classes.
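As a worked example of how the class count changes the new final layer, assume the head is a single fully connected layer on top of ResNet-18's 512-dimensional features (Monk's actual default head may contain additional layers). A dense layer with `in_features` inputs and `num_classes` outputs has `in_features * num_classes` weights plus `num_classes` biases; `dense_param_count` is a hypothetical helper for this arithmetic.

```python
def dense_param_count(in_features: int, num_classes: int) -> int:
    """Weights (in_features x num_classes) plus one bias per class."""
    return in_features * num_classes + num_classes


# Assumption: a single linear head on ResNet-18's 512 features.
for name, classes in [("Cats-Dogs", 2), ("Logos", 16)]:
    print(name, classes, "outputs,", dense_param_count(512, classes), "parameters")
```

Under that assumption, the 2-class head has 512 × 2 + 2 = 1,026 parameters and the 16-class head has 512 × 16 + 16 = 8,208.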
Creating and managing experiments
- Provide a project name.
- Provide an experiment name.
This creates files and directories with the following structure:
workspace
|
|--Project
     |
     |--study-num-classes
          |
          |--experiment-state.json
          |
          |--output
               |
               |--logs (all training logs and graphs are saved here)
               |
               |--models (all trained models are saved here)
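For illustration only, the directory layout above can be recreated with Python's `pathlib`. Monk creates these directories itself; the hypothetical `create_experiment_dirs` helper below merely mirrors the structure, and it writes `experiment-state.json` as an empty placeholder because the real file's contents are not described here.

```python
import tempfile
from pathlib import Path


def create_experiment_dirs(root: Path, project: str, experiment: str) -> Path:
    """Mirror the workspace/<project>/<experiment>/ layout (illustrative only)."""
    exp_dir = root / "workspace" / project / experiment
    (exp_dir / "output" / "logs").mkdir(parents=True, exist_ok=True)    # training logs and graphs
    (exp_dir / "output" / "models").mkdir(parents=True, exist_ok=True)  # trained models
    (exp_dir / "experiment-state.json").touch(exist_ok=True)            # empty placeholder file
    return exp_dir


# Usage: build the layout inside a throwaway temporary directory.
demo_root = Path(tempfile.mkdtemp())
exp_dir = create_experiment_dirs(demo_root, "Project", "study-num-classes")
for path in sorted(exp_dir.rglob("*")):
    print(path.relative_to(demo_root))
```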
Set up default params with the Cats-Dogs dataset
gtf.Default(dataset_path="study_classes/dogs_vs_cats",
model_name="resnet18",
num_epochs=5)
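Between this call and the later logo-classification call, only `dataset_path` changes, so the size of the new final layer has to come from the dataset itself. Assuming a one-subfolder-per-class layout (an assumption; the dataset's folder structure is not shown in this article), the class count could be obtained along these lines. `list_classes` is a hypothetical helper, not part of Monk.

```python
import tempfile
from pathlib import Path
from typing import List


def list_classes(dataset_path: str) -> List[str]:
    """Return class names, assuming one subdirectory per class."""
    return sorted(p.name for p in Path(dataset_path).iterdir() if p.is_dir())


# Usage with a hypothetical two-class dataset created in a temporary folder.
dataset = Path(tempfile.mkdtemp()) / "dogs_vs_cats"
for name in ("cats", "dogs"):
    (dataset / name).mkdir(parents=True)
classes = list_classes(str(dataset))
print(classes, "->", len(classes), "output neurons in the new final layer")
```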
Visualize network
gtf.Visualize_With_Netron(data_shape=(3, 224, 224), port=8081)
The final layer now has 2 output neurons, one per class.
Reset default params with a new dataset: Logo classification
gtf.Default(dataset_path="study_classes/logos",
model_name="resnet18",
num_epochs=5)
Visualize network
gtf.Visualize_With_Netron(data_shape=(3, 224, 224), port=8082)
The final layer now has 16 output neurons, one per logo class.
You can find the complete Jupyter notebook on GitHub.
If you have any questions, feel free to reach out to Abhishek and Akash.
I am extremely passionate about computer vision and deep learning in general. I am an open-source contributor to Monk Libraries.
You can also see my other writings at:
Akula Hemanth Kumar – Medium
Published via Towards AI