Implement a function stratified_train_test_split that splits a dataset into training and testing sets while preserving the proportion of each class in both sets.
Stratified splitting is essential when dealing with classification problems, especially with imbalanced datasets, as it ensures that both training and test sets have representative samples from each class.
Parameters:
- X: A 2D numpy array of shape (n_samples, n_features) representing the feature matrix
- y: A 1D numpy array of shape (n_samples,) representing the class labels
- test_size: A float between 0 and 1 representing the proportion of data to include in the test set
- random_seed: An optional integer for reproducibility of the random shuffling
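A minimal sketch of how the parameter contract above might be checked before splitting. The helper name validate_split_inputs is hypothetical and not part of the required interface; it only enforces the shapes and the 0 < test_size < 1 range stated in the parameter list and the starter docstring.

```python
import numpy as np


def validate_split_inputs(X, y, test_size):
    """Hypothetical helper: enforce the documented shapes and test_size range."""
    X = np.asarray(X)
    y = np.asarray(y)
    # X must be a feature matrix of shape (n_samples, n_features).
    if X.ndim != 2:
        raise ValueError(f"X must be 2D (n_samples, n_features), got ndim={X.ndim}")
    # y must be a label vector of shape (n_samples,).
    if y.ndim != 1:
        raise ValueError(f"y must be 1D (n_samples,), got ndim={y.ndim}")
    # Every row of X needs exactly one label.
    if X.shape[0] != y.shape[0]:
        raise ValueError("X and y must have the same number of samples")
    # The test proportion must lie strictly between 0 and 1.
    if not 0 < test_size < 1:
        raise ValueError("test_size must be strictly between 0 and 1")
    return X, y
```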
Returns:
- A tuple of four arrays: (X_train, X_test, y_train, y_test)
Requirements:
- For each class, calculate the number of test-set samples as int(n_class_samples * test_size)
- Shuffle samples within each class before splitting
- Maintain the original class distribution proportionally in both train and test sets
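Because the per-class count int(n_class_samples * test_size) truncates toward zero, proportions are preserved only approximately, and a very small class can end up with no test samples at all. A quick check of the formula on hypothetical imbalanced labels (10 samples of class 0, 3 of class 1, test_size = 0.25):

```python
import numpy as np

# Hypothetical imbalanced labels: 10 samples of class 0, 3 of class 1.
y = np.array([0] * 10 + [1] * 3)
test_size = 0.25

classes, counts = np.unique(y, return_counts=True)
# Per-class test count, truncated toward zero: int(n_class_samples * test_size).
n_test = {int(c): int(n * test_size) for c, n in zip(classes, counts)}
# Everything not sent to the test set stays in the training set.
n_train = {int(c): int(n) - n_test[int(c)] for c, n in zip(classes, counts)}

print(n_test)   # {0: 2, 1: 0}
print(n_train)  # {0: 8, 1: 3}
```

Here 10 * 0.25 = 2.5 truncates to 2, while 3 * 0.25 = 0.75 truncates to 0, so class 1 appears only in the training set.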
Examples
Example 1:
Input:
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
y = np.array([0, 0, 0, 1, 1, 1])
test_size = 0.5
random_seed = 42
Output:
X_train shape: (4, 2), X_test shape: (2, 2)
y_train class counts: {0: 2, 1: 2}
y_test class counts: {0: 1, 1: 1}
Explanation: With 6 samples (3 per class) and test_size=0.5, we calculate n_test = int(3 * 0.5) = 1 sample per class for the test set. After shuffling within each class using random_seed=42, we select 1 sample from each class for testing and the remaining 2 from each class for training. The original 50-50 class distribution is maintained in both sets: train has 2 samples of class 0 and 2 of class 1, while test has 1 sample of each class.
Starter Code
import numpy as np

def stratified_train_test_split(X, y, test_size, random_seed=None):
    """
    Split data into train and test sets while maintaining class proportions.

    Args:
        X: Feature matrix of shape (n_samples, n_features)
        y: Label vector of shape (n_samples,)
        test_size: Proportion of data for test set (0 < test_size < 1)
        random_seed: Random seed for reproducibility

    Returns:
        X_train, X_test, y_train, y_test
    """
    # Seed NumPy's global RNG so the within-class shuffles are reproducible.
    np.random.seed(random_seed)
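One possible way to complete the starter code, sketched under two assumptions: that seeding NumPy's global RNG (as the starter does) is acceptable, and that classes are processed in the sorted order returned by np.unique. The specific rows chosen may differ from a reference solution, although the shapes and per-class counts should match the example.

```python
import numpy as np


def stratified_train_test_split(X, y, test_size, random_seed=None):
    """Split X and y into train/test sets, preserving per-class proportions."""
    X = np.asarray(X)
    y = np.asarray(y)
    # Seed NumPy's global RNG, as the starter code does, for reproducible shuffles.
    np.random.seed(random_seed)

    train_idx, test_idx = [], []
    for cls in np.unique(y):
        # Row indices belonging to this class, shuffled within the class.
        cls_idx = np.random.permutation(np.flatnonzero(y == cls))
        # Per-class test count, truncated toward zero as the requirements specify.
        n_test = int(len(cls_idx) * test_size)
        test_idx.extend(cls_idx[:n_test])
        train_idx.extend(cls_idx[n_test:])

    # Apply the same index arrays to X and y so rows stay aligned with labels.
    train_idx = np.array(train_idx, dtype=int)
    test_idx = np.array(test_idx, dtype=int)
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]


# Usage on Example 1.
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
y = np.array([0, 0, 0, 1, 1, 1])
X_train, X_test, y_train, y_test = stratified_train_test_split(
    X, y, test_size=0.5, random_seed=42
)
print(X_train.shape, X_test.shape)          # (4, 2) (2, 2)
print(np.bincount(y_train), np.bincount(y_test))  # [2 2] [1 1]
```

Shuffling index arrays rather than X itself keeps each feature row paired with its label, since the same indices are used to slice both arrays.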