3.7 – Putting it all together: MNIST digit recognition

We’ve talked a lot about the math and the individual pieces of how things work. Let’s now do something interesting with them. A good “Hello World” program in machine learning is recognizing handwritten digits, and the MNIST (Modified NIST) dataset is the de facto standard for this classification task. We’ll use the original dataset, which contains 70,000 images, each scaled down to 28×28 pixels. Treating every pixel as a feature gives us 28 × 28 = 784 features per image. While these numbers seem big at first, they’re typical for computer vision problems.
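To make the 784 figure concrete, here is a minimal NumPy sketch of how a single 28×28 grayscale image becomes one row of the feature matrix. The image is a made-up array with a crude vertical stroke drawn in, standing in for a real digit.

```python
import numpy as np

# A hypothetical 28x28 grayscale image: 0 is background, 255 is full ink
image = np.zeros((28, 28), dtype=np.uint8)
image[5:23, 14] = 255  # a crude vertical stroke, roughly a "1"

# Flatten row by row into a single vector of 28 * 28 = 784 pixel features
features = image.reshape(-1)
print(features.shape)  # (784,)

# Stacking many flattened images gives the (n_samples, 784) matrix X
X_toy = np.stack([features, features])
print(X_toy.shape)  # (2, 784)

# Reshaping a row back to 28x28 recovers the original image
assert np.array_equal(X_toy[0].reshape(28, 28), image)
```

This flattening discards the 2D layout of the pixels. Logistic regression treats each pixel as an independent input anyway, so nothing is lost for this model.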

Let’s use the multinomial logistic regression model we just learned for this task. We won’t use our hand-implemented version, because the dataset is too large for a model that isn’t tuned for performance the way library implementations are. However, we’ll make sure we understand what the library is doing.
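One way to check what the library is doing is to compare its probabilities with the softmax function behind multinomial logistic regression. The sketch below uses scikit-learn's small built-in digits dataset (8×8 images, no download needed) instead of MNIST so that it runs in seconds. It assumes scikit-learn 0.22 or later, where the default lbfgs solver fits a multinomial model. The sketch recomputes the class probabilities as the softmax of the linear scores returned by `decision_function` and compares them with `predict_proba`.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression

# Small built-in stand-in for MNIST: 1,797 images of 8x8 = 64 pixels
X_small, y_small = load_digits(return_X_y=True)
X_small = X_small / 16.0  # pixel values range from 0 to 16; scale to [0, 1]

clf = LogisticRegression(max_iter=1000)
clf.fit(X_small, y_small)

# Linear scores z_k = w_k . x + b_k, one column per digit class
scores = clf.decision_function(X_small[:5])

# Softmax: exponentiate and normalize each row so it sums to 1
# (subtracting the row max first avoids overflow and does not change the result)
exp_scores = np.exp(scores - scores.max(axis=1, keepdims=True))
softmax_probs = exp_scores / exp_scores.sum(axis=1, keepdims=True)

library_probs = clf.predict_proba(X_small[:5])
print(np.allclose(softmax_probs, library_probs))  # True
```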

Let’s start by importing the libraries that we need.

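from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

Now let’s get the MNIST dataset. This post originally used the fetch_mldata function, which downloaded the data from mldata.org. That site has since gone offline, and fetch_mldata was removed in scikit-learn 0.22, so we use fetch_openml to download the same 70,000 images from OpenML instead. Then, we get our X and Y from the dataset. If this is your first time downloading the data, it’ll take a while to run because it’s downloading the data behind the scenes. Later runs use a local cache.

mnist = fetch_openml('mnist_784', version=1, as_frame=False)
X = mnist.data.astype('float64')
# OpenML stores the labels as strings ('0' to '9'); convert them to numbers
Y = mnist.target.astype('float64')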

As usual, we’ll split the dataset 70-30:

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.7)
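As a quick sanity check on the 70-30 split, the sketch below runs train_test_split on a synthetic label array with the same length as MNIST. These are not the real labels. With 70,000 samples, train_size=0.7 leaves 21,000 test images, which matches the total support in the report later in this post. The `stratify` argument, which the code above does not use, is shown here as an optional extra. It keeps each digit's share the same in both parts.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in: 70,000 labels cycling through the digits 0-9
# (one dummy feature per sample keeps memory use tiny)
labels = np.arange(70_000) % 10
features = np.zeros((70_000, 1))

X_tr, X_te, y_tr, y_te = train_test_split(
    features, labels,
    train_size=0.7,
    stratify=labels,   # optional: preserve per-digit proportions
    random_state=0,    # fixed seed so the split is reproducible
)

print(len(X_tr), len(X_te))  # 49000 21000
# Every digit makes up 10% of the test set, as in the full data
print(np.bincount(y_te))     # [2100 2100 2100 2100 2100 2100 2100 2100 2100 2100]
```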

Let’s now create the multinomial logistic regression (softmax regression) model. In scikit-learn 0.22 and later, the default lbfgs solver fits a true multinomial model for multi-class problems rather than one-vs-rest. Older releases defaulted to one-vs-rest with the liblinear solver. In both cases, the library applies L2 regularization by default. We’ll keep those defaults and only raise the iteration limit so the solver has room to converge. We then fit the model to our dataset. This will take a while because of the size of the dataset. On my system, I waited for about an hour before going to sleep, so I’m not sure exactly how long it took, and it may be faster or slower on yours. If you don’t want to wait, I’ve also saved the trained model to a file that you can download, and I’ll briefly show how to load and use it.

# Multinomial (softmax) regression with the default L2 penalty; max_iter is
# raised from its default of 100, which is often too few for raw pixel values
model = LogisticRegression(solver='lbfgs', max_iter=1000)
model.fit(X_train, Y_train)

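If you’re using the pre-trained model, load it with pickle instead. Only unpickle files from sources you trust, and use a scikit-learn version compatible with the one that saved the model.

import pickle

# Load the saved model; the with block makes sure the file gets closed
with open('model.pkl', 'rb') as f:
    model = pickle.load(f)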
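For completeness, this is one way to produce such a file yourself after training. The sketch uses a small model trained on scikit-learn's built-in digits data so it runs quickly. It writes the file to a temporary directory; in practice you'd choose your own path, such as model.pkl.

```python
import os
import pickle
import tempfile

import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression

X_small, y_small = load_digits(return_X_y=True)
X_small = X_small / 16.0  # scale pixel values from [0, 16] to [0, 1]
trained = LogisticRegression(max_iter=5000).fit(X_small, y_small)

path = os.path.join(tempfile.mkdtemp(), 'model.pkl')

# Save: open in 'wb' mode because pickle writes bytes
with open(path, 'wb') as f:
    pickle.dump(trained, f)

# Load it back, exactly as with a downloaded file
with open(path, 'rb') as f:
    restored = pickle.load(f)

# The restored model makes the same predictions as the original
same = np.array_equal(trained.predict(X_small), restored.predict(X_small))
print(same)  # True
```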

Let’s use the model to predict on the test set and see the results:

from sklearn.metrics import classification_report, accuracy_score

predictions = model.predict(X_test)
# scikit-learn metrics take the true labels first, then the predictions
print(accuracy_score(Y_test, predictions))
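Accuracy is simply the fraction of test images whose predicted label matches the true one. A tiny made-up example shows that accuracy_score computes the same thing as a plain NumPy comparison:

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Made-up true labels and predictions for 5 images
y_true = np.array([3, 1, 4, 1, 5])
y_pred = np.array([3, 1, 4, 7, 5])  # one mistake: the fourth image

print(accuracy_score(y_true, y_pred))  # 0.8
print(np.mean(y_true == y_pred))       # 0.8
```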

Running the model on the MNIST test set gives an accuracy of 91.17%, which is great for a first model. Let’s look at other metrics too. It’s okay if you don’t know what these mean; I’ll cover them in an upcoming post.

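print(classification_report(Y_test, predictions))

This gives the output below. Note that it was produced with the arguments in the reverse order, classification_report(predictions, Y_test). As a result, its precision and recall columns are swapped relative to the corrected call, and its support column counts predicted labels rather than true labels: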

             precision    recall  f1-score   support

        0.0       0.97      0.95      0.96      2121
        1.0       0.97      0.94      0.96      2443
        2.0       0.88      0.91      0.89      2001
        3.0       0.89      0.89      0.89      2102
        4.0       0.92      0.92      0.92      2108
        5.0       0.86      0.89      0.88      1814
        6.0       0.96      0.94      0.95      2129
        7.0       0.92      0.93      0.93      2134
        8.0       0.86      0.87      0.86      2061
        9.0       0.88      0.87      0.88      2087

avg / total       0.91      0.91      0.91     21000
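Argument order matters here because precision and recall are defined relative to the true labels. Precision for a digit is the fraction of images predicted as that digit that really are that digit. Recall is the fraction of images of that digit that the model found. A small made-up example shows that passing the predictions first turns one into the other:

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score

# Made-up labels for three classes
y_true = np.array([0, 0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 0, 1, 1, 1, 2, 2, 0])

# Correct order: true labels first, predictions second
precision = precision_score(y_true, y_pred, average=None)
recall = recall_score(y_true, y_pred, average=None)

# Reversed order: the "precision" comes out equal to the real recall
swapped_precision = precision_score(y_pred, y_true, average=None)

print(precision)  # per-class precision
print(recall)     # per-class recall
print(np.allclose(swapped_precision, recall))  # True
```

The f1-score is symmetric in precision and recall, so that column is unaffected by the swap.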

These are pretty good scores for a first attempt. We’ll use more sophisticated models later to get better results, but crossing 90% on the first try is great!
