278 changes: 211 additions & 67 deletions mellea/stdlib/sampling/base.py
@@ -1,7 +1,12 @@
"""Base Sampling Strategies."""

import abc
import asyncio
import math
from collections.abc import AsyncGenerator, Callable, Coroutine
from copy import deepcopy
from dataclasses import dataclass
from typing import Any

import tqdm

@@ -16,28 +21,97 @@
from .types import SamplingResult, SamplingStrategy


@dataclass
class _SamplingResultSlice:
"""Helper class for returning the result of a single sample operation."""

success: bool
generation: ModelOutputThunk
validation: list[tuple[Requirement, ValidationResult]]
context: Context
action: Component


def _get_sampling_result(
slices: list[_SamplingResultSlice],
select_from_failure: Callable[
[
list[Component],
list[ModelOutputThunk],
list[list[tuple[Requirement, ValidationResult]]],
],
int,
],
) -> SamplingResult:
sample_generations: list[ModelOutputThunk] = []
sample_validations: list[list[tuple[Requirement, ValidationResult]]] = []
sample_actions: list[Component] = []
sample_contexts: list[Context] = []

success = False
best_index = -1
for i, slice in enumerate(slices):
if slice.success and not success:
# If a success hasn't already been found, update the status and index.
success = True
best_index = i

sample_generations.append(slice.generation)
sample_validations.append(slice.validation)
sample_actions.append(slice.action)
sample_contexts.append(slice.context)

if not success:
best_index = select_from_failure(
sample_actions, sample_generations, sample_validations
)

return SamplingResult(
result_index=best_index,
success=success,
sample_generations=sample_generations,
sample_validations=sample_validations,
sample_actions=sample_actions,
sample_contexts=sample_contexts,
)


class BaseSamplingStrategy(SamplingStrategy):
"""Base class for multiple strategies that rejects samples based on given instructions."""

loop_budget: int
concurrency_budget: int

def __init__(
self, *, loop_budget: int = 1, requirements: list[Requirement] | None = None
self,
*,
loop_budget: int = 1,
concurrency_budget: int = 1,
requirements: list[Requirement] | None = None,
):
"""Initialize a new instance of the class with default parameters.

Will generate at most loop_budget * concurrency_budget requests. The sampling will end at the first valid result.
The loop budget specifies the depth of repair strategies.

For example:
- loop_budget = 1: no repair strategies will be used
- loop_budget = 3 and concurrency_budget = 1: generation -> repair -> generation -> repair -> final generation
- loop_budget = 2 and concurrency_budget = 2: each initial concurrent generation will undergo a repair strategy once and then attempt a second generation

Args:
loop_budget: Number of times to iterate through the process. Must be greater than 0.
validate: Function to validate the results against requirements. If None, validation is provided later through setter.
generate: Function to generate new model output thunks. If None, generate is provided later through setter.
concurrency_budget: Number of concurrent generations per loop. Use the default of 1 for non-concurrent sampling. Must be greater than 0.
requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.

Raises:
AssertionError: If loop_budget or concurrency_budget is not greater than 0.
"""
assert loop_budget > 0, "Loop budget must be at least 1."
assert concurrency_budget > 0, "Concurrency budget must be at least 1."

self.loop_budget = loop_budget
self.concurrency_budget = concurrency_budget
self.requirements = requirements
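For orientation, here is a minimal usage sketch of the budget semantics described in the docstring. It uses the RejectionSamplingStrategy subclass defined later in this file; the Requirement import path and constructor call are assumptions for illustration, not taken from this diff.

```python
# Hedged usage sketch; the Requirement import path and constructor call are assumptions.
from mellea.stdlib.requirement import Requirement
from mellea.stdlib.sampling.base import RejectionSamplingStrategy

# loop_budget=2, concurrency_budget=3 -> at most 2 * 3 = 6 generations in total:
# 3 concurrent chains, each allowed one repair-and-retry pass, and sampling
# stops at the first result that satisfies every requirement.
strategy = RejectionSamplingStrategy(
    loop_budget=2,
    concurrency_budget=3,
    requirements=[Requirement("The response must be valid JSON.")],
)
```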

@staticmethod
@@ -118,11 +192,6 @@ async def sample(

flog = FancyLogger.get_logger()

sampled_results: list[ModelOutputThunk] = []
sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = []
sampled_actions: list[Component] = []
sample_contexts: list[Context] = []

# The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress
# flag to determine whether we should show the pbar.
show_progress = show_progress and flog.getEffectiveLevel() <= FancyLogger.INFO
@@ -136,23 +205,124 @@
reqs += requirements
reqs = list(set(reqs))

loop_count = 0
loop_budget_range_iterator = (
tqdm.tqdm(range(self.loop_budget)) # type: ignore
if show_progress
else range(self.loop_budget) # type: ignore
total_possible_generations = self.loop_budget * self.concurrency_budget
progress_indicator = None
if show_progress:
progress_indicator = tqdm.tqdm(
iterable=range(total_possible_generations),
desc=f"{self.__class__.__name__}",
)

generators: list[AsyncGenerator[_SamplingResultSlice, Any]] = []
for _ in range(self.concurrency_budget):
generators.append(
self._subsample_iteration(
iterations=self.loop_budget,
action=action,
context=context,
backend=backend,
requirements=reqs,
validation_ctx=validation_ctx,
format=format,
model_options=model_options,
tool_calls=tool_calls,
)
)

async def async_generator_producer(
generator: AsyncGenerator[_SamplingResultSlice, Any],
queue: asyncio.Queue[_SamplingResultSlice],
):
"""Add items from an async generator to an async queue."""
async for item in generator:
await queue.put(item)

sample_slice_queue: asyncio.Queue[_SamplingResultSlice] = asyncio.Queue()
producer_tasks = [
# Use tasks so that we don't need to explicitly await each generator.
asyncio.create_task(async_generator_producer(generator, sample_slice_queue))
for generator in generators
]

slices: list[_SamplingResultSlice] = []
while not all(task.done() for task in producer_tasks):
sr_slice = await sample_slice_queue.get()
slices.append(sr_slice)

if progress_indicator:
progress_indicator.update()

if sr_slice.success:
break

# TODO: We could add a sleep here after a success to try to collect
# any other finished sample iterations. But this also risks ceding
# control to some other task / requirement validator that takes much longer
# to process and makes this sampling result take much longer.
# await asyncio.sleep(.1)

for task in producer_tasks:
task.cancel() # Works even if task is already done / cancelled.

while not sample_slice_queue.empty():
try:
# Shouldn't have to wait here since all tasks are cancelled.
sr_slice = sample_slice_queue.get_nowait()
slices.append(sr_slice)

if progress_indicator:
progress_indicator.update()
except asyncio.QueueEmpty:
# This is somewhat redundant but isn't harmful.
break

if progress_indicator:
progress_indicator.close()

s_result = _get_sampling_result(
slices=slices, select_from_failure=self.select_from_failure
)
if not s_result.success:
flog.info(
f"Invoking select_from_failure after {len(s_result.sample_generations)} failed attempts."
)
else:
flog.info("Sampling was successful.")

assert s_result.result_index < len(s_result.sample_generations), (
"The select_from_failure method did not return a valid result. It has to selected from failed_results."
)

assert s_result.result._generate_log is not None
s_result.result._generate_log.is_final_result = True

return s_result
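To make the fan-out easier to follow outside the diff, here is a simplified, standalone sketch of the same producer/queue pattern: several async generators feed one asyncio.Queue, the consumer stops at the first success, and the remaining producer tasks are cancelled. None of these names are mellea APIs, and the bounded consumer loop stands in for the done-check used in sample above.

```python
import asyncio
from collections.abc import AsyncGenerator

ATTEMPTS_PER_CHAIN = 3  # stand-in for loop_budget
CHAINS = 2  # stand-in for concurrency_budget


async def fake_chain(chain_id: int) -> AsyncGenerator[tuple[int, bool], None]:
    """Placeholder for _subsample_iteration: yield (chain_id, success) per attempt."""
    for attempt in range(ATTEMPTS_PER_CHAIN):
        await asyncio.sleep(0.01)  # stand-in for a generation + validation pass
        yield chain_id, attempt == ATTEMPTS_PER_CHAIN - 1  # succeed on the last attempt


async def main() -> None:
    queue: asyncio.Queue[tuple[int, bool]] = asyncio.Queue()

    async def producer(chain: AsyncGenerator[tuple[int, bool], None]) -> None:
        async for item in chain:
            await queue.put(item)

    tasks = [asyncio.create_task(producer(fake_chain(i))) for i in range(CHAINS)]

    collected: list[tuple[int, bool]] = []
    # Consume at most CHAINS * ATTEMPTS_PER_CHAIN items, stopping at the first success.
    for _ in range(CHAINS * ATTEMPTS_PER_CHAIN):
        item = await queue.get()
        collected.append(item)
        if item[1]:
            break

    for task in tasks:
        task.cancel()  # safe even if the task already finished

    print(f"Collected {len(collected)} samples; stopped on a success: {collected[-1][1]}")


asyncio.run(main())
```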

async def _subsample_iteration(
self,
iterations: int,
action: Component,
context: Context,
backend: Backend,
requirements: list[Requirement],
*,
validation_ctx: Context | None = None,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
tool_calls: bool = False,
):
"""Helper function that represents a single sampling iteration: generating a sample and validating it."""
sampled_results: list[ModelOutputThunk] = []
sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = []
sampled_actions: list[Component] = []
sample_contexts: list[Context] = []

next_action = deepcopy(action)
next_context = context
for _ in loop_budget_range_iterator: # type: ignore
loop_count += 1
if not show_progress:
flog.info(f"Running loop {loop_count} of {self.loop_budget}")

for _ in range(iterations): # type: ignore
# run a generation pass
result, result_ctx = await backend.generate_from_context(
next_action,
action=next_action,
ctx=next_context,
format=format,
model_options=model_options,
@@ -162,7 +332,7 @@ async def sample(

# validation pass
val_scores_co = mfuncs.avalidate(
reqs=reqs,
reqs=requirements,
context=result_ctx,
backend=backend,
output=result,
@@ -173,36 +343,33 @@
val_scores = await val_scores_co

# match up reqs with scores
constraint_scores = list(zip(reqs, val_scores))

# collect all data
sampled_results.append(result)
sampled_scores.append(constraint_scores)
sampled_actions.append(next_action)
sample_contexts.append(result_ctx)
constraint_scores = list(zip(requirements, val_scores))

# if all vals are true -- break and return success
success = False
if all(bool(s[1]) for s in constraint_scores):
flog.info("SUCCESS")
success = True
assert (
result._generate_log is not None
) # Cannot be None after generation.
result._generate_log.is_final_result = True

# SUCCESS !!!!
return SamplingResult(
result_index=len(sampled_results) - 1,
success=True,
sample_generations=sampled_results,
sample_validations=sampled_scores,
sample_contexts=sample_contexts,
sample_actions=sampled_actions,
)
yield _SamplingResultSlice(
success=success,
generation=result,
validation=constraint_scores,
context=result_ctx,
action=next_action,
)

if success:
# End generation early.
return

else:
# log partial success and continue
count_valid = len([s for s in constraint_scores if bool(s[1])])
flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}")
# Have to append so that the repair strategy gets the correct info.
sampled_results.append(result)
sampled_scores.append(constraint_scores)
sampled_actions.append(next_action)
sample_contexts.append(result_ctx)

# If we did not pass all constraints, update the instruction and try again.
next_action, next_context = self.repair(
@@ -213,31 +380,8 @@
sampled_scores,
)

flog.info(
f"Invoking select_from_failure after {len(sampled_results)} failed attempts."
)

# if no valid result could be determined, find a last resort.
best_failed_index = self.select_from_failure(
sampled_actions, sampled_results, sampled_scores
)
assert best_failed_index < len(sampled_results), (
"The select_from_failure method did not return a valid result. It has to selected from failed_results."
)

assert (
sampled_results[best_failed_index]._generate_log is not None
) # Cannot be None after generation.
sampled_results[best_failed_index]._generate_log.is_final_result = True # type: ignore

return SamplingResult(
result_index=best_failed_index,
success=False,
sample_generations=sampled_results,
sample_validations=sampled_scores,
sample_actions=sampled_actions,
sample_contexts=sample_contexts,
)
# End generation after all iterations are done.
return


class RejectionSamplingStrategy(BaseSamplingStrategy):
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -185,7 +185,7 @@ combine-as-imports = true
split-on-trailing-comma = false

[tool.codespell]
ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd,mot,rouge,Rouge'
ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd,mot,rouge,Rouge,strat'
check-filenames = true
check-hidden = false
regex = "(?<![a-z])[a-z'`]+|[A-Z][a-z'`]*|[a-z]+'[a-z]*|[a-z]+(?=[_-])|[a-z]+(?=[A-Z])|\\d+"