If you are a computer scientist, you would probably call this task supervised learning, while others might call it classification based on data. Here data means anything that has a well-defined structure and associated classes. We will work through a hands-on example of multi-class prediction with logistic regression, i.e. multinomial logistic regression (MNL). By multi-class I mean that each observation belongs to one of $k$ distinct classes, where $k \in \bf{Z}$ and $k>1$; for $k=2$ this reduces to ordinary binary logistic regression. If you think in terms of Generalized Linear Models (GLMs), the response distribution and its support tell you what kind of class variable you are modelling (binomial for two classes, multinomial for $k$ categories), and the link function, here the logit, connects the class probabilities to a linear predictor. In the statistics literature, classes appear as factors (or categories).
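For reference, the nominal (baseline-category) multinomial logit model gives each class $j$ its own coefficient vector $\boldsymbol\beta_j$, fixes the reference class's coefficients at zero, and turns the linear predictors into probabilities with a softmax. Here $\mathbf{x}_i$ is the $i$-th row of the design matrix and class 1 is taken as the reference; which class serves as reference is only a convention and differs between packages.

$$
P(Y_i = j \mid \mathbf{x}_i) = \frac{\exp(\mathbf{x}_i^{\top}\boldsymbol\beta_j)}{\sum_{l=1}^{k}\exp(\mathbf{x}_i^{\top}\boldsymbol\beta_l)}, \qquad j = 1, \dots, k, \qquad \boldsymbol\beta_1 = \mathbf{0},
$$

so that $\log\{P(Y_i = j \mid \mathbf{x}_i)/P(Y_i = 1 \mid \mathbf{x}_i)\} = \mathbf{x}_i^{\top}\boldsymbol\beta_j$ for $j = 2, \dots, k$.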

The following pointers are helpful background before you read further. From the R perspective, I'd suggest Germán Rodríguez's GLM notes. Will Dwinnell's data mining blog gives one of the clearest descriptions in this direction for MATLAB. An authoritative resource is The Elements of Statistical Learning by Hastie et al., which you can obtain from the authors' Stanford page. Another excellent resource is Professor Brian Ripley of Oxford, whose books and R packages are de facto standards in the field; here I refer in particular to the nnet package and the applied statistics book (Modern Applied Statistics with S, by Venables and Ripley).

Recall that a design matrix $\bf{X}$ is nothing more than a set of $n$ observations placed in the rows, with the $p$ observed variables (features) along the columns. The idea of supervised learning is that we train a model, here a multinomial GLM, on a data set and obtain the model's coefficients (parameters). Let's do this with the simplest possible example, a single predictor ($p=1$). (R code lines start with '>' and MATLAB lines with '>>'.)
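To make the training step concrete, here is a minimal from-scratch sketch in Python of fitting a nominal multinomial logit by plain gradient ascent on the log-likelihood. It is not what mnrfit or multinom do internally; the function names (fit_mnl, class_probs, log_likelihood) are hypothetical, and the step size and iteration count are arbitrary assumptions. It uses all 10 observations of the data set used in this post, with a column of ones prepended for the intercept and class 1 as the reference.

```python
import math


def softmax(scores):
    """Numerically stable softmax: subtract the max score before exponentiating."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def class_probs(beta, row):
    """P(class j | row) for every class j; beta[j] is class j's coefficient vector."""
    return softmax([sum(b * x for b, x in zip(beta_j, row)) for beta_j in beta])


def log_likelihood(beta, X, y):
    """Multinomial log-likelihood of 0-based labels y given design matrix X."""
    return sum(math.log(class_probs(beta, row)[c]) for row, c in zip(X, y))


def fit_mnl(X, y, k, lr=0.05, iters=2000):
    """Plain gradient ascent on the log-likelihood (hypothetical helper).

    Class 0 is the reference category: its coefficients stay fixed at zero.
    The gradient for class j is sum_i (1[y_i == j] - p_ij) * x_i.
    """
    p = len(X[0])
    beta = [[0.0] * p for _ in range(k)]
    for _ in range(iters):
        grad = [[0.0] * p for _ in range(k)]
        for row, c in zip(X, y):
            probs = class_probs(beta, row)
            for j in range(1, k):  # skip the reference class
                resid = (1.0 if c == j else 0.0) - probs[j]
                for d in range(p):
                    grad[j][d] += resid * row[d]
        for j in range(1, k):
            for d in range(p):
                beta[j][d] += lr * grad[j][d]
    return beta


# All 10 observations from the post; labels shifted from 1..3 to 0..2.
xs = [0.0, 0.1, 0.7, 1.0, 1.1, 1.3, 1.4, 1.7, 2.1, 2.2]
ys = [c - 1 for c in [1, 2, 1, 3, 1, 2, 1, 3, 1, 1]]
X = [[1.0, x] for x in xs]  # leading column of ones = intercept

beta = fit_mnl(X, ys, k=3)
print("coefficients:", beta)
print("log-likelihood:", log_likelihood(beta, X, ys),
      "vs", len(ys) * math.log(1 / 3), "at beta = 0")
```

The step size is kept small so that each update cannot overshoot on this tiny data set; a real implementation would use Newton or quasi-Newton steps instead.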

The Statistics Toolbox (now the Statistics and Machine Learning Toolbox) provides a set of functions for multinomial logistic regression; mnrfit is the main function we will use. Let's use a small data set. It is trivial, and the statistics we get from it will be poor, but the aim of this post is the concept.

>> X = [0.0 0.1 0.7 1.0 1.1 1.3 1.4 1.7 2.1 2.2]'; % design matrix

>> Y = [1 2 1 3 1 2 1 3 1 1]'; % associated classes

We can split this set into training and validation sets by choosing the indices:

>> trainIndex = [2 8 10 4 5];

>> validIndex = [1 3 6 7 9];
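MATLAB and R index from 1 while Python indexes from 0, so porting this split needs a shift. A minimal sketch, where the helper name take is hypothetical:

```python
# Data and split from the post; the index vectors are 1-based as in MATLAB/R.
X = [0.0, 0.1, 0.7, 1.0, 1.1, 1.3, 1.4, 1.7, 2.1, 2.2]
Y = [1, 2, 1, 3, 1, 2, 1, 3, 1, 1]
train_index = [2, 8, 10, 4, 5]
valid_index = [1, 3, 6, 7, 9]


def take(values, one_based):
    """Select elements using 1-based indices (hypothetical helper)."""
    return [values[i - 1] for i in one_based]


X_train, Y_train = take(X, train_index), take(Y, train_index)
X_valid, Y_valid = take(X, valid_index), take(Y, valid_index)

# The two index sets together cover every observation exactly once.
print(sorted(train_index + valid_index) == list(range(1, 11)))
print(X_train, Y_train)
print(X_valid, Y_valid)
```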

Now we can obtain the coefficients via mnrfit.

>> betaHat = mnrfit(X(trainIndex), Y(trainIndex), 'model', 'ordinal', 'interactions', 'off', 'link', 'logit');

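Note that with 'model', 'ordinal', mnrfit fits an ordinal (cumulative logit) model, which assumes the classes have a natural order, 1 < 2 < 3. With 'interactions', 'off' the slope is shared across all categories (the proportional odds assumption), so betaHat holds $k-1=2$ intercepts followed by a single slope, three coefficients in total. Then we can get the class probabilities for the validation set with mnrval.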
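To see how an ordinal coefficient vector turns into class probabilities, here is a minimal Python sketch of the cumulative-logit computation. It assumes the layout mnrfit uses for this model with interactions off (the $k-1$ intercepts first, then the shared slope) and the sign convention logit P(Y ≤ j) = α_j + βx; packages differ on that sign. The coefficient values are hypothetical, chosen only for illustration, and are not the fitted betaHat.

```python
import math


def ordinal_probs(beta_hat, x):
    """Class probabilities of a cumulative-logit (proportional odds) model.

    Assumes beta_hat = [alpha_1, ..., alpha_{k-1}, slope] with increasing alphas
    and the convention logit P(Y <= j | x) = alpha_j + slope * x.
    """
    *alphas, slope = beta_hat
    cumulative = [1.0 / (1.0 + math.exp(-(a + slope * x))) for a in alphas]
    cumulative.append(1.0)  # P(Y <= k) = 1
    probs, previous = [], 0.0
    for c in cumulative:
        probs.append(c - previous)  # P(Y = j) = P(Y <= j) - P(Y <= j - 1)
        previous = c
    return probs


# Hypothetical coefficients for illustration only (not the fitted betaHat).
beta_hat = [-0.5, 0.5, 0.3]
for x in [0.0, 0.7, 1.3, 1.4, 2.1]:
    print(x, [round(p, 4) for p in ordinal_probs(beta_hat, x)])
```

Because a single slope shifts all cumulative logits together, the model can only move probability mass toward one end of the class ordering as x changes.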

>> predictProbs = mnrval(betaHat, X(validIndex), 'model', 'ordinal', 'interactions', 'off', 'link', 'logit')

predictProbs =

0.2980 0.1958 0.5061
0.3636 0.2041 0.4324
0.4242 0.2045 0.3713
0.4346 0.2039 0.3615
0.5084 0.1954 0.2961

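Each row is a validation observation and the columns are the probabilities of classes 1, 2 and 3. Taking the class with the highest probability in each row gives the predictions [3, 3, 1, 1, 1]. The true validation classes are Y(validIndex) = [1, 1, 2, 1, 1], so with this approach only the last two observations are predicted correctly (2 out of 5).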
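The argmax rule and the accuracy count can be checked mechanically. This Python sketch copies the probability matrix printed by mnrval and the validation classes Y(validIndex) from the data; nothing else is assumed.

```python
# Probabilities printed by mnrval (rows: validation observations, columns: classes 1..3).
probs = [
    [0.2980, 0.1958, 0.5061],
    [0.3636, 0.2041, 0.4324],
    [0.4242, 0.2045, 0.3713],
    [0.4346, 0.2039, 0.3615],
    [0.5084, 0.1954, 0.2961],
]
y_valid = [1, 1, 2, 1, 1]  # Y(validIndex)

# Predicted class = 1-based column index of the largest probability in each row.
predicted = [row.index(max(row)) + 1 for row in probs]
correct = [p == t for p, t in zip(predicted, y_valid)]
accuracy = sum(correct) / len(correct)
print(predicted, accuracy)  # [3, 3, 1, 1, 1] 0.4
```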

Let's do the same example in R. The procedure is not exactly the same as above: multinom from the nnet package fits a nominal (baseline-category) multinomial model rather than an ordinal one, but it is still a multinomial logistic regression:

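> library(nnet)

> # Design matrix: 10 observations of a single predictor
> X = matrix(c(0.0, 0.1, 0.7, 1.0, 1.1, 1.3, 1.4, 1.7, 2.1, 2.2))

> # Class indicators: an all-zero dummy first column, then one column per class (1, 2, 3)
> # Note: row 8 is coded as class 2 here, whereas the MATLAB Y assigns class 3 to observation 8
> Y = matrix(c(0, 1, 0, 0,
0, 0, 1, 0,
0, 1, 0, 0,
0, 0, 0, 1,
0, 1, 0, 0,
0, 0, 1, 0,
0, 1, 0, 0,
0, 0, 1, 0,
0, 1, 0, 0,
0, 1, 0, 0
), nrow = 10, ncol = 4, byrow = TRUE)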
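Typing an indicator matrix by hand is error-prone; generating it from a class vector is safer, and comparing the result with a hand-typed matrix is a cheap consistency check. A minimal Python sketch, where indicator_matrix is a hypothetical helper and the optional all-zero leading column mirrors the dummy column used in the R matrix:

```python
def indicator_matrix(labels, k, dummy_first=True):
    """One row per observation with a 1 in the column of its class (classes 1..k).

    With dummy_first=True an all-zero leading column is prepended.
    """
    offset = 1 if dummy_first else 0
    rows = []
    for c in labels:
        row = [0] * (k + offset)
        row[c - 1 + offset] = 1
        rows.append(row)
    return rows


Y = [1, 2, 1, 3, 1, 2, 1, 3, 1, 1]  # the MATLAB class vector
for row in indicator_matrix(Y, 3):
    print(row)
```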

The training and validation indices are the same as in MATLAB:

> # Training and validation indices (1-based, as in MATLAB)
> trainIndex = c(2, 8, 10, 4, 5)
> validIndex = c(1, 3, 6, 7, 9)

Now we can fit a multinomial log-linear model with multinom from the nnet package:

> mfit = multinom(formula = Y[trainIndex, ] ~ X[trainIndex] + 0)

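Note that `+ 0` removes the intercept, and that the first column of Y is the all-zero dummy; multinom uses the first column of a matrix response as the reference category, with its coefficients fixed at zero.

> predict(mfit, X[validIndex])

[1] 2 2 2 2 2

The predicted labels refer to columns of Y, so every observation is assigned to column 2, i.e. class 1: the fit does not discriminate between the classes at all. Be aware also that because the formula refers to X[trainIndex] directly, predict() cannot match the supplied vector to that term and re-evaluates X[trainIndex] from the workspace, so these predictions are for the training inputs rather than the validation ones. Fitting with a named column, e.g. Y[trainIndex, ] ~ x + 0 with data = data.frame(x = X[trainIndex]), and predicting with newdata = data.frame(x = X[validIndex]) avoids this. The examples above show a basic approach to supervised learning in MATLAB and R. However, one must check statistical quantities such as residuals, deviance and p-values to judge the quality of the results.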
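As one such check, the held-out deviance (−2 times the log-likelihood of the observed validation classes) can be computed directly from the probabilities printed by mnrval. This Python sketch uses only those printed values and Y(validIndex), and compares the result with a model that guesses each of the 3 classes with probability 1/3.

```python
import math

# Probabilities printed by mnrval (rows: validation observations, columns: classes 1..3).
probs = [
    [0.2980, 0.1958, 0.5061],
    [0.3636, 0.2041, 0.4324],
    [0.4242, 0.2045, 0.3713],
    [0.4346, 0.2039, 0.3615],
    [0.5084, 0.1954, 0.2961],
]
y_valid = [1, 1, 2, 1, 1]  # Y(validIndex)

# Probability the model assigned to the class that was actually observed.
p_true = [row[c - 1] for row, c in zip(probs, y_valid)]
deviance = -2.0 * sum(math.log(p) for p in p_true)

# Baseline: uniform guessing over 3 classes for the same 5 observations.
uniform = -2.0 * len(y_valid) * math.log(1.0 / 3.0)
print(round(deviance, 2), round(uniform, 2))  # 10.64 10.99
```

The fitted model is only marginally better than uniform guessing on these five points, which matches the weak classification seen above.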
