Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 65 additions & 18 deletions httomo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from os import PathLike
from pathlib import Path, PurePath
import sys
import tempfile
from typing import List, Optional, TextIO, Union, Any

import yaml

import click
import shutil
from mpi4py import MPI
from loguru import logger

Expand All @@ -18,9 +20,13 @@
from httomo.sweep_runner.param_sweep_runner import ParamSweepRunner
from httomo.transform_layer import TransformLayer
from httomo.utils import log_exception, log_once, mpi_abort_excepthook
from httomo.yaml_checker import validate_yaml_config
from httomo.yaml_checker import (
validate_yaml_config,
_get_template_yaml_conf,
PipelineConfig,
)
from httomo.runner.task_runner import TaskRunner
from httomo.ui_layer import UiLayer, PipelineFormat
from httomo.ui_layer import UiLayer, PipelineFormat, yaml_loader

try:
from . import __version__
Expand Down Expand Up @@ -259,21 +265,21 @@ def run(
method_wrapper_comm = global_comm if not does_contain_sweep else MPI.COMM_SELF

if global_comm.rank == 0:
initialise_output_directory(pipeline)
initialise_output_directory(pipeline, does_contain_sweep)

setup_logger(Path(httomo.globals.run_out_dir))

# Convert string to enum
format_enum = (
PipelineFormat.Json if pipeline_format == "Json" else PipelineFormat.Yaml
)
pipeline = generate_pipeline(
pipeline_object = generate_pipeline(
in_data_file, pipeline, save_all, method_wrapper_comm, format_enum
)

if not does_contain_sweep:
execute_high_throughput_run(
pipeline,
pipeline_object,
global_comm,
gpu_id,
max_memory,
Expand All @@ -283,17 +289,12 @@ def run(
save_snapshots,
)
else:
execute_sweep_run(pipeline, global_comm)
execute_sweep_run(pipeline_object, global_comm)

if mpi_abort_hook:
sys.excepthook = sys.__excepthook__


def _check_yaml(yaml_config: Path, in_data: Path):
"""Check a YAML pipeline file for errors."""
return validate_yaml_config(yaml_config, in_data)


def transform_limit_str_to_bytes(limit_str: str):
try:
limit_upper = limit_str.upper()
Expand Down Expand Up @@ -369,7 +370,9 @@ def set_global_constants(
httomo.globals.MAX_CPU_SLICES = max_cpu_slices


def initialise_output_directory(pipeline: Union[Path, str]) -> None:
def initialise_output_directory(
pipeline: Union[Path, str], does_contain_sweep: bool
) -> None:
try:
Path.mkdir(httomo.globals.run_out_dir, parents=True, exist_ok=True)
except PermissionError as e:
Expand All @@ -378,9 +381,29 @@ def initialise_output_directory(pipeline: Union[Path, str]) -> None:

# If pipeline is a file path, copy it to output directory
if isinstance(pipeline, Path):
with open(pipeline, "r") as input:
pipeline_conf = yaml_loader(pipeline)
distortion_coeff_path = _get_distortion_coeff_path(pipeline_conf)
if distortion_coeff_path is not None:
shutil.copyfile(
distortion_coeff_path,
Path(httomo.globals.run_out_dir) / "dist_coeff.txt",
)
path_to_pipeline = pipeline
path_to_saved_pipeline = Path(httomo.globals.run_out_dir) / pipeline.name
# if does_contain_sweep do not inject default parameters due to issue around "sweep" aliases in yaml
if not does_contain_sweep:
path_to_pipeline = path_to_saved_pipeline
pipeline_updated = _substitute_omitted_default_values(pipeline_conf)
with open(path_to_saved_pipeline, "w") as file_descriptor:
yaml.dump(
pipeline_updated,
file_descriptor,
default_flow_style=False,
sort_keys=False,
)
with open(path_to_pipeline, "r") as input:
pipeline_contents = input.read()
with open(Path(httomo.globals.run_out_dir) / pipeline.name, "a") as output:
with open(path_to_saved_pipeline, "w") as output:
output.write(f"# Created with HTTomo version {__version__}\n")
output.write(pipeline_contents)
# If pipeline is a JSON string, write it to a file in the output directory
Expand All @@ -389,6 +412,30 @@ def initialise_output_directory(pipeline: Union[Path, str]) -> None:
f.write(pipeline)


def _substitute_omitted_default_values(
    pipeline_conf: PipelineConfig,
) -> PipelineConfig:
    """Fill in parameters omitted from the pipeline using template defaults.

    For each method in the pipeline, the corresponding YAML template is looked
    up and every parameter present in the template but absent from the method's
    ``"parameters"`` mapping is inserted with its template default value.
    ``pipeline_conf`` is mutated in place and also returned.

    Parameters
    ----------
    pipeline_conf
        The loaded pipeline configuration (one dict per method).

    Returns
    -------
    PipelineConfig
        The same configuration object, with omitted defaults filled in.
    """
    templates_conf = _get_template_yaml_conf(pipeline_conf)
    for i, (method, template) in enumerate(zip(pipeline_conf, templates_conf)):
        template_param_dict = template["parameters"]
        method_params = set(method.get("parameters", {}).keys())
        template_params = set(template_param_dict.keys())
        omitted_params = template_params - method_params

        for param in omitted_params:
            # Insert the omitted parameter into the pipeline. setdefault guards
            # against methods that have no "parameters" mapping at all — the
            # plain pipeline_conf[i]["parameters"][param] indexing would raise
            # KeyError in that case, even though the read above tolerates it.
            pipeline_conf[i].setdefault("parameters", {})[param] = template_param_dict[param]
    return pipeline_conf


def _get_distortion_coeff_path(pipeline_conf: PipelineConfig) -> Union[None, Path]:
distortion_coeff_path = None
for method in pipeline_conf:
if "distortion_correction" in method["method"]:
distortion_coeff_path = method["parameters"]["metadata_path"]
return distortion_coeff_path


def generate_pipeline(
in_data_file: Path,
pipeline: Union[Path, str],
Expand All @@ -403,13 +450,13 @@ def generate_pipeline(
comm=method_wrapper_comm,
pipeline_format=pipeline_format,
)
pipeline = init_UiLayer.build_pipeline()
pipeline_object = init_UiLayer.build_pipeline()

# perform transformations on pipeline
tr = TransformLayer(comm=method_wrapper_comm, save_all=save_all)
pipeline = tr.transform(pipeline)
pipeline_object = tr.transform(pipeline_object)

return pipeline
return pipeline_object


def execute_high_throughput_run(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def test_initialise_output_directory_handles_json_string(tmp_path):
json_string = json.dumps(SAMPLE_JSON_PIPELINE)

# Call the function with a JSON string
initialise_output_directory(json_string)
initialise_output_directory(json_string, False)

# Verify directory was created
assert output_dir.exists()
Expand All @@ -206,7 +206,7 @@ def test_initialise_output_directory_handles_path_input(
pipeline_path = Path(__file__).parent.parent / standard_loader

# Call the function with a Path
initialise_output_directory(pipeline_path)
initialise_output_directory(pipeline_path, False)

# Verify directory was created
assert output_dir.exists()
Expand All @@ -219,7 +219,7 @@ def test_output_dir_created_if_doesnt_exist(tmp_path: Path, standard_loader: str
output_dir_cli_arg = tmp_path / "out"
httomo.globals.run_out_dir = output_dir_cli_arg / "httomo-output-dir"
pipeline_path = Path(__file__).parent.parent / standard_loader
initialise_output_directory(pipeline_path)
initialise_output_directory(pipeline_path, False)
assert output_dir_cli_arg.exists()
assert httomo.globals.run_out_dir.exists()

Expand Down
Loading