Kaggle — Predict survival on the Titanic challenge in MATLAB

In this tutorial, I will demonstrate how to use MATLAB to solve this popular Machine Learning challenge

Abhishek Das
Towards Data Science


Photo of the RMS Titanic departing Southampton on April 10, 1912 by F.G.O. Stuart, Public Domain

The objective of this Kaggle challenge is to create a Machine Learning model which is able to predict the survival of a passenger on the Titanic, given their features like age, sex, fare, ticket class etc.

The outline of this tutorial is as follows:

  1. Exploratory Data Analysis
  2. Feature Engineering
  3. Model Fitting
  4. Generating Test Predictions
  5. Conclusion
  6. Further Exploration

Why MATLAB?

In the past few months, I have explored the Statistics and Machine Learning Toolbox in MATLAB for solving various tasks like regression, clustering, PCA, etc. I liked its modular interface and that every function is well documented with examples, which is very handy. Thus, in this post, I have shared some of my findings, which I hope are helpful to the academic community.

Let’s begin …

1. Exploratory Data Analysis

Reading the dataset

>> Titanic_table = readtable('train.csv');    % read the training data into a table
>> Titanic_data = table2cell(Titanic_table);  % cell-array copy for convenient indexing

The train set has 891 passenger entries and 12 columns.

Now, let's take a look at our data. The head function displays the top rows of a table, similar to that in Pandas.

>> head(Titanic_table)

The ‘Survived’ column is our binary target variable, which we need to predict: 0 = did not survive, 1 = survived.

The predictor variables, i.e. the features, are as follows:

  • Pclass: Ticket class
  • Sex
  • Age
  • SibSp: number of siblings or spouses aboard with the passenger
  • Parch: number of parents or children aboard with the passenger
  • Ticket: Ticket number
  • Fare
  • Cabin: Cabin number
  • Embarked: Port of Embarkation

The ‘Pclass’ variable is indicative of socio-economic status. It has 3 possible values, ‘1’, ‘2’ and ‘3’, representing the Upper, Middle and Lower class respectively.

The probability of survival for passengers in the training set is around 38 %.

Now, let's visualize survival based on different features; a short code sketch follows the list below.

  • There are 577 males and 314 female passengers.
  • Females have a much higher survival probability of 74.20 %, compared to 18.90 % for males.
  • We see that the passengers with a higher ticket class are more likely to have survived.
  • Children have a very high survival rate, while the rate drops to around 10 % for people above 65 years of age.
  • These initial statistics are in line with our intuition that the strategy must have been to save the lives of women and children first.
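
As a rough illustration, here is a minimal sketch of how the overall rate and the split by sex could be computed and plotted from the table read earlier; the variable names and plotting choices here are mine, not taken from the original scripts.

% Illustrative sketch: overall survival rate and survival by sex
survived = Titanic_table.Survived;                 % 0/1 target variable
fprintf('Overall survival rate: %.1f%%\n', 100*mean(survived));

isFemale = strcmp(Titanic_table.Sex, 'female');    % logical index for female passengers
rates = [mean(survived(isFemale)), mean(survived(~isFemale))];
bar(rates*100);
set(gca, 'XTickLabel', {'Female', 'Male'});
ylabel('Survival rate (%)');
title('Survival by sex');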

2. Feature Engineering

To transform the ‘Sex’ feature, I have used grp2idx to convert the categorical values ‘male’ and ‘female’ to numerical indices ‘1’ and ‘2’ respectively.
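
A minimal sketch of this step, assuming the 'Sex' column read by readtable; note that grp2idx numbers groups in their order of appearance in the data.

% Convert 'male'/'female' strings to numeric group indices
[sex_idx, sex_labels] = grp2idx(Titanic_table.Sex);
% sex_labels shows which string maps to which index, e.g. sex_labels{1}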

The transformation of the ‘age’ feature has 2 steps-

  • Filling missing entries-
% Check the number of missing entries in the 'age' feature
>> sum(isnan(age))
ans =
   177

We will impute the missing entries in the ‘age’ feature with the mean age of all other passengers, i.e. 29.61 years. There are some other ways in which we can handle these missing entries, discussed in section 6.

  • Categorizing into bins-

MATLAB provides a discretize method which can be used to group data into bins or categories. I have used this method for categorizing age into the following 4 age groups — ‘Under 15’, ‘15–30’, ‘30–65’ and ‘Above 65’.
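
Putting the two steps together, a minimal sketch might look like the following; the variable names are illustrative, and the bin edges follow the four groups listed above.

% Sketch of the two 'Age' transformations
age = Titanic_table.Age;

% 1) Impute missing ages with the mean of the observed ages
age(isnan(age)) = mean(age, 'omitnan');

% 2) Bin ages into 4 groups: Under 15, 15-30, 30-65, Above 65
age_group = discretize(age, [0 15 30 65 Inf], ...
    'categorical', {'Under 15', '15-30', '30-65', 'Above 65'});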

3. Model Fitting

We will look at 4 classifiers, namely:

  • Logistic Regression
  • Decision Trees
  • K Nearest Neighbor
  • Random Forest

Let’s begin with…

Logistic Regression

In MATLAB, we can implement a logistic regression model using the fitglm method.
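
A minimal sketch, assuming X is a numeric matrix of the engineered features and y is the ‘Survived’ column; the logit link is the default for a binomial distribution.

% Sketch: logistic regression on engineered features X and target y
mdl_lr = fitglm(X, y, 'Distribution', 'binomial');   % logit link by default

% Predicted survival probabilities, thresholded at 0.5 to get class labels
p_hat     = predict(mdl_lr, X);
y_pred    = double(p_hat >= 0.5);
train_acc = mean(y_pred == y);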

Decision Trees

The fitctree function returns a fitted binary classification decision tree for a given set of predictor and response variables. We can visualize the fitted tree using the view method, which makes it easy to interpret.

Pruning decision trees is an effective strategy to combat overfitting. I have used the cvloss function to find the ‘BestLevel’ for pruning. The prune function returns the tree pruned to this level.
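
A hedged sketch of this workflow, using the same assumed X and y as above; the exact options used in the original scripts may differ.

% Sketch: fit, inspect, and prune a classification tree
tree_mdl = fitctree(X, y);
view(tree_mdl, 'Mode', 'graph');                      % visualize the fitted tree

% Find the pruning level with the lowest cross-validated loss, then prune
[~, ~, ~, bestLevel] = cvloss(tree_mdl, 'Subtrees', 'all');
pruned_tree = prune(tree_mdl, 'Level', bestLevel);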

K Nearest Neighbors

The ClassificationKNN.fit function is used to fit a k-nearest neighbor classifier.

  • I have performed a search to find the optimal number of neighbors (N).
  • The above graph shows that N=17 gives the lowest Cross Validation Loss. Thus, we may select N=17. However, for N=5 and N=11, we get similar loss. So, we can experiment with N=5, 11, and 17 keeping in mind the model complexity and the chances of overfitting with a large value of N.
  • MATLAB has a range of distance metrics to choose from, e.g. ‘cityblock’, ‘chebychev’, ‘hamming’, etc. The default is the ‘euclidean’ distance metric.
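
A minimal sketch of such a search, using fitcknn (the current equivalent of ClassificationKNN.fit); the range of N values and the 5-fold setting are my assumptions.

% Sketch: search over the number of neighbors using cross-validated loss
n_values = 1:2:25;
cv_loss  = zeros(size(n_values));
for i = 1:numel(n_values)
    knn_mdl    = fitcknn(X, y, 'NumNeighbors', n_values(i));
    cv_loss(i) = kfoldLoss(crossval(knn_mdl, 'KFold', 5));
end
plot(n_values, cv_loss, '-o');
xlabel('Number of neighbors N'); ylabel('Cross-validation loss');

knn_mdl = fitcknn(X, y, 'NumNeighbors', 17);   % refit with the chosen N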

Random Forest Classifier

Individual decision trees tend to overfit. The random forest approach is a bagging method where deep trees, fitted on bootstrap samples, are combined to produce an output with lower variance.

We will use the TreeBagger class to implement a random forest classifier.

  • The above plot shows the out-of-bag classification error as we increase the number of trees in consideration. This allows us to select the optimal number of decision trees to create an ensemble.
  • We can visualize the importance of each feature using the OOBPermutedPredictorDeltaError property.
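
A minimal sketch of this setup; the number of trees (100) is an assumption for illustration.

% Sketch: random forest with out-of-bag error and predictor importance
rf_mdl = TreeBagger(100, X, y, 'Method', 'classification', ...
    'OOBPrediction', 'on', 'OOBPredictorImportance', 'on');

figure;
plot(oobError(rf_mdl));                          % OOB error vs. number of grown trees
xlabel('Number of trees'); ylabel('Out-of-bag classification error');

figure;
bar(rf_mdl.OOBPermutedPredictorDeltaError);      % permutation-based feature importance
xlabel('Feature index'); ylabel('Importance (OOB permuted delta error)');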

The training accuracies for our models are as follows-

  • Logistic Regression: 78.34 %
  • Decision Trees: 80.47 %
  • K Nearest Neighbor: 79.80 %
  • Random Forest: 78.11 %

We see that the pruned decision tree gives the highest accuracy in this case; however, it is prone to overfitting. Let's find out the accuracy on the test set.
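
For reference, here is one way such training accuracies could be computed from the models sketched above; this is an illustration, not necessarily the original evaluation code.

% Sketch: training accuracy for each fitted classifier
acc_lr   = mean((predict(mdl_lr, X) >= 0.5) == y);     % logistic regression
acc_tree = 1 - resubLoss(pruned_tree);                 % pruned decision tree
acc_knn  = 1 - resubLoss(knn_mdl);                     % k-nearest neighbors
acc_rf   = mean(str2double(predict(rf_mdl, X)) == y);  % random forest (labels returned as char)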

4. Generate Test Predictions

The test set comprises 418 passenger entries and 11 columns, without the ‘Survived’ variable, which we will predict using the trained models.

>> head(Titanic_table_test)

We apply the same feature engineering steps to the test dataset and feed it through the models described above to generate the predictions. Finally, we create a .csv file for submission to Kaggle. The snippet below shows an example.
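
A hedged sketch of what such a snippet might look like, assuming X_test holds the engineered test features and reusing the pruned tree from above:

% Sketch: generate test-set predictions and write a Kaggle submission file
y_test_pred = predict(pruned_tree, X_test);    % numeric 0/1 labels, since y was numeric

submission = table(Titanic_table_test.PassengerId, y_test_pred, ...
    'VariableNames', {'PassengerId', 'Survived'});
writetable(submission, 'submission.csv');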

The decision tree gives the highest accuracy of 78.947 % on the test set.

5. Conclusion

We walked through a step-by-step process in MATLAB for solving a Machine Learning task, from visualizing patterns in the dataset and selecting and engineering features, to training multiple classifiers and generating predictions with the trained models.

Some observations-

  • Increasing the number of bins for the ‘age’ feature led to overfitting.
  • Adding a new feature calculated as (fare/ticket frequency) did not give a significant improvement in the test accuracy.

The code is available here for future development.

6. Further Exploration

  • Other strategies can be applied to handle the missing entries in the ‘age’ feature, such as using the median age or family relations; e.g. a child's age could be approximated as, say, (age of father - 25).
  • Including more combinations of features is naturally the next step, which I look forward to analyzing.
  • One can engineer new features based on initial data statistics and intuition. Here is an interesting analysis to read.
  • Since we have multiple models, model ensembling is worth a try.

There are many such ideas discussed for this challenge by the Kaggle community. This notebook provides a great summary of them.


I am a Masters Student in Electrical and Computer Engineering at Carnegie Mellon University | Know more about me — https://abhishek0697.github.io 😁