scikit-learn: Trees with MAE criterion are slow to train

Description

When I use the 'mae' criterion with ExtraTreesRegressor, the model trains for a very long time; it seems to lead to endless training. There is no such problem with 'mse'. I found that I am not the only one with this problem. I have tried two versions (0.18 and 0.19.X), but it made no difference.

https://www.kaggle.com/c/allstate-claims-severity/discussion/24293

Steps/Code to Reproduce

from sklearn.ensemble import ExtraTreesRegressor

rfr = ExtraTreesRegressor(n_estimators=100,
                          max_features=0.8,
                          criterion='mae',
                          max_depth=6,
                          min_samples_leaf=200,
                          n_jobs=-1,
                          random_state=17,
                          verbose=0)
mod = rfr.fit(train[distilled_features], train['y'])

Expected Results

Training completes in a reasonable time when the 'mae' criterion is used.

Actual Results

Training never finishes; the call to fit runs indefinitely.

Versions

Darwin-16.1.0-x86_64-i386-64bit Python 3.6.1 |Anaconda 4.4.0 (x86_64)| (default, May 11 2017, 13:04:09) [GCC 4.2.1 Compatible Apple LLVM 6.0 (clang-600.0.57)] NumPy 1.12.1 SciPy 0.19.0 Scikit-Learn 0.19.X

About this issue

  • Original URL
  • State: open
  • Created 7 years ago
  • Reactions: 26
  • Comments: 61 (40 by maintainers)

Most upvoted comments

I noticed this is currently still an issue: my training with 'mae' as the criterion does not finish (grid search with gradient boosting regression trees). I spent a lot of time trying to debug what was wrong before stumbling on this thread. This is why I would propose adding a warning in the documentation (e.g. "training with 'mae' as criterion means recalculating the loss function after each iteration and can take an extremely long time") to prevent future users from getting stuck on the same problem.

Observation

When using criterion='absolute_error', most of the time is spent pushing, popping, or removing elements from the WeightedMedianCalculator's WeightedPQueue: https://github.com/scikit-learn/scikit-learn/blob/2f65ac764be40e2420817cbeca3301c8d664baa3/sklearn/tree/_utils.pyx#L78-L95

This slows down execution because the data structures are resized by utility functions for memory management: https://github.com/scikit-learn/scikit-learn/blob/2f65ac764be40e2420817cbeca3301c8d664baa3/sklearn/tree/_utils.pyx#L19-L73

Currently, copies are made and memory is moved (see __memmove_evex_unaligned_erms among the entries in the last section below). Moreover, reallocation generally blocks when using the default allocators, which is currently the case in scikit-learn: in a multi-threaded context (e.g. when using RandomForest* or ExtraTrees*), this can create significant contention at the OS level, bringing performance down.
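As a toy illustration (not the actual Cython implementation), an array-backed sorted queue in pure Python behaves the same way: every push or remove shifts the tail of the array, which is essentially the memmove cost showing up in the profiles below.

import bisect

class ToySortedQueue:
    """Array-backed sorted queue: pushes and removes shift elements, so each op is O(n)."""

    def __init__(self):
        self._data = []  # kept sorted at all times

    def push(self, value):
        # bisect finds the slot in O(log n), but list.insert shifts the tail: O(n)
        bisect.insort(self._data, value)

    def remove(self, value):
        i = bisect.bisect_left(self._data, value)
        if i < len(self._data) and self._data[i] == value:
            del self._data[i]  # deleting also shifts the tail: O(n)

    def median(self):
        # assumes at least one element has been pushed
        n = len(self._data)
        mid = n // 2
        return self._data[mid] if n % 2 else 0.5 * (self._data[mid - 1] + self._data[mid])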

Proposed solutions (including proposals for generally improved maintenance):

Elements for troubleshooting performance

Script
# sklearn_9626.py
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=10000, n_features=10)

rfr = ExtraTreesRegressor(
    n_estimators=100,
    max_features=0.8,
    criterion='absolute_error',
    max_depth=6,
    min_samples_leaf=200,
    n_jobs=-1,
    random_state=17,
    verbose=0,
)
rfr.fit(X, y)
Inspection with py-spy

A small follow-up: inspecting with py-spy using this command (taskset(1) is used to limit the execution to one core):

taskset -c 0 py-spy record --rate=500 --native -o sklearn_9626.json -f speedscope \
                    -- python ./sklearn_9626.py

gives the following SpeedScope-inspectable report: sklearn_9626

Inspection with perf(1)
taskset -c 0 perf record  python sklearn_9626.py

gives the following hierarchical report:

Samples: 207K of event 'cycles:u', Event count (approx.): 205685970449                                                                                                                                                                                        
                                                                                                                                                                                                                                                              
-  100.00%        python
   -   74.26%        _utils.cpython-310-x86_64-linux-gnu.so
          59.41%        [.] __pyx_f_7sklearn_4tree_6_utils_14WeightedPQueue_push
          12.62%        [.] __pyx_f_7sklearn_4tree_6_utils_14WeightedPQueue_remove
           0.44%        [.] __pyx_f_7sklearn_4tree_6_utils_24WeightedMedianCalculator_update_median_parameters_post_push
           0.38%        [.] __pyx_f_7sklearn_4tree_6_utils_24WeightedMedianCalculator_update_median_parameters_post_remove
           0.31%        [.] __pyx_f_7sklearn_4tree_6_utils_24WeightedMedianCalculator_pop
           0.29%        [.] __pyx_f_7sklearn_4tree_6_utils_24WeightedMedianCalculator_push
           0.26%        [.] __pyx_f_7sklearn_4tree_6_utils_24WeightedMedianCalculator_get_median
           0.14%        [.] __pyx_f_7sklearn_4tree_6_utils_24WeightedMedianCalculator_size
           0.13%        [.] __pyx_f_7sklearn_4tree_6_utils_14WeightedPQueue_size
           0.08%        [.] __pyx_f_7sklearn_4tree_6_utils_14WeightedPQueue_pop
           0.07%        [.] __pyx_f_7sklearn_4tree_6_utils_24WeightedMedianCalculator_remove
           0.07%        [.] __pyx_f_7sklearn_4tree_6_utils_14WeightedPQueue_get_value_from_index
           0.03%        [.] __pyx_f_7sklearn_4tree_6_utils_14WeightedPQueue_get_weight_from_index
           0.02%        [.] memmove@plt
           0.00%        [.] __pyx_tp_dealloc_7sklearn_4tree_6_utils_WeightedMedianCalculator
           0.00%        [.] __pyx_fuse_6__pyx_f_7sklearn_4tree_6_utils_safe_realloc
           0.00%        [.] __pyx_f_7sklearn_4tree_6_utils_rand_int
           0.00%        [.] __pyx_tp_new_7sklearn_4tree_6_utils_WeightedMedianCalculator
   -   23.89%        libc.so.6
          23.84%        [.] __memmove_evex_unaligned_erms
           0.01%        [.] __memmove_sse2_unaligned_erms
           0.01%        [.] _int_malloc
           0.00%        [.] malloc
           0.00%        [.] __memcmp_evex_movbe
           0.00%        [.] _int_free
           0.00%        [.] __memset_evex_unaligned_erms
           0.00%        [.] printf_positional
           0.00%        [.] memcpy@GLIBC_2.2.5
           0.00%        [.] sem_trywait@@GLIBC_2.34
           0.00%        [.] pthread_mutex_lock@@GLIBC_2.2.5
           0.00%        [.] __strcmp_evex
           0.00%        [.] __strchr_evex
           0.00%        [.] realloc
           0.00%        [.] unlink_chunk.constprop.0
           0.00%        [.] __strncmp_evex
           0.00%        [.] cfree@GLIBC_2.2.5
           0.00%        [.] __strlen_evex
           0.00%        [.] sem_post@GLIBC_2.2.5
           0.00%        [.] 0x0000000000026614
           0.00%        [.] __GI_____strtoll_l_internal
           0.00%        [.] __mbsrtowcs_l
           0.00%        [.] __strchrnul_evex
           0.00%        [.] __parse_one_specmb
           0.00%        [.] __GI___readdir64
           0.00%        [.] __GI___pthread_mutex_unlock_usercnt
           0.00%        [.] pthread_cond_signal@@GLIBC_2.3.2
           0.00%        [.] malloc_consolidate
   +    0.93%        python3.10
   +    0.48%        _criterion.cpython-310-x86_64-linux-gnu.so
   +    0.38%        _splitter.cpython-310-x86_64-linux-gnu.so
   +    0.03%        _multiarray_umath.cpython-310-x86_64-linux-gnu.so
   +    0.02%        ld-linux-x86-64.so.2
   +    0.01%        [unknown]
   +    0.00%        _tree.cpython-310-x86_64-linux-gnu.so
   +    0.00%        _mt19937.cpython-310-x86_64-linux-gnu.so
   +    0.00%        bit_generator.cpython-310-x86_64-linux-gnu.so
   +    0.00%        mtrand.cpython-310-x86_64-linux-gnu.so
   +    0.00%        libcrypto.so.3
   +    0.00%        _bounded_integers.cpython-310-x86_64-linux-gnu.so
   +    0.00%        libstdc++.so.6.0.30
   +    0.00%        _quad_tree.cpython-310-x86_64-linux-gnu.so
   +    0.00%        libbz2.so.1.0.8

https://github.com/yupbank/np_decision_tree/blob/master/decision_tree/strategy_l1.py#L87

I did something in the past to abstract the core data structure into a running median to make faster MAE splits.
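For context, the textbook way to maintain a running median is with two heaps (a max-heap for the lower half, a min-heap for the upper half), so each insertion costs O(log n) instead of O(n). A minimal unweighted sketch is below; the linked code and the weighted case scikit-learn needs are more involved.

import heapq

class RunningMedian:
    """Streaming median via two heaps; each push is O(log n) (unweighted sketch)."""

    def __init__(self):
        self._lo = []  # max-heap (values stored negated) for the lower half
        self._hi = []  # min-heap for the upper half

    def push(self, x):
        if self._lo and x > -self._lo[0]:
            heapq.heappush(self._hi, x)
        else:
            heapq.heappush(self._lo, -x)
        # rebalance so the halves differ in size by at most one
        if len(self._lo) > len(self._hi) + 1:
            heapq.heappush(self._hi, -heapq.heappop(self._lo))
        elif len(self._hi) > len(self._lo):
            heapq.heappush(self._lo, -heapq.heappop(self._hi))

    def median(self):
        # assumes at least one value has been pushed
        if len(self._lo) == len(self._hi):
            return 0.5 * (-self._lo[0] + self._hi[0])
        return -self._lo[0]

The MAE criterion additionally needs removals and per-side running sums to update the absolute-error impurity as samples move across the split, which is what makes the weighted version in scikit-learn harder.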

+1 for one or more PRs to implement the proposed solutions and compare them with an ad-hoc benchmark.
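Such a benchmark could be as simple as the sketch below (criterion names assume scikit-learn >= 1.0; older versions spell them 'mse'/'mae'):

from time import perf_counter

from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=5_000, n_features=10, random_state=0)

for criterion in ("squared_error", "absolute_error"):
    tree = DecisionTreeRegressor(criterion=criterion, max_depth=6, random_state=0)
    start = perf_counter()
    tree.fit(X, y)
    print(f"{criterion}: {perf_counter() - start:.2f}s")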

Note: for py-spy output, if you use -f speedscope, please use a .json filename, otherwise it's quite confusing (even for my Firefox).

That commit was written by past Jeremy, who was much better informed on the issue. It's probably the best I can give you without diving into it again, which I haven't had the time for. Reflecting back, I can also say that I didn't succeed because I wasn't able to set up a good CI cycle for the WIP code (I wasn't as familiar with Cython and testing as I am now), and because I got segmentation faults. I also suspect my code logic around the update was flawed; this could be avoided with some good test cases when implementing the update.

I would suggest starting over from scratch instead of from my very outdated branch, and using the commit more as a reference for where you might need to add code than for what code you need to add (the descriptions in this issue are better for that).

PSA: users are encouraged to use HistGradientBoostingRegressor with loss='least_absolute_deviation', which will be considerably faster.
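For example (the loss spelling depends on the version: 'least_absolute_deviation' before scikit-learn 1.0, 'absolute_error' from 1.0 on; versions before 1.0 also need the experimental import):

# on scikit-learn < 1.0:
# from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=100_000, n_features=10, random_state=0)

# histogram-based boosting bins the features, so an MAE-style loss stays fast
hgb = HistGradientBoostingRegressor(loss="absolute_error", random_state=17)
hgb.fit(X, y)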

@Broundal it's because when the criterion is MAE (and, as a result, the random forest outputs the median instead of the mean), it is more difficult to update the loss function. For MSE you can take the original variance and perform a constant-time update of the loss when you move a sample from one side of the partition to the other, but for MAE the loss is currently recomputed from scratch, which takes O(N) time. As a result, the entire process is O(N^2) instead of O(N). The only time they'll perform comparably is if you have very few samples. At 10k samples you can at least have your computation finish in 20 seconds; at 800k samples the N^2 scaling would imply you need 128,000 seconds, or about 1.5 days. Conversely, 1,000 samples should only take 0.2 seconds.
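To make this concrete, here is a toy sketch (not the actual Cython criterion code) contrasting the constant-time running-sum update that makes an MSE split scan O(N) with the full recomputation that makes the MAE scan O(N^2):

import numpy as np

def mse_split_scan(y):
    """O(n) scan over split positions: variance is maintained from running sums."""
    y = np.asarray(y, dtype=float)
    n = len(y)
    left_sum = left_sq = 0.0
    right_sum, right_sq = y.sum(), (y ** 2).sum()
    impurities = []
    for i in range(1, n):  # move sample i-1 from right to left: O(1) update
        v = y[i - 1]
        left_sum += v; left_sq += v * v
        right_sum -= v; right_sq -= v * v
        left_var = left_sq / i - (left_sum / i) ** 2
        right_var = right_sq / (n - i) - (right_sum / (n - i)) ** 2
        impurities.append(i * left_var + (n - i) * right_var)
    return impurities

def mae_split_scan(y):
    """O(n^2) scan: the absolute error around the median is recomputed each time."""
    y = np.asarray(y, dtype=float)
    n = len(y)
    impurities = []
    for i in range(1, n):  # each position costs O(n), so the whole scan is O(n^2)
        left, right = y[:i], y[i:]
        impurities.append(np.abs(left - np.median(left)).sum()
                          + np.abs(right - np.median(right)).sum())
    return impurities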

We think there’s a faster way to update the loss function.

What was your idea?

I guess, theoretically, it boils down to computing the exact median, which costs O(n log n), right? How about an approximation of the median in linear time, like the Median of Medians?

Especially when it comes to random forests, the error introduced by the median approximation can be neglected or rather seen as another source of randomness.
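A rough sketch of that idea (purely illustrative: a single median-of-group-medians pass as a cheap, roughly linear-time approximation, not the full recursive selection algorithm):

import numpy as np

def approx_median(y, group_size=5):
    """Approximate median: median of per-group medians, computed in one pass."""
    y = np.asarray(y, dtype=float)
    n_groups = int(np.ceil(len(y) / group_size))
    pad = n_groups * group_size - len(y)
    if pad:
        y = np.concatenate([y, np.repeat(y[-1], pad)])  # pad so groups are equal-sized
    group_medians = np.median(y.reshape(n_groups, group_size), axis=1)
    return float(np.median(group_medians))

With groups of 5 the result is only guaranteed to lie roughly between the 30th and 70th percentiles, which may be acceptable as extra randomization in a forest, as suggested above.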

Is there any update on this? Running a multi-output Random Forest Regressor with the mae criterion and a couple of million samples takes days, compared to minutes for mse.