xgboost: [jvm-packages] Persist CrossValidator model with xgboost4j-spark error
Environment info
Operating System: Red Hat 6.5 (with Spark 2.1.0)
Compiler:
Package: jvm, xgboost4j-spark
xgboost version used: latest
I want to save a CrossValidator model, but I get an error:
java.lang.UnsupportedOperationException: Pipeline write will fail on this Pipeline because it contains a stage which does not implement Writable. Non-Writable stage: XGBoostEstimator_88624dc1e519 of type class ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator
at org.apache.spark.ml.Pipeline$SharedReadWrite$$anonfun$validateStages$1.apply(Pipeline.scala:231)
at org.apache.spark.ml.Pipeline$SharedReadWrite$$anonfun$validateStages$1.apply(Pipeline.scala:228)
at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
at org.apache.spark.ml.Pipeline$SharedReadWrite$.validateStages(Pipeline.scala:228)
at org.apache.spark.ml.Pipeline$PipelineWriter.<init>(Pipeline.scala:202)
at org.apache.spark.ml.Pipeline.write(Pipeline.scala:188)
at org.apache.spark.ml.util.MLWritable$class.save(ReadWrite.scala:154)
at org.apache.spark.ml.Pipeline.save(Pipeline.scala:96)
at org.apache.spark.ml.tuning.ValidatorParams$.saveImpl(ValidatorParams.scala:148)
at org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter.saveImpl(CrossValidator.scala:256)
at org.apache.spark.ml.util.MLWriter.save(ReadWrite.scala:111)
... 50 elided
My code:
import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation._
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import scala.collection.mutable
import scala.util.Random
// xgboost parameters
def get_param(): mutable.HashMap[String, Any] = {
  val params = new mutable.HashMap[String, Any]()
  params += "eta" -> 0.1
  params += "max_depth" -> 4
  params += "min_child_weight" -> 4
  params += "num_rounds" -> 20
  params += "silent" -> 1
  params += "objective" -> "binary:logistic"
  params += "booster" -> "gbtree"
  params += "gamma" -> 0.0
  params += "colsample_bylevel" -> 1
  params
}
// synthetic data: the single feature equals the label
val r = new Random(0)
val training = spark.createDataFrame(
  Seq.fill(10000)(r.nextInt(2)).map(i => (i, i))
).toDF("feature", "label")
val test = spark.createDataFrame(
  Seq.fill(10000)(r.nextInt(2)).map(i => (i, i))
).toDF("feature", "label")
// create pipeline
val assembler = new VectorAssembler()
  .setInputCols(Array("feature"))
  .setOutputCol("features")
val xgb = new XGBoostEstimator(get_param().toMap)
  .setFeaturesCol("features")
val pipeline = new Pipeline()
  .setStages(Array(assembler, xgb))
// grid
val paramGrid = new ParamGridBuilder()
  // .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
  .addGrid(xgb.round, Array(10, 30, 50, 100))
  .addGrid(xgb.maxDepth, Array(5, 6, 8, 10))
  .addGrid(xgb.minChildWeight, Array(0.5, 0.7, 1.0))
  .build()
// cv
// val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("probabilities")
val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("rmse")
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(5)
// Run cross-validation, and choose the best set of parameters.
val cvModel = cv.fit(training)
// This save call throws the UnsupportedOperationException shown above.
cvModel.write.overwrite.save("/tmp/xgbModel")
About this issue
- State: closed
- Created 7 years ago
- Comments: 35 (24 by maintainers)
Commits related to this issue
- Add example from https://github.com/dmlc/xgboost/issues/2115#issuecomment-287247577 — committed to geoHeil/xgboost by geoHeil 7 years ago
I got a chance to look at the problem this afternoon. It happens simply because we didn't implement MLWritable for XGBoostEstimator; the problem should be fixed by https://github.com/dmlc/xgboost/pull/2265
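The check that raises this exception can be approximated before calling save. The following is only a minimal sketch, assuming nothing beyond the Spark ML API (`Pipeline.getStages` and the `MLWritable` trait); the object `WritableCheck` and the helper `nonWritableStages` are hypothetical names and are not part of Spark or xgboost4j-spark.

```scala
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.util.MLWritable

// Hypothetical helper, not part of Spark or xgboost4j-spark.
object WritableCheck {
  // Returns the stages that do not implement MLWritable, i.e. the
  // stages that would make Pipeline.write fail with the
  // UnsupportedOperationException from the report.
  def nonWritableStages(pipeline: Pipeline): Seq[PipelineStage] =
    pipeline.getStages.filterNot(_.isInstanceOf[MLWritable]).toSeq
}
```

With the pipeline from the report, `WritableCheck.nonWritableStages(pipeline).foreach(s => println(s.uid))` should, judging by the error message, list the XGBoostEstimator stage on an affected build and nothing once the estimator implements MLWritable.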