k-NN Regression
The Apache Ignite Machine Learning component provides two versions of the widely used k-NN (k-nearest neighbors) algorithm - one for classification tasks and the other for regression tasks.
This documentation reviews k-NN as a solution for regression tasks.
Trainer and Model
The k-NN regression algorithm is a non-parametric method that bases its prediction on the k closest training examples in the feature space. Each training example has an associated numerical property value.
The algorithm uses the whole training set to predict the property value for a given test sample.
The predicted property value is the average of the values of the k nearest neighbors. If k is 1, the test sample is simply assigned the property value of its single nearest neighbor.
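The averaging rule described above can be illustrated with a small, single-JVM sketch. This is not Ignite's distributed implementation: the class name `KnnRegressionSketch`, its method names, and the choice of Euclidean distance are assumptions made only for illustration.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Minimal local sketch of unweighted k-NN regression (hypothetical, not Ignite code). */
final class KnnRegressionSketch {
    /** Euclidean distance between two feature vectors of equal length. */
    static double euclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /** Predicts the mean property value of the k training rows closest to the query. */
    static double predict(double[][] features, double[] labels, double[] query, int k) {
        if (k < 1 || k > features.length)
            throw new IllegalArgumentException("k must be in [1, training set size]");

        // Sort row indices by their distance to the query point.
        Integer[] idx = new Integer[features.length];
        for (int i = 0; i < idx.length; i++)
            idx[i] = i;
        Arrays.sort(idx, Comparator.comparingDouble(i -> euclidean(features[i], query)));

        // Average the property values of the k nearest rows.
        double sum = 0.0;
        for (int i = 0; i < k; i++)
            sum += labels[idx[i]];
        return sum / k;
    }
}
```

With k set to 1, `predict` returns the property value of the single closest training row, which matches the special case described above.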
Ignite currently supports the following parameters for the k-NN regression algorithm:
- k - the number of nearest neighbors
- distanceMeasure - one of the distance metrics provided by the ML framework, such as Euclidean, Hamming, or Manhattan
- isWeighted - false by default; if true, enables the weighted k-NN algorithm
- dataCache - holds the training set of objects whose property values are already known
- indexType - the type of distributed spatial index; one of ARRAY, KD_TREE, or BALL_TREE
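To make the distanceMeasure and isWeighted parameters more concrete, here is a small local sketch of two common distance metrics and a weighted prediction. The inverse-distance weighting (weight = 1 / distance) is an assumption chosen for illustration, because this page does not state the exact formula Ignite uses. The class `WeightedKnnSketch` and its methods are hypothetical and are not part of the Ignite API.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.function.ToDoubleBiFunction;

/** Local sketch of distance metrics and weighted k-NN regression (hypothetical, not Ignite code). */
final class WeightedKnnSketch {
    /** Manhattan distance: sum of absolute coordinate differences. */
    static double manhattan(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++)
            sum += Math.abs(a[i] - b[i]);
        return sum;
    }

    /** Hamming distance: number of coordinates in which the vectors differ. */
    static double hamming(double[] a, double[] b) {
        int diff = 0;
        for (int i = 0; i < a.length; i++)
            if (a[i] != b[i])
                diff++;
        return diff;
    }

    /** Inverse-distance weighted mean of the k nearest property values (assumed weighting scheme). */
    static double predictWeighted(double[][] features, double[] labels, double[] query, int k,
        ToDoubleBiFunction<double[], double[]> distance) {
        if (k < 1 || k > features.length)
            throw new IllegalArgumentException("k must be in [1, training set size]");

        // Compute every distance once, then sort row indices by it.
        double[] dist = new double[features.length];
        Integer[] idx = new Integer[features.length];
        for (int i = 0; i < features.length; i++) {
            dist[i] = distance.applyAsDouble(features[i], query);
            idx[i] = i;
        }
        Arrays.sort(idx, Comparator.comparingDouble(i -> dist[i]));

        double weightedSum = 0.0;
        double weightTotal = 0.0;
        for (int j = 0; j < k; j++) {
            int row = idx[j];
            // An exact match would give an infinite weight, so return its value directly.
            if (dist[row] == 0.0)
                return labels[row];
            double w = 1.0 / dist[row];
            weightedSum += w * labels[row];
            weightTotal += w;
        }
        return weightedSum / weightTotal;
    }
}
```

A call such as `WeightedKnnSketch.predictWeighted(features, labels, query, 3, WeightedKnnSketch::manhattan)` lets closer neighbors contribute more to the prediction than distant ones, which is the general idea behind a weighted k-NN variant.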
// Create the trainer.
KNNRegressionTrainer trainer = new KNNRegressionTrainer()
    .withK(5)
    .withIdxType(SpatialIndexType.BALL_TREE)
    .withDistanceMeasure(new ManhattanDistance())
    .withWeighted(true);

// Train the model. The regression trainer produces a regression model.
KNNRegressionModel knnMdl = trainer.fit(
    ignite,
    dataCache,
    vectorizer
);

// Make a prediction for a single observation (a feature vector).
double prediction = knnMdl.predict(observation);
Example
To see how k-NN regression can be used in practice, try this example, which is available on GitHub and delivered with every Apache Ignite distribution.
The training dataset is the Iris dataset, which can be loaded from the UCI Machine Learning Repository.