5 changes: 5 additions & 0 deletions ddtrace/contrib/celery/__init__.py
@@ -19,6 +19,11 @@ class MyTask(app.Task):
def run(self):
pass

Distributed tracing is disabled by default. To enable it::

from ddtrace import config

config.celery['distributed_tracing'] = True

To change the Celery service names, you can use the ``Config`` API as follows::
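The rest of that example is collapsed in this diff view; a minimal sketch of the intended ``Config`` usage, assuming only the ``producer_service_name`` and ``worker_service_name`` keys registered in ``patch.py`` below (the service name values are placeholders):

from ddtrace import config

config.celery['producer_service_name'] = 'task-queue'    # placeholder producer service name
config.celery['worker_service_name'] = 'worker-notify'   # placeholder worker service name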

1 change: 1 addition & 0 deletions ddtrace/contrib/celery/patch.py
@@ -9,6 +9,7 @@

# Celery default settings
config._add('celery', {
'distributed_tracing': get_env('celery', 'distributed_tracing', default=False),
'producer_service_name': get_env('celery', 'producer_service_name', default=PRODUCER_SERVICE),
'worker_service_name': get_env('celery', 'worker_service_name', default=WORKER_SERVICE),
})
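A hedged sketch of how this new default could be overridden. The environment variable name below is inferred from ``get_env('celery', 'distributed_tracing')`` and ddtrace's usual ``DD_<INTEGRATION>_<SETTING>`` convention; it is an assumption, not something shown in this diff:

# Hedged sketch, not part of the diff: two ways the new flag could presumably be enabled.
import os
os.environ['DD_CELERY_DISTRIBUTED_TRACING'] = 'true'  # inferred name; read by get_env when the integration is patched

from ddtrace import config, patch
patch(celery=True)

# Or toggle it at runtime through the Config API:
config.celery['distributed_tracing'] = True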
20 changes: 20 additions & 0 deletions ddtrace/contrib/celery/signals.py
@@ -5,10 +5,12 @@
from ...constants import ANALYTICS_SAMPLE_RATE_KEY, SPAN_MEASURED_KEY
from ...ext import SpanTypes
from ...internal.logger import get_logger
from ...propagation.http import HTTPPropagator
from . import constants as c
from .utils import tags_from_context, retrieve_task_id, attach_span, detach_span, retrieve_span

log = get_logger(__name__)
propagator = HTTPPropagator()


def trace_prerun(*args, **kwargs):
@@ -27,6 +29,11 @@ def trace_prerun(*args, **kwargs):
log.debug('no pin found on task or task.app task_id=%s', task_id)
return

if config.celery['distributed_tracing']:
context = propagator.extract(task.request.get('headers', {}))
if context.trace_id:
pin.tracer.context_provider.activate(context)

# propagate the `Span` in the current task Context
service = config.celery['worker_service_name']
span = pin.tracer.trace(c.WORKER_ROOT_SPAN, service=service, resource=task.name, span_type=SpanTypes.WORKER)
@@ -95,11 +102,24 @@ def trace_before_publish(*args, **kwargs):
span.set_tag(c.TASK_TAG_KEY, c.TASK_APPLY_ASYNC)
span.set_tag('celery.id', task_id)
span.set_tags(tags_from_context(kwargs))

# Note: adding tags from `traceback` or `state` calls will make an
# API call to the backend for the properties so we should rely
# only on the given `Context`
attach_span(task, task_id, span, is_publish=True)

if config.celery['distributed_tracing']:
trace_headers = {}
propagator.inject(span.context, trace_headers)

# This weirdness is due to yet another Celery bug concerning
# how headers get propagated in async flows
# https://github.com/celery/celery/issues/4875
task_headers = kwargs.get('headers') or {}
task_headers.setdefault('headers', {})
task_headers['headers'].update(trace_headers)
kwargs['headers'] = task_headers


def trace_after_publish(*args, **kwargs):
task_name = kwargs.get('sender')
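Taken together, the two hunks above are the propagation pair: ``trace_before_publish`` injects the active span's context into the outgoing task message headers (nested under an inner ``headers`` key because of celery/celery#4875), and ``trace_prerun`` extracts that context on the worker and activates it so the worker span joins the producer's trace. A self-contained sketch of that inject/extract round trip, using only the ``HTTPPropagator`` and ``Context`` APIs already imported in this changeset:

# Standalone sketch of the handshake the two signal handlers implement.
from ddtrace.context import Context
from ddtrace.propagation.http import HTTPPropagator

propagator = HTTPPropagator()

# Producer side (trace_before_publish): serialize the active context into a dict.
parent = Context(trace_id=12345, span_id=67890, sampling_priority=1)
trace_headers = {}
propagator.inject(parent, trace_headers)

# The real handler then nests this dict under task_headers['headers'] to work
# around https://github.com/celery/celery/issues/4875.

# Worker side (trace_prerun): rebuild the context and activate it on the tracer.
extracted = propagator.extract(trace_headers)
assert extracted.trace_id == 12345  # the worker span joins the producer's trace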
72 changes: 72 additions & 0 deletions tests/contrib/celery/test_integration.py
@@ -1,8 +1,11 @@
import celery
from celery.exceptions import Retry

from ddtrace import Pin
from ddtrace.context import Context
from ddtrace.contrib.celery import patch, unpatch
from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY
from ddtrace.propagation.http import HTTPPropagator

from .base import CeleryBaseTestCase

@@ -545,9 +548,78 @@ def fn_task_parameters(user, force_logout=False):

self.assert_is_measured(dd_span)
assert dd_span.error == 0

assert dd_span.name == "celery.apply"
assert dd_span.resource == "tests.contrib.celery.test_integration.fn_task_parameters"
assert dd_span.service == "celery-producer"
assert dd_span.get_tag("celery.id") == t.task_id
assert dd_span.get_tag("celery.action") == "apply_async"
assert dd_span.get_tag("celery.routing_key") == "celery"


class CeleryDistributedTracingIntegrationTask(CeleryBaseTestCase):
"""Distributed tracing is tricky to test for two reasons:
[Review comment (Member): 👍 thanks for the test and great accompanying documentation too! Much appreciated 😄]

1. We aren't running anything distributed at all in this test suite
2. Celery doesn't run the `before_task_publish` signal we rely
on when running in synchronous mode, e.g. using apply.
https://github.com/celery/celery/issues/3864

To get around #1, we inject our own new context in a prerun signal
to simulate a distributed worker beginning with its own context
(which does not match the parent context from the publisher).

Hopefully if #2 is ever fixed we can test both of the signals we rely
on for distributed trace propagation more robustly. For now, this only
really tests the prerun half of distributed tracing and has to simulate
the before_publish half. :(
"""

def setUp(self):
# Register our context-ruining signal before the normal ones so that
# the "real" task_prerun signal starts with the new context
celery.signals.task_prerun.connect(self.inject_new_context)
super(CeleryDistributedTracingIntegrationTask, self).setUp()
provider = Pin.get_from(self.app).tracer.context_provider
provider.activate(Context(trace_id=12345, span_id=12345, sampling_priority=1))

def tearDown(self):
celery.signals.task_prerun.disconnect(self.inject_new_context)
super(CeleryDistributedTracingIntegrationTask, self).tearDown()

def inject_new_context(self, *args, **kwargs):
pin = Pin.get_from(self.app)
pin.tracer.context_provider.activate(Context(trace_id=99999, span_id=99999, sampling_priority=1))

def test_distributed_tracing_disabled(self):
"""This test is just making sure our signal hackery in this test class
is working the way we expect it to.
"""

@self.app.task
def fn_task():
return 42

fn_task.apply()

traces = self.tracer.writer.pop_traces()
span = traces[0][0]
assert span.trace_id == 99999

def test_distributed_tracing_propagation(self):
@self.app.task
def fn_task():
return 42

# This header manipulation is copying the work that should be done
# by the before_publish signal. Rip it out if Celery ever fixes their bug.
current_context = Pin.get_from(self.app).tracer.context_provider.active()
headers = {}
HTTPPropagator().inject(current_context, headers)

with self.override_config("celery", dict(distributed_tracing=True)):
fn_task.apply(headers=headers)

traces = self.tracer.writer.pop_traces()
span = traces[0][0]
assert span.trace_id == 12345
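For completeness, a hedged end-to-end sketch of the flow these tests simulate. It assumes a real broker and a separately running worker (both deliberately absent from this test suite); the broker URL and the task are placeholders:

# Hypothetical producer-side usage; assumes a running Celery worker and broker.
from ddtrace import config, patch
patch(celery=True)
config.celery['distributed_tracing'] = True

from celery import Celery

app = Celery('tasks', broker='redis://localhost:6379/0')  # placeholder broker URL

@app.task
def add(x, y):
    return x + y

# With the flag enabled, before_task_publish injects the producer's context into
# the message headers and task_prerun activates it on the worker, so the worker's
# span shares the producer's trace_id.
add.delay(1, 2)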