diff --git a/data-pipeline-ffi/src/error.rs b/data-pipeline-ffi/src/error.rs index 9781e68357..0349a2b9ec 100644 --- a/data-pipeline-ffi/src/error.rs +++ b/data-pipeline-ffi/src/error.rs @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 use data_pipeline::trace_exporter::error::{ - AgentErrorKind, BuilderErrorKind, NetworkErrorKind, TraceExporterError, + AgentErrorKind, BuilderErrorKind, InternalErrorKind, NetworkErrorKind, TraceExporterError, }; use std::ffi::{c_char, CString}; use std::fmt::Display; @@ -32,6 +32,7 @@ pub enum ExporterErrorCode { NetworkUnknown, Serde, TimedOut, + Internal, } impl Display for ExporterErrorCode { @@ -57,6 +58,7 @@ impl Display for ExporterErrorCode { Self::NetworkUnknown => write!(f, "Unknown network error"), Self::Serde => write!(f, "Serialization/Deserialization error"), Self::TimedOut => write!(f, "Operation timed out"), + Self::Internal => write!(f, "Internal error"), } } } @@ -89,6 +91,9 @@ impl From for ExporterError { BuilderErrorKind::InvalidTelemetryConfig => ExporterErrorCode::InvalidArgument, BuilderErrorKind::InvalidConfiguration(_) => ExporterErrorCode::InvalidArgument, }, + TraceExporterError::Internal(e) => match e { + InternalErrorKind::InvalidWorkerState(_) => ExporterErrorCode::Internal, + }, TraceExporterError::Deserialization(_) => ExporterErrorCode::Serde, TraceExporterError::Io(e) => match e.kind() { IoErrorKind::InvalidData => ExporterErrorCode::InvalidData, diff --git a/data-pipeline/src/agent_info/fetcher.rs b/data-pipeline/src/agent_info/fetcher.rs index c429c96651..dfa0f48590 100644 --- a/data-pipeline/src/agent_info/fetcher.rs +++ b/data-pipeline/src/agent_info/fetcher.rs @@ -6,8 +6,7 @@ use super::{schema::AgentInfo, AgentInfoArc}; use anyhow::{anyhow, Result}; use arc_swap::ArcSwapOption; -use ddcommon::hyper_migration; -use ddcommon::Endpoint; +use ddcommon::{hyper_migration, worker::Worker, Endpoint}; use http_body_util::BodyExt; use hyper::{self, body::Buf, header::HeaderName}; use std::sync::Arc; @@ -96,12 +95,13 @@ pub async fn fetch_info(info_endpoint: &Endpoint) -> Result> { /// # Example /// ```no_run /// # use anyhow::Result; +/// # use ddcommon::worker::Worker; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// // Define the endpoint /// let endpoint = ddcommon::Endpoint::from_url("http://localhost:8126/info".parse().unwrap()); /// // Create the fetcher -/// let fetcher = data_pipeline::agent_info::AgentInfoFetcher::new( +/// let mut fetcher = data_pipeline::agent_info::AgentInfoFetcher::new( /// endpoint, /// std::time::Duration::from_secs(5 * 60), /// ); @@ -122,6 +122,7 @@ pub async fn fetch_info(info_endpoint: &Endpoint) -> Result> { /// # Ok(()) /// # } /// ``` +#[derive(Debug)] pub struct AgentInfoFetcher { info_endpoint: Endpoint, info: AgentInfoArc, @@ -139,11 +140,20 @@ impl AgentInfoFetcher { } } + /// Return an AgentInfoArc storing the info received by the agent. + /// + /// When the fetcher is running it updates the AgentInfoArc when the agent's info changes. + pub fn get_info(&self) -> AgentInfoArc { + self.info.clone() + } +} + +impl Worker for AgentInfoFetcher { /// Start fetching the info endpoint with the given interval. /// /// # Warning /// This method does not return and should be called within a dedicated task. - pub async fn run(&self) { + async fn run(&mut self) { loop { let current_info = self.info.load(); let current_hash = current_info.as_ref().map(|info| info.state_hash.as_str()); @@ -163,13 +173,6 @@ impl AgentInfoFetcher { sleep(self.refresh_interval).await; } } - - /// Return an AgentInfoArc storing the info received by the agent. - /// - /// When the fetcher is running it updates the AgentInfoArc when the agent's info changes. - pub fn get_info(&self) -> AgentInfoArc { - self.info.clone() - } } #[cfg(test)] @@ -328,7 +331,7 @@ mod tests { }) .await; let endpoint = Endpoint::from_url(server.url("/info").parse().unwrap()); - let fetcher = AgentInfoFetcher::new(endpoint.clone(), Duration::from_millis(100)); + let mut fetcher = AgentInfoFetcher::new(endpoint.clone(), Duration::from_millis(100)); let info = fetcher.get_info(); assert!(info.load().is_none()); tokio::spawn(async move { diff --git a/data-pipeline/src/lib.rs b/data-pipeline/src/lib.rs index e2d7e2571b..190acb3e88 100644 --- a/data-pipeline/src/lib.rs +++ b/data-pipeline/src/lib.rs @@ -12,6 +12,7 @@ pub mod agent_info; mod health_metrics; +mod pausable_worker; #[allow(missing_docs)] pub mod span_concentrator; #[allow(missing_docs)] diff --git a/data-pipeline/src/pausable_worker.rs b/data-pipeline/src/pausable_worker.rs new file mode 100644 index 0000000000..3ff373b0b6 --- /dev/null +++ b/data-pipeline/src/pausable_worker.rs @@ -0,0 +1,172 @@ +// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 + +//! Defines a pausable worker to be able to stop background processes before forks + +use ddcommon::worker::Worker; +use std::fmt::Display; +use tokio::{ + runtime::Runtime, + select, + task::{JoinError, JoinHandle}, +}; +use tokio_util::sync::CancellationToken; + +/// A pausable worker which can be paused and restarted on forks. +/// +/// Used to allow a [`ddcommon::worker::Worker`] to be paused while saving its state when dropping +/// a tokio runtime to be able to restart with the same state on a new runtime. This is used to +/// stop all threads before a fork to avoid deadlocks in child. +/// +/// # Time-to-pause +/// This loop should yield regularly to reduce time-to-pause. See [`tokio::task::yield_now`]. +/// +/// # Cancellation safety +/// The main loop can be interrupted at any yield point (`.await`ed call). The state of the worker +/// at this point will be saved and used to restart the worker. To be able to safely restart, the +/// worker must be in a valid state on every call to `.await`. +/// See [`tokio::select#cancellation-safety`] for more details. +#[derive(Debug)] +pub enum PausableWorker { + Running { + handle: JoinHandle, + stop_token: CancellationToken, + }, + Paused { + worker: T, + }, + InvalidState, +} + +#[derive(Debug)] +pub enum PausableWorkerError { + InvalidState, + TaskAborted, +} + +impl Display for PausableWorkerError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PausableWorkerError::InvalidState => { + write!(f, "Worker is in an invalid state and must be recreated.") + } + PausableWorkerError::TaskAborted => { + write!(f, "Worker task has been aborted and state has been lost.") + } + } + } +} + +impl core::error::Error for PausableWorkerError {} + +impl PausableWorker { + /// Create a new pausable worker from the given worker. + pub fn new(worker: T) -> Self { + Self::Paused { worker } + } + + /// Start the worker on the given runtime. + /// + /// The worker's main loop will be run on the runtime. + /// + /// # Errors + /// Fails if the worker is in an invalid state. + pub fn start(&mut self, rt: &Runtime) -> Result<(), PausableWorkerError> { + if let Self::Running { .. } = self { + Ok(()) + } else if let Self::Paused { mut worker } = std::mem::replace(self, Self::InvalidState) { + // Worker is temporarily in an invalid state, but since this block is failsafe it will + // be replaced by a valid state. + let stop_token = CancellationToken::new(); + let cloned_token = stop_token.clone(); + let handle = rt.spawn(async move { + select! { + _ = worker.run() => {worker} + _ = cloned_token.cancelled() => {worker} + } + }); + + *self = PausableWorker::Running { handle, stop_token }; + Ok(()) + } else { + Err(PausableWorkerError::InvalidState) + } + } + + /// Pause the worker saving it's state to be restarted. + /// + /// # Errors + /// Fails if the worker handle has been aborted preventing the worker from being retrieved. + pub async fn pause(&mut self) -> Result<(), PausableWorkerError> { + match self { + PausableWorker::Running { handle, stop_token } => { + stop_token.cancel(); + if let Ok(worker) = handle.await { + *self = PausableWorker::Paused { worker }; + Ok(()) + } else { + // The task has been aborted and the worker can't be retrieved. + *self = PausableWorker::InvalidState; + Err(PausableWorkerError::TaskAborted) + } + } + PausableWorker::Paused { .. } => Ok(()), + PausableWorker::InvalidState => Err(PausableWorkerError::InvalidState), + } + } + + /// Wait for the run method of the worker to exit. + pub async fn join(self) -> Result<(), JoinError> { + if let PausableWorker::Running { handle, .. } = self { + handle.await?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use tokio::{runtime::Builder, time::sleep}; + + use super::*; + use std::{ + sync::mpsc::{channel, Sender}, + time::Duration, + }; + + /// Test worker incrementing the state and sending it with the sender. + struct TestWorker { + state: u32, + sender: Sender, + } + + impl Worker for TestWorker { + async fn run(&mut self) { + loop { + let _ = self.sender.send(self.state); + self.state += 1; + sleep(Duration::from_millis(100)).await; + } + } + } + + #[test] + fn test_restart() { + let (sender, receiver) = channel::(); + let worker = TestWorker { state: 0, sender }; + let runtime = Builder::new_multi_thread().enable_time().build().unwrap(); + let mut pausable_worker = PausableWorker::new(worker); + + pausable_worker.start(&runtime).unwrap(); + + assert_eq!(receiver.recv().unwrap(), 0); + runtime.block_on(async { pausable_worker.pause().await.unwrap() }); + // Empty the message queue and get the last message + let mut next_message = 1; + for message in receiver.try_iter() { + next_message = message + 1; + } + pausable_worker.start(&runtime).unwrap(); + assert_eq!(receiver.recv().unwrap(), next_message); + } +} diff --git a/data-pipeline/src/stats_exporter.rs b/data-pipeline/src/stats_exporter.rs index d434ace74e..003882de50 100644 --- a/data-pipeline/src/stats_exporter.rs +++ b/data-pipeline/src/stats_exporter.rs @@ -14,7 +14,7 @@ use std::{ use crate::{span_concentrator::SpanConcentrator, trace_exporter::TracerMetadata}; use datadog_trace_protobuf::pb; use datadog_trace_utils::send_with_retry::{send_with_retry, RetryStrategy}; -use ddcommon::Endpoint; +use ddcommon::{worker::Worker, Endpoint}; use hyper; use tokio::select; use tokio_util::sync::CancellationToken; @@ -127,13 +127,15 @@ impl StatsExporter { .flush(time::SystemTime::now(), force_flush), ) } +} +impl Worker for StatsExporter { /// Run loop of the stats exporter /// /// Once started, the stats exporter will flush and send stats on every `self.flush_interval`. /// If the `self.cancellation_token` is cancelled, the exporter will force flush all stats and /// return. - pub async fn run(&mut self) { + async fn run(&mut self) { loop { select! { _ = self.cancellation_token.cancelled() => { diff --git a/data-pipeline/src/telemetry/mod.rs b/data-pipeline/src/telemetry/mod.rs index 938409a6f1..9ffbe17a4e 100644 --- a/data-pipeline/src/telemetry/mod.rs +++ b/data-pipeline/src/telemetry/mod.rs @@ -12,11 +12,11 @@ use datadog_trace_utils::{ }; use ddcommon::tag::Tag; use ddtelemetry::worker::{ - LifecycleAction, TelemetryActions, TelemetryWorkerBuilder, TelemetryWorkerFlavor, - TelemetryWorkerHandle, + LifecycleAction, TelemetryActions, TelemetryWorker, TelemetryWorkerBuilder, + TelemetryWorkerFlavor, TelemetryWorkerHandle, }; use std::{collections::HashMap, time::Duration}; -use tokio::task::JoinHandle; +use tokio::runtime::Handle; /// Structure to build a Telemetry client. /// @@ -86,7 +86,10 @@ impl TelemetryClientBuilder { } /// Builds the telemetry client. - pub async fn build(self) -> Result { + pub fn build( + self, + runtime: Handle, + ) -> Result<(TelemetryClient, TelemetryWorker), TelemetryError> { #[allow(clippy::unwrap_used)] let mut builder = TelemetryWorkerBuilder::new_fetch_host( self.service_name.unwrap(), @@ -102,16 +105,17 @@ impl TelemetryClientBuilder { builder.runtime_id = Some(id); } - let (worker, handle) = builder - .spawn() - .await + let (worker_handle, worker) = builder + .build_worker(runtime) .map_err(|e| TelemetryError::Builder(e.to_string()))?; - Ok(TelemetryClient { - handle, - metrics: Metrics::new(&worker), + Ok(( + TelemetryClient { + metrics: Metrics::new(&worker_handle), + worker: worker_handle, + }, worker, - }) + )) } } @@ -120,7 +124,6 @@ impl TelemetryClientBuilder { pub struct TelemetryClient { metrics: Metrics, worker: TelemetryWorkerHandle, - handle: JoinHandle<()>, } /// Telemetry describing the sending of a trace payload @@ -246,41 +249,34 @@ impl TelemetryClient { /// Starts the client pub async fn start(&self) { - if let Err(_e) = self + _ = self .worker .send_msg(TelemetryActions::Lifecycle(LifecycleAction::Start)) - .await - { - self.handle.abort(); - } + .await; } /// Shutdowns the telemetry client. pub async fn shutdown(self) { - if let Err(_e) = self + _ = self .worker .send_msg(TelemetryActions::Lifecycle(LifecycleAction::Stop)) - .await - { - self.handle.abort(); - } - - let _ = self.handle.await; + .await; } } #[cfg(test)] mod tests { - use ddcommon::hyper_migration; + use ddcommon::{hyper_migration, worker::Worker}; use httpmock::Method::POST; use httpmock::MockServer; use hyper::{Response, StatusCode}; use regex::Regex; + use tokio::time::sleep; use super::*; async fn get_test_client(url: &str) -> TelemetryClient { - TelemetryClientBuilder::default() + let (client, mut worker) = TelemetryClientBuilder::default() .set_service_name("test_service") .set_language("test_language") .set_language_version("test_language_version") @@ -288,9 +284,10 @@ mod tests { .set_url(url) .set_heartbeat(100) .set_debug_enabled(true) - .build() - .await - .unwrap() + .build(Handle::current()) + .unwrap(); + tokio::spawn(async move { worker.run().await }); + client } #[test] @@ -320,15 +317,14 @@ mod tests { } #[cfg_attr(miri, ignore)] - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn spawn_test() { let client = TelemetryClientBuilder::default() .set_service_name("test_service") .set_language("test_language") .set_language_version("test_language_version") .set_tracer_version("test_tracer_version") - .build() - .await; + .build(Handle::current()); assert!(client.is_ok()); } @@ -356,6 +352,9 @@ mod tests { client.start().await; let _ = client.send(&data); client.shutdown().await; + while telemetry_srv.hits_async().await == 0 { + sleep(Duration::from_millis(10)).await; + } telemetry_srv.assert_hits_async(1).await; } @@ -382,6 +381,9 @@ mod tests { client.start().await; let _ = client.send(&data); client.shutdown().await; + while telemetry_srv.hits_async().await == 0 { + sleep(Duration::from_millis(10)).await; + } telemetry_srv.assert_hits_async(1).await; } @@ -408,6 +410,9 @@ mod tests { client.start().await; let _ = client.send(&data); client.shutdown().await; + while telemetry_srv.hits_async().await == 0 { + sleep(Duration::from_millis(10)).await; + } telemetry_srv.assert_hits_async(1).await; } @@ -434,6 +439,9 @@ mod tests { client.start().await; let _ = client.send(&data); client.shutdown().await; + while telemetry_srv.hits_async().await == 0 { + sleep(Duration::from_millis(10)).await; + } telemetry_srv.assert_hits_async(1).await; } @@ -460,6 +468,9 @@ mod tests { client.start().await; let _ = client.send(&data); client.shutdown().await; + while telemetry_srv.hits_async().await == 0 { + sleep(Duration::from_millis(10)).await; + } telemetry_srv.assert_hits_async(1).await; } @@ -486,6 +497,9 @@ mod tests { client.start().await; let _ = client.send(&data); client.shutdown().await; + while telemetry_srv.hits_async().await == 0 { + sleep(Duration::from_millis(10)).await; + } telemetry_srv.assert_hits_async(1).await; } @@ -512,6 +526,9 @@ mod tests { client.start().await; let _ = client.send(&data); client.shutdown().await; + while telemetry_srv.hits_async().await == 0 { + sleep(Duration::from_millis(10)).await; + } telemetry_srv.assert_hits_async(1).await; } @@ -538,6 +555,9 @@ mod tests { client.start().await; let _ = client.send(&data); client.shutdown().await; + while telemetry_srv.hits_async().await == 0 { + sleep(Duration::from_millis(10)).await; + } telemetry_srv.assert_hits_async(1).await; } @@ -675,10 +695,10 @@ mod tests { .set_url(&server.url("/")) .set_heartbeat(100) .set_runtime_id("foo") - .build() - .await; + .build(Handle::current()); - let client = result.unwrap(); + let (client, mut worker) = result.unwrap(); + tokio::spawn(async move { worker.run().await }); client.start().await; client @@ -688,6 +708,9 @@ mod tests { }) .unwrap(); client.shutdown().await; + while telemetry_srv.hits_async().await == 0 { + sleep(Duration::from_millis(10)).await; + } // One payload generate-metrics telemetry_srv.assert_hits_async(1).await; } diff --git a/data-pipeline/src/trace_exporter/error.rs b/data-pipeline/src/trace_exporter/error.rs index 5d4adefc00..9fe9303c57 100644 --- a/data-pipeline/src/trace_exporter/error.rs +++ b/data-pipeline/src/trace_exporter/error.rs @@ -50,6 +50,25 @@ impl Display for BuilderErrorKind { } } } + +/// Represents different kinds of internal errors. +#[derive(Debug, PartialEq)] +pub enum InternalErrorKind { + /// Indicates that some background workers are in an invalid state. The associated `String` + /// contains the error message. + InvalidWorkerState(String), +} + +impl Display for InternalErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + InternalErrorKind::InvalidWorkerState(msg) => { + write!(f, "Invalid worker state: {}", msg) + } + } + } +} + /// Represents different kinds of network errors. #[derive(Copy, Clone, Debug)] pub enum NetworkErrorKind { @@ -147,6 +166,8 @@ pub enum TraceExporterError { Agent(AgentErrorKind), /// Invalid builder input. Builder(BuilderErrorKind), + /// Error internal to the trace exporter. + Internal(InternalErrorKind), /// Error in deserialization of incoming trace payload. Deserialization(DecodeError), /// Generic IO error. @@ -272,6 +293,7 @@ impl Display for TraceExporterError { match self { TraceExporterError::Agent(e) => std::fmt::Display::fmt(e, f), TraceExporterError::Builder(e) => std::fmt::Display::fmt(e, f), + TraceExporterError::Internal(e) => std::fmt::Display::fmt(e, f), TraceExporterError::Deserialization(e) => std::fmt::Display::fmt(e, f), TraceExporterError::Io(e) => std::fmt::Display::fmt(e, f), TraceExporterError::Network(e) => std::fmt::Display::fmt(e, f), diff --git a/data-pipeline/src/trace_exporter/mod.rs b/data-pipeline/src/trace_exporter/mod.rs index a660cc5828..6fa497228e 100644 --- a/data-pipeline/src/trace_exporter/mod.rs +++ b/data-pipeline/src/trace_exporter/mod.rs @@ -4,8 +4,10 @@ pub mod agent_response; pub mod error; use self::agent_response::AgentResponse; use crate::agent_info::{AgentInfoArc, AgentInfoFetcher}; +use crate::pausable_worker::PausableWorker; +use crate::stats_exporter::StatsExporter; use crate::telemetry::{SendPayloadTelemetry, TelemetryClient, TelemetryClientBuilder}; -use crate::trace_exporter::error::{RequestError, TraceExporterError}; +use crate::trace_exporter::error::{InternalErrorKind, RequestError, TraceExporterError}; use crate::{ health_metrics, health_metrics::HealthMetric, span_concentrator::SpanConcentrator, stats_exporter, @@ -21,7 +23,9 @@ use ddcommon::header::{ APPLICATION_MSGPACK_STR, DATADOG_SEND_REAL_HTTP_STATUS_STR, DATADOG_TRACE_COUNT_STR, }; use ddcommon::tag::Tag; +use ddcommon::MutexExt; use ddcommon::{hyper_migration, parse_uri, tag, Endpoint}; +use ddtelemetry::worker::TelemetryWorker; use dogstatsd_client::{new, Client, DogStatsDAction}; use either::Either; use error::BuilderErrorKind; @@ -32,7 +36,7 @@ use std::io; use std::sync::{Arc, Mutex}; use std::time::Duration; use std::{borrow::Borrow, collections::HashMap, str::FromStr, time}; -use tokio::{runtime::Runtime, task::JoinHandle}; +use tokio::runtime::Runtime; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, warn}; @@ -157,10 +161,16 @@ enum StatsComputationStatus { Enabled { stats_concentrator: Arc>, cancellation_token: CancellationToken, - exporter_handle: JoinHandle<()>, }, } +#[derive(Debug)] +struct TraceExporterWorkers { + pub info: PausableWorker, + pub stats: Option>, + pub telemetry: Option>, +} + /// The TraceExporter ingest traces from the tracers serialized as messagepack and forward them to /// the agent while applying some transformation. /// @@ -187,7 +197,7 @@ pub struct TraceExporter { input_format: TraceExporterInputFormat, output_format: TraceExporterOutputFormat, // TODO - do something with the response callback - https://datadoghq.atlassian.net/browse/APMSP-1019 - runtime: Runtime, + runtime: Arc>>>, /// None if dogstatsd is disabled dogstatsd: Option, common_stats_tags: Vec, @@ -196,6 +206,7 @@ pub struct TraceExporter { agent_info: AgentInfoArc, previous_info_state: ArcSwapOption, telemetry: Option, + workers: Arc>, } enum DeserInputFormat { @@ -209,6 +220,72 @@ impl TraceExporter { TraceExporterBuilder::default() } + fn runtime(&self) -> Result, TraceExporterError> { + match self.runtime.lock_or_panic().as_ref() { + Some(runtime) => Ok(runtime.clone()), + None => self.run_worker(), + } + } + + pub fn run_worker(&self) -> Result, TraceExporterError> { + let mut runtime_guard = self.runtime.lock_or_panic(); + let runtime = match runtime_guard.as_ref() { + Some(runtime) => { + // Runtime already running + runtime.clone() + } + None => { + // Create a new current thread runtime with all features enabled + let runtime = Arc::new( + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build()?, + ); + *runtime_guard = Some(runtime.clone()); + runtime + } + }; + + // Restart workers + let mut workers = self.workers.lock_or_panic(); + workers.info.start(&runtime).map_err(|e| { + TraceExporterError::Internal(InternalErrorKind::InvalidWorkerState(e.to_string())) + })?; + if let Some(stats_worker) = &mut workers.stats { + stats_worker.start(&runtime).map_err(|e| { + TraceExporterError::Internal(InternalErrorKind::InvalidWorkerState(e.to_string())) + })?; + } + if let Some(telemetry_worker) = &mut workers.telemetry { + telemetry_worker.start(&runtime).map_err(|e| { + TraceExporterError::Internal(InternalErrorKind::InvalidWorkerState(e.to_string())) + })?; + if let Some(client) = &self.telemetry { + runtime.block_on(client.start()); + } + }; + Ok(runtime) + } + + pub fn stop_worker(&self) { + let runtime = self.runtime.lock_or_panic().take(); + if let Some(ref rt) = runtime { + // Stop workers to save their state + let mut workers = self.workers.lock_or_panic(); + rt.block_on(async { + let _ = workers.info.pause().await; + if let Some(stats_worker) = &mut workers.stats { + let _ = stats_worker.pause().await; + }; + if let Some(telemetry_worker) = &mut workers.telemetry { + let _ = telemetry_worker.pause().await; + }; + }); + } + // Drop runtime to shutdown all threads + drop(runtime); + } + /// Send msgpack serialized traces to the agent /// /// # Arguments @@ -242,53 +319,52 @@ impl TraceExporter { } /// Safely shutdown the TraceExporter and all related tasks - pub fn shutdown(self, timeout: Option) -> Result<(), TraceExporterError> { + pub fn shutdown(mut self, timeout: Option) -> Result<(), TraceExporterError> { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + if let Some(timeout) = timeout { - match self.runtime.block_on(async { - tokio::time::timeout(timeout, async { - let stats_status: Option = - Arc::::into_inner( - self.client_side_stats.into_inner(), - ); - if let Some(StatsComputationStatus::Enabled { - stats_concentrator: _, - cancellation_token, - exporter_handle, - }) = stats_status - { - cancellation_token.cancel(); - let _ = exporter_handle.await; - } - if let Some(telemetry) = self.telemetry { - telemetry.shutdown().await; - } - }) - .await - }) { + match runtime + .block_on(async { tokio::time::timeout(timeout, self.shutdown_async()).await }) + { Ok(()) => Ok(()), Err(e) => Err(TraceExporterError::Io(e.into())), } } else { - self.runtime.block_on(async { - let stats_status: Option = - Arc::::into_inner(self.client_side_stats.into_inner()); - if let Some(StatsComputationStatus::Enabled { - stats_concentrator: _, - cancellation_token, - exporter_handle, - }) = stats_status - { - cancellation_token.cancel(); - let _ = exporter_handle.await; - } - if let Some(telemetry) = self.telemetry { - telemetry.shutdown().await; - } - }); + runtime.block_on(self.shutdown_async()); Ok(()) } } + /// Future used inside `Self::shutdown`. + /// + /// This function should not take ownership of the trace exporter as it will cause the runtime + /// stored in the trace exporter to be dropped in a non-blocking context causing a panic. + async fn shutdown_async(&mut self) { + let stats_status = self.client_side_stats.load(); + if let StatsComputationStatus::Enabled { + cancellation_token, .. + } = stats_status.as_ref() + { + cancellation_token.cancel(); + + let stats_worker = self.workers.lock_or_panic().stats.take(); + + if let Some(stats_worker) = stats_worker { + let _ = stats_worker.join().await; + } + } + if let Some(telemetry) = self.telemetry.take() { + telemetry.shutdown().await; + let telemetry_worker = self.workers.lock_or_panic().telemetry.take(); + + if let Some(telemetry_worker) = telemetry_worker { + let _ = telemetry_worker.join().await; + } + } + } + /// Start the stats exporter and enable stats computation /// /// Should only be used if the agent enabled stats computation @@ -309,23 +385,25 @@ impl TraceExporter { let cancellation_token = CancellationToken::new(); - let mut stats_exporter = stats_exporter::StatsExporter::new( + let stats_exporter = stats_exporter::StatsExporter::new( bucket_size, stats_concentrator.clone(), self.metadata.clone(), Endpoint::from_url(add_path(&self.endpoint.url, STATS_ENDPOINT)), cancellation_token.clone(), ); + let mut stats_worker = PausableWorker::new(stats_exporter); + let runtime = self.runtime()?; + stats_worker.start(&runtime).map_err(|e| { + TraceExporterError::Internal(InternalErrorKind::InvalidWorkerState(e.to_string())) + })?; - let exporter_handle = self.runtime.spawn(async move { - stats_exporter.run().await; - }); + self.workers.lock_or_panic().stats = Some(stats_worker); self.client_side_stats .store(Arc::new(StatsComputationStatus::Enabled { stats_concentrator, cancellation_token, - exporter_handle, })); }; Ok(()) @@ -338,19 +416,20 @@ impl TraceExporter { if let StatsComputationStatus::Enabled { stats_concentrator, cancellation_token, - exporter_handle: _, } = &**self.client_side_stats.load() { - self.runtime.block_on(async { - cancellation_token.cancel(); - }); - #[allow(clippy::unwrap_used)] - let bucket_size = stats_concentrator.lock().unwrap().get_bucket_size(); - - self.client_side_stats - .store(Arc::new(StatsComputationStatus::DisabledByAgent { - bucket_size, - })); + if let Ok(runtime) = self.runtime() { + runtime.block_on(async { + cancellation_token.cancel(); + }); + self.workers.lock_or_panic().stats = None; + let bucket_size = stats_concentrator.lock_or_panic().get_bucket_size(); + + self.client_side_stats + .store(Arc::new(StatsComputationStatus::DisabledByAgent { + bucket_size, + })); + } } } @@ -388,13 +467,10 @@ impl TraceExporter { } } StatsComputationStatus::Enabled { - stats_concentrator, - cancellation_token: _, - exporter_handle: _, + stats_concentrator, .. } => { if agent_info.info.client_drop_p0s.is_some_and(|v| v) { - #[allow(clippy::unwrap_used)] - let mut concentrator = stats_concentrator.lock().unwrap(); + let mut concentrator = stats_concentrator.lock_or_panic(); concentrator.set_span_kinds( agent_info @@ -463,7 +539,7 @@ impl TraceExporter { trace_count: usize, uri: Uri, ) -> Result { - self.runtime.block_on(async { + self.runtime()?.block_on(async { let mut req_builder = hyper::Request::builder() .uri(uri) .header( @@ -717,7 +793,26 @@ impl TraceExporter { let payload_len = mp_payload.len(); - self.runtime.block_on(async { + let runtime = { + let mut runtime_guard = self.runtime.lock_or_panic(); + let runtime = match runtime_guard.as_ref() { + Some(runtime) => runtime.clone(), + None => { + // Create a new current thread runtime with all features enabled + let runtime = Arc::new( + tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build()?, + ); + *runtime_guard = Some(runtime.clone()); + runtime + } + }; + runtime + }; + + runtime.block_on(async { // Send traces to the agent let result = send_with_retry(&endpoint, mp_payload, &headers, &strategy, None).await; @@ -1011,10 +1106,12 @@ impl TraceExporterBuilder { )); } - let runtime = tokio::runtime::Builder::new_multi_thread() - .worker_threads(1) - .enable_all() - .build()?; + let runtime = Arc::new( + tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build()?, + ); let dogstatsd = self.dogstatsd_url.and_then(|u| { new(Endpoint::from_slice(&u)).ok() // If we couldn't set the endpoint return @@ -1034,11 +1131,11 @@ impl TraceExporterBuilder { Endpoint::from_url(add_path(&agent_url, INFO_ENDPOINT)), Duration::from_secs(5 * 60), ); - let agent_info = info_fetcher.get_info(); - runtime.spawn(async move { - info_fetcher.run().await; - }); + let mut info_fetcher_worker = PausableWorker::new(info_fetcher); + info_fetcher_worker.start(&runtime).map_err(|e| { + TraceExporterError::Builder(BuilderErrorKind::InvalidConfiguration(e.to_string())) + })?; // Proxy mode does not support stats if self.input_format != TraceExporterInputFormat::Proxy { @@ -1062,15 +1159,25 @@ impl TraceExporterBuilder { if let Some(id) = telemetry_config.runtime_id { builder = builder.set_runtime_id(&id); } - builder.build().await + builder.build(runtime.handle().clone()) })?) } else { None }; - if let Some(client) = &telemetry { - runtime.block_on(client.start()); - } + let (telemetry_client, telemetry_worker) = match telemetry { + Some((client, worker)) => { + let mut telemetry_worker = PausableWorker::new(worker); + telemetry_worker.start(&runtime).map_err(|e| { + TraceExporterError::Builder(BuilderErrorKind::InvalidConfiguration( + e.to_string(), + )) + })?; + runtime.block_on(client.start()); + (Some(client), Some(telemetry_worker)) + } + None => (None, None), + }; Ok(TraceExporter { endpoint: Endpoint { @@ -1096,13 +1203,18 @@ impl TraceExporterBuilder { input_format: self.input_format, output_format: self.output_format, client_computed_top_level: self.client_computed_top_level, - runtime, + runtime: Arc::new(Mutex::new(Some(runtime))), dogstatsd, common_stats_tags: vec![libdatadog_version], client_side_stats: ArcSwap::new(stats.into()), agent_info, previous_info_state: ArcSwapOption::new(None), - telemetry, + telemetry: telemetry_client, + workers: Arc::new(Mutex::new(TraceExporterWorkers { + info: info_fetcher_worker, + stats: None, + telemetry: telemetry_worker, + })), }) } @@ -1305,9 +1417,15 @@ mod tests { // Wait for the info fetcher to get the config while mock_info.hits() == 0 { - exporter.runtime.block_on(async { - sleep(Duration::from_millis(100)).await; - }) + exporter + .runtime + .lock() + .unwrap() + .as_ref() + .unwrap() + .block_on(async { + sleep(Duration::from_millis(100)).await; + }) } let result = exporter.send(data.as_ref(), 1); @@ -1390,9 +1508,15 @@ mod tests { // Wait for the info fetcher to get the config while mock_info.hits() == 0 { - exporter.runtime.block_on(async { - sleep(Duration::from_millis(100)).await; - }) + exporter + .runtime + .lock() + .unwrap() + .as_ref() + .unwrap() + .block_on(async { + sleep(Duration::from_millis(100)).await; + }) } exporter.send(data.as_ref(), 1).unwrap(); @@ -1773,9 +1897,15 @@ mod tests { traces_endpoint.assert_hits(1); while metrics_endpoint.hits() == 0 { - exporter.runtime.block_on(async { - sleep(Duration::from_millis(100)).await; - }) + exporter + .runtime + .lock() + .unwrap() + .as_ref() + .unwrap() + .block_on(async { + sleep(Duration::from_millis(100)).await; + }) } metrics_endpoint.assert_hits(1); } @@ -1821,9 +1951,15 @@ mod tests { traces_endpoint.assert_hits(1); while metrics_endpoint.hits() == 0 { - exporter.runtime.block_on(async { - sleep(Duration::from_millis(100)).await; - }) + exporter + .runtime + .lock() + .unwrap() + .as_ref() + .unwrap() + .block_on(async { + sleep(Duration::from_millis(100)).await; + }) } metrics_endpoint.assert_hits(1); } @@ -1881,9 +2017,15 @@ mod tests { traces_endpoint.assert_hits(1); while metrics_endpoint.hits() == 0 { - exporter.runtime.block_on(async { - sleep(Duration::from_millis(100)).await; - }) + exporter + .runtime + .lock() + .unwrap() + .as_ref() + .unwrap() + .block_on(async { + sleep(Duration::from_millis(100)).await; + }) } metrics_endpoint.assert_hits(1); } @@ -1964,9 +2106,15 @@ mod tests { // Wait for the info fetcher to get the config while mock_info.hits() == 0 { - exporter.runtime.block_on(async { - sleep(Duration::from_millis(100)).await; - }) + exporter + .runtime + .lock() + .unwrap() + .as_ref() + .unwrap() + .block_on(async { + sleep(Duration::from_millis(100)).await; + }) } let _ = exporter.send(data.as_ref(), 1).unwrap(); diff --git a/data-pipeline/tests/test_fetch_info.rs b/data-pipeline/tests/test_fetch_info.rs index 4b97e85b71..461d3ee0a8 100644 --- a/data-pipeline/tests/test_fetch_info.rs +++ b/data-pipeline/tests/test_fetch_info.rs @@ -6,7 +6,7 @@ mod tracing_integration_tests { use arc_swap::access::Access; use data_pipeline::agent_info::{fetch_info, AgentInfoFetcher}; use datadog_trace_utils::test_utils::datadog_test_agent::DatadogTestAgent; - use ddcommon::Endpoint; + use ddcommon::{worker::Worker, Endpoint}; use std::time::Duration; #[cfg_attr(miri, ignore)] @@ -30,7 +30,7 @@ mod tracing_integration_tests { async fn test_agent_info_fetcher_with_test_agent() { let test_agent = DatadogTestAgent::new(None, None, &[]).await; let endpoint = Endpoint::from_url(test_agent.get_uri_for_endpoint("info", None).await); - let fetcher = AgentInfoFetcher::new(endpoint, Duration::from_secs(1)); + let mut fetcher = AgentInfoFetcher::new(endpoint, Duration::from_secs(1)); let info_arc = fetcher.get_info(); tokio::spawn(async move { fetcher.run().await }); let info_received = async { diff --git a/ddcommon/src/lib.rs b/ddcommon/src/lib.rs index 50f182a333..b40e8c50b1 100644 --- a/ddcommon/src/lib.rs +++ b/ddcommon/src/lib.rs @@ -25,6 +25,7 @@ pub mod hyper_migration; pub mod rate_limiter; pub mod tag; pub mod unix_utils; +pub mod worker; /// Extension trait for `Mutex` to provide a method that acquires a lock, panicking if the lock is /// poisoned. diff --git a/ddcommon/src/worker.rs b/ddcommon/src/worker.rs new file mode 100644 index 0000000000..c79c9317f2 --- /dev/null +++ b/ddcommon/src/worker.rs @@ -0,0 +1,12 @@ +// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 + +/// Trait representing a generic worker. +/// +/// The worker runs an async looping function running periodic tasks. +/// +/// This trait can be used to provide wrapper around a worker. +pub trait Worker { + /// Main worker loop + fn run(&mut self) -> impl std::future::Future + Send; +} diff --git a/ddtelemetry/src/worker/mod.rs b/ddtelemetry/src/worker/mod.rs index 9336487d28..5a89388155 100644 --- a/ddtelemetry/src/worker/mod.rs +++ b/ddtelemetry/src/worker/mod.rs @@ -11,8 +11,9 @@ use crate::{ metrics::{ContextKey, MetricBuckets, MetricContexts}, }; use ddcommon::Endpoint; -use ddcommon::{hyper_migration, tag::Tag}; +use ddcommon::{hyper_migration, tag::Tag, worker::Worker}; +use std::fmt::Debug; use std::iter::Sum; use std::ops::Add; use std::{ @@ -103,6 +104,7 @@ pub struct LogIdentifier { } // Holds the current state of the telemetry worker +#[derive(Debug)] struct TelemetryWorkerData { started: bool, dependencies: store::Store, @@ -116,6 +118,7 @@ struct TelemetryWorkerData { } pub struct TelemetryWorker { + flavor: TelemetryWorkerFlavor, config: Config, mailbox: mpsc::Receiver, cancellation_token: CancellationToken, @@ -125,6 +128,50 @@ pub struct TelemetryWorker { deadlines: scheduler::Scheduler, data: TelemetryWorkerData, } +impl Debug for TelemetryWorker { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TelemetryWorker") + .field("flavor", &self.flavor) + .field("config", &self.config) + .field("mailbox", &self.mailbox) + .field("cancellation_token", &self.cancellation_token) + .field("seq_id", &self.seq_id) + .field("runtime_id", &self.runtime_id) + .field("deadlines", &self.deadlines) + .field("data", &self.data) + .finish() + } +} + +impl Worker for TelemetryWorker { + // Runs a state machine that waits for actions, either from the worker's + // mailbox, or scheduled actions from the worker's deadline object. + async fn run(&mut self) { + loop { + if self.cancellation_token.is_cancelled() { + return; + } + + let action = self.recv_next_action().await; + + let action_result = match self.flavor { + TelemetryWorkerFlavor::Full => self.dispatch_action(action).await, + TelemetryWorkerFlavor::MetricsLogs => { + self.dispatch_metrics_logs_action(action).await + } + }; + + match action_result { + ControlFlow::Continue(()) => {} + ControlFlow::Break(()) => { + if !self.config.restartable { + break; + } + } + }; + } + } +} #[derive(Debug, Default, Serialize, Deserialize)] pub struct TelemetryWorkerStats { @@ -213,25 +260,6 @@ impl TelemetryWorker { }) } - async fn run_metrics_logs(mut self) { - loop { - if self.cancellation_token.is_cancelled() { - return; - } - - let action = self.recv_next_action().await; - - match self.dispatch_metrics_logs_action(action).await { - ControlFlow::Continue(()) => {} - ControlFlow::Break(()) => { - if !self.config.restartable { - break; - } - } - }; - } - } - async fn dispatch_metrics_logs_action(&mut self, action: TelemetryActions) -> ControlFlow<()> { telemetry_worker_log!(self, DEBUG, "Handling metric action {:?}", action); use LifecycleAction::*; @@ -272,6 +300,12 @@ impl TelemetryWorker { if !(self.data.started || self.config.restartable) { return CONTINUE; } + + #[allow(clippy::unwrap_used)] + self.deadlines + .schedule_event(LifecycleAction::FlushData) + .unwrap(); + let batch = self.build_observability_batch(); if !batch.is_empty() { let payload = data::Payload::MessageBatch(batch); @@ -280,11 +314,6 @@ impl TelemetryWorker { Err(e) => self.log_err(&e), } } - - #[allow(clippy::unwrap_used)] - self.deadlines - .schedule_event(LifecycleAction::FlushData) - .unwrap(); } AddConfig(_) | AddDependecy(_) | AddIntegration(_) | Lifecycle(ExtendedHeartbeat) => {} Lifecycle(Stop) => { @@ -313,27 +342,6 @@ impl TelemetryWorker { CONTINUE } - // Runs a state machine that waits for actions, either from the worker's - // mailbox, or scheduled actions from the worker's deadline object. - async fn run(mut self) { - loop { - if self.cancellation_token.is_cancelled() { - return; - } - - let action = self.recv_next_action().await; - - match self.dispatch_action(action).await { - ControlFlow::Continue(()) => {} - ControlFlow::Break(()) => { - if !self.config.restartable { - break; - } - } - }; - } - } - async fn dispatch_action(&mut self, action: TelemetryActions) -> ControlFlow<()> { telemetry_worker_log!(self, DEBUG, "Handling action {:?}", action); @@ -385,6 +393,12 @@ impl TelemetryWorker { if !(self.data.started || self.config.restartable) { return CONTINUE; } + + #[allow(clippy::unwrap_used)] + self.deadlines + .schedule_event(LifecycleAction::FlushData) + .unwrap(); + let mut batch = self.build_app_events_batch(); let payload = if batch.is_empty() { data::Payload::AppHeartbeat(()) @@ -405,11 +419,6 @@ impl TelemetryWorker { Err(err) => self.log_err(&err), } } - - #[allow(clippy::unwrap_used)] - self.deadlines - .schedule_event(LifecycleAction::FlushData) - .unwrap(); } Lifecycle(ExtendedHeartbeat) => { self.data.dependencies.unflush_stored(); @@ -900,7 +909,7 @@ impl TelemetryWorkerHandle { /// How many dependencies/integrations/configs we keep in memory at most pub const MAX_ITEMS: usize = 5000; -#[derive(Default, Clone, Copy)] +#[derive(Debug, Default, Clone, Copy)] pub enum TelemetryWorkerFlavor { /// Send all telemetry messages including lifecylce events like app-started, hearbeats, /// dependencies and configurations @@ -974,7 +983,10 @@ impl TelemetryWorkerBuilder { } } - fn build_worker( + /// Build the corresponding worker and it's handle. + /// The runtime handle is wrapped in the worker handle and should be the one used to run the + /// worker task. + pub fn build_worker( self, tokio_runtime: Handle, ) -> Result<(TelemetryWorkerHandle, TelemetryWorker)> { @@ -991,6 +1003,7 @@ impl TelemetryWorkerBuilder { #[allow(clippy::unwrap_used)] let worker = TelemetryWorker { + flavor: self.flavor, data: TelemetryWorkerData { started: false, dependencies: self.dependencies, @@ -1040,13 +1053,9 @@ impl TelemetryWorkerBuilder { pub async fn spawn(self) -> Result<(TelemetryWorkerHandle, JoinHandle<()>)> { let tokio_runtime = tokio::runtime::Handle::current(); - let flavor = self.flavor; - let (worker_handle, worker) = self.build_worker(tokio_runtime.clone())?; + let (worker_handle, mut worker) = self.build_worker(tokio_runtime.clone())?; - let join_handle = match flavor { - TelemetryWorkerFlavor::Full => tokio_runtime.spawn(worker.run()), - TelemetryWorkerFlavor::MetricsLogs => tokio_runtime.spawn(worker.run_metrics_logs()), - }; + let join_handle = tokio_runtime.spawn(async move { worker.run().await }); Ok((worker_handle, join_handle)) } @@ -1056,14 +1065,10 @@ impl TelemetryWorkerBuilder { let runtime = tokio::runtime::Builder::new_current_thread() .enable_all() .build()?; - let flavor = self.flavor; - let (handle, worker) = self.build_worker(runtime.handle().clone())?; + let (handle, mut worker) = self.build_worker(runtime.handle().clone())?; let notify_shutdown = handle.shutdown.clone(); std::thread::spawn(move || { - match flavor { - TelemetryWorkerFlavor::Full => runtime.block_on(worker.run()), - TelemetryWorkerFlavor::MetricsLogs => runtime.block_on(worker.run_metrics_logs()), - } + runtime.block_on(worker.run()); runtime.shutdown_background(); notify_shutdown.shutdown_finished(); }); diff --git a/ddtelemetry/src/worker/store.rs b/ddtelemetry/src/worker/store.rs index 8bcf9ca2cd..277a0f1edf 100644 --- a/ddtelemetry/src/worker/store.rs +++ b/ddtelemetry/src/worker/store.rs @@ -10,6 +10,7 @@ mod queuehashmap { hash::{BuildHasher, Hash}, }; + #[derive(Debug)] pub struct QueueHashMap { table: HashTable, hash_builder: DefaultHashBuilder, @@ -135,7 +136,7 @@ mod queuehashmap { pub use queuehashmap::QueueHashMap; -#[derive(Default)] +#[derive(Debug, Default)] /// Stores telemetry data item, like dependencies and integrations /// /// * Bounds the length of the collection it uses to prevent memory leaks