Skip to content

Commit eb06a1f

Browse files
vezhnickcopybara-github
authored andcommitted
Add ability to save and load rational and basic agents to/from json.
- add methods for saving and loading to components - add save / load to all main agent factories - saving and loading of associative memories PiperOrigin-RevId: 690960131 Change-Id: Ie323681fe63fbc966d70bbd21bcaf9d09f06c48d
1 parent c8a45e5 commit eb06a1f

21 files changed

+832
-69
lines changed

concordia/agents/entity_agent.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,14 @@ def get_component(
107107
component = self._context_components[name]
108108
return cast(entity_component.ComponentT, component)
109109

110+
def get_act_component(self) -> entity_component.ActingComponent:
111+
return self._act_component
112+
113+
def get_all_context_components(
114+
self,
115+
) -> Mapping[str, entity_component.ContextComponent]:
116+
return types.MappingProxyType(self._context_components)
117+
110118
def _parallel_call_(
111119
self,
112120
method_name: str,

concordia/agents/entity_agent_with_logging.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
"""A modular entity agent using the new component system with side logging."""
1616

1717
from collections.abc import Mapping
18+
import copy
1819
import types
1920
from typing import Any
2021
from absl import logging
2122
from concordia.agents import entity_agent
23+
from concordia.associative_memory import formative_memories
2224
from concordia.typing import agent
2325
from concordia.typing import entity_component
2426
from concordia.utils import measurements as measurements_lib
@@ -39,6 +41,7 @@ def __init__(
3941
types.MappingProxyType({})
4042
),
4143
component_logging: measurements_lib.Measurements | None = None,
44+
config: formative_memories.AgentConfig | None = None,
4245
):
4346
"""Initializes the agent.
4447
@@ -56,6 +59,7 @@ def __init__(
5659
None, a NoOpContextProcessor will be used.
5760
context_components: The ContextComponents that will be used by the agent.
5861
component_logging: The channels where components publish events.
62+
config: The agent configuration, used for checkpointing and debug.
5963
"""
6064
super().__init__(agent_name=agent_name,
6165
act_component=act_component,
@@ -75,6 +79,7 @@ def __init__(
7579
on_error=lambda e: logging.error('Error in component logging: %s', e))
7680
else:
7781
self._channel_names = []
82+
self._config = copy.deepcopy(config)
7883

7984
def _set_log(self, log: tuple[Any, ...]) -> None:
8085
"""Set the logging object to return from get_last_log.
@@ -89,3 +94,6 @@ def _set_log(self, log: tuple[Any, ...]) -> None:
8994
def get_last_log(self):
9095
self._tick.on_next(None) # Trigger the logging.
9196
return self._log
97+
98+
def get_config(self) -> formative_memories.AgentConfig | None:
99+
return copy.deepcopy(self._config)

concordia/associative_memory/associative_memory.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@
2525
import threading
2626

2727
from concordia.associative_memory import importance_function
28+
from concordia.typing import entity_component
2829
import numpy as np
2930
import pandas as pd
3031

32+
3133
_NUM_TO_RETRIEVE_TO_CONTEXTUALIZE_IMPORTANCE = 25
3234

3335

@@ -79,6 +81,29 @@ def __init__(
7981
self._interval = clock_step_size
8082
self._stored_hashes = set()
8183

84+
def get_state(self) -> entity_component.ComponentState:
85+
"""Converts the AssociativeMemory to a dictionary."""
86+
87+
with self._memory_bank_lock:
88+
output = {
89+
'seed': self._seed,
90+
'stored_hashes': list(self._stored_hashes),
91+
'memory_bank': self._memory_bank.to_json(),
92+
}
93+
if self._interval:
94+
output['interval'] = self._interval.total_seconds()
95+
return output
96+
97+
def set_state(self, state: entity_component.ComponentState) -> None:
98+
"""Sets the AssociativeMemory from a dictionary."""
99+
100+
with self._memory_bank_lock:
101+
self._seed = state['seed']
102+
self._stored_hashes = set(state['stored_hashes'])
103+
self._memory_bank = pd.read_json(state['memory_bank'])
104+
if 'interval' in state:
105+
self._interval = datetime.timedelta(seconds=state['interval'])
106+
82107
def add(
83108
self,
84109
text: str,

concordia/associative_memory/formative_memories.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
"""This is a factory for generating memories for concordia agents."""
1717

18-
from collections.abc import Callable, Iterable, Sequence
18+
from collections.abc import Callable, Collection, Sequence
1919
import dataclasses
2020
import datetime
2121
import logging
@@ -58,9 +58,24 @@ class AgentConfig:
5858
specific_memories: str = ''
5959
goal: str = ''
6060
date_of_birth: datetime.datetime = DEFAULT_DOB
61-
formative_ages: Iterable[int] = DEFAULT_FORMATIVE_AGES
61+
formative_ages: Collection[int] = DEFAULT_FORMATIVE_AGES
6262
extras: dict[str, Any] = dataclasses.field(default_factory=dict)
6363

64+
def to_dict(self) -> dict[str, Any]:
65+
"""Converts the AgentConfig to a dictionary."""
66+
result = dataclasses.asdict(self)
67+
result['date_of_birth'] = self.date_of_birth.isoformat()
68+
return result
69+
70+
@classmethod
71+
def from_dict(cls, data: dict[str, Any]) -> 'AgentConfig':
72+
"""Initializes an AgentConfig from a dictionary."""
73+
date_of_birth = datetime.datetime.fromisoformat(
74+
data['date_of_birth']
75+
)
76+
data = data | {'date_of_birth': date_of_birth}
77+
return cls(**data)
78+
6479

6580
class FormativeMemoryFactory:
6681
"""Generator of formative memories."""

concordia/components/agent/action_spec_ignored.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616

1717
import abc
1818
import threading
19-
from typing import Final
19+
from typing import Final, Any
2020

2121
from concordia.typing import entity as entity_lib
2222
from concordia.typing import entity_component
23+
from typing_extensions import override
2324

2425

2526
class ActionSpecIgnored(
@@ -89,3 +90,11 @@ def get_named_component_pre_act_value(self, component_name: str) -> str:
8990
"""Returns the pre-act value of a named component of the parent entity."""
9091
return self.get_entity().get_component(
9192
component_name, type_=ActionSpecIgnored).get_pre_act_value()
93+
94+
@override
95+
def set_state(self, state: entity_component.ComponentState) -> Any:
96+
return None
97+
98+
@override
99+
def get_state(self) -> entity_component.ComponentState:
100+
return {}

concordia/components/agent/all_similar_memories.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Return all memories similar to a prompt and filter them for relevance.
16-
"""
15+
"""Return all memories similar to a prompt and filter them for relevance."""
1716

18-
from collections.abc import Mapping
1917
import types
18+
from typing import Mapping
2019

2120
from concordia.components.agent import action_spec_ignored
2221
from concordia.components.agent import memory_component

concordia/components/agent/concat_act_component.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,11 @@ def _log(self,
164164
'Value': result,
165165
'Prompt': prompt.view().text().splitlines(),
166166
})
167+
168+
def get_state(self) -> entity_component.ComponentState:
169+
"""Converts the component to a dictionary."""
170+
return {}
171+
172+
def set_state(self, state: entity_component.ComponentState) -> None:
173+
pass
174+

concordia/components/agent/memory_component.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ def _check_phase(self) -> None:
5757
'You can only access the memory outside of the `UPDATE` phase.'
5858
)
5959

60+
def get_state(self) -> Mapping[str, Any]:
61+
with self._lock:
62+
return self._memory.get_state()
63+
64+
def set_state(self, state: Mapping[str, Any]) -> None:
65+
with self._lock:
66+
self._memory.set_state(state)
67+
6068
def retrieve(
6169
self,
6270
query: str = '',

concordia/components/agent/plan.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,15 @@ def _make_pre_act_value(self) -> str:
163163
})
164164

165165
return result
166+
167+
def get_state(self) -> entity_component.ComponentState:
168+
"""Converts the component to JSON data."""
169+
with self._lock:
170+
return {
171+
'current_plan': self._current_plan,
172+
}
173+
174+
def set_state(self, state: entity_component.ComponentState) -> None:
175+
"""Sets the component state from JSON data."""
176+
with self._lock:
177+
self._current_plan = state['current_plan']

concordia/components/agent/question_of_query_associated_memories.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,12 @@ def pre_act(
189189
def update(self) -> None:
190190
self._component.update()
191191

192+
def get_state(self) -> entity_component.ComponentState:
193+
return self._component.get_state()
194+
195+
def set_state(self, state: entity_component.ComponentState) -> None:
196+
self._component.set_state(state)
197+
192198

193199
class Identity(QuestionOfQueryAssociatedMemories):
194200
"""Identity component containing a few characteristics.

0 commit comments

Comments
 (0)