
A General Purpose Transfer Learning Framework Based on Keras


Objective

The transfer learning technique is used when a dataset is not of sufficient size. It is common to fine-tune a network pre-trained on a large dataset like ImageNet for classification tasks. For further information, the reader is advised to refer to CS231n by Stanford.

A framework for general purpose transfer learning is proposed. The framework was developed for my MSc thesis study and is made publicly available so that researchers can make use of it.

Using this framework, a researcher can easily fine-tune a network for a classification task.

Audience

This article can be useful for anyone seeking information about implementing transfer learning.

Python knowledge is required to make use of the supplied code. Introductory knowledge of Keras, a deep learning API, is also necessary.

Also, in order to run the code, a proper deep learning system with a decent graphics card (GPU) with CUDA Compute Capability 3.0 or higher is necessary.

Information on how to create a deep learning system can be found in one of the following resources:


Why Keras?

Keras is a deep learning API which can be used with TensorFlow, Theano or CNTK. Keras offers a simple and intuitive API, and it is easy to find resources about it. Keras also comes with network weights for popular convolutional neural networks.

Why a Transfer Learning Framework?

Keras already provides a simple and intuitive interface for transfer learning. For more information, the Keras Applications page is worth visiting.

However, doing research requires more than what Keras provides. The proposed Transfer Learning Framework aims to eliminate boilerplate code for researchers. To give an idea: researchers need to run many tests with different parameters, and it often becomes hard to keep track of experiment results and configurations. Keras does not come with tools to visualize experiment results.

The Transfer Learning Framework tries to provide configuration options to execute different experiments without requiring code changes.

User Guide

A. Fine Tune A Model

Let's look at how a model is fine-tuned. As a demo, the VGGFace model is fine-tuned for emotion recognition using the FER-2013 dataset.

The input parameters are the location of the dataset, the weights file to fine-tune, and the working directory for outputs.

The dimension is determined by the network being fine-tuned; for the VGGFace model it is 224x224. The batch size is determined by the image dimensions and the available GPU memory; the larger the batch size, the faster the learning process will be. The learning rate is best left at 0.001, and the optimizer can be left as stochastic gradient descent. The number of layers to train during fine-tuning is also important.

All configuration values that are likely to affect model performance are recorded in a parameters.txt file.

from keras.optimizers import SGD
from work_image.abstract_train_cnn_base import TrainCnnConfiguration
from work_image.emot_train_cnn_base import EmotTrainCnnBase

# working directory
WORK_DIR = 'workdir/'
# specify the weights file. It can be a path to a trained model, or one of vgg16,
# resnet or inception for ImageNet weights, or vggface for Oxford face weights.
# vggface is based on https://github.com/rcmalli/keras-vggface
WEIGHTS_FILE = 'vggface'
# point to the dataset. Note that the target dir must have data.csv
DATA_DIR = 'dataset/fer2013/'
BATCH_SIZE = 64  # adjust according to available GPU memory
DIMENSION = 224
LR = 0.001
OPTIMIZER = SGD(lr=LR, momentum=0.9, decay=0.0005)
# img_train_gen_params = {'rotation_range': 20, 'zoom_range': 0.2, 'horizontal_flip': True}
img_train_gen_params = None

config = TrainCnnConfiguration(data_dir=DATA_DIR, batch_size=BATCH_SIZE, dimension=DIMENSION,
                               optimizer=OPTIMIZER, weights_file=WEIGHTS_FILE, freeze_layer=12,
                               reduce_lr_factor=None, reduce_lr_patience=5,
                               img_train_gen_params=img_train_gen_params, per_class_log=False,
                               top_epochs=5)
train = EmotTrainCnnBase(work_dir=WORK_DIR, config=config)
train.train(nb_epoch=100)
All outputs can be found in an experiment directory created under the working directory. The experiment directory is named using the hash value of the parameters.txt file and a timestamp, so it is easy to track multiple experiments with the same configuration.
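As an illustration only, the naming scheme described above could be sketched as follows. The hash function, the truncation and the exact format are assumptions, not the framework's actual code:

import hashlib
import time

def experiment_dir_name(params_path):
    # hash parameters.txt so identical configurations map to the same value;
    # md5 and the 8-character truncation are assumed choices
    with open(params_path, 'rb') as f:
        digest = hashlib.md5(f.read()).hexdigest()[:8]
    # append a timestamp so repeated runs with the same configuration
    # still get distinct experiment directories
    return '%s-%s' % (digest, time.time())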

The most important output file is the weights file, which can be found under the checkpoints directory. The weights files also contain the network architecture, so they are sufficient to use the trained network later.
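Because a checkpoint stores the architecture together with the weights, the standard Keras loader is enough to restore the trained network. The path below is a placeholder for an actual checkpoint file:

from keras.models import load_model

# restore both the architecture and the weights from a checkpoint produced
# during training; replace the placeholder path with a real checkpoint
model = load_model('workdir/<experiment-dir>/checkpoints/weights.hdf5')
model.summary()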

Validation results are kept in a results.txt file. The results file starts with the confusion matrix, followed by a results row containing loss, accuracy, precision and recall values. The last row gives the number of validation samples used.

The logs directory keeps the log of validation and training accuracy/loss values after each epoch. Before fine-tuning a model, a short initial fine-tuning step is run, whose length is specified by the top_epochs parameter. This is why you will find two log files.

Confusion matrices are created in the experiment directory. Loss and accuracy curves are created in the same experiment directory. The accuracy and loss curves are divided into two sections by a vertical red line, which marks the boundary between the short initial fine-tuning and the actual fine-tuning procedures.

As an example, the following directory of the GitHub repository can be examined:

https://github.com/habanoz/deep-emotion-recognition/tree/master/thesis/models/vggface/1-1512579280.12-17403535-best

Training/Validation Accuracy Curve
Normalized Confusion Matrix

B. Fine Tune Using Weights Files

Suppose you want to fine-tune a model you created again, using another dataset. It is as easy as pointing the WEIGHTS_FILE variable to the location of the weights file and pointing the DATA_DIR variable to the new dataset. That's all.
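For example, reusing the script from section A, only two variables change. Both paths below are placeholders:

# point WEIGHTS_FILE to a previously trained model instead of a bundled
# network name; the path is a placeholder for a real checkpoint file
WEIGHTS_FILE = 'workdir/<experiment-dir>/checkpoints/weights.hdf5'
# point DATA_DIR to the new dataset (it must contain data.csv as before)
DATA_DIR = 'dataset/<new-dataset>/'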

C. Training For Other Problems

Extend the KerasTrainCnnBase class, as in the case of the EmotTrainCnnBase class, for a new problem. The only major difference will probably be the number of classes.
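A minimal sketch of such a subclass is given below. The import path for KerasTrainCnnBase and the way the class count is passed are assumptions; consult EmotTrainCnnBase in the repository for the actual pattern:

from work_image.abstract_train_cnn_base import KerasTrainCnnBase  # assumed module

class FlowerTrainCnn(KerasTrainCnnBase):
    # hypothetical classifier for a 5-class flower dataset
    def __init__(self, work_dir, config):
        super(FlowerTrainCnn, self).__init__(work_dir=work_dir, config=config)
        # the attribute name nb_classes is an assumption
        self.nb_classes = 5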

D. Changing the Underlying Network

Currently, the InceptionV3, ResNet50, VGG16 and VGGFace networks are supported. Model weights are for the ImageNet dataset, except for the VGGFace model, which is trained on the Oxford face dataset. InceptionV3, ResNet50 and VGG16 weights are provided by Keras. VGGFace weights are provided by https://github.com/rcmalli/keras-vggface.
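Switching networks only requires changing the WEIGHTS_FILE name and matching the input dimension to the chosen network; InceptionV3, for instance, conventionally expects 299x299 inputs:

# select ResNet50 with ImageNet weights instead of VGGFace
WEIGHTS_FILE = 'resnet'
DIMENSION = 224
# or select InceptionV3, whose conventional input size is 299x299
# WEIGHTS_FILE = 'inception'
# DIMENSION = 299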

E. Processing Input Images

Input images are not always ready for training and may require some processing. The FER-2013 dataset is ready for training, but others may not be. The extract and util.dataset_scripts packages provide utilities for image processing needs. In order to use the utility classes on a dataset, it is required to have a data_file.csv file at the dataset root directory. Training and validation images should be put inside the train and val directories respectively.
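The expected dataset layout therefore looks like this:

dataset/
├── data_file.csv
├── train/   (training images)
└── val/     (validation images)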

The util.dataset_scripts package contains structuring scripts for creating the data_file.csv of a dataset. Note that the structuring process is specific to each dataset. If there is no structuring script for your dataset in the package, you will need to add one. The existing structuring scripts can be used as examples.
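A hedged sketch of what such a structuring script might look like follows. The column layout of data_file.csv and the per-class subdirectory assumption are hypothetical; treat the existing scripts in util.dataset_scripts as the authoritative reference:

import csv
import os

def structure_dataset(root_dir):
    # walk train/ and val/ and record each image with its split and label;
    # the (split, relative path, label) column layout is assumed
    rows = []
    for split in ('train', 'val'):
        split_dir = os.path.join(root_dir, split)
        for label in sorted(os.listdir(split_dir)):
            label_dir = os.path.join(split_dir, label)
            for name in sorted(os.listdir(label_dir)):
                rows.append((split, os.path.join(split, label, name), label))
    with open(os.path.join(root_dir, 'data_file.csv'), 'w') as f:
        csv.writer(f).writerows(rows)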

E.1. Face Alignment

In all the images, the positions of the eyes and nose should be the same. Without alignment, training may require more data.

Datasets may require face alignment before training. If multiple datasets will be used, they should share the same face alignment.

extract.extract_faces.py contains the ExtractFaces class, which can be used to extract and align faces from source images. For face extraction and alignment, the MtcnnDetector from https://github.com/pangyupo/mxnet_mtcnn_face_detection is used.
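A hedged sketch of the extraction step follows; the constructor arguments, chip size and padding are assumed values, so the repository's demo should be treated as the reference:

import cv2
from mtcnn_detector import MtcnnDetector

# the model_folder argument points to the MTCNN model files shipped with
# the repository; the value here is an assumption
detector = MtcnnDetector(model_folder='model')

img = cv2.imread('face_source.jpg')  # hypothetical input image
result = detector.detect_face(img)
if result is not None:
    boxes, points = result  # bounding boxes and facial landmarks per face
    # extract_image_chips crops and aligns faces using the landmarks;
    # the chip size (224) and padding (0.37) are assumed values
    chips = detector.extract_image_chips(img, points, 224, 0.37)
    for i, chip in enumerate(chips):
        cv2.imwrite('aligned_%d.jpg' % i, chip)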

E.2. Pre-Processing

It is important to prepare source images for training. extract.preprocess.PreprocessFaces can be used to pre-process aligned face images. It applies grayscale conversion, histogram equalization and resizing using cubic interpolation.
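The operations it applies can be illustrated with OpenCV. This is only a sketch of the described steps, not the class itself:

import cv2

def preprocess_face(path, dimension=224):
    img = cv2.imread(path)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # grayscale conversion
    equalized = cv2.equalizeHist(gray)            # histogram equalization
    # resize to the target dimension using cubic interpolation
    return cv2.resize(equalized, (dimension, dimension),
                      interpolation=cv2.INTER_CUBIC)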

Conclusion

The General Purpose Transfer Learning framework can easily be used for fine-tuning tasks. It expects minimal parameters and produces as many relevant outputs as possible.

