Skip to content

Commit 563ffb7

Browse files
committed
feat: fix ExportCallback example case
1 parent 5d27327 commit 563ffb7

File tree

3 files changed

+15
-26
lines changed

3 files changed

+15
-26
lines changed

elastica/callback_functions.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,12 @@ def make_callback(self, system: T, time: np.float64, current_step: int) -> None:
4343
Simulation step.
4444
4545
"""
46-
pass
46+
47+
def on_close(self) -> None:
48+
"""
49+
This method is called collectively when when .close() is
50+
called by the system collection.
51+
"""
4752

4853

4954
class MyCallBack(CallBackBaseClass):
@@ -264,15 +269,10 @@ def get_last_saved_path(self) -> Optional[str]:
264269
else:
265270
return self.save_path.format(self.file_count - 1, self._ext)
266271

267-
def close(self) -> None:
272+
def on_close(self) -> None:
268273
"""
269-
Save residual buffer
274+
Save residual buffer.
275+
Can be called using `simulator.close()`.
270276
"""
271277
if self.buffer_size:
272278
self._dump()
273-
274-
def clear(self) -> None:
275-
"""
276-
Alias to `close`
277-
"""
278-
self.close()

elastica/modules/memory_block.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
Cosserat Rods, Rigid Body etc.
44
"""
55
from typing import cast
6+
67
from elastica.typing import (
78
RodType,
89
RigidBodyType,

tests/test_callback_functions.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -171,22 +171,22 @@ def test_export_call_back_step_skip_param(self, step_skip):
171171
callback = ExportCallBack(step_skip, "rod", temp_dir_path, "npz")
172172
callback.make_callback(mock_rod, 1, step_skip - 1)
173173
# Check empty
174-
callback.clear()
174+
callback.on_close()
175175
saved_path_name = callback.get_last_saved_path()
176176
assert saved_path_name is None, "No file should be saved."
177177

178178
# Check saved
179179
callback.make_callback(mock_rod, 1, step_skip)
180-
callback.clear()
180+
callback.on_close()
181181
saved_path_name = callback.get_last_saved_path()
182182
assert saved_path_name is not None, "File should be saved."
183183
assert os.path.exists(saved_path_name), "File should be saved"
184184

185185
# Check saved file number
186186
callback.make_callback(mock_rod, 1, step_skip * 2)
187-
callback.clear()
187+
callback.on_close()
188188
callback.make_callback(mock_rod, 1, step_skip * 5)
189-
callback.clear()
189+
callback.on_close()
190190
saved_path_name = callback.get_last_saved_path()
191191
assert (
192192
str(2) in saved_path_name
@@ -222,19 +222,7 @@ def test_export_call_back_close_test(self, rng):
222222
)
223223
for step in range(10):
224224
callback.make_callback(mock_rod, 1, step)
225-
callback.close()
226-
saved_path_name = callback.get_last_saved_path()
227-
assert os.path.exists(saved_path_name), "File is not saved."
228-
229-
def test_export_call_back_clear_test(self, rng):
230-
mock_rod = MockRodWithElements(5)
231-
with tempfile.TemporaryDirectory() as temp_dir_path:
232-
callback = ExportCallBack(
233-
1, "rod", temp_dir_path, "npz", file_save_interval=50
234-
)
235-
for step in range(10):
236-
callback.make_callback(mock_rod, 1, step)
237-
callback.clear()
225+
callback.on_close()
238226
saved_path_name = callback.get_last_saved_path()
239227
assert os.path.exists(saved_path_name), "File is not saved."
240228

0 commit comments

Comments
 (0)