Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions process/core/init.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import datetime
import getpass
import socket
import subprocess
from pathlib import Path
from typing import TYPE_CHECKING
from warnings import warn

import process
Expand Down Expand Up @@ -62,8 +65,11 @@
from process.models.stellarator.initialization import st_init
from process.models.tfcoil.base import TFCoilShapeModel

if TYPE_CHECKING:
from process.core.model import DataStructure


def init_process():
def init_process(data_structure: DataStructure):
"""Routine that calls the initialisation routines

This routine calls the main initialisation routines that set
Expand All @@ -77,7 +83,7 @@ def init_process():
process_output.OutputFileManager.open_files()

# Input any desired new initial values
inputs = parse_input_file()
inputs = parse_input_file(data_structure)

# Set active constraints
set_active_constraints()
Expand Down
33 changes: 20 additions & 13 deletions process/core/input.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
"""Handle parsing, validation, and actioning of a PROCESS input file (*IN.DAT)."""

from __future__ import annotations

import copy
import re
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any
from warnings import warn

import process
import process.data_structure as data_structure
from process.core.exceptions import ProcessValidationError, ProcessValueError
from process.core.solver.constraints import ConstraintManager

if TYPE_CHECKING:
from process.core.model import DataStructure

NumberType = int | float
ValidInputTypes = NumberType | str

Expand Down Expand Up @@ -45,7 +50,7 @@ class InputVariable:
array: bool = False
"""Is this input assigning values to an array?"""
additional_validation: (
Callable[[str, ValidInputTypes, int | None, "InputVariable"], ValidInputTypes]
Callable[[str, ValidInputTypes, int | None, InputVariable], ValidInputTypes]
| None
) = None
"""A function that takes the input variable: name, value, array index, and config (this dataclass)
Expand All @@ -56,7 +61,7 @@ class InputVariable:
been cast to the specified `type`.
"""
additional_actions: (
Callable[[str, ValidInputTypes, int | None, "InputVariable"], None] | None
Callable[[str, ValidInputTypes, int | None, InputVariable], None] | None
) = None
"""A function that takes the input variable: name, value, array index, and config (this dataclass)
as input and performs some additional action in addition to the default actions prescribed by the variables
Expand Down Expand Up @@ -184,9 +189,7 @@ def __post_init__(self):
"admv": InputVariable(
data_structure.buildings_variables, float, range=(1.0e4, 1.0e6)
),
"airtemp": InputVariable(
data_structure.water_usage_variables, float, range=(-15.0, 40.0)
),
"airtemp": InputVariable("water_use", float, range=(-15.0, 40.0)),
"alfapf": InputVariable(data_structure.pfcoil_variables, float, range=(1e-12, 1.0)),
"alstroh": InputVariable(
data_structure.pfcoil_variables, float, range=(1000000.0, 100000000000.0)
Expand Down Expand Up @@ -1770,18 +1773,14 @@ def __post_init__(self):
"water_buildings_w": InputVariable(
data_structure.buildings_variables, float, range=(10.0, 1000.0)
),
"watertemp": InputVariable(
data_structure.water_usage_variables, float, range=(0.0, 25.0)
),
"watertemp": InputVariable("water_use", float, range=(0.0, 25.0)),
"wgt": InputVariable(
data_structure.buildings_variables, float, range=(10000.0, 1000000.0)
),
"wgt2": InputVariable(
data_structure.buildings_variables, float, range=(10000.0, 1000000.0)
),
"windspeed": InputVariable(
data_structure.water_usage_variables, float, range=(0.0, 10.0)
),
"windspeed": InputVariable("water_use", float, range=(0.0, 10.0)),
"workshop_h": InputVariable(
data_structure.buildings_variables, float, range=(1.0, 100.0)
),
Expand Down Expand Up @@ -2148,7 +2147,7 @@ def __post_init__(self):
}


def parse_input_file():
def parse_input_file(data_structure_obj: DataStructure):
input_file = data_structure.global_variables.fileprefix

input_file_path = Path("IN.DAT")
Expand Down Expand Up @@ -2186,6 +2185,14 @@ def parse_input_file():

variable_config = INPUT_VARIABLES.get(variable_name)

# string indicates it should be set on the new object data structure
if isinstance(variable_config.module, str):
module = data_structure_obj
for name in variable_config.module.split("."):
module = getattr(module, name)

variable_config.module = module

if variable_config is None:
error_msg = (
f"Unrecognised input '{variable_name}' at line {line_no} of input file."
Expand Down
15 changes: 9 additions & 6 deletions process/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ def __init__(self, config_file: str, solver: str = "vmcon"):
# dir changes happen in old run_process code
self.config_file = Path(config_file).resolve()
self.solver = solver
self.data = DataStructure()

def run(self):
"""Perform a VaryRun by running multiple SingleRuns.
Expand All @@ -328,7 +329,7 @@ def run(self):
setup_loggers(Path(config.wdir) / "process.log")

init.init_all_module_vars()
init.init_process()
init.init_process(self.data)

_neqns, itervars = get_neqns_itervars()
lbs, ubs = get_variable_range(itervars, config.factor)
Expand Down Expand Up @@ -405,8 +406,9 @@ def __init__(
self.validate_input(update_obsolete)
self.init_module_vars()
self.set_filenames()
self.data = DataStructure()
self.initialise()
self.models = Models()
self.models = Models(self.data)
self.solver = solver

def run(self):
Expand Down Expand Up @@ -479,7 +481,7 @@ def initialise(self):

initialise_imprad()
# Reads in input file
init.init_process()
init.init_process(self.data)

# Order optimisation parameters (arbitrary order in input file)
# Ensures consistency and makes output comparisons more straightforward
Expand Down Expand Up @@ -664,11 +666,13 @@ class Models:
engineering modules.
"""

def __init__(self):
def __init__(self, data: DataStructure):
"""Create physics and engineering model objects.

This also initialises module variables in the Fortran for that module.
"""
self.data = data

self._costs_custom = None
self._costs_1990 = Costs()
self._costs_2015 = Costs2015()
Expand Down Expand Up @@ -778,9 +782,8 @@ def setup_data_structure(self):
# This Models class should be replaced with a dataclass so we can
# iterate over the `fields`.
# This can be a disgusting temporary measure :(
data = DataStructure()
for model in self.models:
model.data = data
model.data = self.data


# setup handlers for writing to terminal (on warnings+)
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from process import main
from process.core.log import logging_model_handler
from process.core.model import DataStructure
from process.main import Models


Expand Down Expand Up @@ -235,4 +236,4 @@ def _plot_show_and_close_class(request):

@pytest.fixture
def process_models():
return Models()
return Models(DataStructure())
51 changes: 33 additions & 18 deletions tests/unit/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
import process.core.input as process_input
import process.data_structure as data_structure
from process.core.exceptions import ProcessValidationError
from process.core.model import DataStructure


@pytest.fixture
def data_structure_obj():
return DataStructure()


def _create_input_file(directory, content: str):
Expand Down Expand Up @@ -54,15 +60,15 @@ def _create_input_file(directory, content: str):
]
+ [("0.546816593988753", 0.546816593988753)],
)
def test_parse_real(epsvmc, expected, tmp_path):
def test_parse_real(epsvmc, expected, tmp_path, data_structure_obj):
"""Tests the parsing of real numbers into PROCESS.

Program to get the expected value for 0.008 provided at https://github.com/ukaea/PROCESS/pull/3067
"""
data_structure.global_variables.fileprefix = _create_input_file(
tmp_path, f"epsvmc = {epsvmc}"
)
init.init_process()
init.init_process(data_structure_obj)

assert data_structure.numerics.epsvmc == expected

Expand All @@ -78,7 +84,7 @@ def test_parse_real(epsvmc, expected, tmp_path):
[0.1293140904093427],
],
)
def test_exact_parsing(value, tmp_path):
def test_exact_parsing(value, tmp_path, data_structure_obj):
"""Tests the parsing of real numbers into PROCESS.

These tests failed using the old input parser and serve to show that the Python parser generally
Expand All @@ -87,37 +93,37 @@ def test_exact_parsing(value, tmp_path):
data_structure.global_variables.fileprefix = _create_input_file(
tmp_path, f"epsvmc = {value}"
)
init.init_process()
init.init_process(data_structure_obj)

assert data_structure.numerics.epsvmc == value


def test_parse_input(tmp_path):
def test_parse_input(tmp_path, data_structure_obj):
data_structure.global_variables.fileprefix = _create_input_file(
tmp_path,
("runtitle = my run title\nioptimz = -2\nepsvmc = 0.6\nboundl(1) = 0.5"),
)
init.init_process()
init.init_process(data_structure_obj)

assert data_structure.global_variables.runtitle == "my run title"
assert data_structure.numerics.ioptimz == -2
assert pytest.approx(data_structure.numerics.epsvmc) == 0.6
assert pytest.approx(data_structure.numerics.boundl[0]) == 0.5


def test_input_choices(tmp_path):
def test_input_choices(tmp_path, data_structure_obj):
data_structure.global_variables.fileprefix = _create_input_file(
tmp_path, ("ioptimz = -1")
)

with pytest.raises(ProcessValidationError):
init.init_process()
init.init_process(data_structure_obj)


@pytest.mark.parametrize(
("input_file_value"), ((-0.01,), (1.1,)), ids=("violate lower", "violate upper")
)
def test_input_range(tmp_path, input_file_value):
def test_input_range(tmp_path, input_file_value, data_structure_obj):
data_structure.global_variables.fileprefix = _create_input_file(
tmp_path, (f"epsvmc = {input_file_value}")
)
Expand All @@ -126,42 +132,51 @@ def test_input_range(tmp_path, input_file_value):
assert process_input.INPUT_VARIABLES["epsvmc"].range == (0.0, 1.0)

with pytest.raises(ProcessValidationError):
init.init_process()
init.init_process(data_structure_obj)


def test_input_array_when_not(tmp_path):
def test_input_array_when_not(tmp_path, data_structure_obj):
data_structure.global_variables.fileprefix = _create_input_file(
tmp_path, ("epsvmc(1) = 0.5")
)

with pytest.raises(ProcessValidationError):
init.init_process()
init.init_process(data_structure_obj)


def test_input_not_array_when_is(tmp_path):
def test_input_not_array_when_is(tmp_path, data_structure_obj):
data_structure.global_variables.fileprefix = _create_input_file(
tmp_path, ("boundl = 0.5")
)

with pytest.raises(ProcessValidationError):
init.init_process()
init.init_process(data_structure_obj)


def test_input_float_when_int(tmp_path):
def test_input_float_when_int(tmp_path, data_structure_obj):
data_structure.global_variables.fileprefix = _create_input_file(
tmp_path, ("ioptimz = 0.5")
)

with pytest.raises(ProcessValidationError):
init.init_process()
init.init_process(data_structure_obj)


def test_input_array(tmp_path):
def test_input_array(tmp_path, data_structure_obj):
data_structure.global_variables.fileprefix = _create_input_file(
tmp_path, ("boundl = 0.1, 0.2, 1.0, 0.0, 1.0e2")
)

init.init_process()
init.init_process(data_structure_obj)
np.testing.assert_array_equal(
data_structure.numerics.boundl[:6], [0.1, 0.2, 1.0, 0.0, 1.0e2, 0]
)


def test_input_on_new_data_structure(tmp_path, data_structure_obj):
data_structure.global_variables.fileprefix = _create_input_file(
tmp_path, ("windspeed = 1.22")
)

init.init_process(data_structure_obj)
assert data_structure_obj.water_use.windspeed == 1.22
2 changes: 2 additions & 0 deletions tests/unit/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from process import data_structure, main
from process.core.model import DataStructure
from process.main import Process, SingleRun, VaryRun


Expand Down Expand Up @@ -136,6 +137,7 @@ def single_run(monkeypatch, input_file, tmp_path):

single_run.input_file = str(temp_input_file)
single_run.models = None
single_run.data = DataStructure()
single_run.set_filenames()
single_run.initialise()
return single_run
Expand Down
Loading