Commit a5d1649

feat(langchain): add support for tracing RunnableLambda (#15477)
## Description

Adds APM and LLM Observability support for tracing `RunnableLambda` operations. Additionally, ensures that span linking happens automatically.

## Testing

Added unit tests; also verified locally.

## Risks

None.
1 parent ba60fe5 commit a5d1649

11 files changed: +566 −15 lines
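For orientation, a minimal sketch of the behavior this commit enables, assuming ddtrace's documented `patch()` and `LLMObs.enable()` entry points (the `ml_app` name and the lambda itself are illustrative, and LLM Observability still needs the usual agent or agentless configuration):

```python
from ddtrace import patch
from ddtrace.llmobs import LLMObs
from langchain_core.runnables import RunnableLambda

patch(langchain=True)          # installs the wrap() calls added below
LLMObs.enable(ml_app="demo")   # "demo" is an illustrative ml_app name

def double(inputs: dict) -> int:
    return inputs["x"] * 2

runnable = RunnableLambda(double, name="double")
print(runnable.invoke({"x": 21}))             # 42 — traced as span "double"
print(runnable.batch([{"x": 1}, {"x": 2}]))   # [2, 4] — traced as span "double_batch"
```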

ddtrace/contrib/internal/langchain/patch.py

Lines changed: 82 additions & 0 deletions

```diff
@@ -612,6 +612,78 @@ def patched_vectorstore_init_subclass(func, instance, args, kwargs):
         log.warning("Unable to patch LangChain VectorStore class %s", str(cls))


+def traced_runnable_lambda_operation(is_batch: bool = False):
+    @with_traced_module
+    def _traced_runnable_lambda_impl(langchain_core, pin, func, instance, args, kwargs):
+        integration: LangChainIntegration = langchain_core._datadog_integration
+
+        instance_name = getattr(instance, "name", None)
+        default_name = f"{instance.__class__.__name__}.{func.__name__}"
+        if is_batch:
+            span_name = f"{instance_name}_batch" if instance_name else default_name
+        else:
+            span_name = instance_name or default_name
+
+        span = integration.trace(
+            pin,
+            span_name,
+            submit_to_llmobs=True,
+            instance=instance,
+        )
+
+        integration.record_instance(instance, span)
+
+        result = None
+
+        try:
+            result = func(*args, **kwargs)
+            return result
+        except Exception:
+            span.set_exc_info(*sys.exc_info())
+            raise
+        finally:
+            integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=result, operation="runnable_lambda")
+            span.finish()
+
+    return _traced_runnable_lambda_impl
+
+
+def traced_runnable_lambda_operation_async(is_batch: bool = False):
+    @with_traced_module
+    async def _traced_runnable_lambda_impl(langchain_core, pin, func, instance, args, kwargs):
+        integration: LangChainIntegration = langchain_core._datadog_integration
+
+        instance_name = getattr(instance, "name", None)
+        default_name = f"{instance.__class__.__name__}.{func.__name__}"
+        if is_batch:
+            span_name = f"{instance_name}_batch" if instance_name else default_name
+        else:
+            span_name = instance_name or default_name
+
+        span = integration.trace(
+            pin,
+            span_name,
+            submit_to_llmobs=True,
+            instance=instance,
+        )
+
+        integration.record_instance(instance, span)
+
+        result = None
+
+        try:
+            result = await func(*args, **kwargs)
+            return result
+        except Exception:
+            span.set_exc_info(*sys.exc_info())
+            raise
+        finally:
+            integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=result, operation="runnable_lambda")
+            span.finish()
+
+    return _traced_runnable_lambda_impl
+
+
 def patch():
     if getattr(langchain_core, "_datadog_patch", False):
         return
@@ -626,6 +698,7 @@ def patch():
     from langchain_core.language_models.chat_models import BaseChatModel
     from langchain_core.language_models.llms import BaseLLM
     from langchain_core.prompts.base import BasePromptTemplate
+    from langchain_core.runnables.base import RunnableLambda
     from langchain_core.runnables.base import RunnableSequence
     from langchain_core.tools import BaseTool
     from langchain_core.vectorstores import VectorStore
@@ -651,6 +724,11 @@ def patch():
     wrap(RunnableSequence, "stream", traced_chain_stream(langchain_core))
     wrap(RunnableSequence, "astream", traced_chain_stream(langchain_core))

+    wrap(RunnableLambda, "invoke", traced_runnable_lambda_operation(is_batch=False)(langchain_core))
+    wrap(RunnableLambda, "ainvoke", traced_runnable_lambda_operation_async(is_batch=False)(langchain_core))
+    wrap(RunnableLambda, "batch", traced_runnable_lambda_operation(is_batch=True)(langchain_core))
+    wrap(RunnableLambda, "abatch", traced_runnable_lambda_operation_async(is_batch=True)(langchain_core))
+
     wrap(BasePromptTemplate, "invoke", patched_base_prompt_template_invoke(langchain_core))
     wrap(BasePromptTemplate, "ainvoke", patched_base_prompt_template_ainvoke(langchain_core))

@@ -683,6 +761,10 @@ def unpatch():
     unwrap(langchain_core.runnables.base.RunnableSequence, "abatch")
     unwrap(langchain_core.runnables.base.RunnableSequence, "stream")
     unwrap(langchain_core.runnables.base.RunnableSequence, "astream")
+    unwrap(langchain_core.runnables.base.RunnableLambda, "invoke")
+    unwrap(langchain_core.runnables.base.RunnableLambda, "ainvoke")
+    unwrap(langchain_core.runnables.base.RunnableLambda, "batch")
+    unwrap(langchain_core.runnables.base.RunnableLambda, "abatch")
     unwrap(langchain_core.language_models.chat_models.BaseChatModel, "stream")
     unwrap(langchain_core.language_models.chat_models.BaseChatModel, "astream")
     unwrap(langchain_core.language_models.llms.BaseLLM, "stream")
```
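Worth noting in the hunks above: `traced_runnable_lambda_operation(is_batch=...)` is a factory whose result is a `with_traced_module`-decorated wrapper, so each `wrap(...)` call invokes it twice — once to bind `is_batch`, once to bind `langchain_core`. A stripped-down sketch of that calling convention, with simplified stand-ins for the ddtrace internals (the real `with_traced_module` also injects a `Pin` and the span plumbing is far richer; everything here is illustrative):

```python
def with_traced_module(func):
    """Simplified stand-in: binds the traced module as the wrapper's first argument."""
    def bind_module(module):
        def wrapper(wrapped, instance, args, kwargs):
            return func(module, wrapped, instance, args, kwargs)
        return wrapper
    return bind_module


def traced_operation(is_batch: bool = False):            # layer 1: binds the option
    @with_traced_module
    def _impl(module, wrapped, instance, args, kwargs):   # layer 3: runs per call
        name = getattr(instance, "name", None) or wrapped.__name__
        if is_batch:
            name = f"{name}_batch"
        print(f"[{module}] span: {name}")
        return wrapped(*args, **kwargs)
    return _impl


handler = traced_operation(is_batch=True)("langchain_core")  # layer 2: binds the module
print(handler(len, object(), (["a", "b"],), {}))             # prints the span line, then 2
```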

ddtrace/llmobs/_integrations/langchain.py

Lines changed: 17 additions & 2 deletions

```diff
@@ -71,7 +71,7 @@
     "system": "system",
 }

-SUPPORTED_OPERATIONS = ["llm", "chat", "chain", "embedding", "retrieval", "tool"]
+SUPPORTED_OPERATIONS = ["llm", "chat", "chain", "embedding", "retrieval", "tool", "runnable_lambda"]
 LANGCHAIN_BASE_URL_FIELDS = [
     "api_base",
     "api_host",
@@ -165,7 +165,7 @@ def _llmobs_set_tags(
         args: List[Any],
         kwargs: Dict[str, Any],
         response: Optional[Any] = None,
-        operation: str = "",  # oneof "llm","chat","chain","embedding","retrieval","tool"
+        operation: str = "",  # oneof SUPPORTED_OPERATIONS
     ) -> None:
         """Sets meta tags and metrics for span events to be sent to LLMObs."""
         if not self.llmobs_enabled:
@@ -211,6 +211,8 @@ def _llmobs_set_tags(
             self._llmobs_set_meta_tags_from_similarity_search(span, args, kwargs, response, is_workflow=is_workflow)
         elif operation == "tool":
             self._llmobs_set_meta_tags_from_tool(span, tool_inputs=kwargs, tool_output=response)
+        elif operation == "runnable_lambda":
+            self._llmobs_set_meta_tags_from_runnable_lambda(span, args, kwargs, response)

     def _set_links(self, span: Span) -> None:
         """
@@ -763,6 +765,19 @@ def _llmobs_set_meta_tags_from_tool(self, span: Span, tool_inputs: Dict[str, Any
             }
         )

+    def _llmobs_set_meta_tags_from_runnable_lambda(
+        self, span: Span, args: List[Any], kwargs: Dict[str, Any], response: Any
+    ) -> None:
+        inputs = get_argument_value(args, kwargs, 0, "inputs")
+
+        span._set_ctx_items(
+            {
+                SPAN_KIND: "task",
+                INPUT_VALUE: safe_json(inputs),
+                OUTPUT_VALUE: safe_json(response),
+            }
+        )
+
     def _set_base_span_tags(
         self,
         span: Span,
```
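The new helper tags the span as a `task` LLMObs span with JSON-serialized input and output. A rough stand-in for how `get_argument_value` resolves the lambda's input — the real ddtrace helper also raises when the argument is absent entirely; this simplified version is an assumption for illustration only:

```python
from typing import Any, Dict, List, Optional, Tuple

def get_argument_value(args: Tuple[Any, ...], kwargs: Dict[str, Any], position: int, name: str) -> Optional[Any]:
    # Simplified: prefer the positional argument, fall back to the keyword.
    if len(args) > position:
        return args[position]
    return kwargs.get(name)

# RunnableLambda.invoke(payload) passes the payload positionally...
assert get_argument_value(({"a": 1},), {}, 0, "inputs") == {"a": 1}
# ...while a keyword call is still resolved by name.
assert get_argument_value((), {"inputs": [1, 2]}, 0, "inputs") == [1, 2]
```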
(release note, new file)

Lines changed: 4 additions & 0 deletions

```diff
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    langchain: Adds support for tracing ``RunnableLambda`` instances.
```

tests/contrib/langchain/test_langchain.py

Lines changed: 40 additions & 0 deletions

```diff
@@ -553,3 +553,43 @@ def test_streamed_chat_model_with_no_output(langchain_openai, openai_url):
     except Exception as e:
         if not isinstance(e, APITimeoutError):
             assert False, f"Expected APITimeoutError, got {e}"
+
+
+@pytest.mark.snapshot(ignores=IGNORE_FIELDS)
+def test_runnable_lambda_invoke(langchain_core):
+    def add(inputs: dict) -> int:
+        return inputs["a"] + inputs["b"]
+
+    runnable_lambda = langchain_core.runnables.RunnableLambda(add)
+    result = runnable_lambda.invoke(dict(a=1, b=2))
+    assert result == 3
+
+
+@pytest.mark.snapshot(ignores=IGNORE_FIELDS)
+async def test_runnable_lambda_ainvoke(langchain_core):
+    async def add(inputs: dict) -> int:
+        return inputs["a"] + inputs["b"]
+
+    runnable_lambda = langchain_core.runnables.RunnableLambda(add)
+    result = await runnable_lambda.ainvoke(dict(a=1, b=2))
+    assert result == 3
+
+
+@pytest.mark.snapshot(ignores=IGNORE_FIELDS)
+def test_runnable_lambda_batch(langchain_core):
+    def add(inputs: dict) -> int:
+        return inputs["a"] + inputs["b"]
+
+    runnable_lambda = langchain_core.runnables.RunnableLambda(add)
+    result = runnable_lambda.batch([dict(a=1, b=2), dict(a=3, b=4), dict(a=5, b=6)])
+    assert result == [3, 7, 11]
+
+
+@pytest.mark.snapshot(ignores=IGNORE_FIELDS)
+async def test_runnable_lambda_abatch(langchain_core):
+    async def add(inputs: dict) -> int:
+        return inputs["a"] + inputs["b"]
+
+    runnable_lambda = langchain_core.runnables.RunnableLambda(add)
+    result = await runnable_lambda.abatch([dict(a=1, b=2), dict(a=3, b=4), dict(a=5, b=6)])
+    assert result == [3, 7, 11]
```
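The description also mentions automatic span linking; that path is easiest to see when the lambda participates in a larger chain. A hedged sketch — pipe composition into a `RunnableSequence` is standard langchain-core behavior, and the tracing outcome described in the comment assumes patching is enabled as in this commit:

```python
from langchain_core.runnables import RunnableLambda

add_one = RunnableLambda(lambda x: x + 1, name="add_one")
double = RunnableLambda(lambda x: x * 2, name="double")

# `|` composes the lambdas into a RunnableSequence. With ddtrace patching
# enabled, the sequence's already-traced invoke and the two new
# RunnableLambda.invoke spans are all recorded, and record_instance()
# gives the integration what it needs to link them.
chain = add_one | double
assert chain.invoke(3) == 8  # (3 + 1) * 2
```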
