• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import logging
16import threading
17import time
18from typing import Any, AnyStr, Dict, Iterable, List, Optional, Set, Union
19
20import grpc
21
22# pytype: disable=pyi-error
23from grpc_observability import _cyobservability
24from grpc_observability import _observability
25from grpc_observability import _open_telemetry_measures
26from grpc_observability._cyobservability import MetricsName
27from grpc_observability._cyobservability import PLUGIN_IDENTIFIER_SEP
28from grpc_observability._observability import OptionalLabelType
29from grpc_observability._observability import StatsData
30from opentelemetry.metrics import Counter
31from opentelemetry.metrics import Histogram
32from opentelemetry.metrics import Meter
33
_LOGGER = logging.getLogger(__name__)

# Stand-in aliases. The capsule types are opaque objects handed back by the
# Cython layer; the remaining names cannot be imported here because the
# modules that define them import this module.
ClientCallTracerCapsule = Any  # it appears only once in the function signature
ServerCallTracerFactoryCapsule = Any  # it appears only once in the function signature
grpc_observability = Any  # grpc_observability.py imports this module.
OpenTelemetryPlugin = Any  # _open_telemetry_plugin.py imports this module.
OpenTelemetryPluginOption = Any  # _open_telemetry_plugin.py imports this module.

# Label keys / values used when recording metrics.
GRPC_METHOD_LABEL = "grpc.method"
GRPC_TARGET_LABEL = "grpc.target"
GRPC_CLIENT_METRIC_PREFIX = "grpc.client"
GRPC_OTHER_LABEL_VALUE = "other"

# Guards the single process-wide OpenTelemetryObservability instance below.
_observability_lock: threading.RLock = threading.RLock()
_OPEN_TELEMETRY_OBSERVABILITY: Optional["OpenTelemetryObservability"] = None

# grpc.StatusCode -> its upper-case name, e.g. grpc.StatusCode.OK -> "OK".
# The codes are listed explicitly so the mapping (and its order) is fixed.
GRPC_STATUS_CODE_TO_STRING = {
    status_code: status_code.name
    for status_code in (
        grpc.StatusCode.OK,
        grpc.StatusCode.CANCELLED,
        grpc.StatusCode.UNKNOWN,
        grpc.StatusCode.INVALID_ARGUMENT,
        grpc.StatusCode.DEADLINE_EXCEEDED,
        grpc.StatusCode.NOT_FOUND,
        grpc.StatusCode.ALREADY_EXISTS,
        grpc.StatusCode.PERMISSION_DENIED,
        grpc.StatusCode.UNAUTHENTICATED,
        grpc.StatusCode.RESOURCE_EXHAUSTED,
        grpc.StatusCode.FAILED_PRECONDITION,
        grpc.StatusCode.ABORTED,
        grpc.StatusCode.OUT_OF_RANGE,
        grpc.StatusCode.UNIMPLEMENTED,
        grpc.StatusCode.INTERNAL,
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.DATA_LOSS,
    )
}
72
73
class _OpenTelemetryPlugin:
    """Bridges one OpenTelemetryPlugin to the gRPC Cython stats pipeline.

    On construction, instruments for the base gRPC metrics are registered on
    the plugin's meter provider (if it has one). Stats records are then
    translated into measurements on those instruments. Their labels are
    enriched by the plugin options enabled for the client or server side.
    """

    _plugin: OpenTelemetryPlugin
    # Cython metric name -> OpenTelemetry instrument that records it.
    _metric_to_recorder: Dict[MetricsName, Union[Counter, Histogram]]
    # Plugin options enabled for client / server calls. None until the
    # matching activate_*_plugin_options() call has run.
    _enabled_client_plugin_options: Optional[List[OpenTelemetryPluginOption]]
    _enabled_server_plugin_options: Optional[List[OpenTelemetryPluginOption]]
    # str(id(self)); joined by OpenTelemetryObservability._get_identifier().
    identifier: str

    def __init__(self, plugin: OpenTelemetryPlugin):
        self._plugin = plugin
        self._metric_to_recorder = dict()
        self.identifier = str(id(self))
        self._enabled_client_plugin_options = None
        self._enabled_server_plugin_options = None

        meter_provider = self._plugin.meter_provider
        if meter_provider:
            meter = meter_provider.get_meter("grpc-python", grpc.__version__)
            enabled_metrics = _open_telemetry_measures.base_metrics()
            self._metric_to_recorder = self._register_metrics(
                meter, enabled_metrics
            )

    def _should_record(self, stats_data: StatsData) -> bool:
        """Return True if an instrument is registered for this metric."""
        return stats_data.name in self._metric_to_recorder

    def _record_stats_data(self, stats_data: StatsData) -> None:
        """Record one stats entry on its instrument.

        Labels are processed in this order:
          1. deserialized, only when exchange labels are included;
          2. augmented by the enabled plugin options;
          3. decoded to str;
          4. filtered through the plugin's target / method attribute filters.
        """
        recorder = self._metric_to_recorder[stats_data.name]
        if GRPC_CLIENT_METRIC_PREFIX in recorder.name:
            enabled_plugin_options = self._enabled_client_plugin_options
        else:
            enabled_plugin_options = self._enabled_server_plugin_options
        # Options stay None until activated for this side; treat as empty.
        enabled_plugin_options = enabled_plugin_options or []

        labels = stats_data.labels
        # Only deserialize labels if we need to add exchanged labels.
        if stats_data.include_exchange_labels:
            labels = self._deserialize_labels(labels, enabled_plugin_options)
        # _maybe_add_labels returns a new dict, so stats_data.labels (which
        # is shared with every other plugin) is never modified.
        labels = self._maybe_add_labels(
            stats_data.include_exchange_labels,
            labels,
            enabled_plugin_options,
        )
        decoded_labels = self.decode_labels(labels)

        # Only filter the target when the record carries one. Otherwise a
        # rejecting filter would add grpc.target="other" to records that have
        # no target attribute at all.
        if GRPC_TARGET_LABEL in decoded_labels:
            target = decoded_labels[GRPC_TARGET_LABEL]
            if not self._plugin.target_attribute_filter(target):
                decoded_labels[GRPC_TARGET_LABEL] = GRPC_OTHER_LABEL_VALUE

        method = decoded_labels.get(GRPC_METHOD_LABEL, "")
        if not (
            stats_data.registered_method
            or self._plugin.generic_method_attribute_filter(method)
        ):
            # Filter method name if it's not a registered method and
            # generic_method_attribute_filter returns false.
            decoded_labels[GRPC_METHOD_LABEL] = GRPC_OTHER_LABEL_VALUE

        value = (
            stats_data.value_float
            if stats_data.measure_double
            else stats_data.value_int
        )
        if isinstance(recorder, Counter):
            recorder.add(value, attributes=decoded_labels)
        elif isinstance(recorder, Histogram):
            recorder.record(value, attributes=decoded_labels)

    def maybe_record_stats_data(self, stats_data: StatsData) -> None:
        """Record a single stats entry to the MeterProvider if it is tracked."""
        if self._should_record(stats_data):
            self._record_stats_data(stats_data)

    @staticmethod
    def _get_label_injector_method(
        plugin_option: OpenTelemetryPluginOption, method_name: str
    ) -> Any:
        """Return plugin_option's label-injector method, or None.

        Returns None when the option has no ``get_label_injector`` or the
        returned injector lacks ``method_name``. ``get_label_injector()`` is
        only called after confirming it exists.
        """
        get_label_injector = getattr(plugin_option, "get_label_injector", None)
        if get_label_injector is None:
            return None
        return getattr(get_label_injector(), method_name, None)

    @staticmethod
    def _collect_exchange_labels(
        plugin_options: Optional[List[OpenTelemetryPluginOption]],
    ) -> Dict[str, AnyStr]:
        """Merge get_labels_for_exchange() from every option that offers it."""
        labels_for_exchange = {}
        for plugin_option in plugin_options or []:
            get_labels = _OpenTelemetryPlugin._get_label_injector_method(
                plugin_option, "get_labels_for_exchange"
            )
            if get_labels is not None:
                labels_for_exchange.update(get_labels())
        return labels_for_exchange

    def get_client_exchange_labels(self) -> Dict[str, AnyStr]:
        """Get labels used for client side Metadata Exchange.

        Returns an empty dict if client plugin options are not activated.
        """
        return self._collect_exchange_labels(
            self._enabled_client_plugin_options
        )

    def get_server_exchange_labels(self) -> Dict[str, str]:
        """Get labels used for server side Metadata Exchange.

        Returns an empty dict if server plugin options are not activated.
        """
        return self._collect_exchange_labels(
            self._enabled_server_plugin_options
        )

    def activate_client_plugin_options(self, target: bytes) -> None:
        """Activate client plugin options based on option settings.

        The selection is computed once, from the first target seen. Later
        calls keep that result, even when it is an empty list.
        """
        if self._enabled_client_plugin_options is not None:
            return
        target_str = target.decode("utf-8", "replace")
        self._enabled_client_plugin_options = [
            plugin_option
            for plugin_option in self._plugin.plugin_options
            if hasattr(plugin_option, "is_active_on_client_channel")
            and plugin_option.is_active_on_client_channel(target_str)
        ]

    def activate_server_plugin_options(self, xds: bool) -> None:
        """Activate server plugin options based on option settings.

        The selection is computed once. Later calls keep that result, even
        when it is an empty list.
        """
        if self._enabled_server_plugin_options is not None:
            return
        self._enabled_server_plugin_options = [
            plugin_option
            for plugin_option in self._plugin.plugin_options
            if hasattr(plugin_option, "is_active_on_server")
            and plugin_option.is_active_on_server(xds)
        ]

    @staticmethod
    def _deserialize_labels(
        labels: Dict[str, AnyStr],
        enabled_plugin_options: List[OpenTelemetryPluginOption],
    ) -> Dict[str, AnyStr]:
        """Pass labels through each option's deserialize_labels(), in order."""
        for plugin_option in enabled_plugin_options or []:
            deserialize = _OpenTelemetryPlugin._get_label_injector_method(
                plugin_option, "deserialize_labels"
            )
            if deserialize is not None:
                labels = deserialize(labels)
        return labels

    @staticmethod
    def _maybe_add_labels(
        include_exchange_labels: bool,
        labels: Dict[str, str],
        enabled_plugin_options: List[OpenTelemetryPluginOption],
    ) -> Dict[str, AnyStr]:
        """Return a copy of labels extended with each option's extra labels.

        The input dict is not modified.
        """
        merged_labels = dict(labels)
        for plugin_option in enabled_plugin_options or []:
            get_additional_labels = (
                _OpenTelemetryPlugin._get_label_injector_method(
                    plugin_option, "get_additional_labels"
                )
            )
            if get_additional_labels is not None:
                merged_labels.update(
                    get_additional_labels(include_exchange_labels)
                )
        return merged_labels

    def get_enabled_optional_labels(self) -> List[OptionalLabelType]:
        return self._plugin._get_enabled_optional_labels()

    @staticmethod
    def _register_metrics(
        meter: Meter, metrics: List[_open_telemetry_measures.Metric]
    ) -> Dict[MetricsName, Union[Counter, Histogram]]:
        """Create one OpenTelemetry instrument per supported metric.

        "Started" metrics become counters. Duration and byte-size metrics
        become histograms. Any other metric is skipped with a warning, so it
        is never bound to an unrelated instrument.

        Returns:
          A dict keyed by each metric's ``cyname``.
        """
        counter_metrics = (
            _open_telemetry_measures.CLIENT_ATTEMPT_STARTED,
            _open_telemetry_measures.SERVER_STARTED_RPCS,
        )
        histogram_metrics = (
            _open_telemetry_measures.CLIENT_ATTEMPT_DURATION,
            _open_telemetry_measures.CLIENT_RPC_DURATION,
            _open_telemetry_measures.CLIENT_ATTEMPT_SEND_BYTES,
            _open_telemetry_measures.CLIENT_ATTEMPT_RECEIVED_BYTES,
            _open_telemetry_measures.SERVER_RPC_DURATION,
            _open_telemetry_measures.SERVER_RPC_SEND_BYTES,
            _open_telemetry_measures.SERVER_RPC_RECEIVED_BYTES,
        )
        metric_to_recorder_map = {}
        for metric in metrics:
            if metric in counter_metrics:
                create_instrument = meter.create_counter
            elif metric in histogram_metrics:
                create_instrument = meter.create_histogram
            else:
                _LOGGER.warning(
                    "Skipping unsupported OpenTelemetry metric: %s",
                    metric.name,
                )
                continue
            metric_to_recorder_map[metric.cyname] = create_instrument(
                name=metric.name,
                unit=metric.unit,
                description=metric.description,
            )
        return metric_to_recorder_map

    @staticmethod
    def decode_labels(labels: Dict[str, AnyStr]) -> Dict[str, str]:
        """Return a new dict with bytes values decoded as UTF-8.

        Undecodable bytes (labels may originate from a peer) are replaced
        rather than raising, which would abort recording.
        """
        return {
            key: (
                value.decode("utf-8", "replace")
                if isinstance(value, bytes)
                else value
            )
            for key, value in labels.items()
        }
315
316
def start_open_telemetry_observability(
    *,
    plugins: Iterable[_OpenTelemetryPlugin],
) -> None:
    """Build an OpenTelemetryObservability for plugins and start it globally.

    Raises:
      RuntimeError: if observability was already started.
    """
    observability = OpenTelemetryObservability(plugins=plugins)
    _start_open_telemetry_observability(observability)
323    )
324
325
def end_open_telemetry_observability() -> None:
    """Stop the globally started OpenTelemetry observability.

    Raises:
      RuntimeError: if observability was never started.
    """
    return _end_open_telemetry_observability()
328
329
class _OpenTelemetryExporterDelegator(_observability.Exporter):
    """Exporter that forwards every stats record to each plugin."""

    _plugins: List[_OpenTelemetryPlugin]

    def __init__(self, plugins: Iterable[_OpenTelemetryPlugin]):
        # Materialize: export_stats_data iterates the plugins on every batch.
        # A one-shot iterable (e.g. a generator) would be empty after the
        # first batch. None is treated as "no plugins".
        self._plugins = list(plugins) if plugins is not None else []

    def export_stats_data(
        self, stats_data: List[_observability.StatsData]
    ) -> None:
        """Record each stats entry via every plugin that tracks it."""
        for data in stats_data:
            for plugin in self._plugins:
                plugin.maybe_record_stats_data(data)

    def export_tracing_data(
        self, tracing_data: List[_observability.TracingData]
    ) -> None:
        # Tracing is not exported by the OpenTelemetry plugins.
        pass
348
349
# pylint: disable=no-self-use
class OpenTelemetryObservability(grpc._observability.ObservabilityPlugin):
    """OpenTelemetry based plugin implementation.

    This class is part of an EXPERIMENTAL API.

    Args:
      plugins: _OpenTelemetryPlugins to enable. None is treated as an empty
        collection.
    """

    _exporter: "grpc_observability.Exporter"
    _plugins: List[_OpenTelemetryPlugin]
    _registered_methods: Set[bytes]
    _client_option_activated: bool
    _server_option_activated: bool

    def __init__(
        self,
        *,
        plugins: Optional[Iterable[_OpenTelemetryPlugin]],
    ):
        # Materialize once. The plugins are iterated on every call-tracer
        # creation, exchange-label lookup and identifier computation, so a
        # one-shot iterable (e.g. a generator) would be empty after the first
        # pass. The exporter receives the same plugins.
        self._plugins = list(plugins) if plugins is not None else []
        self._exporter = _OpenTelemetryExporterDelegator(self._plugins)
        self._registered_methods = set()
        self._client_option_activated = False
        self._server_option_activated = False

    def observability_init(self):
        """Activate stats collection and register this plugin with gRPC.

        Raises:
          ValueError: if activating stats in the Cython layer fails.
        """
        try:
            _cyobservability.activate_stats()
            self.set_stats(True)
        except Exception as e:  # pylint: disable=broad-except
            raise ValueError(
                f"Activate observability metrics failed with: {e}"
            ) from e

        try:
            _cyobservability.cyobservability_init(self._exporter)
        # TODO(xuanwn): Use specific exceptions
        except Exception as e:  # pylint: disable=broad-except
            # Logged only; registration with gRPC below still proceeds.
            _LOGGER.exception("Initiate observability failed with: %s", e)

        grpc._observability.observability_init(self)

    def observability_deinit(self) -> None:
        """Flush pending data, disable stats/tracing and unregister."""
        # Sleep so we don't lose any data. If we shut down the export thread
        # immediately after exit, it's possible that core didn't call
        # RecordEnd in callTracer, and all data recorded by calling RecordEnd
        # would be lost.
        # CENSUS_EXPORT_BATCH_INTERVAL_SECS: The time equals to the time in
        # AwaitNextBatchLocked.
        # TODO(xuanwn): explicit synchronization
        # https://github.com/grpc/grpc/issues/33262
        time.sleep(_cyobservability.CENSUS_EXPORT_BATCH_INTERVAL_SECS)
        self.set_tracing(False)
        self.set_stats(False)
        _cyobservability.observability_deinit()
        grpc._observability.observability_deinit()

    def create_client_call_tracer(
        self, method_name: bytes, target: bytes
    ) -> ClientCallTracerCapsule:
        """Create a Cython client call tracer for one call to target."""
        trace_id = b"TRACE_ID"
        self._maybe_activate_client_plugin_options(target)
        exchange_labels = self._get_client_exchange_labels()
        enabled_optional_labels = set()
        for plugin in self._plugins:
            enabled_optional_labels.update(plugin.get_enabled_optional_labels())

        return _cyobservability.create_client_call_tracer(
            method_name,
            target,
            trace_id,
            self._get_identifier(),
            exchange_labels,
            enabled_optional_labels,
            method_name in self._registered_methods,
        )

    def create_server_call_tracer_factory(
        self,
        *,
        xds: bool = False,
    ) -> Optional[ServerCallTracerFactoryCapsule]:
        """Create a Cython server call tracer factory capsule."""
        self._maybe_activate_server_plugin_options(xds)
        exchange_labels = self._get_server_exchange_labels()
        return _cyobservability.create_server_call_tracer_factory_capsule(
            exchange_labels, self._get_identifier()
        )

    def save_trace_context(
        self, trace_id: str, span_id: str, is_sampled: bool
    ) -> None:
        pass

    def record_rpc_latency(
        self,
        method: str,
        target: str,
        rpc_latency: float,
        status_code: grpc.StatusCode,
    ) -> None:
        """Forward an RPC latency record to the Cython layer.

        Status codes missing from GRPC_STATUS_CODE_TO_STRING are reported as
        "UNKNOWN".
        """
        status_code_name = GRPC_STATUS_CODE_TO_STRING.get(
            status_code, "UNKNOWN"
        )
        # Registered methods are stored as bytes (see save_registered_method).
        encoded_method = method.encode("utf8")
        _cyobservability._record_rpc_latency(
            self._exporter,
            method,
            target,
            rpc_latency,
            status_code_name,
            self._get_identifier(),
            encoded_method in self._registered_methods,
        )

    def save_registered_method(self, method_name: bytes) -> None:
        self._registered_methods.add(method_name)

    def _get_client_exchange_labels(self) -> Dict[str, AnyStr]:
        client_exchange_labels = {}
        for plugin in self._plugins:
            client_exchange_labels.update(plugin.get_client_exchange_labels())
        return client_exchange_labels

    def _get_server_exchange_labels(self) -> Dict[str, AnyStr]:
        server_exchange_labels = {}
        for plugin in self._plugins:
            server_exchange_labels.update(plugin.get_server_exchange_labels())
        return server_exchange_labels

    def _maybe_activate_client_plugin_options(self, target: bytes) -> None:
        # Activation happens only on the first client call tracer.
        if not self._client_option_activated:
            for plugin in self._plugins:
                plugin.activate_client_plugin_options(target)
            self._client_option_activated = True

    def _maybe_activate_server_plugin_options(self, xds: bool) -> None:
        # Activation happens only on the first server call tracer factory.
        if not self._server_option_activated:
            for plugin in self._plugins:
                plugin.activate_server_plugin_options(xds)
            self._server_option_activated = True

    def _get_identifier(self) -> str:
        """Join all plugin identifiers with PLUGIN_IDENTIFIER_SEP."""
        return PLUGIN_IDENTIFIER_SEP.join(
            plugin.identifier for plugin in self._plugins
        )

    def get_enabled_optional_labels(self) -> List[OptionalLabelType]:
        return []
500
501
def _start_open_telemetry_observability(
    otel_o11y: OpenTelemetryObservability,
) -> None:
    """Install otel_o11y as the process-wide instance and initialize it.

    Raises:
      RuntimeError: if an instance is already installed.
    """
    global _OPEN_TELEMETRY_OBSERVABILITY  # pylint: disable=global-statement
    with _observability_lock:
        if _OPEN_TELEMETRY_OBSERVABILITY is not None:
            raise RuntimeError(
                "gPRC Python observability was already initialized!"
            )
        # Publish before initializing, under the lock.
        _OPEN_TELEMETRY_OBSERVABILITY = otel_o11y
        otel_o11y.observability_init()
514
515
def _end_open_telemetry_observability() -> None:
    """Deinitialize the process-wide instance and clear it.

    Raises:
      RuntimeError: if no instance is installed.
    """
    global _OPEN_TELEMETRY_OBSERVABILITY  # pylint: disable=global-statement
    with _observability_lock:
        current = _OPEN_TELEMETRY_OBSERVABILITY
        if not current:
            raise RuntimeError(
                "Trying to end gPRC Python observability without initialize first!"
            )
        current.observability_deinit()
        _OPEN_TELEMETRY_OBSERVABILITY = None
526