CrossValidator — Model Tuning / Finding The Best Model

CrossValidator is an Estimator for model tuning, i.e. finding the best model for given parameters and a dataset.

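CrossValidator randomly splits the dataset into numFolds non-overlapping folds and creates numFolds pairs of training and validation datasets: each fold is used once as the validation dataset while the remaining folds make up the training dataset.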
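The splitting idea can be sketched in plain Scala over an in-memory Seq. This is an illustration only, not Spark's code: Spark's MLUtils.kFold samples each fold from the RDD with a seeded Bernoulli sampler, so its fold sizes are only approximately equal, whereas this sketch swaps in a simpler technique (a seeded shuffle followed by dealing the elements round-robin into folds). The KFoldSketch object and kFold method are hypothetical names.

```scala
import scala.util.Random

object KFoldSketch {
  /** Splits `data` into `numFolds` non-overlapping folds and returns one
    * (training, validation) pair per fold: fold i is the validation data,
    * all the other folds together form the training data. */
  def kFold[T](data: Seq[T], numFolds: Int, seed: Long): Seq[(Seq[T], Seq[T])] = {
    require(numFolds >= 2, "numFolds must be at least 2")
    // Shuffle deterministically so that the same seed gives the same folds
    val shuffled = new Random(seed).shuffle(data)
    // Deal elements round-robin: the j-th shuffled element goes to fold j % numFolds
    val withFold = shuffled.zipWithIndex.map { case (x, j) => (x, j % numFolds) }
    (0 until numFolds).map { i =>
      val training   = withFold.collect { case (x, f) if f != i => x }
      val validation = withFold.collect { case (x, f) if f == i => x }
      (training, validation)
    }
  }
}
```

For example, KFoldSketch.kFold(1 to 10, 3, 42L) gives three pairs whose validation parts have 4, 3 and 3 elements and never overlap.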

CrossValidator generates a CrossValidatorModel to hold the best model and average cross-validation metrics.

CrossValidator takes any Estimator for model selection, including the Pipeline that is used to transform raw datasets and generate a Model.
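Use ParamGridBuilder to build the parameter grid, i.e. the collection of ParamMaps to try during model tuning.

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}

val pipeline: Pipeline = ...
val paramGrid: Array[ParamMap] = new ParamGridBuilder().
  addGrid(...).
  build
val cv = new CrossValidator().
  setEstimator(pipeline).
  setEvaluator(...).
  setEstimatorParamMaps(paramGrid)
val bestModel: CrossValidatorModel = cv.fit(...)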
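ParamGridBuilder.build returns the Cartesian product of all the values registered with addGrid, so the number of ParamMaps (and hence of models fitted per fold) is the product of the grid sizes. The plain-Scala sketch below reproduces only that expansion. A parameter map is modeled here as a Map[String, Any], which is an assumption for illustration and not Spark's ParamMap; ParamGridSketch and the example parameter values are hypothetical.

```scala
object ParamGridSketch {
  /** Expands a grid (parameter name -> candidate values) into every
    * combination of values, one Map per combination (a Cartesian product). */
  def build(grid: Seq[(String, Seq[Any])]): Seq[Map[String, Any]] =
    grid.foldLeft(Seq(Map.empty[String, Any])) { case (acc, (name, values)) =>
      // Extend every partial combination with every value of this parameter
      for {
        partial <- acc
        value   <- values
      } yield partial + (name -> value)
    }
}
```

With three values for one parameter and two for another, the sketch yields 3 × 2 = 6 parameter maps, which means CrossValidator would fit 6 models per fold plus the final best model.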

CrossValidator is an MLWritable, i.e. it can be saved to a path (using save or write).

Table 1. CrossValidator's Parameters

Parameter           Default Value                Description
estimator           (undefined)                  Estimator for best model selection
estimatorParamMaps  (undefined)                  Param maps for the estimator
evaluator           (undefined)                  Evaluator to select hyper-parameters that maximize the validated metric
numFolds            3                            The number of folds for cross validation. Must be at least 2.
parallelism         1                            The number of threads to use while fitting models. Must be at least 1.
seed                hash code of the class name  Random seed

Enable INFO or DEBUG logging levels for logger to see what happens inside.

Add the following line to conf/

Refer to Logging.

Finding The Best Model — fit Method

fit(dataset: Dataset[_]): CrossValidatorModel
fit is part of the Estimator Contract to fit a model (i.e. produce a model).

fit first validates the schema of the input dataset (using transformSchema with logging turned on).

You should see the following DEBUG message in the logs:

DEBUG CrossValidator: Input schema: [json]

fit makes sure that the estimator, evaluator, estimatorParamMaps and parallelism parameters are defined, and throws a NoSuchElementException otherwise.

java.util.NoSuchElementException: Failed to find a default value for [name]
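The check boils down to resolving every parameter to its explicitly set value, falling back to its default value, and failing when there is neither. Since estimator, evaluator and estimatorParamMaps have no default values, leaving any of them unset ends up in the exception above. A minimal plain-Scala sketch of that lookup order follows; the Params class and its getOrDefault method are hypothetical and only mimic the described behavior, they are not Spark's Params trait.

```scala
object ParamLookupSketch {
  /** Hypothetical parameter holder: explicitly set values and default values. */
  final case class Params(explicit: Map[String, Any], defaults: Map[String, Any]) {
    // Explicit value first, then the default, else fail like the message above
    def getOrDefault(name: String): Any =
      explicit.get(name)
        .orElse(defaults.get(name))
        .getOrElse(throw new NoSuchElementException(s"Failed to find a default value for $name"))
  }
}
```

For instance, with numFolds explicitly set to 5 and defaults for numFolds (3) and parallelism (1), looking up numFolds returns 5, parallelism returns 1, and estimator fails with "Failed to find a default value for estimator".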

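fit creates an ExecutionContext per the parallelism parameter, i.e. one that runs tasks in the calling thread when parallelism is 1, or a thread pool of at most parallelism threads otherwise.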

fit creates an Instrumentation and requests it to print out the numFolds, seed and parallelism parameters to the logs.


fit requests Instrumentation to print out the tuning parameters to the logs.


fit splits the underlying RDD of the dataset into numFolds pairs of training and validation RDDs using MLUtils.kFold (with the numFolds and seed parameters).

fit computes metrics for every pair of training and validation RDDs.

fit calculates the average metric of every ParamMap over all the folds.

You should see the following INFO message in the logs:

INFO Average cross-validation metrics: [metrics]
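Every fold yields one metric per ParamMap, and the average is taken per ParamMap across the folds, so there is one average metric per element of estimatorParamMaps. A plain-Scala sketch of that averaging, assuming the per-fold metrics are laid out as one Seq per fold (the layout and the AvgMetricsSketch name are assumptions for illustration):

```scala
object AvgMetricsSketch {
  /** foldMetrics(i)(j) is the metric of the model trained with the j-th
    * ParamMap on fold i; the result holds one average per ParamMap. */
  def averageMetrics(foldMetrics: Seq[Seq[Double]]): Seq[Double] = {
    val numFolds = foldMetrics.size
    // Transpose to group the metrics by ParamMap, then average over the folds
    foldMetrics.transpose.map(_.sum / numFolds)
  }
}
```

With three folds and two ParamMaps, the metrics (0.8, 0.6), (0.9, 0.7) and (0.7, 0.5) average to roughly (0.8, 0.6).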

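fit selects the best cross-validation metric (and with it the best set of parameters) using the Evaluator: the largest average metric if the Evaluator's metric is larger-is-better (isLargerBetter), or the smallest one otherwise.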

You should see the following INFO message in the logs:

INFO Best set of parameters:
INFO Best cross-validation metric: [bestMetric].
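The selection can be sketched in plain Scala as picking the extreme average metric together with its index; the index is what identifies the best ParamMap in estimatorParamMaps. The isLargerBetter flag stands in for the Evaluator's method of the same name; the BestModelSketch object and selectBest method are hypothetical.

```scala
object BestModelSketch {
  /** Returns (bestMetric, bestIndex): the largest average metric when larger
    * is better (e.g. area under ROC), the smallest otherwise (e.g. RMSE). */
  def selectBest(avgMetrics: Seq[Double], isLargerBetter: Boolean): (Double, Int) = {
    require(avgMetrics.nonEmpty, "no metrics to select from")
    val indexed = avgMetrics.zipWithIndex
    if (isLargerBetter) indexed.maxBy(_._1) else indexed.minBy(_._1)
  }
}
```

For average metrics (0.7, 0.9, 0.8) with larger-is-better, the result is (0.9, 1), i.e. the second ParamMap wins.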

fit requests the Estimator to fit the best model on the whole dataset with the best ParamMap from estimatorParamMaps.

You should see the following INFO message in the logs:

INFO training finished

In the end, fit creates a CrossValidatorModel (with the ID, the best model and the average metrics, one per ParamMap in estimatorParamMaps) and copies the parameters to it.

fit and Computing Metrics for Training and Validation RDDs

fit computes metrics for every pair of training and validation RDDs (from kFold).

fit creates the training and validation datasets (as DataFrames with the schema of the input dataset) and persists (caches) them.

You can monitor the persisted datasets in the Storage tab of the web UI.

fit prints out the following DEBUG message to the logs:

DEBUG Train split [index] with multiple sets of parameters.

For every ParamMap in the estimatorParamMaps parameter, fit fits a model on the training dataset using the Estimator.

fit does the fitting in parallel per parallelism parameter.

The parallelism parameter defaults to 1, i.e. models are fitted sequentially, one at a time.

fit unpersists the training data (per pair of training and validation RDDs) when all models have been trained.

fit requests the models to transform their respective validation datasets (with the corresponding parameters from estimatorParamMaps) and then requests the Evaluator to evaluate the transformed datasets.

fit prints out the following DEBUG message to the logs:

DEBUG Got metric [metric] for model trained with [paramMap].

fit waits until all metrics are available and unpersists the validation dataset.
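The per-fold steps can be sketched with plain Scala Futures on a fixed-size thread pool whose size plays the role of the parallelism parameter. This is a hedged illustration, not Spark's code: the "estimator" (a hypothetical model predicting a scaled mean of the training data) and the "evaluator" (mean absolute error, smaller is better) are toy stand-ins, and persisting and unpersisting the datasets is left out.

```scala
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object FoldMetricsSketch {
  // Hypothetical toy model: predicts the same value for every row
  final case class Model(prediction: Double)

  // Hypothetical estimator: `scale` acts as the tuned parameter
  def fit(training: Seq[Double], scale: Double): Model =
    Model(scale * training.sum / training.size)

  // Hypothetical evaluator: mean absolute error (smaller is better)
  def evaluate(model: Model, validation: Seq[Double]): Double =
    validation.map(v => math.abs(v - model.prediction)).sum / validation.size

  /** Fits one model per parameter value on a pool of `parallelism` threads,
    * evaluates each model on the validation data as soon as it is fitted,
    * and blocks until all the metrics of this fold are available. */
  def foldMetrics(training: Seq[Double], validation: Seq[Double],
                  paramValues: Seq[Double], parallelism: Int): Seq[Double] = {
    require(parallelism >= 1, "parallelism must be at least 1")
    val pool = Executors.newFixedThreadPool(parallelism)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)
    try {
      val modelFutures  = paramValues.map(p => Future(fit(training, p)))
      val metricFutures = modelFutures.map(_.map(m => evaluate(m, validation)))
      Await.result(Future.sequence(metricFutures), Duration.Inf)
    } finally pool.shutdown()
  }
}
```

For training data (1.0, 2.0, 3.0), validation data (2.0, 2.0) and scales (0.5, 1.0, 2.0), the fold metrics are 1.0, 0.0 and 2.0, so the scale 1.0 would win for this fold.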

Creating CrossValidator Instance

CrossValidator takes the following when created:

  • Unique ID

Validating and Transforming Schema — transformSchema Method

transformSchema(schema: StructType): StructType
transformSchema is part of the PipelineStage Contract.

transformSchema simply passes the call to transformSchemaImpl (that is shared between CrossValidator and TrainValidationSplit).
