Using Deep Learning to Classify Heart Defects
by Audrey Wiseman
Although many patients trust a diagnosis, or lack thereof, given by a doctor, doctors cannot always perfectly read the tests or scans they are presented with. A misdiagnosis can lead to confusion regarding symptoms, and in the worst-case scenario it can cause death if the true disease is left untreated. A recent study showed that clinician accuracy ranges from 58% to 85% depending on the experience level of the physician (Satia). In this project, I aim to classify and locate a variety of heart defects in chest x-rays with better accuracy than most clinicians.
Overview
I. Introduction
II. Data
III. Class Imbalance and Focal Loss
IV. Keras ImageDataGenerator
V. Model
VI. Results
VII. Next Steps
VIII. Conclusion
Introduction
When I discovered NIH’s ChestX-ray8 dataset, I was impressed by the amount of data they had collected and was inspired to apply Computer Vision to the task of diagnosing heart disease. The project seemed to provide ample opportunity for hands-on experience with deep learning while also incorporating biotechnological elements that interested me. Despite the undesirable results, this project helped me better understand real-world issues that arise when working with Artificial Neural Networks and Deep Learning.
Data
The data I used to train the model was published by NIH in 2017 and contains over 112,000 chest x-rays. This dataset contains 15 classes, or defects, and more than one defect can be present in each scan. These classes include Atelectasis, Cardiomegaly, Effusion, Infiltration, Mass, Nodule, Pneumonia, Pneumothorax, Consolidation, Edema, Emphysema, Fibrosis, Pleural Thickening, Hernia, and No Findings. The dataset also includes bounding box coordinates for some, but not all, of the defects present in the images.
Class Imbalance and Focal Loss
A major problem I faced with this dataset was the significant imbalance between classes. It is understandable that the dataset would be imbalanced, because some diseases are more prevalent than others. However, an imbalance this large can cause the model to over-predict some classes and under-predict others, and my model had difficulty breaking the habit of choosing the most common label for each image.
Because of this major class imbalance, I used a focal loss to help counteract it. The Sigmoid Focal Cross Entropy Loss that I applied to the model down-weights the examples the model already classifies confidently and correctly, so training concentrates on the harder, often rarer, cases instead of being dominated by the most common labels. For example, if the model has seen a thousand cases of ‘Pneumothorax’ and is then given a ‘Hernia’, the misclassified ‘Hernia’ example still contributes strongly to the loss, so the model is not simply pushed toward predicting ‘Pneumothorax’.
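To make the down-weighting concrete, here is a minimal sketch of a sigmoid focal loss for a single binary label. It is an illustration rather than the code used in this project, and the values alpha=0.25 and gamma=2.0 are assumed defaults, not settings reported above.

```python
import math

def sigmoid_focal_loss(y_true, logit, alpha=0.25, gamma=2.0):
    """Focal loss for one binary label, computed from a raw logit.

    alpha and gamma are assumed values chosen for illustration.
    """
    p = 1.0 / (1.0 + math.exp(-logit))               # predicted probability of the positive class
    p_t = p if y_true == 1 else 1.0 - p              # probability assigned to the true label
    alpha_t = alpha if y_true == 1 else 1.0 - alpha  # class-balancing weight
    cross_entropy = -math.log(p_t)
    # (1 - p_t) ** gamma shrinks the loss of examples that are already well classified
    return alpha_t * (1.0 - p_t) ** gamma * cross_entropy

easy = sigmoid_focal_loss(1, 4.0)    # positive example, confidently correct
hard = sigmoid_focal_loss(1, -4.0)   # positive example, confidently wrong
print(easy, hard)
```

With gamma set to 0 the modulating factor disappears and the function reduces to a weighted cross entropy, which is one way to see what the focal term adds.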
Keras ImageDataGenerator
Another problem I encountered while working on this project was the amount of RAM needed to load the data. I was initially loading all of the training and testing images at once, and I quickly learned that this was not going to work. As an alternative, I decided to use Keras’s ImageDataGenerator, which loads the data one batch at a time during training. Because only a single batch of images is held in memory at any moment, the memory issues went away.
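The idea of batch-at-a-time loading can be sketched with a plain Python generator. This is not the Keras implementation and not the code used in this project; batch_generator and the load_image argument are hypothetical names introduced only for illustration.

```python
def batch_generator(image_paths, labels, batch_size, load_image):
    """Yield (images, labels) one batch at a time instead of loading everything up front.

    load_image is a hypothetical callable that turns a file path into image data.
    """
    for start in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[start:start + batch_size]
        batch_labels = labels[start:start + batch_size]
        # Only the images belonging to the current batch are read into memory here.
        images = [load_image(path) for path in batch_paths]
        yield images, batch_labels

# Usage with a stand-in loader that just tags the path instead of reading a file.
paths = ["a.png", "b.png", "c.png"]
for images, batch_labels in batch_generator(paths, [0, 1, 0], 2, lambda p: "pixels of " + p):
    print(len(images), batch_labels)
```

Unlike this single-pass sketch, a training generator typically also shuffles the data and repeats across epochs.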
Model
The model used to predict the defects in each image is a Convolutional Neural Network (CNN) that uses DenseNet121 as a feature extractor, followed by a dense layer, a dropout layer, and the final classification layer. The hidden dense layer uses a ReLU activation, and the final classification layer uses a sigmoid activation so that each label receives its own independent probability. Adam is used as the optimizer with a learning rate of 1e-5.
I included dropout in my model because it helps reduce overfitting. During training, a dropout layer randomly sets a fraction of the values passed from one layer to the next to zero, based on a given probability (for my model, I chose 0.2). This lowers the likelihood of overfitting because the model cannot rely on any single connection it may have ‘memorized’, which makes it more robust.
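As a rough sketch of what a dropout layer does during training, the function below zeroes each value with the given probability and rescales the survivors so the expected sum stays the same. It is a simplified illustration in plain Python, not the Keras layer itself, and the function name and arguments are assumptions made for the example.

```python
import random

def dropout(activations, rate, rng):
    """Training-time dropout: zero each value with probability `rate`.

    Kept values are scaled by 1 / (1 - rate) so the expected total is unchanged.
    """
    keep_prob = 1.0 - rate
    return [a / keep_prob if rng.random() >= rate else 0.0 for a in activations]

rng = random.Random(0)  # seeded only so the example is repeatable
print(dropout([1.0, 2.0, 3.0, 4.0, 5.0], rate=0.2, rng=rng))
```

At prediction time no values are dropped, so the layer passes its input through unchanged.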
I also decided to apply transfer learning to this problem because it would save time compared with training a large CNN from scratch and would compensate for the small amount of data I was using to train. Transfer learning is the practice of using a model developed for one problem as a starting point for a second problem. It saves time because the pre-trained layers do not have to be trained again. It also helps on problems with less training data because the pre-trained model has already learned from a large set of other images. Some common models used for transfer learning are VGG16, ResNet, and DenseNet. I chose DenseNet for this specific problem because it performs well on ImageNet classification. In the code below, every layer except the last four is frozen. Because the last four layers of the combined model are the newly added Flatten, Dense, Dropout, and output layers, the pre-trained DenseNet weights stay fixed and only the new classifier layers are adjusted to fit this specific classification problem.
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.models import Model
# load model without classifier layers
model = DenseNet121(include_top=False, input_shape=(1024, 1024, 3), weights='imagenet', pooling='max')
# add new classifier layers
flat1 = Flatten()(model.layers[-1].output)
class1 = Dense(1024, activation='relu')(flat1)
dropout = Dropout(0.2)(class1)
output = Dense(15, activation='sigmoid')(dropout)
# define new model, then freeze every layer except the last four
model = Model(inputs=model.inputs, outputs=output)
for layer in model.layers[:-4]:
    layer.trainable = False
Results
Unfortunately, the class imbalance is still causing problems, so the scores are not very impressive. The AUC score increases slowly but does not drastically improve between the first epoch and the last.
I used the Area Under Curve (AUC) metric to measure how well the model was performing. AUC was an appropriate metric for this problem because it measures performance across all classification thresholds. The AUC represents the probability that a random positive example is ranked more highly than a random negative example (Google Developers). The ideal AUC score is 1.0, and if the model's rankings are 100% incorrect, the AUC is 0. Accuracy is not a good metric for this project because it is not well-suited for multi-label problems. For example, if the model predicts only one of the two labels present, accuracy would consider this prediction entirely incorrect.
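The ranking interpretation of AUC can be checked directly on a handful of scores. The sketch below is a minimal illustration for a single label, not the metric implementation used in this project. It compares every positive example with every negative one, counts ties as half a win, and reports the fraction of pairs the positive wins.

```python
def auc_score(y_true, y_score):
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    positives = [s for t, s in zip(y_true, y_score) if t == 1]
    negatives = [s for t, s in zip(y_true, y_score) if t == 0]
    if not positives or not negatives:
        raise ValueError("AUC needs at least one positive and one negative example")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0   # positive ranked above negative
            elif p == n:
                wins += 0.5   # a tie counts as half
    return wins / (len(positives) * len(negatives))

print(auc_score([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

For a multi-label problem, one option is to compute this score separately for each label and average the results.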
The final scores for the validation data were a 25.52% loss, 50.64% accuracy, and 54.03% AUC. The best reported accuracy for this problem was roughly 70%, so hopefully future changes will improve these results.
Next Steps
Although I am still struggling with the class imbalance of this dataset, there are a few approaches that I can try to get past this problem. The first thing I want to do is balance the current dataset. Because there are roughly 112,000 images available, there should be enough images to have about the same number of instances for each label. This will involve some analysis of the dataset to see which classes are more or less prominent, but hopefully it will help the model learn more quickly. In addition to balancing the dataset, I would also like to do some data augmentation in order to get more images to train on. This would also make the model more robust and reduce overfitting. The last change I would like to make is to the type of network I am using as a feature extractor. I am currently using DenseNet, but perhaps ResNet or VGG16 is better suited for this problem.
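The first step of that analysis, counting how often each label appears, could look something like the sketch below. It assumes each image's labels are stored as a single string separated by "|", and the sample rows are made up for illustration rather than taken from the dataset.

```python
from collections import Counter

def label_counts(label_strings, separator="|"):
    """Count how many images carry each label in a multi-label dataset."""
    counts = Counter()
    for labels in label_strings:
        counts.update(labels.split(separator))
    return counts

# Made-up rows, one string of labels per image.
rows = ["No Findings", "Effusion|Mass", "Effusion", "Hernia"]
for label, count in label_counts(rows).most_common():
    print(label, count)
```

Counts like these show which classes would need to be down-sampled or augmented to reach a more even distribution.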
After I solve the issues I am facing with the class imbalance, I would also like to address the second part of the intended use of the dataset by predicting the bounding boxes that locate each defect. The only issue with this is that not all images contain bounding boxes, so there will be less data to train on.
Conclusion
Although I did not see the results I wanted for this project, I appreciate all of the practical lessons that it has taught me. In addition to what I anticipated working with, such as CNNs, transfer learning, and regularization, I also ran into challenges and opportunities that I would not have encountered without working on this project. Because of this, I feel as though it has helped me better understand what working with Deep Learning is like in practice. The first roadblock that helped grow my knowledge of deep learning was the issue I had while loading the data. Because I had not used Keras before, I did not know that the ImageDataGenerator loads data by batch and makes the data-loading process much easier. Now that I have seen how it is used as well as the type of problems it can solve, I will definitely keep it in mind for future projects.
The second element of this project that taught me about deep learning in practice was the significant imbalance in the dataset. Although I was aware of the differences in how often each label appears in the dataset, I did not realize how much of an impact they would have on the project as a whole. This element of the project has shown me the significance of a well-rounded dataset. If the labels had been more evenly distributed, perhaps this would have been a much simpler classification problem. Although I wish this project had produced better results, the obstacles I faced along the way provided me with a much more practical understanding of deep learning.
Sources
Satia, Imran, et al. “Assessing the accuracy and certainty in interpreting chest X-rays in the medical division.” Royal College of Physicians, 2013.
“Machine Learning Crash Course.” Google Developers. https://developers.google.com/machine-learning/crash-course/classification/roc-and-auc#AUC
All code can be accessed here.