⚡️ Speed up function _number_of_shards_in_gen_kwargs by 53% #132

Open
codeflash-ai[bot] wants to merge 1 commit into main from codeflash/optimize-_number_of_shards_in_gen_kwargs-mlcxbbhi

@codeflash-ai codeflash-ai bot commented Feb 7, 2026

📄 53% (0.53x) speedup for _number_of_shards_in_gen_kwargs in src/datasets/utils/sharding.py

⏱️ Runtime : 281 microseconds → 183 microseconds (best of 92 runs)

📝 Explanation and details

The optimization achieves a 53% speedup (281μs → 183μs, roughly 35% less runtime) by eliminating redundant operations and combining validation into a single pass through the dictionary.
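For reference, the percentages in this report are consistent with measuring the gain relative to the optimized runtime. The short calculation below is only an illustration (the helper names `speedup` and `runtime_reduction` are made up for this example) of how the two totals relate to the headline figure and to the reduction in runtime.

```python
def speedup(original_us: float, optimized_us: float) -> float:
    """Gain relative to the optimized runtime (the convention the figures above match)."""
    return (original_us - optimized_us) / optimized_us


def runtime_reduction(original_us: float, optimized_us: float) -> float:
    """Fraction of the original runtime that was saved."""
    return (original_us - optimized_us) / original_us


# Totals from the report, in microseconds (rounded values).
print(f"speedup: {speedup(281, 183):.1%}")
print(f"runtime reduction: {runtime_reduction(281, 183):.1%}")

# A per-test figure from the generated tests: 4.89μs -> 2.45μs.
print(f"per-test speedup: {speedup(4.89, 2.45):.1%}")
```

With the rounded totals the speedup comes out slightly above the reported 53%, which suggests the report was computed from unrounded timings.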

Key Changes

1. Single-Pass Processing with Early Exit
The original code made two full passes through the data:

  • First pass: Dictionary comprehension to collect all list lengths (57.2% of original runtime)
  • Second pass: Creating a set to check for length mismatches (14.3% of original runtime)

The optimized version processes everything in one loop, tracking the first list length encountered and immediately detecting mismatches as it iterates. This eliminates the expensive set() creation and reduces dictionary iterations.
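The diff itself is not shown in this description, so the following is a minimal sketch of the two approaches as described above, not the actual code in `src/datasets/utils/sharding.py`. The function names `shards_two_pass` and `shards_single_pass` and the error text are assumptions made for illustration.

```python
def shards_two_pass(gen_kwargs: dict) -> int:
    """Baseline as described: collect every list length, then check them via a set."""
    lists_lengths = {k: len(v) for k, v in gen_kwargs.items() if isinstance(v, list)}
    if len(set(lists_lengths.values())) > 1:
        raise RuntimeError(f"ambiguous sharding: {lists_lengths}")
    return max(1, max(lists_lengths.values(), default=0))


def shards_single_pass(gen_kwargs: dict) -> int:
    """One loop: remember the first list length and compare later lists to it."""
    first_length = None
    for value in gen_kwargs.values():
        if not isinstance(value, list):
            continue  # tuples, strings, dicts, etc. do not count as shards
        length = len(value)
        if first_length is None:
            first_length = length
        elif length != first_length:
            raise RuntimeError("ambiguous sharding: list lengths differ")
    return max(1, first_length if first_length is not None else 0)


print(shards_single_pass({"urls": ["a", "b", "c"], "timeout": 30}))
```

The single-pass version allocates neither the intermediate dictionary nor the set on the success path, which is where the description attributes most of the original runtime.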

2. Efficient Mismatch Detection
Instead of building a complete set of unique lengths and then checking if multiple exist, the code now compares each list length against the first one found. This allows early termination when a mismatch is detected, avoiding unnecessary iterations.
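One consequence of comparing against the first length is that the loop can stop before it has seen every list, while one of the generated tests states that the error message should list each offending key with its length. A hedged sketch of one way to reconcile the two is to keep the success path allocation-free and gather all lengths only once a mismatch has been found. The name `common_list_length` and the message wording are hypothetical.

```python
from typing import Optional


def common_list_length(gen_kwargs: dict) -> Optional[int]:
    """Return the length shared by all list values, or None if there are no lists."""
    first_length: Optional[int] = None
    for value in gen_kwargs.values():
        if isinstance(value, list):
            if first_length is None:
                first_length = len(value)
            elif len(value) != first_length:
                # Slow path, taken only on error: collect every list length so the
                # message can name all keys, not just the first mismatching one.
                lengths = {k: len(v) for k, v in gen_kwargs.items() if isinstance(v, list)}
                details = "\n".join(f"\t- key {k} has length {n}" for k, n in lengths.items())
                raise RuntimeError(f"Sharding is ambiguous, list lengths differ:\n{details}")
    return first_length
```

Because the extra dictionary is built only when an error is about to be raised, the error path is expected to gain less than the success path, which matches the smaller speedups reported for the mismatch tests.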

3. Direct Length Calculation
The optimization replaces max(lists_lengths.values(), default=0) with a simple variable check (first_length if first_length is not None else 0), avoiding the overhead of calling max() on dictionary values.
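The replacement is safe because, once the mismatch check has passed, every list has the same length, so the maximum equals the first length seen. A small sketch of that equivalence, using hypothetical helper names and a pre-validated `lists_lengths` mapping:

```python
def max_based(lists_lengths: dict) -> int:
    """Original style: take the maximum of all recorded lengths."""
    return max(1, max(lists_lengths.values(), default=0))


def first_based(lists_lengths: dict) -> int:
    """Optimized style: use the first recorded length, or 0 when there is none."""
    first_length = next(iter(lists_lengths.values()), None)
    return max(1, first_length if first_length is not None else 0)


# Both agree whenever all recorded lengths are equal (including the empty case).
for lengths in ({}, {"a": 0}, {"a": 3, "b": 3}):
    print(lengths, max_based(lengths), first_based(lengths))
```

The outer `max(1, ...)` is kept in both variants so that an empty list, or no list at all, still yields one shard.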

Performance Impact

Every generated test ran faster with the optimization, although the gain varies by scenario:

  • Simple cases (no lists, single list): roughly 90-118% faster - The single-pass approach eliminates overhead from creating empty sets and iterating empty dictionaries
  • Multiple lists (same length): 52-87% faster - Avoids building intermediate set structures
  • Error cases (mismatched lengths): 15-25% faster - Early detection stops processing as soon as a mismatch is found
  • Large workloads (many lists/keys): 5-59% faster - Benefits from reduced iterations and memory allocations; the low end (5%) is a dictionary with 500 non-list keys, where iterating over the keys presumably dominates in both versions
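To reproduce numbers of this kind locally, a micro-benchmark along the following lines could be used. This is a sketch under assumptions: the two functions are simplified stand-ins for the original and optimized implementations (error handling omitted), `best_us_per_call` is a hypothetical helper, and absolute timings will differ by machine.

```python
import timeit


def two_pass(gen_kwargs: dict) -> int:
    # Stand-in for the original: dict comprehension plus set, then max().
    lengths = {k: len(v) for k, v in gen_kwargs.items() if isinstance(v, list)}
    assert len(set(lengths.values())) <= 1
    return max(1, max(lengths.values(), default=0))


def single_pass(gen_kwargs: dict) -> int:
    # Stand-in for the optimized version: one loop, no intermediate containers.
    first = None
    for value in gen_kwargs.values():
        if isinstance(value, list):
            if first is None:
                first = len(value)
            else:
                assert len(value) == first
    return max(1, first if first is not None else 0)


def best_us_per_call(fn, arg, repeat: int = 5, number: int = 10_000) -> float:
    """Best-of-`repeat` time per call, in microseconds."""
    totals = timeit.repeat(lambda: fn(arg), repeat=repeat, number=number)
    return min(totals) / number * 1e6


cases = {
    "no lists": {"a": 42, "b": (1, 2, 3)},
    "8 lists": {f"source_{i}": list(range(100)) for i in range(8)},
    "500 scalars + 1 list": {"data": [1, 2, 3], **{f"p{i}": i for i in range(500)}},
}
for name, kwargs in cases.items():
    before = best_us_per_call(two_pass, kwargs)
    after = best_us_per_call(single_pass, kwargs)
    print(f"{name}: {before:.2f}μs -> {after:.2f}μs")
```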

Relevance to Workloads

According to the function_references, this function is called from _prepare_split() in the dataset builder during multiprocessing setup. It returns the number of shards, which the builder uses to decide whether dataset generation can be parallelized across multiple processes. Because it runs during dataset preparation when num_proc > 1, the 53% speedup trims dataset loading initialization time, most noticeably when preparing datasets with multiple data sources, although at a few microseconds per call the absolute saving is small.
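As an illustration of how a caller might use the shard count, the sketch below caps the requested process count by the number of shards. It is not the actual `_prepare_split()` logic; `effective_num_proc` is a hypothetical helper written for this example.

```python
from typing import Optional


def effective_num_proc(num_input_shards: int, num_proc: Optional[int]) -> int:
    """Hypothetical helper: never use more processes than there are shards."""
    if num_proc is None or num_proc <= 1:
        return 1  # multiprocessing was not requested
    if num_input_shards <= 1:
        return 1  # nothing to split across processes
    return min(num_proc, num_input_shards)


# Three shards cannot keep sixteen processes busy.
print(effective_num_proc(num_input_shards=3, num_proc=16))
```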

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 4 Passed |
| 🌀 Generated Regression Tests | 48 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |

⚙️ Existing Unit Tests

| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
| --- | --- | --- | --- |
| test_sharding_utils.py::test_number_of_shards_in_gen_kwargs | 7.54μs | 6.62μs | 13.9%✅ |

🌀 Generated Regression Tests
import pytest  # used for our unit tests
from src.datasets.utils.sharding import _number_of_shards_in_gen_kwargs

def test_no_lists_returns_one():
    """
    Basic case: when there are no list values in gen_kwargs, the function should return 1.
    This covers dictionaries with scalars, tuples, strings, etc., which are not considered lists.
    """
    # tuples are intentionally not lists; should be ignored
    gen_kwargs = {"a": 42, "b": (1, 2, 3), "c": "a string"}
    codeflash_output = _number_of_shards_in_gen_kwargs(gen_kwargs) # 4.89μs -> 2.45μs (99.6% faster)

def test_single_list_returns_its_length():
    """
    Basic case: a single list in gen_kwargs should make the function return that list's length.
    """
    gen_kwargs = {"data": [10, 20, 30, 40]}
    codeflash_output = _number_of_shards_in_gen_kwargs(gen_kwargs) # 4.59μs -> 2.31μs (98.8% faster)

def test_multiple_lists_same_length_returns_length():
    """
    Basic case: multiple lists of the same length should return that common length.
    """
    gen_kwargs = {"a": [1, 2, 3], "b": ["x", "y", "z"], "c": [None, None, None]}
    codeflash_output = _number_of_shards_in_gen_kwargs(gen_kwargs) # 4.91μs -> 2.79μs (76.1% faster)

def test_list_of_length_zero_returns_one():
    """
    Edge case: a single empty list should not produce zero shards; the function normalizes to at least 1.
    """
    gen_kwargs = {"empty": []}
    codeflash_output = _number_of_shards_in_gen_kwargs(gen_kwargs) # 4.53μs -> 2.29μs (98.0% faster)

def test_all_lists_zero_length_returns_one():
    """
    Edge case: several lists all of length zero should still produce at least one shard.
    """
    gen_kwargs = {"a": [], "b": []}
    codeflash_output = _number_of_shards_in_gen_kwargs(gen_kwargs) # 4.82μs -> 2.57μs (87.7% faster)

def test_mixed_list_and_non_list_types_uses_list_length_only():
    """
    Edge case: when gen_kwargs contains both list and non-list types, only lists count toward sharding.
    Non-list types (tuples, strings, ints) should be ignored for shard count computation.
    """
    gen_kwargs = {"list_key": [1, 2, 3, 4], "tuple_key": (1, 2), "str_key": "hello"}
    # Only the list length (4) should determine the number of shards
    codeflash_output = _number_of_shards_in_gen_kwargs(gen_kwargs) # 5.03μs -> 2.83μs (77.4% faster)

def test_nested_lists_count_as_list_length():
    """
    Edge case: values which are lists of lists are still lists; only their top-level length matters.
    """
    gen_kwargs = {"nested": [[1], [2], [3]]}  # length == 3
    codeflash_output = _number_of_shards_in_gen_kwargs(gen_kwargs) # 4.54μs -> 2.30μs (97.0% faster)

def test_unequal_list_lengths_raises_runtimeerror_and_message_contains_details():
    """
    Edge case: when multiple lists of different lengths are present, the function must raise RuntimeError.
    The error message must mention the ambiguity and list each offending key with its length.
    """
    gen_kwargs = {"first": [1, 2], "second": [3], "third": [7, 8]}
    with pytest.raises(RuntimeError) as exc_info:
        _number_of_shards_in_gen_kwargs(gen_kwargs) # 6.74μs -> 5.39μs (24.9% faster)
    msg = str(exc_info.value)

def test_strings_are_ignored_not_treated_as_list():
    """
    Edge case: strings are not lists; they must be ignored by the function.
    If only string values are present, result should be 1.
    """
    gen_kwargs = {"s1": "abc", "s2": "de"}
    codeflash_output = _number_of_shards_in_gen_kwargs(gen_kwargs) # 4.42μs -> 2.18μs (103% faster)

def test_single_large_list_performance_and_correctness():
    """
    Large-scale test: verify correct behavior with a large (but bounded) list.
    Use a list under 1000 elements to respect test constraints.
    """
    large_len = 800
    gen_kwargs = {"bigdata": list(range(large_len))}
    codeflash_output = _number_of_shards_in_gen_kwargs(gen_kwargs) # 4.67μs -> 2.39μs (95.3% faster)

def test_many_lists_same_large_length():
    """
    Large-scale test: several large lists (but kept under 1000 elements each) with identical lengths
    should return that common length. This checks the function's handling of repetitive similar-sized lists.
    """
    n_lists = 5
    list_length = 999  # near the 1000-element guideline but still within allowed bounds
    gen_kwargs = {f"list_{i}": list(range(list_length)) for i in range(n_lists)}
    codeflash_output = _number_of_shards_in_gen_kwargs(gen_kwargs) # 5.70μs -> 3.59μs (58.8% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import pytest
from src.datasets.utils.sharding import _number_of_shards_in_gen_kwargs

def test_empty_gen_kwargs():
    """Test with empty dictionary - should return 1 as default"""
    codeflash_output = _number_of_shards_in_gen_kwargs({}); result = codeflash_output # 4.14μs -> 1.91μs (118% faster)

def test_single_empty_list():
    """Test with a single empty list - should return 1 (max of 1 and 0)"""
    codeflash_output = _number_of_shards_in_gen_kwargs({"data": []}); result = codeflash_output # 4.58μs -> 2.25μs (104% faster)

def test_single_list_one_element():
    """Test with a single list containing one element"""
    codeflash_output = _number_of_shards_in_gen_kwargs({"data": [1]}); result = codeflash_output # 4.48μs -> 2.33μs (92.1% faster)

def test_single_list_multiple_elements():
    """Test with a single list containing multiple elements"""
    codeflash_output = _number_of_shards_in_gen_kwargs({"data": [1, 2, 3, 4, 5]}); result = codeflash_output # 4.45μs -> 2.25μs (97.6% faster)

def test_single_list_with_strings():
    """Test with a list of strings"""
    codeflash_output = _number_of_shards_in_gen_kwargs({"filenames": ["a.txt", "b.txt", "c.txt"]}); result = codeflash_output # 4.44μs -> 2.32μs (91.0% faster)

def test_multiple_lists_same_length():
    """Test with multiple lists of the same length"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "urls": ["url1", "url2", "url3"],
        "names": ["name1", "name2", "name3"]
    }); result = codeflash_output # 4.73μs -> 2.62μs (80.5% faster)

def test_non_list_values_ignored():
    """Test that non-list values are ignored in shard count calculation"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "data": [1, 2, 3],
        "config": "some_config",
        "num_workers": 4
    }); result = codeflash_output # 4.89μs -> 2.63μs (85.7% faster)

def test_tuple_values_ignored():
    """Test that tuple values are not considered for sharding"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "data": [1, 2],
        "metadata": (10, 20, 30)  # tuple should be ignored
    }); result = codeflash_output # 4.84μs -> 2.63μs (83.8% faster)

def test_dict_values_ignored():
    """Test that dict values are ignored"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "data": [1, 2, 3],
        "config": {"key": "value"}
    }); result = codeflash_output # 4.81μs -> 2.58μs (86.9% faster)

def test_none_values_ignored():
    """Test that None values are ignored"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "data": [1, 2],
        "extra": None
    }); result = codeflash_output # 4.64μs -> 2.37μs (95.4% faster)

def test_string_values_ignored():
    """Test that string values are not treated as lists"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "data": [1, 2],
        "description": "some description"
    }); result = codeflash_output # 4.62μs -> 2.38μs (93.8% faster)

def test_mismatched_list_lengths_two_lists():
    """Test with two lists of different lengths - should raise RuntimeError"""
    with pytest.raises(RuntimeError) as exc_info:
        _number_of_shards_in_gen_kwargs({
            "data1": [1, 2, 3],
            "data2": [1, 2]
        }) # 6.83μs -> 5.95μs (14.9% faster)

def test_mismatched_list_lengths_three_lists():
    """Test with three lists of different lengths - should raise RuntimeError"""
    with pytest.raises(RuntimeError) as exc_info:
        _number_of_shards_in_gen_kwargs({
            "list1": [1, 2, 3],
            "list2": [1, 2],
            "list3": [1, 2, 3, 4]
        }) # 6.84μs -> 5.61μs (21.9% faster)

def test_two_lists_same_length_one_different():
    """Test with two lists of same length and one different - should raise RuntimeError"""
    with pytest.raises(RuntimeError) as exc_info:
        _number_of_shards_in_gen_kwargs({
            "data1": [1, 2, 3],
            "data2": [1, 2, 3],
            "data3": [1, 2]
        }) # 6.69μs -> 5.81μs (15.2% faster)

def test_empty_list_with_non_empty_list():
    """Test with one empty list and one non-empty list - should raise RuntimeError"""
    with pytest.raises(RuntimeError) as exc_info:
        _number_of_shards_in_gen_kwargs({
            "empty": [],
            "non_empty": [1, 2, 3]
        }) # 6.42μs -> 5.39μs (19.1% faster)

def test_multiple_lists_all_empty():
    """Test with multiple empty lists - should return 1"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "list1": [],
        "list2": [],
        "list3": []
    }); result = codeflash_output # 5.05μs -> 2.99μs (69.0% faster)

def test_list_with_complex_objects():
    """Test with lists containing complex objects like dicts"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "data": [
            {"key": "value1"},
            {"key": "value2"},
            {"key": "value3"}
        ]
    }); result = codeflash_output # 4.53μs -> 2.47μs (83.3% faster)

def test_list_with_nested_lists():
    """Test with lists containing nested lists"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "data": [[1, 2], [3, 4], [5, 6]]
    }); result = codeflash_output # 4.48μs -> 2.24μs (100% faster)

def test_multiple_empty_lists_with_non_lists():
    """Test with multiple empty lists and non-list values"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "list1": [],
        "list2": [],
        "config": "value",
        "num": 42
    }); result = codeflash_output # 4.92μs -> 2.88μs (71.2% faster)

def test_single_element_per_key_multiple_keys():
    """Test with multiple keys each having a single-element list"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "a": [1],
        "b": [2],
        "c": [3]
    }); result = codeflash_output # 4.97μs -> 2.80μs (77.4% faster)

def test_very_long_list():
    """Test with a very long list"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "data": list(range(1000))
    }); result = codeflash_output # 4.59μs -> 2.42μs (89.8% faster)

def test_mixed_types_in_gen_kwargs():
    """Test with mixed types: list, int, str, float, bool"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "urls": ["a", "b", "c"],
        "timeout": 30,
        "name": "dataset",
        "probability": 0.5,
        "enabled": True
    }); result = codeflash_output # 5.25μs -> 3.02μs (73.9% faster)

def test_large_single_list():
    """Test with a large single list (500 elements)"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "data": list(range(500))
    }); result = codeflash_output # 4.54μs -> 2.31μs (96.5% faster)

def test_large_multiple_lists_same_length():
    """Test with multiple large lists of the same length (100 elements each, 8 lists)"""
    gen_kwargs = {
        f"source_{i}": list(range(100))
        for i in range(8)
    }
    codeflash_output = _number_of_shards_in_gen_kwargs(gen_kwargs); result = codeflash_output # 5.66μs -> 3.71μs (52.7% faster)

def test_large_mixed_data_with_many_lists():
    """Test with many lists and non-list values (10 lists of 200 items each)"""
    gen_kwargs = {
        f"list_{i}": list(range(200))
        for i in range(10)
    }
    # Add many non-list values
    for i in range(50):
        gen_kwargs[f"config_{i}"] = f"value_{i}"
    
    codeflash_output = _number_of_shards_in_gen_kwargs(gen_kwargs); result = codeflash_output # 9.08μs -> 7.19μs (26.3% faster)

def test_large_list_with_large_objects():
    """Test with a large list containing large string objects"""
    large_strings = ["x" * 1000 for _ in range(250)]
    codeflash_output = _number_of_shards_in_gen_kwargs({"data": large_strings}); result = codeflash_output # 4.47μs -> 2.36μs (89.9% faster)

def test_large_number_of_keys_with_single_list():
    """Test with a large number of non-list keys and one list"""
    gen_kwargs = {"data": list(range(100))}
    # Add many non-list keys
    for i in range(500):
        gen_kwargs[f"param_{i}"] = i
    
    codeflash_output = _number_of_shards_in_gen_kwargs(gen_kwargs); result = codeflash_output # 33.9μs -> 32.2μs (5.09% faster)

def test_error_message_quality_with_large_mismatches():
    """Test that error message contains information about all mismatched lists"""
    gen_kwargs = {
        "list_a": list(range(100)),
        "list_b": list(range(150)),
        "list_c": list(range(200))
    }
    
    with pytest.raises(RuntimeError) as exc_info:
        _number_of_shards_in_gen_kwargs(gen_kwargs) # 6.61μs -> 5.34μs (23.7% faster)
    
    error_msg = str(exc_info.value)

def test_exact_same_length_verification_large():
    """Test that lists must be exactly the same length (not just similar)"""
    with pytest.raises(RuntimeError):
        _number_of_shards_in_gen_kwargs({
            "list_a": list(range(250)),
            "list_b": list(range(251))
        }) # 6.38μs -> 5.38μs (18.7% faster)

def test_list_with_falsy_values():
    """Test that lists with falsy values (0, False, None) are counted correctly"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "data": [0, False, None, "", []]
    }); result = codeflash_output # 4.51μs -> 2.30μs (95.7% faster)

def test_list_with_none_elements():
    """Test list containing only None values"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "data": [None, None, None]
    }); result = codeflash_output # 4.41μs -> 2.27μs (94.4% faster)

def test_numeric_keys_in_gen_kwargs():
    """Test that numeric-like string keys are handled correctly"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "1": [1, 2, 3],
        "2": "not a list"
    }); result = codeflash_output # 4.71μs -> 2.38μs (97.6% faster)

def test_special_characters_in_keys():
    """Test keys with special characters"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "data-source": [1, 2],
        "config_file": "path/to/config"
    }); result = codeflash_output # 4.57μs -> 2.40μs (89.9% faster)

def test_unicode_keys():
    """Test keys with unicode characters"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "데이터": [1, 2, 3],
        "🔑": "value"
    }); result = codeflash_output # 4.61μs -> 2.35μs (95.9% faster)

def test_key_with_length_name():
    """Test that a key named 'length' doesn't cause issues"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "data": [1, 2, 3],
        "length": 100  # This is not a list, should be ignored
    }); result = codeflash_output # 4.74μs -> 2.38μs (99.2% faster)

def test_list_as_non_list_value():
    """Test that only direct list values are counted, not lists within other structures"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "main_data": [1, 2, 3],
        "nested": {"inner_list": [4, 5, 6]}  # This list is inside a dict
    }); result = codeflash_output # 4.86μs -> 2.61μs (86.1% faster)

def test_identical_multiple_lists_max_length():
    """Test that the maximum length is correctly identified with multiple identical lists"""
    codeflash_output = _number_of_shards_in_gen_kwargs({
        "list_a": [1, 2, 3, 4],
        "list_b": [1, 2, 3, 4],
        "list_c": [1, 2, 3, 4]
    }); result = codeflash_output # 4.91μs -> 2.73μs (79.3% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-_number_of_shards_in_gen_kwargs-mlcxbbhi` and push.

@codeflash-ai codeflash-ai bot requested a review from aseembits93 February 7, 2026 23:06
@codeflash-ai codeflash-ai bot added the labels ⚡️ codeflash (Optimization PR opened by Codeflash AI) and 🎯 Quality: High (Optimization Quality according to Codeflash) Feb 7, 2026