Cross-Validation | Ignite Documentation

Cross-Validation

Cross-validation in Apache Ignite is provided by the CrossValidation class. It is a score calculator parameterized by the model type and by the key and value types of the data. Its constructor takes no arguments; after instantiation, we configure the instance through its with* methods and call the scoreByFolds method to perform cross-validation.

Suppose we have a trainer and a training set, and we want to run cross-validation with 4 folds, using accuracy as the metric. Apache Ignite lets us do this as shown in the following example:

Cross-Validation (without Pipeline API usage)

// Create classification trainer
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);

// Create cross-validation instance
CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator
  = new CrossValidation<>();

// Set up the cross-validation process
scoreCalculator
    .withIgnite(ignite)
    .withUpstreamCache(trainingSet)
    .withTrainer(trainer)
    .withMetric(MetricName.ACCURACY)
    .withPreprocessor(vectorizer)
    .withAmountOfFolds(4)
    .isRunningOnPipeline(false);

// Calculate accuracy for each fold
double[] accuracyScores = scoreCalculator.scoreByFolds();

In this example, we pass the common training arguments (the Ignite instance, the upstream cache, the trainer, the metric, and the vectorizer as a preprocessor) and then specify the number of folds. The scoreByFolds method returns an array containing the chosen metric for each fold, that is, one score for each split of the training set into a training part and a held-out part.
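
To make the shape of that result concrete, here is a minimal plain-Java sketch of per-fold scoring. It does not use Ignite and is not Ignite's implementation: the class KFoldSketch, the methods accuracyByFolds and mean, the "index modulo fold count" split, and the majority-class stand-in model are all assumptions chosen for illustration only.

```java
import java.util.Arrays;

/** Plain-Java illustration of per-fold scoring; not Ignite's implementation. */
class KFoldSketch {
    /**
     * Assigns sample i to fold (i % folds), "trains" a majority-class predictor
     * on the remaining folds, and returns the accuracy on each held-out fold.
     * Labels are assumed to be binary (0 or 1).
     */
    static double[] accuracyByFolds(int[] labels, int folds) {
        if (folds < 2 || labels.length < folds)
            throw new IllegalArgumentException("Need at least 2 folds and one sample per fold");

        double[] scores = new double[folds];
        for (int fold = 0; fold < folds; fold++) {
            // "Training": count the positive labels outside the held-out fold.
            int trainSize = 0, trainOnes = 0;
            for (int i = 0; i < labels.length; i++) {
                if (i % folds != fold) {
                    trainSize++;
                    trainOnes += labels[i];
                }
            }
            int predicted = 2 * trainOnes >= trainSize ? 1 : 0;

            // "Evaluation": accuracy of the constant prediction on the held-out fold.
            int testSize = 0, correct = 0;
            for (int i = 0; i < labels.length; i++) {
                if (i % folds == fold) {
                    testSize++;
                    if (labels[i] == predicted) correct++;
                }
            }
            scores[fold] = (double) correct / testSize;
        }
        return scores;
    }

    /** Arithmetic mean of the per-fold scores. */
    static double mean(double[] scores) {
        return Arrays.stream(scores).average().orElse(Double.NaN);
    }

    public static void main(String[] args) {
        int[] labels = {1, 1, 1, 1, 0, 1, 1, 1};
        double[] scores = accuracyByFolds(labels, 4);
        // Prints: [0.5, 1.0, 1.0, 1.0] mean=0.875
        System.out.println(Arrays.toString(scores) + " mean=" + mean(scores));
    }
}
```

With 4 folds the result has 4 entries, one per held-out fold. Averaging them is a common way to reduce the array to a single figure, although the spread across folds is often worth inspecting as well.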

Cross-Validation (with Pipeline API usage)

To run cross-validation on a pipeline, define the pipeline and pass it to the CrossValidation instance.

Caution
The Pipeline API is experimental and may change in future releases.

// Create classification trainer
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);

Pipeline<Integer, Vector, Integer, Double> pipeline
  = new Pipeline<Integer, Vector, Integer, Double>()
    .addVectorizer(vectorizer)
    .addPreprocessingTrainer(new ImputerTrainer<Integer, Vector>())
    .addPreprocessingTrainer(new MinMaxScalerTrainer<Integer, Vector>())
    .addTrainer(trainer);


// Create cross-validation instance
CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator
  = new CrossValidation<>();

// Set up the cross-validation process
scoreCalculator
    .withIgnite(ignite)
    .withUpstreamCache(trainingSet)
    .withPipeline(pipeline)
    .withMetric(MetricName.ACCURACY)
    .withPreprocessor(vectorizer)
    .withAmountOfFolds(4)
    .isRunningOnPipeline(true);

// Calculate accuracy for each fold
double[] accuracyScores = scoreCalculator.scoreByFolds();
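
A common motivation for cross-validating a whole pipeline rather than a bare trainer is that preprocessing steps, such as min-max scaling, can then be fitted on each fold's training part only. The text above does not describe how Ignite handles this internally, so the following is only a plain-Java sketch of the general idea. The class FoldScalingSketch and the methods fitMinMax and scale are hypothetical names that are not part of the Ignite API.

```java
/** Plain-Java illustration of fold-local preprocessing; not Ignite's implementation. */
class FoldScalingSketch {
    /** Returns {min, max} of the values at the given (training) indices. */
    static double[] fitMinMax(double[] values, int[] trainIdx) {
        double min = Double.POSITIVE_INFINITY, max = Double.NEGATIVE_INFINITY;
        for (int i : trainIdx) {
            min = Math.min(min, values[i]);
            max = Math.max(max, values[i]);
        }
        return new double[] {min, max};
    }

    /** Applies min-max scaling using previously fitted statistics. */
    static double scale(double x, double[] minMax) {
        double range = minMax[1] - minMax[0];
        return range == 0 ? 0.0 : (x - minMax[0]) / range;
    }

    public static void main(String[] args) {
        double[] feature = {1, 2, 3, 4, 100};
        int[] trainIdx = {0, 1, 2, 3}; // sample 4 is held out in this fold

        // Statistics fitted on the training part of the fold only.
        double[] foldStats = fitMinMax(feature, trainIdx);
        // Statistics fitted on all samples, including the held-out one.
        double[] leakyStats = fitMinMax(feature, new int[] {0, 1, 2, 3, 4});

        System.out.println(scale(feature[4], foldStats));  // prints 33.0
        System.out.println(scale(feature[4], leakyStats)); // prints 1.0
    }
}
```

When the scaler is fitted on the training part only, the held-out value falls far outside the [0, 1] range, which is what the model would face on unseen data. Fitting on all samples hides this, because information about the held-out fold leaks into the preprocessing step.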

Example

To see how cross-validation can be used in practice, try this example and see step 8 of the ML Tutorial. Both are available on GitHub and are delivered with every Apache Ignite distribution.