Split the dataset into test and train datasets | Ignite Documentation

Split the dataset into test and train datasets

Data splitting divides the data stored in a cache into two parts: the training part, which is used to train the model, and the test part, which is used to estimate the model quality.

All fit() methods have a special parameter for passing a filter condition that is applied to the cache.

Note

Because dataset operations are distributed and lazy, the dataset split is a lazy operation too. It is defined as a filter condition that is applied to the initial cache to form both the train and the test datasets.
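To illustrate the idea of a lazy split, here is a minimal sketch in plain Java with no Ignite dependency. It is not Ignite's implementation: the class LazySplitSketch, the helper toUnitInterval, and the multiplicative hash are assumptions made for illustration only. The point it shows is that a split can be expressed as two complementary predicates that are evaluated per row, so no data is copied or partitioned up front.

```java
import java.util.function.BiPredicate;

// Hypothetical sketch: a "split" that is nothing more than a pair of filters.
final class LazySplitSketch {
    /** Maps a key to a pseudo-uniform point in [0, 1) with a multiplicative hash (illustrative helper). */
    static double toUnitInterval(int key) {
        int h = key * 0x9E3779B9;                 // int multiplication wraps around on overflow
        return (h >>> 1) / (double) (1L << 31);   // keep the top 31 bits, scale to [0, 1)
    }

    /** Accepts a row when its key falls into the first trainSize share of [0, 1). */
    static BiPredicate<Integer, double[]> trainFilter(double trainSize) {
        return (key, row) -> toUnitInterval(key) < trainSize;
    }

    /** Accepts exactly the rows that the train filter rejects. */
    static BiPredicate<Integer, double[]> testFilter(double trainSize) {
        return trainFilter(trainSize).negate();
    }

    public static void main(String[] args) {
        BiPredicate<Integer, double[]> train = trainFilter(0.75);
        BiPredicate<Integer, double[]> test = testFilter(0.75);

        int trainCount = 0;
        int testCount = 0;
        // The filters are evaluated row by row; nothing is materialized in advance.
        for (int key = 0; key < 1_000; key++) {
            if (train.test(key, null))
                trainCount++;
            if (test.test(key, null))
                testCount++;
        }
        System.out.printf("train rows: %d, test rows: %d%n", trainCount, testCount);
    }
}
```

Because the decision depends only on the key, the same row always lands in the same subset, no matter which node evaluates the filter or when.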

In the example below, the model is trained on only 75% of the initial dataset. The filter argument is the result of split.getTrainFilter(), which accepts or rejects each row of the initial dataset for use during training.

// Define the cache.
IgniteCache<Integer, Vector> dataCache = ...;

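// Define the share of the initial dataset that goes to the train subset (0.75 = 75%).
// The type arguments are spelled out because the diamond operator cannot infer them in a chained call.
TrainTestSplit<Integer, Vector> split = new TrainTestDatasetSplitter<Integer, Vector>().split(0.75);

// Train only on the rows accepted by the train filter.
IgniteModel<Vector, Double> mdl = trainer
  .fit(ignite, dataCache, split.getTrainFilter(), vectorizer);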

The split.getTestFilter() result can be used to validate the model on the test data. The example below works with the cache directly: it prints the predicted and the actual regression value for each row in the test subset of the initial dataset.

// Define the cache query and set the filter.
ScanQuery<Integer, Vector> qry = new ScanQuery<>();
qry.setFilter(split.getTestFilter());

// Iterate over the test subset only; the cursor is closed automatically.
try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(qry)) {
    for (Cache.Entry<Integer, Vector> observation : observations) {
        Vector val = observation.getValue();

        // The label is stored at index 0; the remaining coordinates are the features.
        Vector inputs = val.copyOfRange(1, val.size());
        double groundTruth = val.get(0);

        double prediction = mdl.predict(inputs);

        System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
    }
}
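Printing individual rows is useful for inspection, but a single number is often easier to compare between models. The following sketch, again in plain Java without Ignite, aggregates the same prediction/ground-truth pairs into a mean squared error over the rows accepted by a test filter. Everything in it is an assumption for illustration: the class TestSubsetErrorSketch, the in-memory Map standing in for the cache, the stand-in filter, and the toy model. It keeps the row layout used above, with the label at index 0 and the features after it.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.BiPredicate;
import java.util.function.Function;

// Hypothetical sketch: aggregate the error over the test subset instead of printing each row.
final class TestSubsetErrorSketch {
    /** Mean squared error over the rows accepted by the filter; NaN when no row is accepted. */
    static double meanSquaredError(Map<Integer, double[]> rows,
                                   BiPredicate<Integer, double[]> filter,
                                   Function<double[], Double> model) {
        double sumSquared = 0.0;
        int count = 0;

        for (Map.Entry<Integer, double[]> entry : rows.entrySet()) {
            double[] row = entry.getValue();
            if (!filter.test(entry.getKey(), row))
                continue; // the row belongs to the other subset

            double groundTruth = row[0];                              // label at index 0
            double[] inputs = Arrays.copyOfRange(row, 1, row.length); // features after it
            double error = model.apply(inputs) - groundTruth;

            sumSquared += error * error;
            count++;
        }
        return count == 0 ? Double.NaN : sumSquared / count;
    }

    public static void main(String[] args) {
        // In-memory stand-in for the cache: key -> {label, feature}.
        Map<Integer, double[]> rows = new LinkedHashMap<>();
        rows.put(1, new double[] {2.0, 1.0});
        rows.put(2, new double[] {4.5, 2.0});
        rows.put(3, new double[] {6.0, 3.0});
        rows.put(4, new double[] {7.0, 4.0});

        // Stand-in test filter (even keys) and toy model (label = 2 * feature).
        BiPredicate<Integer, double[]> testFilter = (key, row) -> key % 2 == 0;
        Function<double[], Double> model = inputs -> 2.0 * inputs[0];

        System.out.printf("MSE on the test subset: %.4f%n",
            meanSquaredError(rows, testFilter, model));
    }
}
```

With a real cache, the same accumulation could be done inside the query cursor loop shown above, replacing the printf call with the running sum and count.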