diff --git a/tools/submission/power/power_checker.py b/tools/submission/power/power_checker.py index c723adbfd9..b8f17e0710 100755 --- a/tools/submission/power/power_checker.py +++ b/tools/submission/power/power_checker.py @@ -32,10 +32,12 @@ class LineWithoutTimeStamp(Exception): + """ Exception raised when there exists a line without a timestamp in the log file. """ pass class CheckerWarning(Exception): + """ Exception raised internally when a Checker reports something wrong. """ pass diff --git a/tools/submission/submission_checker/checks/accuracy_check.py b/tools/submission/submission_checker/checks/accuracy_check.py index db1b1a7559..0180a332de 100644 --- a/tools/submission/submission_checker/checks/accuracy_check.py +++ b/tools/submission/submission_checker/checks/accuracy_check.py @@ -1,3 +1,5 @@ +""" Module for performing accuracy-related checks on MLPerf submission artifacts. """ + from .base import BaseCheck from ..constants import * from ..loader import SubmissionLogs diff --git a/tools/submission/submission_checker/checks/base.py b/tools/submission/submission_checker/checks/base.py index 8e2a678fb9..ae38b39fd6 100644 --- a/tools/submission/submission_checker/checks/base.py +++ b/tools/submission/submission_checker/checks/base.py @@ -1,3 +1,4 @@ +""" Define a base Checker class for MLPerf submission checks. """ from abc import ABC, abstractmethod @@ -8,6 +9,12 @@ class BaseCheck(ABC): """ def __init__(self, log, path): + """Initialize checker + + Args: + log (Logger): A logger instance for logging check results and errors. + path (str): A path to the submission artifact being checked. + """ self.checks = [] self.log = log self.path = path diff --git a/tools/submission/submission_checker/checks/compliance_check.py b/tools/submission/submission_checker/checks/compliance_check.py index 13cc6b16b8..bc7f1dd43b 100644 --- a/tools/submission/submission_checker/checks/compliance_check.py +++ b/tools/submission/submission_checker/checks/compliance_check.py @@ -1,4 +1,4 @@ - +""" Module for performing compliance checks on MLPerf submission artifacts. """ from .base import BaseCheck from ..constants import * from ..loader import SubmissionLogs diff --git a/tools/submission/submission_checker/checks/measurements_checks.py b/tools/submission/submission_checker/checks/measurements_checks.py index 06b89f56fc..732aad4034 100644 --- a/tools/submission/submission_checker/checks/measurements_checks.py +++ b/tools/submission/submission_checker/checks/measurements_checks.py @@ -1,3 +1,4 @@ +""" Module for performing measurement-related checks on MLPerf submission artifacts. """ from .base import BaseCheck from ..constants import * from ..loader import SubmissionLogs diff --git a/tools/submission/submission_checker/checks/performance_check.py b/tools/submission/submission_checker/checks/performance_check.py index e54e7b5564..6b1dcb09dd 100644 --- a/tools/submission/submission_checker/checks/performance_check.py +++ b/tools/submission/submission_checker/checks/performance_check.py @@ -1,3 +1,4 @@ +""" Module for performing performance-related checks on MLPerf submission artifacts. """ from .base import BaseCheck from ..constants import * from ..loader import SubmissionLogs diff --git a/tools/submission/submission_checker/checks/power/power_checker.py b/tools/submission/submission_checker/checks/power/power_checker.py index bf6835133b..0a4de785c6 100755 --- a/tools/submission/submission_checker/checks/power/power_checker.py +++ b/tools/submission/submission_checker/checks/power/power_checker.py @@ -30,12 +30,13 @@ logging.basicConfig(level=logging.INFO) log = logging.getLogger("main") - class LineWithoutTimeStamp(Exception): + """ Exception raised when there exists a line without a timestamp in the log file. """ pass class CheckerWarning(Exception): + """ Exception raised internally when a Checker reports something wrong. """ pass @@ -108,6 +109,14 @@ def _sort_dict(x: Dict[str, Any]) -> "OrderedDict[str, Any]": def hash_dir(dirname: str) -> Dict[str, str]: + """For all files in a directory, create a dictionary that maps their name to their hash. + + Args: + dirname (str): Directory to traverse + + Returns: + Dict[str, str]: Map from fname to hash + """ result: Dict[str, str] = {} for path, dirs, files in os.walk(dirname, topdown=True): @@ -125,6 +134,20 @@ def hash_dir(dirname: str) -> Dict[str, str]: def get_time_from_line( line: str, data_regexp: str, file: str, timezone_offset: int ) -> float: + """Extract time from a given line using regex and save it as a UTC timestamp. + + Args: + line (str): Line to be parsed. + data_regexp (str): Date format regex + file (str): File to be searched in. Used for logging and error-checking. + timezone_offset (int): Offset added to timezone as needed. + + Raises: + LineWithoutTimeStamp: An exception raised if the regex does not find a timestamp in the line. + + Returns: + float: UTC timestamp. + """ log_time_str = re.search(data_regexp, line) if log_time_str and log_time_str.group(0): log_datetime = datetime.strptime( @@ -135,13 +158,20 @@ def get_time_from_line( class SessionDescriptor: + """Class for holding and checking session descriptor data.""" def __init__(self, path: str): + """Initialize session descriptor from JSON file. + + Args: + path (str): Path to session descriptor. + """ self.path = path with open(path, "r") as f: self.json_object: Dict[str, Any] = json.loads(f.read()) self.required_fields_check() def required_fields_check(self) -> None: + """Check that all required fields are present in session descriptor JSON.""" required_fields = [ "version", "timezone", @@ -161,6 +191,13 @@ def required_fields_check(self) -> None: def compare_dicts_values( d1: Dict[str, str], d2: Dict[str, str], comment: str) -> None: + """Assert that all keys in d1 are in d2 and have the same value as the values in d1. + + Args: + d1 (Dict[str, str]): Reference dictionary. + d2 (Dict[str, str]): Dictionary to be compared with. + comment (str): Comment for warning that will popup if assert fails. + """ files_with_diff_check_sum = {k: d1[k] for k in d1 if k in d2 and d1[k] != d2[k]} assert len(files_with_diff_check_sum) == 0, f"{comment}" + "".join( @@ -173,6 +210,13 @@ def compare_dicts_values( def compare_dicts(s1: Dict[str, str], s2: Dict[str, str], comment: str) -> None: + """Ensure that the keys and values in s1 and s2 are the same. + + Args: + s1 (Dict[str, str]): The first dictionary to work with. + s2 (Dict[str, str]): The second dictionary to work with. + comment (str): Comment for warning that will popup if assert fails. + """ assert ( not s1.keys() - s2.keys() ), f"{comment} Missing {', '.join(sorted(s1.keys() - s2.keys()))!r}" @@ -205,6 +249,14 @@ def ptd_messages_check(sd: SessionDescriptor) -> None: msgs: List[Dict[str, str]] = sd.json_object["ptd_messages"] def get_ptd_answer(command: str) -> str: + """From the list of messages, return the reply from the first instance of a given command. + + Args: + command (str): Name of command to look for. + + Returns: + str: The reply of the command if found, otherwise an empty string. + """ for msg in msgs: if msg["cmd"] == command: return msg["reply"] @@ -226,6 +278,12 @@ def get_ptd_answer(command: str) -> str: ), f"Power meter {power_meter_model!r} is not supported. Only {', '.join(SUPPORTED_MODEL.keys())} are supported." def check_reply(cmd: str, reply: str) -> None: + """For a given command, look for a particular reply. If the reply is what is expected, continue. Otherwise raise AssertionError. + + Args: + cmd (str): Command to look for. + reply (str): Expected reply. + """ stop_counter = 0 for msg in msgs: if msg["cmd"].startswith(cmd): @@ -248,6 +306,15 @@ def check_reply(cmd: str, reply: str) -> None: check_reply("Stop", "Stopping untimed measurement") def get_initial_range(param_num: int, reply: str) -> str: + """Get the initial range of a value from the power logs. + + Args: + param_num (int): Identify which field of the power log we need to look into. + reply (str): Name of source of data, used for assertions + + Returns: + str: Initial range when possible, otherwise "Auto". + """ reply_list = reply.split(",") try: if reply_list[param_num] == "0" and float( @@ -259,6 +326,15 @@ def get_initial_range(param_num: int, reply: str) -> str: def get_command_by_value_and_number( cmd: str, number: int) -> Optional[str]: + """From the list of msgs, get the `number` occurence of the command if possible. + + Args: + cmd (str): Command to look for. + number (int): Index of command. 1-indexed, so 1 will get the first instance of the command. + + Returns: + Optional[str]: The full command if found, otherwise None. + """ command_counter = 0 for msg in msgs: if msg["cmd"].startswith(cmd): @@ -347,6 +423,13 @@ def phases_check( def compare_time( phases_client: List[List[float]], phases_server: List[List[float]], mode: str ) -> None: + """Compare the time difference between each checkpoint on the client and the server. If they are less than the TIME_DELTA_TOLERANCE, raise AssertionError. + + Args: + phases_client (List[List[float]]): List of client checkpoints for ranging and testing. + phases_server (List[List[float]]): List of server checkpoints for ranging and testing. + mode (str): Mode information string. Used for assertion message. + """ assert len(phases_client) == len( phases_server ), f"Phases amount is not equal for {mode} mode." @@ -361,6 +444,15 @@ def compare_time( compare_time(phases_testing_c, phases_testing_s, TESTING_MODE) def compare_duration(range_duration: float, test_duration: float) -> None: + """Compare the duration between the range mode and the test mode. Fail if the duration difference is more than 5 percent. + + Args: + range_duration (float): Length of range duration. + test_duration (float): Length of test duration. + + Raises: + CheckerWarning: Raised if duration difference is more than 5 percent. + """ duration_diff = (range_duration - test_duration) / range_duration if duration_diff > 0.5: @@ -373,6 +465,14 @@ def compare_duration(range_duration: float, test_duration: float) -> None: def compare_time_boundaries( begin: float, end: float, phases: List[Any], mode: str ) -> None: + """Temporarily compare time boundaries between beginning and end, and raise an AssertionError if false. + + Args: + begin (float): Beginning timestamp + end (float): Ending timestamp + phases (List[Any]): List of phases. + mode (str): Mode name, used for assertion message. + """ # TODO: temporary workaround, remove when proper DST handling is # implemented! assert ( @@ -409,6 +509,15 @@ def compare_time_boundaries( compare_duration(ranging_duration_d, testing_duration_d) def get_avg_power(power_path: str, run_path: str) -> Tuple[float, float]: + """Get the average power from the power log. + + Args: + power_path (str): Path of power log. Unused. + run_path (str): Path of run log. + + Returns: + Tuple[float, float]: Return average power and pf. + """ # parse the power logs power_begin, power_end = _get_begin_end_time_from_mlperf_log_detail( @@ -512,6 +621,15 @@ def messages_check(client_sd: SessionDescriptor, # Server.json contains all client.json messages and replies. Checked # earlier. def get_version(regexp: str, line: str) -> str: + """Try to get client and server version within server.json using a given regex. + + Args: + regexp (str): Regex to look for a particular system in. + line (str): Line of text to search. + + Returns: + str: Version information if possible, otherwise returns AssertionError + """ version_o = re.search(regexp, line) assert version_o is not None, f"Server version is not defined in:'{line}'" return version_o.group(1) @@ -553,6 +671,11 @@ def results_check( result_paths_copy.remove("power/client.json") def remove_optional_path(res: Dict[str, str]) -> None: + """Given a dictionary of string, hash pairs, delete the paths not in results_paths_copy. + + Args: + res (Dict[str, str]): _description_ + """ keys = list(res.keys()) for path in keys: # Ignore all the optional files. @@ -591,6 +714,13 @@ def remove_optional_path(res: Dict[str, str]) -> None: def result_files_compare( res: Dict[str, str], ref_res: List[str], path: str ) -> None: + """Check if all of the required files are present, using the result dictionary. + + Args: + res (Dict[str, str]): Result dictionary. + ref_res (List[str]): Reference result dictionary. + path (str): Path to file. Used for AssertionError. + """ # If a file is required (in ref_res) but is not present in results directory (res), # then the submission is invalid. absent_files = set(ref_res) - set(res.keys()) @@ -626,6 +756,16 @@ def check_ptd_logs( ptd_log_lines = f.readlines() def find_error_or_warning(reg_exp: str, line: str, error: bool) -> None: + """Given a particular log line and a regex and an error, perform error handling. Any known and common testing error is accepted, while any known and common ranging error leads to an AssertionError. + + Args: + reg_exp (str): Regex to find lines with errors or warnings. + line (str): Log line to check. + error (bool): Whether the line represents an error. + + Raises: + CheckerWarning: If the line is a warning. + """ problem_line = re.search(reg_exp, line) if problem_line and problem_line.group(0): @@ -672,6 +812,14 @@ def find_error_or_warning(reg_exp: str, line: str, error: bool) -> None: start_ranging_line = f": Go with mark {ranging_mark!r}" def get_msg_without_time(line: str) -> Optional[str]: + """Try to extract the message contents of a log line without getting the timestamp. + + Args: + line (str): Line of text to extract from. + + Returns: + Optional[str]: Message content if possible, otherwise None. + """ try: get_time_from_line(line, date_regexp, file_path, timezone_offset) except LineWithoutTimeStamp: @@ -773,6 +921,15 @@ def debug_check(server_sd: SessionDescriptor) -> None: def check_with_logging( check_name: str, check: Callable[[], None]) -> Tuple[bool, bool]: + """Try running `check`, but log any and all errors to a logfile, including tracebacks. + + Args: + check_name (str): Name of check being ran. Used for logging. + check (Callable[[], None]): The check function being ran. + + Returns: + Tuple[bool, bool]: A tuple of (No Errors Detected, Warnings Detected) + """ try: check() except AssertionError as e: @@ -794,6 +951,14 @@ def check_with_logging( def check(path: str) -> int: + """Run the power checker on a particular path. + + Args: + path (str): Path to json files to be checked. + + Returns: + int: 1 if there is an error otherwise 0. Used in sys.exit() call. + """ client = SessionDescriptor(os.path.join(path, "power/client.json")) server = SessionDescriptor(os.path.join(path, "power/server.json")) diff --git a/tools/submission/submission_checker/checks/power_check.py b/tools/submission/submission_checker/checks/power_check.py index d3519a3503..6ca5d6d847 100644 --- a/tools/submission/submission_checker/checks/power_check.py +++ b/tools/submission/submission_checker/checks/power_check.py @@ -1,3 +1,5 @@ +""" Module for performing power-related checks on MLPerf submission artifacts. """ + from .base import BaseCheck from ..constants import * from ..loader import SubmissionLogs diff --git a/tools/submission/submission_checker/checks/structure_check.py b/tools/submission/submission_checker/checks/structure_check.py index e0667bb7f6..076e470156 100644 --- a/tools/submission/submission_checker/checks/structure_check.py +++ b/tools/submission/submission_checker/checks/structure_check.py @@ -2,10 +2,23 @@ class StructureCheck(BaseCheck): + """Simple Sample Check to test the structure of the checker. Not used in actual checking.""" def __init__(self, log, path, parsed_log): + """Initialize Sample Checker. + + Args: + log (str): Path to log + path (str): Path to results + parsed_log (str): Parsed log + """ super().__init__(log, path) self.parsed_log = parsed_log self.checks.append(self.sample_check) def sample_check(self): + """Simple check that always returns true. + + Returns: + bool: True + """ return True diff --git a/tools/submission/submission_checker/configuration/configuration.py b/tools/submission/submission_checker/configuration/configuration.py index 97a20bc4a2..4f07c84a6c 100644 --- a/tools/submission/submission_checker/configuration/configuration.py +++ b/tools/submission/submission_checker/configuration/configuration.py @@ -21,6 +21,23 @@ def __init__( skip_dataset_size_check=False, submitter=None, ): + """Initialize submission checker configuration. + + Args: + version (str): Version Number + extra_model_benchmark_map (_type_): Extra model benchmark map. + ignore_uncommited (bool, optional): Whether to ignore uncommitted changes in loadgen. Defaults to False. + skip_compliance (bool, optional): Whether to skip compliance checks. Defaults to False. + skip_power_check (bool, optional): Whether to skip power checks. Defaults to False. + skip_meaningful_fields_emptiness_check (bool, optional): Whether to skip meaningful fields emptiness checks. Defaults to False. + skip_check_power_measure_files (bool, optional): Whether to skip power measure files checks. Defaults to False. + skip_empty_files_check (bool, optional): Whether to skip empty files checks. Defaults to False. + skip_extra_files_in_root_check (bool, optional): Whether to skip extra files in root checks. Defaults to False. + skip_extra_accuracy_files_check (bool, optional): Whether to skip extra accuracy files checks. Defaults to False. + skip_all_systems_have_results_check (bool, optional): Whether to skip all systems have results checks. Defaults to False. + skip_calibration_check (bool, optional): Whether to skip calibration checks. Defaults to False. + skip_dataset_size_check (bool, optional): Whether to skip dataset size checks. Defaults to False. + """ self.base = MODEL_CONFIG.get(version) self.extra_model_benchmark_map = extra_model_benchmark_map self.version = version @@ -41,6 +58,11 @@ def __init__( self.load_config(version) def load_config(self, version): + """Unused system used to load config from base or other system. + + Args: + version (str): Version Information. + """ # TODO: Load values from self.models = self.base["models"] self.seeds = self.base["seeds"] @@ -58,6 +80,14 @@ def load_config(self, version): self.optional = None def set_type(self, submission_type): + """Set submission type based on if datacenter, edge, or both. + + Args: + submission_type (Enum): Either "datacenter", "edge", "datacenter,edge" or "edge,datacenter". + + Raises: + ValueError: Incorrect system type. + """ if submission_type == "datacenter": self.required = self.base["required-scenarios-datacenter"] self.optional = self.base["optional-scenarios-datacenter"] @@ -76,6 +106,15 @@ def get_submitter(self): return self.submitter def get_mlperf_model(self, model, extra_model_mapping=None): + """Get mlperf model scenarios, converting some naming differences when possible. + + Args: + model (str): Model name + extra_model_mapping (_type_, optional): Dictionary of other model names. Defaults to None. + + Returns: + str: MLPerf model scenarios + """ # preferred - user is already using the official name if model in self.models: return model @@ -107,26 +146,72 @@ def get_mlperf_model(self, model, extra_model_mapping=None): return mlperf_model def get_required(self, model): + """Get required scenarios for a given model. + + Args: + model (str): Model name + + Returns: + set: Set of required scenarios if model is known, otherwise None. + """ model = self.get_mlperf_model(model) if model not in self.required: return None return set(self.required[model]) def get_optional(self, model): + """Get optional scenarios for a given model. + + Args: + model (str): Model name + + Returns: + set: Set of optional scenarios if model is known, otherwise an empty set. + """ model = self.get_mlperf_model(model) if model not in self.optional: return set() return set(self.optional[model]) def get_accuracy_target(self, model): + """Get accuracy target of a given model. + + Args: + model (str): Model name + + Raises: + ValueError: Model not found in mapping. + + Returns: + float: Accuracy target for the given model. + """ if model not in self.accuracy_target: raise ValueError("model not known: " + model) return self.accuracy_target[model] def get_accuracy_upper_limit(self, model): + """Get accuracy upper limit of a given model. + + Args: + model (str): Model name + + Raises: + ValueError: Model not found in mapping. + + Returns: + float: Accuracy upper limit for the given model. + """ return self.accuracy_upper_limit.get(model, None) def get_accuracy_values(self, model): + """Get accuracy patterns and targets for a given model. + + Args: + model (str): Model name + + Returns: + Tuple: Patterns, accuracy targets, accuracy types, accuracy upper limits, and accuracy regex patterns for the given model. + """ patterns = [] acc_targets = [] acc_types = [] @@ -151,18 +236,48 @@ def get_accuracy_values(self, model): return patterns, acc_targets, acc_types, acc_limits, up_patterns, acc_upper_limit def get_performance_sample_count(self, model): + """Get performance sample count for a given model. + + Args: + model (str): Model name + + Raises: + ValueError: If model name not found in performance sample count, get it. + + Returns: + int: Performance sample count for a given model. + """ model = self.get_mlperf_model(model) if model not in self.performance_sample_count: raise ValueError("model not known: " + model) return self.performance_sample_count[model] def get_accuracy_sample_count(self, model): + """Get accuracy sample count for a given model. + + Args: + model (str): Model name + + Raises: + ValueError: If model name not found in performance sample count, get it. + + Returns: + int: Accuracy sample count for a given model. + """ model = self.get_mlperf_model(model) if model not in self.accuracy_sample_count: return self.get_dataset_size(model) return self.accuracy_sample_count[model] def ignore_errors(self, line): + """Check if a given line should be ignored in parsing for errors. + + Args: + line (str): log line + + Returns: + bool: Whether to ignore the line or not. + """ for error in self.base["ignore_errors"]: if error in line: return True @@ -174,18 +289,50 @@ def ignore_errors(self, line): return False def get_min_query_count(self, model, scenario): + """Get minimum query count for a given model and scenario. + + Args: + model (str): Model name + scenario (str): Model scenario. + + Raises: + ValueError: Raised if model is unknown. + + Returns: + int: Minimum number of queries from configuration. + """ model = self.get_mlperf_model(model) if model not in self.min_queries: raise ValueError("model not known: " + model) return self.min_queries[model].get(scenario) def get_dataset_size(self, model): + """Get dataset size for a given model. + + Args: + model (str): Model name + + Raises: + ValueError: Raised if model is unknown. + + Returns: + int: Dataset size for a given model. + """ model = self.get_mlperf_model(model) if model not in self.dataset_size: raise ValueError("model not known: " + model) return self.dataset_size[model] def get_delta_perc(self, model, metric): + """Get delta percentage of a given metric for a model. + + Args: + model (str): Model name + metric (str): Metric name + + Returns: + float: Percentage delta. + """ if model in self.accuracy_delta_perc: if metric in self.accuracy_delta_perc[model]: return self.accuracy_delta_perc[model][metric] @@ -198,12 +345,34 @@ def get_delta_perc(self, model, metric): return required_delta_perc def has_new_logging_format(self): + """Return if the system has the new logging format or not. + + Returns: + bool: True + """ return True def uses_early_stopping(self, scenario): + """Return whether the scenario uses early stopping or not. + + Args: + scenario (str): Scenario Name + + Returns: + bool : Whether the scenario uses early stopping or not. + """ return scenario in ["Server", "SingleStream", "MultiStream"] def requires_equal_issue(self, model, division): + """Return whether the scenario requires equal issue or not. + + Args: + model (str): Model name + division (str): Model division. + + Returns: + bool: Whether the scenario requires equal issue or not. + """ return ( division in ["closed", "network"] and model @@ -220,6 +389,11 @@ def requires_equal_issue(self, model, division): ) def get_llm_models(self): + """Get a list of llm models. + + Returns: + list[str]: A list of LLM model names. + """ return [ "llama2-70b-99", "llama2-70b-99.9", diff --git a/tools/submission/submission_checker/constants.py b/tools/submission/submission_checker/constants.py index dc45cd83d2..2f4abd87f8 100644 --- a/tools/submission/submission_checker/constants.py +++ b/tools/submission/submission_checker/constants.py @@ -1132,12 +1132,12 @@ "84", "59", "12", - "31", + "31", "86", - "122", - "233", + "122", + "233", "96", - ] + ] }, } } diff --git a/tools/submission/submission_checker/main.py b/tools/submission/submission_checker/main.py index babb1193be..08bcaef8f7 100644 --- a/tools/submission/submission_checker/main.py +++ b/tools/submission/submission_checker/main.py @@ -253,6 +253,15 @@ def main(): ) def merge_two_dict(x, y): + """Merge two dictionaries by key. If a key is present in both dictionaries, the new key will have the values of the sum. + + Args: + x (dict): First dictionary + y (dict): Second dictionary + + Returns: + dict: Merged dictionary + """ z = x.copy() for key in y: if key not in z: @@ -285,6 +294,14 @@ def merge_two_dict(x, y): # Counting the number of closed,open and network results def sum_dict_values(x): + """Sum over the values in the dictionary. + + Args: + x (dict): Dictionary to process + + Returns: + float: Sum of dictionary values. + """ count = 0 for key in x: count += x[key] diff --git a/tools/submission/submission_checker/parsers/base.py b/tools/submission/submission_checker/parsers/base.py index 1af266b53b..2776bcdf6e 100644 --- a/tools/submission/submission_checker/parsers/base.py +++ b/tools/submission/submission_checker/parsers/base.py @@ -3,6 +3,7 @@ class BaseParser: + """Base class for parsing the detailed logs.""" def __init__(self, log_path): """ Helper class to parse the detail logs. diff --git a/tools/submission/submission_checker/parsers/loadgen_parser.py b/tools/submission/submission_checker/parsers/loadgen_parser.py index b2812c0b78..356afad185 100644 --- a/tools/submission/submission_checker/parsers/loadgen_parser.py +++ b/tools/submission/submission_checker/parsers/loadgen_parser.py @@ -18,10 +18,8 @@ import sys from .base import BaseParser -# pylint: disable=missing-docstring - - class LoadgenParser(BaseParser): + """Loadgenerator Logs Parser.""" def __init__(self, log_path, strict=True): """ Helper class to parse the detail logs. diff --git a/tools/submission/submission_checker/utils.py b/tools/submission/submission_checker/utils.py index 6435b9e165..806a90b229 100644 --- a/tools/submission/submission_checker/utils.py +++ b/tools/submission/submission_checker/utils.py @@ -4,29 +4,69 @@ def list_dir(*path): + """List directories inside of a given directory. + + Args: + path (list[str]): Path to directory to search from. + + Returns: + List[str]: List of files + """ path = os.path.join(*path) return sorted([f for f in os.listdir( path) if os.path.isdir(os.path.join(path, f))]) def list_files(*path): + """List files inside of a directory. + + Args: + path (list[str]): Path to directory to search from. + + Returns: + List[str]: List of files + """ path = os.path.join(*path) return sorted([f for f in os.listdir( path) if os.path.isfile(os.path.join(path, f))]) def list_empty_dirs_recursively(*path): + """List all empty directories inside of a directory, recursively. + + Args: + path (list[str]): Path to directory to search from. + + Returns: + List[str]: List of empty directories + """ path = os.path.join(*path) return [dirpath for dirpath, dirs, files in os.walk( path) if not dirs and not files] def list_dirs_recursively(*path): + """List all directories ( both empty and not empty ) inside of a directory, recursively. + + Args: + path (list[str]): Path to directory to search from. + + Returns: + List[str]: List of all directories + """ path = os.path.join(*path) return [dirpath for dirpath, dirs, files in os.walk(path)] def list_files_recursively(*path): + """List all files inside of a directory, looking recursively. + + Args: + path (list[str]): Path to directory to search from. + + Returns: + List[str]: List of all files + """ path = os.path.join(*path) return [ os.path.join(dirpath, file) @@ -44,6 +84,16 @@ def files_diff(list1, list2, optional=None): def check_extra_files(path, target_files): + """Check if there are extra files in the directory compared to the target files. + + Args: + path (str): Path to directory to check. + target_files (dict): Dictionary of target files, with keys as subdirectories and values as lists of files in those + subdirectories. + Returns: + bool: Whether there are extra files or not. + List[str]: List of extra files. + """ missing_files = [] check_pass = True folders = list_dir(path) @@ -71,10 +121,29 @@ def check_extra_files(path, target_files): def split_path(m): + """Naively split path from string to a list, separating on forward slashes. Converts backslash pairs to forward slashes. + + Args: + m (string): Path to split + + Returns: + List[string]: Path as list + """ return m.replace("\\", "/").split("/") def get_boolean(s): + """Convert a bool, string or int to a bool. Strings are case-insensitive, and ints are converted to bools by checking if they are 0. None values are converted to False. + + Args: + s (Any): Element to convert into a bool. + + Raises: + TypeError: If the input is not a bool, string or int. + + Returns: + bool: The converted boolean value. + """ if s is None: return False elif isinstance(s, bool): @@ -90,6 +159,15 @@ def get_boolean(s): def merge_two_dict(x, y): + """Merge two dictionaries by key. If a key is present in both dictionaries, the new value is the sum of the values in each of them. + + Args: + x (dict): A dictionary + y (dict): Another dictionary + + Returns: + dict: The sum of the dictionaries. Note: Neither x nor y are changed by this operation. + """ z = x.copy() for key in y: if key not in z: @@ -100,6 +178,14 @@ def merge_two_dict(x, y): def sum_dict_values(x): + """Compute the sum of the values in a given dictionary. Expects values to be numeric. + + Args: + x (dict): Dictionary whose values we need to sum over. + + Returns: + float: The sum of all values in the dictionary. + """ count = 0 for key in x: count += x[key] @@ -107,6 +193,14 @@ def sum_dict_values(x): def is_number(s): + """Check if argument is a number by trying to cast it to a float. + + Args: + s (Any): Some value + + Returns: + bool: Whether the argument is a number or not. + """ try: float(s) return True @@ -129,6 +223,18 @@ def contains_list(l1, l2): def get_performance_metric( config, model, path, scenario_fixed): + """Get performance metric from the logs. + + Args: + config (Config): Configuration class. + model (str): Model name + path (str): Log path + scenario_fixed (str): Scenario information. + + Returns: + float: Performance metric. + """ + # Assumes new logging format version = config.version @@ -164,6 +270,19 @@ def get_performance_metric( def get_inferred_result( scenario_fixed, scenario, res, mlperf_log, config, log_error=False ): + """Get result from logs. + + Args: + scenario_fixed (str): Scenario fixed type + scenario (str): Scenario type. + res (float): Result variable + mlperf_log (LoadgenParser): Parsed log object to get the result from. + config (Config): Configuration object to check for scenario properties. + log_error (bool, optional): Whether to log errors. Defaults to False. + + Returns: + tuple: A tuple containing the inferred flag, result, and validity flag. + """ inferred = False is_valid = True @@ -212,6 +331,14 @@ def get_inferred_result( def check_compliance_perf_dir(test_dir): + """Check that the compliance perf directory is valid. + + Args: + test_dir (str): Test directory. + + Returns: + bool: True if the directory is valid, False otherwise. + """ is_valid = False import logging log = logging.getLogger("main") @@ -254,6 +381,18 @@ def check_compliance_perf_dir(test_dir): def get_power_metric(config, scenario_fixed, log_path, is_valid, res): + """Get power metric from logs and config info. + + Args: + config (Config): Unuseed. + scenario_fixed (str): Fixed scenario. + log_path (str): Path to log + is_valid (bool): Whether the metric passes the checks or not. + res (float): Result. + + Returns: + Tuple: If the result was valid, the power metric, the scenario, and the average power efficiency. + """ # parse the power logs import datetime import logging