xgboost: How does XGBoost generate class probabilities in a multiclass classification problem?
When I try to dump a trained XGBoost tree, I get it in the following format:
0:[f37<39811]
  1:[f52<199.5]
    3:leaf=-0.021461
    4:[f2<0.00617284]
      9:leaf=-0.0118755
      10:[f35<-83.548]
        19:[f13<0.693844]
          23:[f37<1831]
            29:leaf=-0
            30:leaf=0.123949
          24:leaf=0.0108628
        20:[f35<-74.6198]
          25:leaf=-0.0175897
          26:leaf=0.051898
  2:leaf=-0.0239901
While the tree structure is clear, it is not clear how to interpret the leaf values. For a binary classification problem with a log-loss cost function, the leaf values can be converted to class probabilities using a logistic function: 1/(1+exp(-value)). However, for a multiclass classification problem there is no information on which class each tree's values belong to, and without that information it is not clear how to calculate the class probabilities.
Any ideas? Or is there any other function to get that information out of the trained trees?
That’s a very good observation, which I hadn’t noticed before. It’s true that the number of trees in the model is n_estimators x n_class. However, to figure out the order of the trees, I used the following toy example:
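A minimal sketch along those lines (the scikit-learn wrapper is assumed; the variable names, the 100x repetition, and n_estimators=2 are illustrative choices):

```python
import numpy as np
import xgboost as xgb
from sklearn.preprocessing import LabelEncoder

# One-hot rows mapped to deliberately unsorted labels:
# [1,0,0] -> 'a', [0,1,0] -> 'c', [0,0,1] -> 'b'
X = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]] * 100)
labels = np.array(['a', 'c', 'b'] * 100)

# LabelEncoder orders classes the way numpy.unique() does
# (alphabetically here), so 'a' -> 0, 'b' -> 1, 'c' -> 2.
y = LabelEncoder().fit_transform(labels)

clf = xgb.XGBClassifier(n_estimators=2, objective='multi:softprob')
clf.fit(X, y)

# Dump the individual trees of the underlying booster
for i, tree in enumerate(clf.get_booster().get_dump()):
    print('booster[%d]:' % i)
    print(tree)
```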
which basically creates a training dataset where [1,0,0] is mapped to ‘a’, [0,1,0] is mapped to ‘c’, and [0,0,1] is mapped to ‘b’. I intentionally switched the order of ‘b’ and ‘c’ to see if xgboost sorts classes based on their labels before dumping the trees.
Looking at the resulting dumped model, here are the conclusions:
- The trees are ordered class by class within each boosting round: with n_class classes, tree i belongs to class i mod n_class.
- The classes are sorted in the order numpy.unique() returns (see the quick check below) - here it is alphabetical, and that’s why the first tree is using f0 (class 'a') while the second is using f2 (class 'b').

It would be nice if this information were added to the xgboost documentation.
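As a quick check of the ordering claim:

```python
import numpy as np
print(np.unique(['a', 'c', 'b']))  # ['a' 'b' 'c'] - sorted, not in order of appearance
```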
For the example in my earlier comment, if I try to predict the class probabilities for [1, 0, 0] with predict_proba(), there is no way for 1/(1+exp(-value)) to generate the numbers it returns. The only way is that these probabilities are generated by a softmax of the values summed across the trees of each class:

prob[i] = exp(val[i]) / (exp(val[0]) + exp(val[1]) + ... + exp(val[N-1]))

where i is the target class (out of N classes), and val[i] is the sum of all values generated from the trees belonging to that class. In our example, this softmax generates exactly the probabilities that the predict_proba() function gave us.
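A minimal way to check this (a sketch assuming the toy model above is available as clf; Booster.predict with output_margin=True returns the per-class sums, up to a constant base_score that cancels in the softmax):

```python
import numpy as np
import xgboost as xgb

x = np.array([[1, 0, 0]])

# Raw per-class margins: val[i] plus a constant base_score,
# identical for every class, so it cancels in the softmax.
margins = clf.get_booster().predict(xgb.DMatrix(x), output_margin=True)[0]

# Softmax over the per-class sums
manual_probs = np.exp(margins) / np.sum(np.exp(margins))

print(manual_probs)
print(clf.predict_proba(x)[0])  # should match the manual softmax
```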
@mlxai I also only tried the Python API, and I don’t remember being able to get them per class.
However, in my opinion, the definition of feature importance in XGBoost as the total number of splits on a feature is a bit simplistic. Personally, I calculate feature importance by removing a feature from the feature set, training and testing the model without it, and measuring the resulting drop (or increase) in accuracy. I do this for all features, and I define “feature importance” as the amount of drop (or minus the increase) in accuracy. Since training and testing can be repeated with different train/test subsets for each feature removal, one can also estimate confidence intervals for the feature importances. Plus, by calculating importance separately for each class, you get feature importances per class. This has produced more meaningful results in my projects than counting the number of splits.
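A minimal sketch of that drop-one-feature scheme (my own illustration of the idea, assuming a numeric feature matrix X and labels y, with accuracy standing in for whatever metric you use):

```python
import numpy as np
import xgboost as xgb
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

def drop_feature_importance(X, y, n_repeats=10):
    """Importance of feature j = mean drop in accuracy when j is removed."""
    n_features = X.shape[1]
    drops = np.zeros((n_features, n_repeats))
    for r in range(n_repeats):
        # A different train/test split per repeat gives a spread,
        # from which confidence intervals can be estimated.
        X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=r)
        base = xgb.XGBClassifier()
        base.fit(X_tr, y_tr)
        base_acc = accuracy_score(y_te, base.predict(X_te))
        for j in range(n_features):
            keep = [k for k in range(n_features) if k != j]
            model = xgb.XGBClassifier()
            model.fit(X_tr[:, keep], y_tr)
            acc = accuracy_score(y_te, model.predict(X_te[:, keep]))
            drops[j, r] = base_acc - acc  # drop (or minus the increase)
    return drops.mean(axis=1), drops.std(axis=1)
```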
I also have the same question. The only thing I noticed is that if you have 20 classes and set the number of estimators to 100, you’ll have 20*100 = 2000 trees printed in your model. My guess was that the first 100 estimators classify the first class vs. the others, the next 100 estimators classify the second class vs. the others, and so on.
Can’t confirm it, but maybe it could be something like this?
@sosata your example is very clear. Thank you very much!
@chrisplyn The prediction instance [1,0,0] is routed through booster[0] to booster[5] and gets the leaf values [0.122449, -0.0674157, -0.0674157, 0.10941, -0.0650523, -0.0650523], respectively. Boosters 0 and 3 belong to class 0, boosters 1 and 4 to class 1, and boosters 2 and 5 to class 2, so val[0] = (0.122449 + 0.10941), val[1] = (-0.0674157 + -0.0650523), val[2] = (-0.0674157 + -0.0650523). Is that clearer?
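In code, the same grouping and softmax (a small check using only the numbers above):

```python
import numpy as np

leaf_values = [0.122449, -0.0674157, -0.0674157, 0.10941, -0.0650523, -0.0650523]
num_class = 3

# booster i belongs to class i % num_class
val = np.array([sum(leaf_values[c::num_class]) for c in range(num_class)])
# val = [0.231859, -0.132468, -0.132468]

probs = np.exp(val) / np.sum(np.exp(val))
print(probs)  # approximately [0.4185, 0.2907, 0.2907]
```

This should match what predict_proba() returns for that instance.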