What is K-Nearest Neighbors?
K-Nearest Neighbors (KNN) is a simple, yet powerful machine learning algorithm used for both classification and regression tasks. It's a non-parametric method that makes predictions based on the K nearest data points in the feature space.
How KNN Works
1. Choose a K value: select the number of nearest neighbors to consider.
2. Calculate distances: compute the distance from the new point to every training point.
3. Find the K nearest: identify the K points with the smallest distances.
4. Make a prediction: classify by majority vote of the K neighbors (or, for regression, average their values).

Each of these steps is implemented from scratch later in this article.
Distance Metrics
Euclidean Distance
d = √[(x₂-x₁)² + (y₂-y₁)²]
The most common metric; it measures straight-line distance between two points.
Manhattan Distance
d = |x₂-x₁| + |y₂-y₁|
The sum of absolute coordinate differences; useful for sparse data.
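A minimal sketch of this metric in NumPy, mirroring the Euclidean implementation shown later; the function name manhattan_distance is our own and is not used in the examples below:

import numpy as np

def manhattan_distance(point1, point2):
    """Calculate Manhattan (L1) distance between two points"""
    # Sum of absolute coordinate differences
    return np.sum(np.abs(np.array(point1) - np.array(point2)))

# Example: |5-1| + |1-4| = 4 + 3 = 7
print(manhattan_distance([1, 4], [5, 1]))  # 7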
KNN Implementation from Scratch
Let's build KNN from the ground up to understand exactly how it works:
Distance Calculation Function
import numpy as np
from collections import Counter
def euclidean_distance(point1, point2):
    """Calculate Euclidean distance between two points"""
    return np.sqrt(np.sum((np.array(point1) - np.array(point2))**2))
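A quick sanity check on a 3-4-5 right triangle:

# Distance from the origin to (3, 4) should be 5
print(euclidean_distance([0, 0], [3, 4]))  # 5.0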
KNN Prediction Function
def knn_predict(training_data, training_labels, test_point, k):
    """Predict class for test_point using k nearest neighbors"""
    distances = []

    # Calculate distances to all training points
    for i in range(len(training_data)):
        dist = euclidean_distance(test_point, training_data[i])
        distances.append((dist, training_labels[i]))

    # Sort by distance and get k nearest
    distances.sort(key=lambda x: x[0])
    k_nearest_labels = [label for _, label in distances[:k]]

    # Return most common class
    return Counter(k_nearest_labels).most_common(1)[0][0]
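The same structure handles regression, which the introduction noted KNN also supports: average the targets of the k nearest neighbors instead of taking a majority vote. This knn_regress helper is a sketch of ours, not part of the t-shirt example:

def knn_regress(training_data, training_targets, test_point, k):
    """Predict a numeric value as the mean of the k nearest targets"""
    distances = []
    # Distance from the test point to every training point
    for i in range(len(training_data)):
        dist = euclidean_distance(test_point, training_data[i])
        distances.append((dist, training_targets[i]))
    # Average the targets of the k closest points
    distances.sort(key=lambda x: x[0])
    return sum(target for _, target in distances[:k]) / k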
Example Usage
# Sample data: [height, weight] -> t-shirt size
training_data = [
    [158, 58], [160, 59], [163, 60], [165, 61], [168, 62],
    [170, 63], [158, 63], [160, 64], [163, 64], [165, 65]
]
training_labels = ['M', 'M', 'M', 'L', 'L', 'L', 'M', 'L', 'L', 'L']
# Predict for new person: height=162, weight=61
new_person = [162, 61]
predicted_size = knn_predict(training_data, training_labels, new_person, k=3)
print(f"Predicted t-shirt size: {predicted_size}")
Using Scikit-learn
Scikit-learn provides an optimized KNN implementation that is well suited to production use:
Basic Classification
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_blobs
from sklearn.metrics import accuracy_score
# Generate sample data
X, y = make_blobs(n_samples=300, centers=4, n_features=2, random_state=42)
# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# Create and train KNN classifier
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
# Make predictions
y_pred = knn.predict(X_test)
# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy:.2f}')
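The same API handles regression through KNeighborsRegressor. A brief sketch on synthetic data (the make_regression setup and variable names here are illustrative, not part of the classification example above):

from sklearn.neighbors import KNeighborsRegressor
from sklearn.datasets import make_regression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

# Synthetic regression data
X_r, y_r = make_regression(n_samples=300, n_features=2, noise=10, random_state=42)
Xr_train, Xr_test, yr_train, yr_test = train_test_split(X_r, y_r, test_size=0.3, random_state=42)

# weights='distance' gives closer neighbors more influence
reg = KNeighborsRegressor(n_neighbors=5, weights='distance')
reg.fit(Xr_train, yr_train)
print(f"Test MSE: {mean_squared_error(yr_test, reg.predict(Xr_test)):.2f}")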
Hyperparameter Tuning
from sklearn.model_selection import GridSearchCV
# Define parameter grid
param_grid = {
    'n_neighbors': [3, 5, 7, 9, 11],
    'weights': ['uniform', 'distance'],
    'metric': ['euclidean', 'manhattan']
}
# Grid search with cross-validation
grid_search = GridSearchCV(
    KNeighborsClassifier(),
    param_grid,
    cv=5,
    scoring='accuracy'
)
grid_search.fit(X_train, y_train)
print(f"Best parameters: {grid_search.best_params_}")
print(f"Best score: {grid_search.best_score_:.3f}")
Step-by-step Manual Example
Let's work through a complete example using the t-shirt size dataset:
Training Data
| Height (cm) | Weight (kg) | T-shirt Size |
|---|---|---|
| 158 | 58 | M |
| 160 | 59 | M |
| 163 | 60 | M |
| 165 | 61 | L |
| 168 | 62 | L |
| 170 | 63 | L |
| 158 | 63 | M |
| 160 | 64 | L |
| 163 | 64 | L |
| 165 | 65 | L |
Make a Prediction
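For the new person with height 162 cm and weight 61 kg, compute the Euclidean distance to every training point. The three smallest distances are:

| Neighbor (height, weight) | Distance | Size |
|---|---|---|
| (163, 60) | √(1² + 1²) ≈ 1.41 | M |
| (160, 59) | √(2² + 2²) ≈ 2.83 | M |
| (165, 61) | √(3² + 0²) = 3.00 | L |

With K = 3, the majority vote is M (two votes to one), so the predicted t-shirt size is M, matching the output of knn_predict above.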
Performance Analysis
Key Insights
- Optimal K: use an odd K in binary classification to avoid tied votes
- Bias-Variance Tradeoff: small K gives low bias but high variance; large K gives the reverse
- Computational Complexity: O(n·d) per prediction with brute-force search over n training points in d dimensions
- Memory Usage: KNN is a "lazy" learner and stores the entire training dataset
Best Practices
- Scale features to similar ranges, since distance metrics are scale-sensitive (see the pipeline sketch after this list)
- Use cross-validation to select K
- Consider dimensionality reduction for high-dimensional data
- Use efficient data structures (KD-trees) for large datasets
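A sketch combining the first two practices, reusing X_train and y_train from the scikit-learn section; wrapping the scaler in a Pipeline ensures it is refit inside each cross-validation fold rather than leaking test-fold statistics:

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Scale features, then classify
pipe = Pipeline([
    ('scaler', StandardScaler()),
    ('knn', KNeighborsClassifier(n_neighbors=5)),
])

# Pick K by cross-validated accuracy
for k in [3, 5, 7, 9, 11]:
    pipe.set_params(knn__n_neighbors=k)
    scores = cross_val_score(pipe, X_train, y_train, cv=5)
    print(f"k={k}: mean CV accuracy = {scores.mean():.3f}")

# For large datasets, force a KD-tree index:
# knn = KNeighborsClassifier(n_neighbors=5, algorithm='kd_tree')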