Skip to content

Commit 01d0fc3

Browse files
committed
feat: add simulator.close() method for collective end-call for callbacks
1 parent 563ffb7 commit 01d0fc3

File tree

6 files changed

+35
-22
lines changed

6 files changed

+35
-22
lines changed

elastica/modules/base_system.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
Basic coordinating for multiple, smaller systems that have an independently integrable
66
interface (i.e. works with symplectic or explicit routines `timestepper.py`.)
77
"""
8-
from typing import TYPE_CHECKING, Type, Generator, Any, overload
8+
from typing import TYPE_CHECKING, Type, Generator, Any, overload, Callable
99
from typing import final
1010
from elastica.typing import (
1111
SystemType,
@@ -72,6 +72,9 @@ def __init__(self) -> None:
7272
self._feature_group_callback: OperatorGroupFIFO[
7373
OperatorCallbackType, ModuleProtocol
7474
] = OperatorGroupFIFO()
75+
self._feature_group_on_close: OperatorGroupFIFO[Callable, ModuleProtocol] = (
76+
OperatorGroupFIFO()
77+
)
7578
self._feature_group_finalize: list[OperatorFinalizeType] = []
7679
# We need to initialize our mixin classes
7780
super().__init__()
@@ -282,6 +285,15 @@ def apply_callbacks(self, time: np.float64, current_step: int) -> None:
282285
for func in self._feature_group_callback:
283286
func(time=time, current_step=current_step)
284287

288+
@final
289+
def close(self) -> None:
290+
"""
291+
Call close functions for all features.
292+
Features are registered in _feature_group_on_close.
293+
"""
294+
for func in self._feature_group_on_close:
295+
func()
296+
285297

286298
if TYPE_CHECKING:
287299
from .protocol import SystemCollectionProtocol

elastica/modules/callbacks.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def collect_diagnostics(
5757
_callback: ModuleProtocol = _CallBack(sys_idx)
5858
self._callback_list.append(_callback)
5959
self._feature_group_callback.append_id(_callback)
60+
self._feature_group_on_close.append_id(_callback)
6061

6162
return _callback
6263

@@ -70,6 +71,9 @@ def _finalize_callback(self: SystemCollectionWithCallbackProtocol) -> None:
7071
callback_instance.make_callback, system=self[sys_id]
7172
)
7273
self._feature_group_callback.add_operators(callback, [callback_operator])
74+
self._feature_group_on_close.add_operators(
75+
callback, [callback_instance.on_close]
76+
)
7377

7478
self._callback_list.clear()
7579
del self._callback_list

elastica/modules/memory_block.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -51,25 +51,11 @@ def construct_memory_block_structures(
5151
temp_list_for_rigid_body_systems_idx.append(system_idx)
5252

5353
elif isinstance(sys_to_be_added, SurfaceBase):
54+
# TODO: Surface type is passive system
5455
pass
55-
# surface_system = cast(SurfaceType, sys_to_be_added)
56-
# raise NotImplementedError(
57-
# "Surfaces are not yet implemented in memory block construction."
58-
# )
5956

6057
else:
61-
raise TypeError(
62-
"{0}\n"
63-
"is not a system passing validity\n"
64-
"checks for constructing block structure. If you are sure that\n"
65-
"{0}\n"
66-
"satisfies all criteria for being a system, please add\n"
67-
"it here with correct memory block implementation.\n"
68-
"The allowed types are\n"
69-
"{1} {2} {3}".format(
70-
sys_to_be_added.__class__, RodBase, RigidBodyBase, SurfaceBase
71-
)
72-
)
58+
continue # No error:: any typechecking should be finished by BaseSystemCollection._check_type
7359

7460
if temp_list_for_cosserat_rod_systems:
7561
_memory_blocks.append(

elastica/modules/protocol.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Protocol, Generator, TypeVar, Any, Type, overload, Iterator
1+
from typing import Protocol, Generator, TypeVar, Any, Type, overload, Iterator, Callable
22
from typing import TYPE_CHECKING
33
from typing_extensions import Self # python 3.11: from typing import Self
44

@@ -80,6 +80,10 @@ def apply_callbacks(self, time: np.float64, current_step: int) -> None: ...
8080

8181
def finalize(self) -> None: ...
8282

83+
_feature_group_on_close: "OperatorGroupFIFO[Callable, ModuleProtocol]"
84+
85+
def close(self) -> None: ...
86+
8387

8488
# Mixin Protocols (Used to type Self)
8589
class ConnectedSystemCollectionProtocol(SystemCollectionProtocol, Protocol):

elastica/timestepper/protocol.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ def n_stages(self) -> int: ...
2525
def step_methods(self) -> SteppersOperatorsType: ...
2626

2727
def step(
28-
self, SystemCollection: SystemCollectionType, time: np.float64, dt: np.float64
28+
self,
29+
SystemCollection: SystemCollectionType,
30+
time: np.float64 | float,
31+
dt: np.float64 | float,
2932
) -> np.float64: ...
3033

3134
def step_single_instance(

elastica/timestepper/symplectic_steppers.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,15 @@ def n_stages(self: SymplecticStepperProtocol) -> int:
6666
def step(
6767
self: SymplecticStepperProtocol,
6868
SystemCollection: SystemCollectionType,
69-
time: np.float64,
70-
dt: np.float64,
69+
time: np.float64 | float,
70+
dt: np.float64 | float,
7171
) -> np.float64:
7272
return SymplecticStepperMixin.do_step(
73-
self, self.steps_and_prefactors, SystemCollection, time, dt
73+
self,
74+
self.steps_and_prefactors,
75+
SystemCollection,
76+
np.float64(time),
77+
np.float64(dt),
7478
)
7579

7680
# TODO: Merge with .step method in the future.

0 commit comments

Comments
 (0)