diff --git a/tree/ml/inc/ROOT/ML/RClusterLoader.hxx b/tree/ml/inc/ROOT/ML/RClusterLoader.hxx index fabad80b3be33..6a8d651569fb4 100644 --- a/tree/ml/inc/ROOT/ML/RClusterLoader.hxx +++ b/tree/ml/inc/ROOT/ML/RClusterLoader.hxx @@ -227,18 +227,27 @@ public: // --- Shuffled path // Every cluster contributes a prefix to training and a suffix to validation. // Cost: Each cluster is read twice per epoch, only when validation split is more than 0. - // TODO(staider) Swicth between prefix or suffix for validation randomly per cluster + // We generate a random boolean value to decide whether the training set gets the prefix + // or suffix of each cluster to ensure better shuffling across runs when splitting. + std::mt19937 g(fSetSeed); + std::uniform_int_distribution coin(0, 1); + for (const RClusterRange &c : fAllClusters) { const std::size_t sz = c.GetNumEntries(); const std::size_t trainSz = static_cast((1.0f - fValidationSplit) * sz); const std::size_t valSz = sz - trainSz; + // Randomly assign prefix or suffix to training + const uint64_t trainIsPrefix = coin(g); + const uint64_t trainStart = trainIsPrefix ? c.start : c.start + static_cast(valSz); + const uint64_t valStart = trainIsPrefix ? c.start + static_cast(trainSz) : c.start; + if (trainSz > 0) { - fTrainingClusters.push_back({c.rdfIdx, c.start, c.start + static_cast(trainSz)}); + fTrainingClusters.push_back({c.rdfIdx, trainStart, trainStart + static_cast(trainSz)}); fNumTrainingEntries += trainSz; } if (valSz > 0) { - fValidationClusters.push_back({c.rdfIdx, c.start + static_cast(trainSz), c.end}); + fValidationClusters.push_back({c.rdfIdx, valStart, valStart + static_cast(valSz)}); fNumValidationEntries += valSz; } } @@ -392,14 +401,26 @@ public: std::min(static_cast(totalFiltered * (1.0f - fValidationSplit)), trainRemaining); const std::size_t valCount = totalFiltered - trainCount; + // We generate a random boolean value to decide whether the training set gets the prefix + // or suffix of each cluster to ensure better shuffling across runs when splitting. + std::mt19937 g(fSetSeed + fAccumulatedFilteredForTrain); // vary per cluster + std::uniform_int_distribution coin(0, 1); + const uint64_t trainIsPrefix = coin(g); + // The boundary is the raw entry index of the first entry assigned to validation. // Stable across epochs since the same filter always produces the same ordered entries. - const std::uint64_t boundary = (valCount > 0) ? rdfEntries[trainCount] : endRow; + const std::uint64_t trainBoundaryEntry = trainIsPrefix ? rdfEntries[trainCount] : rdfEntries[valCount]; + const std::uint64_t boundary = (valCount > 0) ? trainBoundaryEntry : endRow; + + const std::uint64_t trainStart = trainIsPrefix ? startRow : boundary; + const std::uint64_t trainEnd = trainIsPrefix ? boundary : endRow; + const std::uint64_t valStart = trainIsPrefix ? boundary : startRow; + const std::uint64_t valEnd = trainIsPrefix ? endRow : boundary; if (trainCount > 0) - fTrainingClusters.push_back({rdfIdx, startRow, boundary, trainCount}); + fTrainingClusters.push_back({rdfIdx, trainStart, trainEnd, trainCount}); if (valCount > 0) - fValidationClusters.push_back({rdfIdx, boundary, endRow, valCount}); + fValidationClusters.push_back({rdfIdx, valStart, valEnd, valCount}); fAccumulatedFilteredForTrain += trainCount; return trainCount;