diff --git a/httomo/cli.py b/httomo/cli.py index 580fcd600..45f756c6a 100644 --- a/httomo/cli.py +++ b/httomo/cli.py @@ -3,10 +3,12 @@ from os import PathLike from pathlib import Path, PurePath import sys -import tempfile from typing import List, Optional, TextIO, Union, Any +import yaml + import click +import shutil from mpi4py import MPI from loguru import logger @@ -18,9 +20,13 @@ from httomo.sweep_runner.param_sweep_runner import ParamSweepRunner from httomo.transform_layer import TransformLayer from httomo.utils import log_exception, log_once, mpi_abort_excepthook -from httomo.yaml_checker import validate_yaml_config +from httomo.yaml_checker import ( + validate_yaml_config, + _get_template_yaml_conf, + PipelineConfig, +) from httomo.runner.task_runner import TaskRunner -from httomo.ui_layer import UiLayer, PipelineFormat +from httomo.ui_layer import UiLayer, PipelineFormat, yaml_loader try: from . import __version__ @@ -259,7 +265,7 @@ def run( method_wrapper_comm = global_comm if not does_contain_sweep else MPI.COMM_SELF if global_comm.rank == 0: - initialise_output_directory(pipeline) + initialise_output_directory(pipeline, does_contain_sweep) setup_logger(Path(httomo.globals.run_out_dir)) @@ -267,13 +273,13 @@ def run( format_enum = ( PipelineFormat.Json if pipeline_format == "Json" else PipelineFormat.Yaml ) - pipeline = generate_pipeline( + pipeline_object = generate_pipeline( in_data_file, pipeline, save_all, method_wrapper_comm, format_enum ) if not does_contain_sweep: execute_high_throughput_run( - pipeline, + pipeline_object, global_comm, gpu_id, max_memory, @@ -283,17 +289,12 @@ def run( save_snapshots, ) else: - execute_sweep_run(pipeline, global_comm) + execute_sweep_run(pipeline_object, global_comm) if mpi_abort_hook: sys.excepthook = sys.__excepthook__ -def _check_yaml(yaml_config: Path, in_data: Path): - """Check a YAML pipeline file for errors.""" - return validate_yaml_config(yaml_config, in_data) - - def transform_limit_str_to_bytes(limit_str: str): try: limit_upper = limit_str.upper() @@ -369,7 +370,9 @@ def set_global_constants( httomo.globals.MAX_CPU_SLICES = max_cpu_slices -def initialise_output_directory(pipeline: Union[Path, str]) -> None: +def initialise_output_directory( + pipeline: Union[Path, str], does_contain_sweep: bool +) -> None: try: Path.mkdir(httomo.globals.run_out_dir, parents=True, exist_ok=True) except PermissionError as e: @@ -378,9 +381,29 @@ def initialise_output_directory(pipeline: Union[Path, str]) -> None: # If pipeline is a file path, copy it to output directory if isinstance(pipeline, Path): - with open(pipeline, "r") as input: + pipeline_conf = yaml_loader(pipeline) + distortion_coeff_path = _get_distortion_coeff_path(pipeline_conf) + if distortion_coeff_path is not None: + shutil.copyfile( + distortion_coeff_path, + Path(httomo.globals.run_out_dir) / "dist_coeff.txt", + ) + path_to_pipeline = pipeline + path_to_saved_pipeline = Path(httomo.globals.run_out_dir) / pipeline.name + # if does_contain_sweep do not inject default parameters due to issue around "sweep" aliases in yaml + if not does_contain_sweep: + path_to_pipeline = path_to_saved_pipeline + pipeline_updated = _substitute_omitted_default_values(pipeline_conf) + with open(path_to_saved_pipeline, "w") as file_descriptor: + yaml.dump( + pipeline_updated, + file_descriptor, + default_flow_style=False, + sort_keys=False, + ) + with open(path_to_pipeline, "r") as input: pipeline_contents = input.read() - with open(Path(httomo.globals.run_out_dir) / pipeline.name, "a") as output: + with open(path_to_saved_pipeline, "w") as output: output.write(f"# Created with HTTomo version {__version__}\n") output.write(pipeline_contents) # If pipeline is a JSON string, write it to a file in the output directory @@ -389,6 +412,30 @@ def initialise_output_directory(pipeline: Union[Path, str]) -> None: f.write(pipeline) +def _substitute_omitted_default_values( + pipeline_conf: PipelineConfig, +) -> PipelineConfig: + templates_conf = _get_template_yaml_conf(pipeline_conf) + for i, (method, template) in enumerate(zip(pipeline_conf, templates_conf)): + template_param_dict = template["parameters"] + method_params = set(method.get("parameters", {}).keys()) + template_params = set(template_param_dict.keys()) + omitted_params = template_params - method_params + + for param in omitted_params: + # insert ommited parameter into the pipeline + pipeline_conf[i]["parameters"][param] = template["parameters"][param] + return pipeline_conf + + +def _get_distortion_coeff_path(pipeline_conf: PipelineConfig) -> Union[None, Path]: + distortion_coeff_path = None + for method in pipeline_conf: + if "distortion_correction" in method["method"]: + distortion_coeff_path = method["parameters"]["metadata_path"] + return distortion_coeff_path + + def generate_pipeline( in_data_file: Path, pipeline: Union[Path, str], @@ -403,13 +450,13 @@ def generate_pipeline( comm=method_wrapper_comm, pipeline_format=pipeline_format, ) - pipeline = init_UiLayer.build_pipeline() + pipeline_object = init_UiLayer.build_pipeline() # perform transformations on pipeline tr = TransformLayer(comm=method_wrapper_comm, save_all=save_all) - pipeline = tr.transform(pipeline) + pipeline_object = tr.transform(pipeline_object) - return pipeline + return pipeline_object def execute_high_throughput_run( diff --git a/tests/test_cli.py b/tests/test_cli.py index 27d6bc39f..cd236acb6 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -181,7 +181,7 @@ def test_initialise_output_directory_handles_json_string(tmp_path): json_string = json.dumps(SAMPLE_JSON_PIPELINE) # Call the function with a JSON string - initialise_output_directory(json_string) + initialise_output_directory(json_string, False) # Verify directory was created assert output_dir.exists() @@ -206,7 +206,7 @@ def test_initialise_output_directory_handles_path_input( pipeline_path = Path(__file__).parent.parent / standard_loader # Call the function with a Path - initialise_output_directory(pipeline_path) + initialise_output_directory(pipeline_path, False) # Verify directory was created assert output_dir.exists() @@ -219,7 +219,7 @@ def test_output_dir_created_if_doesnt_exist(tmp_path: Path, standard_loader: str output_dir_cli_arg = tmp_path / "out" httomo.globals.run_out_dir = output_dir_cli_arg / "httomo-output-dir" pipeline_path = Path(__file__).parent.parent / standard_loader - initialise_output_directory(pipeline_path) + initialise_output_directory(pipeline_path, False) assert output_dir_cli_arg.exists() assert httomo.globals.run_out_dir.exists()